pull/2/head
Yibing Liu 7 years ago
parent fd102c2110
commit d9d9514269

@ -58,7 +58,7 @@ parser.add_argument(
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default='ds2_new_models_0628/params.pass-51.tar.gz',
default='checkpoints/params.latest.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
@ -162,9 +162,10 @@ def infer():
for i, probs in enumerate(probs_split)
]
# external scorer
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
## decode and print
## decode and print
wer_sum, wer_counter = 0, 0
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(

@ -133,8 +133,8 @@ std::vector<std::pair<double, std::string> >
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
prob = prob * ext_scorer->get_score(sentence);
}
double log_prob = log(it->second);
beam_result.push_back(std::pair<double, std::string>(log_prob, it->first));
double log_prob = log(prob);
beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
}
}
// sort the result and return

@ -47,13 +47,12 @@ inline void strip(std::string &str, char ch=' ') {
int Scorer::word_count(std::string sentence) {
strip(sentence);
int cnt = 0;
int cnt = 1;
for (int i=0; i<sentence.size(); i++) {
if (sentence[i] == ' ' && sentence[i-1] != ' ') {
cnt ++;
}
}
if (cnt > 0) cnt ++;
return cnt;
}
@ -68,15 +67,16 @@ double Scorer::language_model_score(std::string sentence) {
ret = model->FullScore(state, vocab, out_state);
state = out_state;
}
double score = ret.prob;
//log10 prob
double log_prob = ret.prob;
return pow(10, score);
return log_prob;
}
double Scorer::get_score(std::string sentence) {
double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence);
double final_score = pow(lm_score, _alpha) * pow(word_cnt, _beta);
double final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
return final_score;
}

Loading…
Cancel
Save