From 921c6d0cc1d9f2595edd710db3a7770e7d392988 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 6 Nov 2017 23:20:20 +0800 Subject: [PATCH 1/2] Add the scoring of last word/char of prefixes in CTC beam search decoder --- decoders/swig/ctc_beam_search_decoder.cpp | 28 +++++++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/decoders/swig/ctc_beam_search_decoder.cpp b/decoders/swig/ctc_beam_search_decoder.cpp index 624784b05..0f33e8573 100644 --- a/decoders/swig/ctc_beam_search_decoder.cpp +++ b/decoders/swig/ctc_beam_search_decoder.cpp @@ -110,17 +110,17 @@ std::vector> ctc_beam_search_decoder( // language model scoring if (ext_scorer != nullptr && (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_toscore = nullptr; + PathTrie *prefix_to_score = nullptr; // skip scoring the space if (ext_scorer->is_character_based()) { - prefix_toscore = prefix_new; + prefix_to_score = prefix_new; } else { - prefix_toscore = prefix; + prefix_to_score = prefix; } - double score = 0.0; + float score = 0.0; std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_toscore); + ngram = ext_scorer->make_ngram(prefix_to_score); score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; log_p += score; log_p += ext_scorer->beta; @@ -131,6 +131,7 @@ std::vector> ctc_beam_search_decoder( } // end of loop over prefix } // end of loop over vocabulary + prefixes.clear(); // update log probs root.iterate_to_vec(prefixes); @@ -147,6 +148,23 @@ std::vector> ctc_beam_search_decoder( } } // end of loop over time + // score the last word/character of each prefix + if (ext_scorer != nullptr) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (prefix->character != space_id && !prefix->is_empty()) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score += score; + } + } + } + + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + // compute aproximate ctc score as the return score, without affecting the // return order of decoding result. To delete when decoder gets stable. for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { From 4b20a7029c24a9645dea56682d527b8137150cd2 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 7 Nov 2017 17:10:15 +0800 Subject: [PATCH 2/2] skip scoring the end when using character-based scorer --- decoders/swig/ctc_beam_search_decoder.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/decoders/swig/ctc_beam_search_decoder.cpp b/decoders/swig/ctc_beam_search_decoder.cpp index 0f33e8573..4a63af26a 100644 --- a/decoders/swig/ctc_beam_search_decoder.cpp +++ b/decoders/swig/ctc_beam_search_decoder.cpp @@ -148,11 +148,11 @@ std::vector> ctc_beam_search_decoder( } } // end of loop over time - // score the last word/character of each prefix - if (ext_scorer != nullptr) { + // score the last word of each prefix that doesn't end with space + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { auto prefix = prefixes[i]; - if (prefix->character != space_id && !prefix->is_empty()) { + if (!prefix->is_empty() && prefix->character != space_id) { float score = 0.0; std::vector ngram = ext_scorer->make_ngram(prefix); score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;