diff --git a/deploy.py b/deploy.py index 3272371b..d8a7e5b2 100644 --- a/deploy.py +++ b/deploy.py @@ -58,7 +58,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='ds2_new_models_0628/params.pass-51.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -162,9 +162,10 @@ def infer(): for i, probs in enumerate(probs_split) ] + # external scorer ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) - ## decode and print + ## decode and print wer_sum, wer_counter = 0, 0 for i, probs in enumerate(probs_split): beam_result = ctc_beam_search_decoder( diff --git a/deploy/ctc_beam_search_decoder.cpp b/deploy/ctc_beam_search_decoder.cpp index 297c7c24..68d1a845 100644 --- a/deploy/ctc_beam_search_decoder.cpp +++ b/deploy/ctc_beam_search_decoder.cpp @@ -15,10 +15,10 @@ bool pair_comp_second_rev(const std::pair a, const std::pair b) return a.second > b.second; } -/* CTC beam search decoder in C++, the interface is consistent with the original +/* CTC beam search decoder in C++, the interface is consistent with the original decoder in Python version. */ -std::vector > +std::vector > ctc_beam_search_decoder(std::vector > probs_seq, int beam_size, std::vector vocabulary, @@ -29,15 +29,15 @@ std::vector > ) { int num_time_steps = probs_seq.size(); - - // assign space ID + + // assign space ID std::vector::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); int space_id = it-vocabulary.begin(); if(space_id >= vocabulary.size()) { std::cout<<"The character space is not in the vocabulary!"; - exit(1); + exit(1); } - + // initialize // two sets containing selected and candidate prefixes respectively std::map prefix_set_prev, prefix_set_next; @@ -47,7 +47,7 @@ std::vector > prefix_set_prev["\t"] = 1.0; probs_b_prev["\t"] = 1.0; probs_nb_prev["\t"] = 0.0; - + for (int time_step=0; time_step > } prob_idx = std::vector >(prob_idx.begin(), prob_idx.begin()+cutoff_len); } - // extend prefix - for (std::map::iterator it = prefix_set_prev.begin(); + // extend prefix + for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { std::string l = it->first; if( prefix_set_next.find(l) == prefix_set_next.end()) { @@ -109,12 +109,12 @@ std::vector > } } - prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; + prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; } probs_b_prev = probs_b_cur; probs_nb_prev = probs_nb_cur; - std::vector > + std::vector > prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev); int k = beam_size > // post processing std::vector > beam_result; - for (std::map::iterator it = prefix_set_prev.begin(); + for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { if (it->second > 0.0 && it->first.size() > 1) { double prob = it->second; @@ -133,8 +133,8 @@ std::vector > if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { prob = prob * ext_scorer->get_score(sentence); } - double log_prob = log(it->second); - beam_result.push_back(std::pair(log_prob, it->first)); + double log_prob = log(prob); + beam_result.push_back(std::pair(log_prob, sentence)); } } // sort the result and return diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index 9cb68055..d7f68d71 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -35,7 +35,7 @@ inline void strip(std::string &str, char ch=' ') { break; } } - + if (start == 0 && end == str.size()-1) return; if (start > end) { std::string emp_str; @@ -47,13 +47,12 @@ inline void strip(std::string &str, char ch=' ') { int Scorer::word_count(std::string sentence) { strip(sentence); - int cnt = 0; + int cnt = 1; for (int i=0; i 0) cnt ++; return cnt; } @@ -68,15 +67,16 @@ double Scorer::language_model_score(std::string sentence) { ret = model->FullScore(state, vocab, out_state); state = out_state; } - double score = ret.prob; - - return pow(10, score); + //log10 prob + double log_prob = ret.prob; + + return log_prob; } double Scorer::get_score(std::string sentence) { double lm_score = language_model_score(sentence); int word_cnt = word_count(sentence); - double final_score = pow(lm_score, _alpha) * pow(word_cnt, _beta); + double final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta); return final_score; }