diff --git a/deploy.py b/deploy.py index d43ab1e0..60bdcb0c 100644 --- a/deploy.py +++ b/deploy.py @@ -18,7 +18,7 @@ import time parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", - default=4, + default=10, type=int, help="Number of samples for inference. (default: %(default)s)") parser.add_argument( @@ -95,12 +95,12 @@ parser.add_argument( help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha", - default=0.26, + default=1.5, type=float, help="Parameter associated with language model. (default: %(default)f)") parser.add_argument( "--beta", - default=0.1, + default=0.3, type=float, help="Parameter associated with word count. (default: %(default)f)") parser.add_argument( @@ -109,6 +109,12 @@ parser.add_argument( type=float, help="The cutoff probability of pruning" "in beam search. (default: %(default)f)") +parser.add_argument( + "--cutoff_top_n", + default=40, + type=int, + help="The cutoff number of pruning" + "in beam search. (default: %(default)f)") args = parser.parse_args() @@ -184,6 +190,7 @@ def infer(): vocabulary=data_generator.vocab_list, blank_id=len(data_generator.vocab_list), cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, ext_scoring_func=ext_scorer, ) batch_beam_results += [beam_result] else: @@ -194,6 +201,7 @@ def infer(): blank_id=len(data_generator.vocab_list), num_processes=args.num_processes_beam_search, cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, ext_scoring_func=ext_scorer, ) for i, beam_result in enumerate(batch_beam_results): diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp index 9304c780..7933b01d 100644 --- a/deploy/ctc_decoders.cpp +++ b/deploy/ctc_decoders.cpp @@ -62,6 +62,7 @@ std::vector > std::vector vocabulary, int blank_id, double cutoff_prob, + int cutoff_top_n, Scorer *ext_scorer) { // dimension check @@ -116,19 +117,33 @@ std::vector > prob_idx.push_back(std::pair(i, prob[i])); } + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (ext_scorer != nullptr) { + int num_prefixes = std::min((int)prefixes.size(), beam_size); + std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, + prefix_compare); + min_cutoff = prefixes[num_prefixes-1]->_score + log(prob[blank_id]) + - std::max(0.0, ext_scorer->beta); + full_beam = (num_prefixes == beam_size); + } + // pruning of vacobulary int cutoff_len = prob.size(); - if (cutoff_prob < 1.0) { + if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) { std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); - double cum_prob = 0.0; - cutoff_len = 0; - for (int i=0; i= cutoff_prob) break; + if (cutoff_prob < 1.0) { + double cum_prob = 0.0; + cutoff_len = 0; + for (int i=0; i= cutoff_prob) break; + } } + cutoff_len = std::min(cutoff_len, cutoff_top_n); prob_idx = std::vector >( prob_idx.begin(), prob_idx.begin() + cutoff_len); } @@ -138,15 +153,17 @@ std::vector > log_prob_idx.push_back(std::pair (prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); } - // loop over chars for (int index = 0; index < log_prob_idx.size(); index++) { auto c = log_prob_idx[index].first; float log_prob_c = log_prob_idx[index].second; - //float log_probs_prev; for (int i = 0; i < prefixes.size() && i_score < min_cutoff) { + break; + } // blank if (c == blank_id) { prefix->_log_prob_b_cur = log_sum_exp( @@ -178,7 +195,7 @@ std::vector > (c == space_id || ext_scorer->is_character_based()) ) { PathTrie *prefix_to_score = nullptr; - // don't score the space + // skip scoring the space if (ext_scorer->is_character_based()) { prefix_to_score = prefix_new; } else { @@ -202,10 +219,10 @@ std::vector > } // end of loop over chars prefixes.clear(); - // update log probabilities + // update log probs root.iterate_to_vec(prefixes); - // sort prefixes by score + // preserve top beam_size prefixes if (prefixes.size() >= beam_size) { std::nth_element(prefixes.begin(), prefixes.begin() + beam_size, @@ -218,18 +235,20 @@ std::vector > } } + // compute aproximate ctc score as the return score for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { double approx_ctc = prefixes[i]->_score; - // remove word insert: - std::vector output; - prefixes[i]->get_path_vec(output); - size_t prefix_length = output.size(); - // remove language model weight: if (ext_scorer != nullptr) { - // auto words = split_labels(output); - // approx_ctc = approx_ctc - path_length * ext_scorer->beta; - // approx_ctc -= (_lm->get_sent_log_prob(words)) * ext_scorer->alpha; + std::vector output; + prefixes[i]->get_path_vec(output); + size_t prefix_length = output.size(); + auto words = ext_scorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + // remove language model weight: + approx_ctc -= (ext_scorer->get_sent_log_prob(words)) + * ext_scorer->alpha; } prefixes[i]->_approx_ctc = approx_ctc; @@ -253,11 +272,9 @@ std::vector > for (int j = 0; j < output.size(); j++) { output_str += vocabulary[output[j]]; } - std::pair output_pair(space_prefixes[i]->_score, - output_str); - output_vecs.emplace_back( - output_pair - ); + std::pair + output_pair(-space_prefixes[i]->_approx_ctc, output_str); + output_vecs.emplace_back(output_pair); } return output_vecs; @@ -272,6 +289,7 @@ std::vector > > int blank_id, int num_processes, double cutoff_prob, + int cutoff_top_n, Scorer *ext_scorer ) { if (num_processes <= 0) { @@ -295,7 +313,8 @@ std::vector > > for (int i = 0; i < batch_size; i++) { res.emplace_back( pool.enqueue(ctc_beam_search_decoder, probs_split[i], - beam_size, vocabulary, blank_id, cutoff_prob, ext_scorer) + beam_size, vocabulary, blank_id, cutoff_prob, + cutoff_top_n, ext_scorer) ); } // get decoding results diff --git a/deploy/ctc_decoders.h b/deploy/ctc_decoders.h index 23890382..f339cbd0 100644 --- a/deploy/ctc_decoders.h +++ b/deploy/ctc_decoders.h @@ -39,6 +39,7 @@ std::vector > std::vector vocabulary, int blank_id, double cutoff_prob=1.0, + int cutoff_top_n=40, Scorer *ext_scorer=NULL ); @@ -66,6 +67,7 @@ std::vector>> int blank_id, int num_processes, double cutoff_prob=1.0, + int cutoff_top_n=40, Scorer *ext_scorer=NULL ); diff --git a/deploy/scorer.h b/deploy/scorer.h index e5bfecaf..7d7ce430 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -50,6 +50,7 @@ public: void fill_dictionary(bool add_space); // set char map void set_char_map(std::vector char_list); + std::vector split_labels(const std::vector &labels); // expose to decoder double alpha; double beta; @@ -60,7 +61,6 @@ protected: void load_LM(const char* filename); double get_log_prob(const std::vector& words); std::string vec2str(const std::vector &input); - std::vector split_labels(const std::vector &labels); private: void* _language_model; diff --git a/deploy/swig_decoders_wrapper.py b/deploy/swig_decoders_wrapper.py index 51f3173b..b44fae0a 100644 --- a/deploy/swig_decoders_wrapper.py +++ b/deploy/swig_decoders_wrapper.py @@ -43,6 +43,7 @@ def ctc_beam_search_decoder(probs_seq, vocabulary, blank_id, cutoff_prob=1.0, + cutoff_top_n=40, ext_scoring_func=None): """Wrapper for the CTC Beam Search Decoder. @@ -59,6 +60,10 @@ def ctc_beam_search_decoder(probs_seq, :param cutoff_prob: Cutoff probability in pruning, default 1.0, no pruning. :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int :param ext_scoring_func: External scoring function for partially decoded sentence, e.g. word count or language model. @@ -67,9 +72,9 @@ def ctc_beam_search_decoder(probs_seq, results, in descending order of the probability. :rtype: list """ - return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size, - vocabulary, blank_id, - cutoff_prob, ext_scoring_func) + return swig_decoders.ctc_beam_search_decoder( + probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob, + cutoff_top_n, ext_scoring_func) def ctc_beam_search_decoder_batch(probs_split, @@ -78,6 +83,7 @@ def ctc_beam_search_decoder_batch(probs_split, blank_id, num_processes, cutoff_prob=1.0, + cutoff_top_n=40, ext_scoring_func=None): """Wrapper for the batched CTC beam search decoder. @@ -92,11 +98,15 @@ def ctc_beam_search_decoder_batch(probs_split, :type blank_id: int :param num_processes: Number of parallel processes. :type num_processes: int - :param cutoff_prob: Cutoff probability in pruning, + :param cutoff_prob: Cutoff probability in vocabulary pruning, default 1.0, no pruning. + :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int :param num_processes: Number of parallel processes. :type num_processes: int - :type cutoff_prob: float :param ext_scoring_func: External scoring function for partially decoded sentence, e.g. word count or language model. @@ -109,4 +119,4 @@ def ctc_beam_search_decoder_batch(probs_split, return swig_decoders.ctc_beam_search_decoder_batch( probs_split, beam_size, vocabulary, blank_id, num_processes, - cutoff_prob, ext_scoring_func) + cutoff_prob, cutoff_top_n, ext_scoring_func)