|
|
|
@ -110,17 +110,17 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
// language model scoring
|
|
|
|
|
if (ext_scorer != nullptr &&
|
|
|
|
|
(c == space_id || ext_scorer->is_character_based())) {
|
|
|
|
|
PathTrie *prefix_toscore = nullptr;
|
|
|
|
|
PathTrie *prefix_to_score = nullptr;
|
|
|
|
|
// skip scoring the space
|
|
|
|
|
if (ext_scorer->is_character_based()) {
|
|
|
|
|
prefix_toscore = prefix_new;
|
|
|
|
|
prefix_to_score = prefix_new;
|
|
|
|
|
} else {
|
|
|
|
|
prefix_toscore = prefix;
|
|
|
|
|
prefix_to_score = prefix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double score = 0.0;
|
|
|
|
|
float score = 0.0;
|
|
|
|
|
std::vector<std::string> ngram;
|
|
|
|
|
ngram = ext_scorer->make_ngram(prefix_toscore);
|
|
|
|
|
ngram = ext_scorer->make_ngram(prefix_to_score);
|
|
|
|
|
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
|
|
|
|
log_p += score;
|
|
|
|
|
log_p += ext_scorer->beta;
|
|
|
|
@ -131,6 +131,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
} // end of loop over prefix
|
|
|
|
|
} // end of loop over vocabulary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prefixes.clear();
|
|
|
|
|
// update log probs
|
|
|
|
|
root.iterate_to_vec(prefixes);
|
|
|
|
@ -147,6 +148,23 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
}
|
|
|
|
|
} // end of loop over time
|
|
|
|
|
|
|
|
|
|
// score the last word/character of each prefix
|
|
|
|
|
if (ext_scorer != nullptr) {
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
|
|
|
auto prefix = prefixes[i];
|
|
|
|
|
if (prefix->character != space_id && !prefix->is_empty()) {
|
|
|
|
|
float score = 0.0;
|
|
|
|
|
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
|
|
|
|
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
|
|
|
|
score += ext_scorer->beta;
|
|
|
|
|
prefix->score += score;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
|
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
|
|
|
|
|
|
|
|
|
// compute aproximate ctc score as the return score, without affecting the
|
|
|
|
|
// return order of decoding result. To delete when decoder gets stable.
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
|
|
|