diff --git a/decoders/swig/ctc_beam_search_decoder.cpp b/decoders/swig/ctc_beam_search_decoder.cpp index 624784b0..0f33e857 100644 --- a/decoders/swig/ctc_beam_search_decoder.cpp +++ b/decoders/swig/ctc_beam_search_decoder.cpp @@ -110,17 +110,17 @@ std::vector> ctc_beam_search_decoder( // language model scoring if (ext_scorer != nullptr && (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_toscore = nullptr; + PathTrie *prefix_to_score = nullptr; // skip scoring the space if (ext_scorer->is_character_based()) { - prefix_toscore = prefix_new; + prefix_to_score = prefix_new; } else { - prefix_toscore = prefix; + prefix_to_score = prefix; } - double score = 0.0; + float score = 0.0; std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_toscore); + ngram = ext_scorer->make_ngram(prefix_to_score); score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; log_p += score; log_p += ext_scorer->beta; @@ -131,6 +131,7 @@ std::vector> ctc_beam_search_decoder( } // end of loop over prefix } // end of loop over vocabulary + prefixes.clear(); // update log probs root.iterate_to_vec(prefixes); @@ -147,6 +148,23 @@ std::vector> ctc_beam_search_decoder( } } // end of loop over time + // score the last word/character of each prefix + if (ext_scorer != nullptr) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (prefix->character != space_id && !prefix->is_empty()) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score += score; + } + } + } + + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + // compute aproximate ctc score as the return score, without affecting the // return order of decoding result. To delete when decoder gets stable. for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {