pull/2/head
Yibing Liu 7 years ago
parent fd102c2110
commit d9d9514269

@ -58,7 +58,7 @@ parser.add_argument(
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--model_filepath", "--model_filepath",
default='ds2_new_models_0628/params.pass-51.tar.gz', default='checkpoints/params.latest.tar.gz',
type=str, type=str,
help="Model filepath. (default: %(default)s)") help="Model filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
@ -162,9 +162,10 @@ def infer():
for i, probs in enumerate(probs_split) for i, probs in enumerate(probs_split)
] ]
# external scorer
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
## decode and print
## decode and print
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder( beam_result = ctc_beam_search_decoder(

@ -15,10 +15,10 @@ bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
return a.second > b.second; return a.second > b.second;
} }
/* CTC beam search decoder in C++, the interface is consistent with the original /* CTC beam search decoder in C++, the interface is consistent with the original
decoder in Python version. decoder in Python version.
*/ */
std::vector<std::pair<double, std::string> > std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
@ -29,15 +29,15 @@ std::vector<std::pair<double, std::string> >
) )
{ {
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
// assign space ID // assign space ID
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it-vocabulary.begin(); int space_id = it-vocabulary.begin();
if(space_id >= vocabulary.size()) { if(space_id >= vocabulary.size()) {
std::cout<<"The character space is not in the vocabulary!"; std::cout<<"The character space is not in the vocabulary!";
exit(1); exit(1);
} }
// initialize // initialize
// two sets containing selected and candidate prefixes respectively // two sets containing selected and candidate prefixes respectively
std::map<std::string, double> prefix_set_prev, prefix_set_next; std::map<std::string, double> prefix_set_prev, prefix_set_next;
@ -47,7 +47,7 @@ std::vector<std::pair<double, std::string> >
prefix_set_prev["\t"] = 1.0; prefix_set_prev["\t"] = 1.0;
probs_b_prev["\t"] = 1.0; probs_b_prev["\t"] = 1.0;
probs_nb_prev["\t"] = 0.0; probs_nb_prev["\t"] = 0.0;
for (int time_step=0; time_step<num_time_steps; time_step++) { for (int time_step=0; time_step<num_time_steps; time_step++) {
prefix_set_next.clear(); prefix_set_next.clear();
probs_b_cur.clear(); probs_b_cur.clear();
@ -70,8 +70,8 @@ std::vector<std::pair<double, std::string> >
} }
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len); prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len);
} }
// extend prefix // extend prefix
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin(); for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
it != prefix_set_prev.end(); it++) { it != prefix_set_prev.end(); it++) {
std::string l = it->first; std::string l = it->first;
if( prefix_set_next.find(l) == prefix_set_next.end()) { if( prefix_set_next.find(l) == prefix_set_next.end()) {
@ -109,12 +109,12 @@ std::vector<std::pair<double, std::string> >
} }
} }
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
} }
probs_b_prev = probs_b_cur; probs_b_prev = probs_b_cur;
probs_nb_prev = probs_nb_cur; probs_nb_prev = probs_nb_cur;
std::vector<std::pair<std::string, double> > std::vector<std::pair<std::string, double> >
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>); std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>);
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size(); int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size();
@ -124,7 +124,7 @@ std::vector<std::pair<double, std::string> >
// post processing // post processing
std::vector<std::pair<double, std::string> > beam_result; std::vector<std::pair<double, std::string> > beam_result;
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin(); for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
it != prefix_set_prev.end(); it++) { it != prefix_set_prev.end(); it++) {
if (it->second > 0.0 && it->first.size() > 1) { if (it->second > 0.0 && it->first.size() > 1) {
double prob = it->second; double prob = it->second;
@ -133,8 +133,8 @@ std::vector<std::pair<double, std::string> >
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
prob = prob * ext_scorer->get_score(sentence); prob = prob * ext_scorer->get_score(sentence);
} }
double log_prob = log(it->second); double log_prob = log(prob);
beam_result.push_back(std::pair<double, std::string>(log_prob, it->first)); beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
} }
} }
// sort the result and return // sort the result and return

@ -35,7 +35,7 @@ inline void strip(std::string &str, char ch=' ') {
break; break;
} }
} }
if (start == 0 && end == str.size()-1) return; if (start == 0 && end == str.size()-1) return;
if (start > end) { if (start > end) {
std::string emp_str; std::string emp_str;
@ -47,13 +47,12 @@ inline void strip(std::string &str, char ch=' ') {
int Scorer::word_count(std::string sentence) { int Scorer::word_count(std::string sentence) {
strip(sentence); strip(sentence);
int cnt = 0; int cnt = 1;
for (int i=0; i<sentence.size(); i++) { for (int i=0; i<sentence.size(); i++) {
if (sentence[i] == ' ' && sentence[i-1] != ' ') { if (sentence[i] == ' ' && sentence[i-1] != ' ') {
cnt ++; cnt ++;
} }
} }
if (cnt > 0) cnt ++;
return cnt; return cnt;
} }
@ -68,15 +67,16 @@ double Scorer::language_model_score(std::string sentence) {
ret = model->FullScore(state, vocab, out_state); ret = model->FullScore(state, vocab, out_state);
state = out_state; state = out_state;
} }
double score = ret.prob; //log10 prob
double log_prob = ret.prob;
return pow(10, score);
return log_prob;
} }
double Scorer::get_score(std::string sentence) { double Scorer::get_score(std::string sentence) {
double lm_score = language_model_score(sentence); double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence); int word_cnt = word_count(sentence);
double final_score = pow(lm_score, _alpha) * pow(word_cnt, _beta); double final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
return final_score; return final_score;
} }

Loading…
Cancel
Save