From 3c3aa6b59421f8f911247cd667426095f2298d58 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 12 Oct 2022 12:31:20 +0000 Subject: [PATCH] simple ctc prefix beam search compile ok --- speechx/speechx/base/common.h | 2 + speechx/speechx/base/macros.h | 3 +- speechx/speechx/decoder/CMakeLists.txt | 3 +- .../decoder/ctc_beam_search_decoder.cc | 12 +- .../speechx/decoder/ctc_beam_search_decoder.h | 3 +- .../decoder/ctc_prefix_beam_search_decoder.cc | 304 ++++++++++++++++++ .../decoder/ctc_prefix_beam_search_decoder.h | 70 +++- .../decoder/ctc_prefix_beam_search_score.h | 50 ++- speechx/speechx/decoder/ctc_tlg_decoder.h | 2 - speechx/speechx/decoder/decoder_itf.h | 3 +- speechx/speechx/utils/math.cc | 4 +- 11 files changed, 406 insertions(+), 50 deletions(-) diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index 70b11b691..b470b9de5 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -35,6 +36,7 @@ #include #include #include +#include #include #include #include diff --git a/speechx/speechx/base/macros.h b/speechx/speechx/base/macros.h index 14332a806..faf39373d 100644 --- a/speechx/speechx/base/macros.h +++ b/speechx/speechx/base/macros.h @@ -25,8 +25,7 @@ namespace ppspeech { void operator=(const TypeName&) = delete #endif -constexpr float kFloatMax = std::numeric_limits::max(); - +// kSpaceSymbol in UTF-8 is: ▁ const std::string kSpaceSymbol = "\xe2\x96\x81"; } // namespace ppspeech diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 20e935237..b08aaba59 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -2,10 +2,11 @@ project(decoder) include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) add_library(decoder STATIC - ctc_beam_search_decoder.cc ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp + ctc_beam_search_decoder.cc + ctc_prefix_beam_search_decoder.cc ctc_tlg_decoder.cc recognizer.cc ) diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index ff3298b2a..76342b870 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -26,9 +26,7 @@ using FSTMATCH = fst::SortedMatcher; CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts), init_ext_scorer_(nullptr), - blank_id_(opts.blank), space_id_(-1), - num_frame_decoded_(0), root_(nullptr) { LOG(INFO) << "dict path: " << opts_.dict_file; if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { @@ -43,7 +41,7 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); } - CHECK(blank_id_==0); + CHECK(opts_.blank==0); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); space_id_ = it - vocabulary_.begin(); @@ -167,7 +165,7 @@ void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { continue; } min_cutoff = prefixes_[num_prefixes_ - 1]->score + - std::log(prob[blank_id_]) - + std::log(prob[opts_.blank]) - std::max(0.0, init_ext_scorer_->beta); full_beam = (num_prefixes_ == beam_size); @@ -195,9 +193,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { for (size_t i = beam_size; i < prefixes_.size(); ++i) { prefixes_[i]->remove(); } - } // if + } // end if num_frame_decoded_++; - } // for probs_seq + } // end for probs_seq } int32 CTCBeamSearch::SearchOneChar( @@ -215,7 +213,7 @@ int32 CTCBeamSearch::SearchOneChar( break; } - if (c == blank_id_) { + if (c == opts_.blank) { prefix->log_prob_b_cur = log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); continue; diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index e36eb4a01..516f8b2ca 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -66,11 +66,10 @@ class CTCBeamSearch : public DecoderInterface { CTCBeamSearchOptions opts_; std::shared_ptr init_ext_scorer_; // todo separate later std::vector vocabulary_; // todo remove later - size_t blank_id_; int space_id_; std::shared_ptr root_; std::vector prefixes_; - int num_frame_decoded_; + DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); }; diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc index 0544a1e29..fd6890235 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc @@ -11,3 +11,307 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + + +#include "base/common.h" +#include "decoder/ctc_beam_search_opt.h" +#include "decoder/ctc_prefix_beam_search_score.h" +#include "decoder/ctc_prefix_beam_search_decoder.h" +#include "utils/math.h" + +#ifdef USE_PROFILING +#include "paddle/fluid/platform/profiler.h" +using paddle::platform::RecordEvent; +using paddle::platform::TracerEventType; +#endif + +namespace ppspeech { + +CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts) + : opts_(opts) { + InitDecoder(); +} + +void CTCPrefixBeamSearch::InitDecoder() { + num_frame_decoded_ = 0; + + cur_hyps_.clear(); + + hypotheses_.clear(); + likelihood_.clear(); + viterbi_likelihood_.clear(); + times_.clear(); + outputs_.clear(); + + abs_time_step_ = 0; + + // empty hyp with Score + std::vector empty; + PrefixScore prefix_score; + prefix_score.b = 0.0f; // log(1) + prefix_score.nb = -kBaseFloatMax; // log(0) + prefix_score.v_b = 0.0f; // log(1) + prefix_score.v_nb = 0.0f; // log(1) + cur_hyps_[empty] = prefix_score; + + outputs_.emplace_back(empty); + hypotheses_.emplace_back(empty); + likelihood_.emplace_back(prefix_score.TotalScore()); + times_.emplace_back(empty); + +} + +void CTCPrefixBeamSearch::Reset() { + InitDecoder(); +} + +void CTCPrefixBeamSearch::Decode( + std::shared_ptr decodable) { + return; +} + +int32 CTCPrefixBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; } + + +void CTCPrefixBeamSearch::UpdateOutputs( + const std::pair, PrefixScore>& prefix) { + const std::vector& input = prefix.first; + // const std::vector& start_boundaries = prefix.second.start_boundaries; + // const std::vector& end_boundaries = prefix.second.end_boundaries; + + std::vector output; + int s = 0; + int e = 0; + for (int i = 0; i < input.size(); ++i) { + // if (s < start_boundaries.size() && i == start_boundaries[s]){ + // // + // output.emplace_back(context_graph_->start_tag_id()); + // ++s; + // } + + output.emplace_back(input[i]); + + // if (e < end_boundaries.size() && i == end_boundaries[e]){ + // // + // output.emplace_back(context_graph_->end_tag_id()); + // ++e; + // } + } + + outputs_.emplace_back(output); +} + + +void CTCPrefixBeamSearch::AdvanceDecode( + const std::shared_ptr& decodable) { + while (1) { + std::vector frame_prob; + bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); + if (flag == false) break; + std::vector> likelihood; + likelihood.push_back(frame_prob); + AdvanceDecoding(likelihood); + } +} + +static bool PrefixScoreCompare( + const std::pair, PrefixScore>& a, + const std::pair, PrefixScore>& b) { + // log domain + return a.second.TotalScore() > b.second.TotalScore(); +} + + +void CTCPrefixBeamSearch::AdvanceDecoding(const std::vector>& logp) { +#ifdef USE_PROFILING + RecordEvent event( + "CtcPrefixBeamSearch::AdvanceDecoding", TracerEventType::UserDefined, 1); +#endif + + if (logp.size() == 0) return; + + int first_beam_size = + std::min(static_cast(logp[0].size()), opts_.first_beam_size); + + for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) { + const std::vector& logp_t = logp[t]; + std::unordered_map, PrefixScore, PrefixScoreHash> next_hyps; + + // 1. first beam prune, only select topk candidates + std::vector topk_score; + std::vector topk_index; + TopK(logp_t, first_beam_size, &topk_score, &topk_index); + + // 2. token passing + for (int i = 0; i < topk_index.size(); ++i) { + int id = topk_index[i]; + auto prob = topk_score[i]; + + for (const auto& it : cur_hyps_) { + const std::vector& prefix = it.first; + const PrefixScore& prefix_score = it.second; + + // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert + // PrefixScore(-inf, -inf) by default, since the default constructor + // of PrefixScore will set fields b(blank ending Score) and + // nb(none blank ending Score) to -inf, respectively. + + if (id == opts_.blank) { + // case 0: *a + => *a, *a + => *a, prefix not + // change + PrefixScore& next_score = next_hyps[prefix]; + next_score.b = LogSumExp(next_score.b, prefix_score.Score() + prob); + + // timestamp, blank is slince, not effact timestamp + next_score.v_b = prefix_score.ViterbiScore() + prob; + next_score.times_b = prefix_score.Times(); + + // Prefix not changed, copy the context from pefix + if (context_graph_ && !next_score.has_context) { + next_score.CopyContext(prefix_score); + next_score.has_context = true; + } + + } else if (!prefix.empty() && id == prefix.back()) { + // case 1: *a + a => *a, prefix not changed + PrefixScore& next_score1 = next_hyps[prefix]; + next_score1.nb = LogSumExp(next_score1.nb, prefix_score.nb + prob); + + // timestamp, non-blank symbol effact timestamp + if (next_score1.v_nb < prefix_score.v_nb + prob) { + // compute viterbi Score + next_score1.v_nb = prefix_score.v_nb + prob; + if (next_score1.cur_token_prob < prob) { + // store max token prob + next_score1.cur_token_prob = prob; + // update this timestamp as token appeared here. + next_score1.times_nb = prefix_score.times_nb; + assert(next_score1.times_nb.size() > 0); + next_score1.times_nb.back() = abs_time_step_; + } + } + + // Prefix not changed, copy the context from pefix + if (context_graph_ && !next_score1.has_context) { + next_score1.CopyContext(prefix_score); + next_score1.has_context = true; + } + + // case 2: *a + a => *aa, prefix changed. + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score2 = next_hyps[new_prefix]; + next_score2.nb = LogSumExp(next_score2.nb, prefix_score.b + prob); + + // timestamp, non-blank symbol effact timestamp + if (next_score2.v_nb < prefix_score.v_b + prob) { + // compute viterbi Score + next_score2.v_nb = prefix_score.v_b + prob; + // new token added + next_score2.cur_token_prob = prob; + next_score2.times_nb = prefix_score.times_b; + next_score2.times_nb.emplace_back(abs_time_step_); + } + + // Prefix changed, calculate the context Score. + if (context_graph_ && !next_score2.has_context) { + next_score2.UpdateContext( + context_graph_, prefix_score, id, prefix.size()); + next_score2.has_context = true; + } + + } else { + // id != prefix.back() + // case 3: *a + b => *ab, *a +b => *ab + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score = next_hyps[new_prefix]; + next_score.nb = LogSumExp(next_score.nb, prefix_score.Score() + prob); + + // timetamp, non-blank symbol effact timestamp + if (next_score.v_nb < prefix_score.ViterbiScore() + prob) { + next_score.v_nb = prefix_score.ViterbiScore() + prob; + + next_score.cur_token_prob = prob; + next_score.times_nb = prefix_score.Times(); + next_score.times_nb.emplace_back(abs_time_step_); + } + + // Prefix changed, calculate the context Score. + if (context_graph_ && !next_score.has_context) { + next_score.UpdateContext( + context_graph_, prefix_score, id, prefix.size()); + next_score.has_context = true; + } + } + } // end for (const auto& it : cur_hyps_) + } // end for (int i = 0; i < topk_index.size(); ++i) + + // 3. second beam prune, only keep top n best paths + std::vector, PrefixScore>> arr(next_hyps.begin(), + next_hyps.end()); + int second_beam_size = + std::min(static_cast(arr.size()), opts_.second_beam_size); + std::nth_element(arr.begin(), + arr.begin() + second_beam_size, + arr.end(), + PrefixScoreCompare); + arr.resize(second_beam_size); + std::sort(arr.begin(), arr.end(), PrefixScoreCompare); + + // 4. update cur_hyps by next_hyps, and get new result + UpdateHypotheses(arr); + + num_frame_decoded_++; + } // end for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) +} + + +void CTCPrefixBeamSearch::UpdateHypotheses( + const std::vector, PrefixScore>>& hyps) { + cur_hyps_.clear(); + + outputs_.clear(); + hypotheses_.clear(); + likelihood_.clear(); + viterbi_likelihood_.clear(); + times_.clear(); + + for (auto& item : hyps) { + cur_hyps_[item.first] = item.second; + + UpdateOutputs(item); + hypotheses_.emplace_back(std::move(item.first)); + likelihood_.emplace_back(item.second.TotalScore()); + viterbi_likelihood_.emplace_back(item.second.ViterbiScore()); + times_.emplace_back(item.second.Times()); + } +} + +void CTCPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); } + + +void CTCPrefixBeamSearch::UpdateFinalContext() { + if (context_graph_ == nullptr) return; + assert(hypotheses_.size() == cur_hyps_.size()); + assert(hypotheses_.size() == likelihood_.size()); + + // We should backoff the context Score/state when the context is + // not fully matched at the last time. + for (const auto& prefix : hypotheses_) { + PrefixScore& prefix_score = cur_hyps_[prefix]; + if (prefix_score.context_score != 0) { + // prefix_score.UpdateContext(context_graph_, prefix_score, 0, + // prefix.size()); + } + } + std::vector, PrefixScore>> arr(cur_hyps_.begin(), + cur_hyps_.end()); + std::sort(arr.begin(), arr.end(), PrefixScoreCompare); + + // Update cur_hyps_ and get new result + UpdateHypotheses(arr); +} + + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h index 745c4a835..b67733e81 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h @@ -18,10 +18,8 @@ #include "decoder/ctc_prefix_beam_search_score.h" #include "decoder/decoder_itf.h" -#include "kaldi/decoder/decodable-itf.h" - namespace ppspeech { - +class ContextGraph; class CTCPrefixBeamSearch : public DecoderInterface { public: explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts); @@ -29,36 +27,74 @@ class CTCPrefixBeamSearch : public DecoderInterface { void InitDecoder(); + void Reset(); + + void AdvanceDecode( + const std::shared_ptr& decodable); + + std::string GetFinalBestPath(); + + std::string GetPartialResult() { + CHECK(false) << "Not implement."; + return {}; + } + void Decode(std::shared_ptr decodable); std::string GetBestPath(); std::vector> GetNBestPath(); - std::string GetFinalBestPath(); int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); - void AdvanceDecode( - const std::shared_ptr& decodable); - void Reset(); + const std::vector& ViterbiLikelihood() const { + return viterbi_likelihood_; + } + + const std::vector>& Inputs() const { return hypotheses_; } + + const std::vector>& Outputs() const { return outputs_; } + + const std::vector& Likelihood() const { return likelihood_; } + const std::vector>& Times() const { return times_; } + private: - void ResetPrefixes(); - int32 SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff); - void CalculateApproxScore(); - void LMRescore(); - void AdvanceDecoding(const std::vector>& probs); + void AdvanceDecoding(const std::vector>& logp); + + void FinalizeSearch(); + void UpdateOutputs(const std::pair, PrefixScore>& prefix); + void UpdateHypotheses( + const std::vector, PrefixScore>>& prefix); + void UpdateFinalContext(); + + + private: CTCBeamSearchOptions opts_; - size_t blank_id_; - int num_frame_decoded_; + + int abs_time_step_ = 0; + + std::unordered_map, PrefixScore, PrefixScoreHash> + cur_hyps_; + + // n-best list and corresponding likelihood, in sorted order + std::vector> hypotheses_; + std::vector likelihood_; + + std::vector> times_; + std::vector viterbi_likelihood_; + + // Outputs contain the hypotheses_ and tags lik: and + std::vector> outputs_; + + std::shared_ptr context_graph_ = nullptr; + DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch); }; -} // namespace basr \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_score.h b/speechx/speechx/decoder/ctc_prefix_beam_search_score.h index 19423b5e0..da2fb80a9 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_score.h +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_score.h @@ -20,35 +20,55 @@ namespace ppspeech { -struct PrefxiScore { +class ContextGraph; + +struct PrefixScore { // decoding, unit in log scale - float b = -kFloatMax; // blank ending score - float nb = -kFloatMax; // none-blank ending score + float b = -kBaseFloatMax; // blank ending score + float nb = -kBaseFloatMax; // none-blank ending score + + // decoding score, sum + float Score() const { return LogSumExp(b, nb); } // timestamp, unit in log sclae - float v_b = -kFloatMax; // viterbi blank ending score - float v_nb = -kFloatMax; // niterbi none-blank ending score - float cur_token_prob = -kFloatMax; // prob of current token - std::vector times_b; // times of viterbi blank path - std::vector times_nb; // times of viterbi non-blank path + float v_b = -kBaseFloatMax; // viterbi blank ending score + float v_nb = -kBaseFloatMax; // niterbi none-blank ending score + float cur_token_prob = -kBaseFloatMax; // prob of current token + std::vector times_b; // times of viterbi blank path + std::vector times_nb; // times of viterbi non-blank path + + + // timestamp score, max + float ViterbiScore() const { return std::max(v_b, v_nb); } + + // get timestamp + const std::vector& Times() const { + return v_b > v_nb ? times_b : times_nb; + } // context state bool has_context = false; int context_state = 0; float context_score = 0; + std::vector start_boundaries; + std::vector end_boundaries; - // decoding score, sum - float Score() const { return LogSumExp(b, nb); } // decodign score with context bias float TotalScore() const { return Score() + context_score; } - // timestamp score, max - float ViterbiScore() const { return std::max(v_b, v_nb); } + void CopyContext(const PrefixScore& prefix_score) { + context_state = prefix_score.context_state; + context_score = prefix_score.context_score; + start_boundaries = prefix_score.start_boundaries; + end_boundaries = prefix_score.end_boundaries; + } - // get timestamp - const std::vector& Times() const { - return v_b > v_nb ? times_b : times_nb; + void UpdateContext(const std::shared_ptr& constext_graph, + const PrefixScore& prefix_score, + int word_id, + int prefix_len) { + CHECK(false); } }; diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index f2282cb87..f3ecde73b 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -63,8 +63,6 @@ class TLGDecoder : public DecoderInterface { std::shared_ptr decoder_; std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; - // the frame size which have decoded starts from 0. - int32 num_frame_decoded_; }; diff --git a/speechx/speechx/decoder/decoder_itf.h b/speechx/speechx/decoder/decoder_itf.h index 010619397..1bbc6b114 100644 --- a/speechx/speechx/decoder/decoder_itf.h +++ b/speechx/speechx/decoder/decoder_itf.h @@ -31,7 +31,6 @@ class DecoderInterface { virtual void AdvanceDecode( const std::shared_ptr& decodable) = 0; - virtual std::string GetFinalBestPath() = 0; virtual std::string GetPartialResult() = 0; @@ -46,7 +45,7 @@ class DecoderInterface { // std::vector& nbest_words); - private: + protected: // void AdvanceDecoding(kaldi::DecodableInterface* decodable); // current decoding frame number diff --git a/speechx/speechx/utils/math.cc b/speechx/speechx/utils/math.cc index 5087ac60b..6a13f69ba 100644 --- a/speechx/speechx/utils/math.cc +++ b/speechx/speechx/utils/math.cc @@ -28,8 +28,8 @@ namespace ppspeech { // Sum in log scale float LogSumExp(float x, float y) { - if (x <= -kFloatMax) return y; - if (y <= -kFloatMax) return x; + if (x <= -kBaseFloatMax) return y; + if (y <= -kBaseFloatMax) return x; float max = std::max(x, y); return max + std::log(std::exp(x - max) + std::exp(y - max)); }