add the support of parallel beam search decoding in deployment

pull/2/head
Yibing Liu 8 years ago
parent d1189a7950
commit dad406a49b

@ -18,7 +18,7 @@ import time
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
default=10, default=32,
type=int, type=int,
help="Number of samples for inference. (default: %(default)s)") help="Number of samples for inference. (default: %(default)s)")
parser.add_argument( parser.add_argument(
@ -46,6 +46,11 @@ parser.add_argument(
default=multiprocessing.cpu_count(), default=multiprocessing.cpu_count(),
type=int, type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)") help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--mean_std_filepath", "--mean_std_filepath",
default='mean_std.npz', default='mean_std.npz',
@ -70,8 +75,8 @@ parser.add_argument(
"--decode_method", "--decode_method",
default='beam_search', default='beam_search',
type=str, type=str,
help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)" help="Method for ctc decoding: beam_search or beam_search_batch. "
) "(default: %(default)s)")
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
default=200, default=200,
@ -169,6 +174,8 @@ def infer():
## decode and print ## decode and print
time_begin = time.time() time_begin = time.time()
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
batch_beam_results = []
if args.decode_method == 'beam_search':
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder( beam_result = ctc_beam_search_decoder(
probs_seq=probs, probs_seq=probs,
@ -177,7 +184,18 @@ def infer():
blank_id=len(data_generator.vocab_list), blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, ) ext_scoring_func=ext_scorer, )
batch_beam_results += [beam_result]
else:
batch_beam_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
beam_size=args.beam_size,
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
for i, beam_result in enumerate(batch_beam_results):
print("\nTarget Transcription:\t%s" % target_transcription[i]) print("\nTarget Transcription:\t%s" % target_transcription[i])
print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1])) print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1]))
wer_cur = wer(target_transcription[i], beam_result[0][1]) wer_cur = wer(target_transcription[i], beam_result[0][1])
@ -185,6 +203,7 @@ def infer():
wer_counter += 1 wer_counter += 1
print("cur wer = %f , average wer = %f" % print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter)) (wer_cur, wer_sum / wer_counter))
time_end = time.time() time_end = time.time()
print("total time = %f" % (time_end - time_begin)) print("total time = %f" % (time_end - time_begin))

@ -1,12 +1,25 @@
### Installation ### Installation
The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/) and [openfst](http://www.openfst.org/twiki/bin/view/FST/WebHome), first clone kenlm and download openfst to current directory (i.e., `deep_speech_2/deploy`) The build of the decoder for deployment depends on several open-sourced projects, first clone or download them to current directory (i.e., `deep_speech_2/deploy`)
- [**KenLM**](https://github.com/kpu/kenlm/): Faster and Smaller Language Model Queries
```shell ```shell
git clone https://github.com/kpu/kenlm.git git clone https://github.com/kpu/kenlm.git
```
- [**OpenFst**](http://www.openfst.org/twiki/bin/view/FST/WebHome): A library for finite-state transducers
```shell
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz tar -xzvf openfst-1.6.3.tar.gz
``` ```
- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool
```shell
git clone https://github.com/progschj/ThreadPool.git
```
Then run the setup Then run the setup
```shell ```shell

@ -6,6 +6,7 @@
#include <limits> #include <limits>
#include "ctc_decoders.h" #include "ctc_decoders.h"
#include "decoder_utils.h" #include "decoder_utils.h"
#include "ThreadPool.h"
typedef double log_prob_type; typedef double log_prob_type;
@ -33,7 +34,8 @@ T log_sum_exp(T x, T y)
} }
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary) { std::vector<std::string> vocabulary)
{
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) { for (int i=0; i<num_time_steps; i++) {
@ -83,8 +85,8 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob, double cutoff_prob,
Scorer *ext_scorer, Scorer *ext_scorer)
bool nproc) { {
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) { for (int i=0; i<num_time_steps; i++) {
@ -260,3 +262,39 @@ std::vector<std::pair<double, std::string> >
pair_comp_first_rev<double, std::string>); pair_comp_first_rev<double, std::string>);
return beam_result; return beam_result;
} }
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
int num_processes,
double cutoff_prob,
Scorer *ext_scorer
)
{
if (num_processes <= 0) {
std::cout << "num_processes must be nonnegative!" << std::endl;
exit(1);
}
// thread pool
ThreadPool pool(num_processes);
// number of samples
int batch_size = probs_split.size();
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (int i = 0; i < batch_size; i++) {
res.emplace_back(
pool.enqueue(ctc_beam_search_decoder, probs_split[i],
beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer)
);
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (int i = 0; i < batch_size; i++) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}

@ -6,8 +6,20 @@
#include <utility> #include <utility>
#include "scorer.h" #include "scorer.h"
/* CTC Beam Search Decoder, the interface is consistent with the /* CTC Best Path Decoder
* original decoder in Python version. *
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary);
/* CTC Beam Search Decoder
* Parameters: * Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities * probs_seq: 2-D vector that each element is a vector of probabilities
@ -17,7 +29,6 @@
* blank_id: ID of blank. * blank_id: ID of blank.
* cutoff_prob: Cutoff probability of pruning * cutoff_prob: Cutoff probability of pruning
* ext_scorer: External scorer to evaluate a prefix. * ext_scorer: External scorer to evaluate a prefix.
* nproc: Whether this function used in multiprocessing.
* Return: * Return:
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
@ -28,21 +39,35 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob=1.0, double cutoff_prob=1.0,
Scorer *ext_scorer=NULL, Scorer *ext_scorer=NULL
bool nproc=false
); );
/* CTC Best Path Decoder /* CTC Beam Search Decoder for batch data, the interface is consistent with the
* * original decoder in Python version.
* Parameters: * Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities * probs_seq: 3-D vector that each element is a 2-D vector that can be used
* over vocabulary of one time step. * by ctc_beam_search_decoder().
* .
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary. * vocabulary: A vector of vocabulary.
* blank_id: ID of blank.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability of pruning
* ext_scorer: External scorer to evaluate a prefix.
* Return: * Return:
* A vector that each element is a pair of score and decoding result, * A 2-D vector that each element is a vector of decoding result for one
* in desending order. * sample.
*/ */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::vector<std::vector<std::pair<double, std::string>>>
std::vector<std::string> vocabulary); ctc_beam_search_decoder_batch(std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size,
std::vector<std::string> vocabulary,
int blank_id,
int num_processes,
double cutoff_prob=1.0,
Scorer *ext_scorer=NULL
);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

@ -17,6 +17,8 @@ namespace std{
%template(Pair) std::pair<float, std::string>; %template(Pair) std::pair<float, std::string>;
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >; %template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >; %template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
%template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
%template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
} }
%import decoder_utils.h %import decoder_utils.h

@ -36,12 +36,12 @@ if compile_test('lzma.h', 'lzma'):
os.system('swig -python -c++ ./decoders.i') os.system('swig -python -c++ ./decoders.i')
ctc_beam_search_decoder_module = [ decoders_module = [
Extension( Extension(
name='_swig_decoders', name='_swig_decoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='C++', language='C++',
include_dirs=['.', './kenlm', './openfst-1.6.3/src/include'], include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include', 'ThreadPool'],
libraries=LIBS, libraries=LIBS,
extra_compile_args=ARGS) extra_compile_args=ARGS)
] ]
@ -50,5 +50,5 @@ setup(
name='swig_decoders', name='swig_decoders',
version='0.1', version='0.1',
description="""CTC decoders""", description="""CTC decoders""",
ext_modules=ctc_beam_search_decoder_module, ext_modules=decoders_module,
py_modules=['swig_decoders'], ) py_modules=['swig_decoders'], )

@ -4,7 +4,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import swig_decoders import swig_decoders
import multiprocessing
class Scorer(swig_decoders.Scorer): class Scorer(swig_decoders.Scorer):
@ -39,14 +38,13 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary) return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary)
def ctc_beam_search_decoder( def ctc_beam_search_decoder(probs_seq,
probs_seq,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id, blank_id,
cutoff_prob=1.0, cutoff_prob=1.0,
ext_scoring_func=None, ): ext_scoring_func=None):
"""Wrapper for CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized step, with each element being a list of normalized
@ -81,24 +79,34 @@ def ctc_beam_search_decoder_batch(probs_split,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
ext_scoring_func=None): ext_scoring_func=None):
"""Wrapper for CTC beam search decoder in batch """Wrapper for the batched CTC beam search decoder.
"""
# TODO: to resolve PicklingError
if not num_processes > 0: :param probs_seq: 3-D list with each element as an instance of 2-D list
raise ValueError("Number of processes must be positive!") of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
pool = Pool(processes=num_processes) :param beam_size: Width for beam search.
results = [] :type beam_size: int
args_list = [] :param vocabulary: Vocabulary list.
for i, probs_list in enumerate(probs_split): :type vocabulary: list
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, :param blank_id: ID of blank.
ext_scoring_func) :type blank_id: int
args_list.append(args) :param num_processes: Number of parallel processes.
results.append(pool.apply_async(ctc_beam_search_decoder, args)) :type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
pool.close() return swig_decoders.ctc_beam_search_decoder_batch(
pool.join() probs_split, beam_size, vocabulary, blank_id, num_processes,
beam_search_results = [result.get() for result in results] cutoff_prob, ext_scoring_func)
return beam_search_results

Loading…
Cancel
Save