|
|
@ -65,7 +65,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
prefixes.push_back(&root);
|
|
|
|
prefixes.push_back(&root);
|
|
|
|
|
|
|
|
|
|
|
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
|
|
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
|
|
|
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
|
|
|
|
auto fst_dict =
|
|
|
|
|
|
|
|
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
|
|
|
|
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
|
|
|
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
|
|
|
root.set_dictionary(dict_ptr);
|
|
|
|
root.set_dictionary(dict_ptr);
|
|
|
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
|
|
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
|
|
@ -80,10 +81,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
bool full_beam = false;
|
|
|
|
bool full_beam = false;
|
|
|
|
if (ext_scorer != nullptr) {
|
|
|
|
if (ext_scorer != nullptr) {
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
std::sort(
|
|
|
|
std::sort(prefixes.begin(),
|
|
|
|
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
|
|
|
prefixes.begin() + num_prefixes,
|
|
|
|
|
|
|
|
prefix_compare);
|
|
|
|
min_cutoff = prefixes[num_prefixes - 1]->score +
|
|
|
|
min_cutoff = prefixes[num_prefixes - 1]->score +
|
|
|
|
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
|
|
|
|
std::log(prob[blank_id]) -
|
|
|
|
|
|
|
|
std::max(0.0, ext_scorer->beta);
|
|
|
|
full_beam = (num_prefixes == beam_size);
|
|
|
|
full_beam = (num_prefixes == beam_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -101,14 +104,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// blank
|
|
|
|
// blank
|
|
|
|
if (c == blank_id) {
|
|
|
|
if (c == blank_id) {
|
|
|
|
prefix->log_prob_b_cur =
|
|
|
|
prefix->log_prob_b_cur = log_sum_exp(
|
|
|
|
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
|
|
|
prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// repeated character
|
|
|
|
// repeated character
|
|
|
|
if (c == prefix->character) {
|
|
|
|
if (c == prefix->character) {
|
|
|
|
prefix->log_prob_nb_cur = log_sum_exp(
|
|
|
|
prefix->log_prob_nb_cur =
|
|
|
|
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
|
|
|
log_sum_exp(prefix->log_prob_nb_cur,
|
|
|
|
|
|
|
|
log_prob_c + prefix->log_prob_nb_prev);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// get new prefix
|
|
|
|
// get new prefix
|
|
|
|
auto prefix_new = prefix->get_path_trie(c);
|
|
|
|
auto prefix_new = prefix->get_path_trie(c);
|
|
|
@ -137,7 +141,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
float score = 0.0;
|
|
|
|
float score = 0.0;
|
|
|
|
std::vector<std::string> ngram;
|
|
|
|
std::vector<std::string> ngram;
|
|
|
|
ngram = ext_scorer->make_ngram(prefix_to_score);
|
|
|
|
ngram = ext_scorer->make_ngram(prefix_to_score);
|
|
|
|
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
|
|
|
score = ext_scorer->get_log_cond_prob(ngram) *
|
|
|
|
|
|
|
|
ext_scorer->alpha;
|
|
|
|
log_p += score;
|
|
|
|
log_p += score;
|
|
|
|
log_p += ext_scorer->beta;
|
|
|
|
log_p += ext_scorer->beta;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -171,7 +176,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
if (!prefix->is_empty() && prefix->character != space_id) {
|
|
|
|
if (!prefix->is_empty() && prefix->character != space_id) {
|
|
|
|
float score = 0.0;
|
|
|
|
float score = 0.0;
|
|
|
|
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
|
|
|
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
|
|
|
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
|
|
|
score =
|
|
|
|
|
|
|
|
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
|
|
|
score += ext_scorer->beta;
|
|
|
|
score += ext_scorer->beta;
|
|
|
|
prefix->score += score;
|
|
|
|
prefix->score += score;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -179,7 +185,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
|
|
|
std::sort(
|
|
|
|
|
|
|
|
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
|
|
|
|
|
|
|
|
|
|
|
// compute approximate ctc score as the return score, without affecting the
|
|
|
|
// compute approximate ctc score as the return score, without affecting the
|
|
|
|
// return order of decoding result. To delete when decoder gets stable.
|
|
|
|
// return order of decoding result. To delete when decoder gets stable.
|
|
|
@ -193,7 +200,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
// remove word insert
|
|
|
|
// remove word insert
|
|
|
|
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
|
|
|
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
|
|
|
// remove language model weight:
|
|
|
|
// remove language model weight:
|
|
|
|
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
|
|
|
approx_ctc -=
|
|
|
|
|
|
|
|
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
prefixes[i]->approx_ctc = approx_ctc;
|
|
|
|
prefixes[i]->approx_ctc = approx_ctc;
|
|
|
|
}
|
|
|
|
}
|
|
|
|