diff --git a/deploy.py b/deploy.py index 2d29973fb..76b616052 100644 --- a/deploy.py +++ b/deploy.py @@ -18,7 +18,7 @@ import time parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", - default=10, + default=32, type=int, help="Number of samples for inference. (default: %(default)s)") parser.add_argument( @@ -46,6 +46,11 @@ parser.add_argument( default=multiprocessing.cpu_count(), type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -70,8 +75,8 @@ parser.add_argument( "--decode_method", default='beam_search', type=str, - help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)" -) + help="Method for ctc decoding: beam_search or beam_search_batch. 
" + "(default: %(default)s)") parser.add_argument( "--beam_size", default=200, @@ -169,15 +174,28 @@ def infer(): ## decode and print time_begin = time.time() wer_sum, wer_counter = 0, 0 - for i, probs in enumerate(probs_split): - beam_result = ctc_beam_search_decoder( - probs_seq=probs, + batch_beam_results = [] + if args.decode_method == 'beam_search': + for i, probs in enumerate(probs_split): + beam_result = ctc_beam_search_decoder( + probs_seq=probs, + beam_size=args.beam_size, + vocabulary=data_generator.vocab_list, + blank_id=len(data_generator.vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) + batch_beam_results += [beam_result] + else: + batch_beam_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, beam_size=args.beam_size, vocabulary=data_generator.vocab_list, blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, cutoff_prob=args.cutoff_prob, ext_scoring_func=ext_scorer, ) + for i, beam_result in enumerate(batch_beam_results): print("\nTarget Transcription:\t%s" % target_transcription[i]) print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1])) wer_cur = wer(target_transcription[i], beam_result[0][1]) @@ -185,6 +203,7 @@ def infer(): wer_counter += 1 print("cur wer = %f , average wer = %f" % (wer_cur, wer_sum / wer_counter)) + time_end = time.time() print("total time = %f" % (time_end - time_begin)) diff --git a/deploy/README.md b/deploy/README.md index 162a396a4..9bd55dd9a 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -1,13 +1,26 @@ ### Installation -The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/) and [openfst](http://www.openfst.org/twiki/bin/view/FST/WebHome), first clone kenlm and download openfst to current directory (i.e., `deep_speech_2/deploy`) +The build of the decoder for deployment depends on several open-sourced projects, first clone or download them to current directory (i.e., 
`deep_speech_2/deploy`) + +- [**KenLM**](https://github.com/kpu/kenlm/): Faster and Smaller Language Model Queries ```shell git clone https://github.com/kpu/kenlm.git +``` + +- [**OpenFst**](http://www.openfst.org/twiki/bin/view/FST/WebHome): A library for finite-state transducers + +```shell wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz tar -xzvf openfst-1.6.3.tar.gz ``` -Compiling for python interface requires swig, please make sure swig being installed. +- [**swig**](http://www.swig.org/): Compiling the Python interface requires swig; please make sure swig is installed. + +- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool + +```shell +git clone https://github.com/progschj/ThreadPool.git +``` Then run the setup diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp index 836fb435d..fd553be61 100644 --- a/deploy/ctc_decoders.cpp +++ b/deploy/ctc_decoders.cpp @@ -6,11 +6,13 @@ #include #include "ctc_decoders.h" #include "decoder_utils.h" +#include "ThreadPool.h" typedef double log_prob_type; std::string ctc_best_path_decoder(std::vector > probs_seq, - std::vector vocabulary) { + std::vector vocabulary) +{ // dimension check int num_time_steps = probs_seq.size(); for (int i=0; i > std::vector vocabulary, int blank_id, double cutoff_prob, - Scorer *ext_scorer, - bool nproc) { + Scorer *ext_scorer) +{ // dimension check int num_time_steps = probs_seq.size(); for (int i=0; i > pair_comp_first_rev); return beam_result; } + + +std::vector>> + ctc_beam_search_decoder_batch( + std::vector>> probs_split, + int beam_size, + std::vector vocabulary, + int blank_id, + int num_processes, + double cutoff_prob, + Scorer *ext_scorer + ) +{ + if (num_processes <= 0) { + std::cout << "num_processes must be nonnegative!" 
<< std::endl; + exit(1); + } + // thread pool + ThreadPool pool(num_processes); + // number of samples + int batch_size = probs_split.size(); + // enqueue the tasks of decoding + std::vector>>> res; + for (int i = 0; i < batch_size; i++) { + res.emplace_back( + pool.enqueue(ctc_beam_search_decoder, probs_split[i], + beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer) + ); + } + // get decoding results + std::vector>> batch_results; + for (int i = 0; i < batch_size; i++) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; +} diff --git a/deploy/ctc_decoders.h b/deploy/ctc_decoders.h index 50a6014f0..238903820 100644 --- a/deploy/ctc_decoders.h +++ b/deploy/ctc_decoders.h @@ -6,8 +6,20 @@ #include #include "scorer.h" -/* CTC Beam Search Decoder, the interface is consistent with the - * original decoder in Python version. +/* CTC Best Path Decoder + * + * Parameters: + * probs_seq: 2-D vector that each element is a vector of probabilities + * over vocabulary of one time step. + * vocabulary: A vector of vocabulary. + * Return: + * A vector that each element is a pair of score and decoding result, + * in desending order. + */ +std::string ctc_best_path_decoder(std::vector > probs_seq, + std::vector vocabulary); + +/* CTC Beam Search Decoder * Parameters: * probs_seq: 2-D vector that each element is a vector of probabilities @@ -17,7 +29,6 @@ * blank_id: ID of blank. * cutoff_prob: Cutoff probability of pruning * ext_scorer: External scorer to evaluate a prefix. - * nproc: Whether this function used in multiprocessing. * Return: * A vector that each element is a pair of score and decoding result, * in desending order. @@ -28,21 +39,35 @@ std::vector > std::vector vocabulary, int blank_id, double cutoff_prob=1.0, - Scorer *ext_scorer=NULL, - bool nproc=false + Scorer *ext_scorer=NULL ); -/* CTC Best Path Decoder - * +/* CTC Beam Search Decoder for batch data, the interface is consistent with the + * original decoder in Python version. 
+ * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. + * probs_split: 3-D vector in which each element is a 2-D vector that can be + * used by ctc_beam_search_decoder(). + * + * beam_size: The width of beam search. * vocabulary: A vector of vocabulary. + * blank_id: ID of blank. + * num_processes: Number of threads for beam search. + * cutoff_prob: Cutoff probability of pruning + * ext_scorer: External scorer to evaluate a prefix. * Return: - * A vector that each element is a pair of score and decoding result, - * in desending order. - */ -std::string ctc_best_path_decoder(std::vector > probs_seq, - std::vector vocabulary); + * A 2-D vector in which each element is the vector of decoding results for + * one sample. +*/ +std::vector>> + ctc_beam_search_decoder_batch(std::vector>> probs_split, + int beam_size, + std::vector vocabulary, + int blank_id, + int num_processes, + double cutoff_prob=1.0, + Scorer *ext_scorer=NULL + ); + #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deploy/decoders.i b/deploy/decoders.i index ed7c85e67..8059199d1 100644 --- a/deploy/decoders.i +++ b/deploy/decoders.i @@ -19,6 +19,8 @@ namespace std { %template(Pair) std::pair; %template(PairFloatStringVector) std::vector >; %template(PairDoubleStringVector) std::vector >; + %template(PairDoubleStringVector2) std::vector > >; + %template(DoubleVector3) std::vector > >; } %template(IntDoublePairCompSecondRev) pair_comp_second_rev; diff --git a/deploy/setup.py b/deploy/setup.py index 077cabd08..1342478b2 100644 --- a/deploy/setup.py +++ b/deploy/setup.py @@ -36,12 +36,12 @@ if compile_test('lzma.h', 'lzma'): os.system('swig -python -c++ ./decoders.i') -ctc_beam_search_decoder_module = [ +decoders_module = [ Extension( name='_swig_decoders', sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), language='C++', - include_dirs=['.', './kenlm', './openfst-1.6.3/src/include'], + include_dirs=['.', 'kenlm', 
'openfst-1.6.3/src/include', 'ThreadPool'], libraries=LIBS, extra_compile_args=ARGS) ] @@ -50,5 +50,5 @@ setup( name='swig_decoders', version='0.1', description="""CTC decoders""", - ext_modules=ctc_beam_search_decoder_module, + ext_modules=decoders_module, py_modules=['swig_decoders'], ) diff --git a/deploy/swig_decoders_wrapper.py b/deploy/swig_decoders_wrapper.py index 54c430147..51f3173b2 100644 --- a/deploy/swig_decoders_wrapper.py +++ b/deploy/swig_decoders_wrapper.py @@ -4,7 +4,6 @@ from __future__ import division from __future__ import print_function import swig_decoders -import multiprocessing class Scorer(swig_decoders.Scorer): @@ -39,14 +38,13 @@ def ctc_best_path_decoder(probs_seq, vocabulary): return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary) -def ctc_beam_search_decoder( - probs_seq, - beam_size, - vocabulary, - blank_id, - cutoff_prob=1.0, - ext_scoring_func=None, ): - """Wrapper for CTC Beam Search Decoder. +def ctc_beam_search_decoder(probs_seq, + beam_size, + vocabulary, + blank_id, + cutoff_prob=1.0, + ext_scoring_func=None): + """Wrapper for the CTC Beam Search Decoder. :param probs_seq: 2-D list of probability distributions over each time step, with each element being a list of normalized @@ -81,24 +79,34 @@ def ctc_beam_search_decoder_batch(probs_split, num_processes, cutoff_prob=1.0, ext_scoring_func=None): - """Wrapper for CTC beam search decoder in batch - """ - - # TODO: to resolve PicklingError - - if not num_processes > 0: - raise ValueError("Number of processes must be positive!") + """Wrapper for the batched CTC beam search decoder. 
- pool = Pool(processes=num_processes) - results = [] - args_list = [] - for i, probs_list in enumerate(probs_split): - args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, - ext_scoring_func) - args_list.append(args) - results.append(pool.apply_async(ctc_beam_search_decoder, args)) + :param probs_split: 3-D list with each element as an instance of 2-D list + of probabilities used by ctc_beam_search_decoder(). + :type probs_split: 3-D list + :param beam_size: Width for beam search. + :type beam_size: int + :param vocabulary: Vocabulary list. + :type vocabulary: list + :param blank_id: ID of blank. + :type blank_id: int + :param num_processes: Number of parallel processes. + :type num_processes: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param ext_scoring_func: External scoring function for + partially decoded sentence, e.g. word count + or language model. + :type ext_scoring_func: callable + :return: List of per-sample decoding results, each a list of tuples of + log probability and sentence in descending order of probability. + :rtype: list + """ + probs_split = [probs_seq.tolist() for probs_seq in probs_split] - pool.close() - pool.join() - beam_search_results = [result.get() for result in results] - return beam_search_results + return swig_decoders.ctc_beam_search_decoder_batch( + probs_split, beam_size, vocabulary, blank_id, num_processes, + cutoff_prob, ext_scoring_func)