diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh index e5fccc03..794b533f 100755 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ b/speechx/examples/ds2_ol/aishell/run.sh @@ -135,7 +135,7 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # TLG decoder utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ - tlg_decoder_main \ + ctc_tlg_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ --param_path=$model_dir/avg_1.jit.pdiparams \ diff --git a/speechx/examples/ds2_ol/aishell/run_fbank.sh b/speechx/examples/ds2_ol/aishell/run_fbank.sh index 88ed6287..1c3c3e01 100755 --- a/speechx/examples/ds2_ol/aishell/run_fbank.sh +++ b/speechx/examples/ds2_ol/aishell/run_fbank.sh @@ -133,7 +133,7 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # TLG decoder utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ - tlg_decoder_main \ + ctc_tlg_decoder_main \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --model_path=$model_dir/avg_5.jit.pdmodel \ --param_path=$model_dir/avg_5.jit.pdiparams \ diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 8d04a997..20e93523 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -15,7 +15,7 @@ set(BINS ctc_beam_search_decoder_main nnet_logprob_decoder_main recognizer_main - tlg_decoder_main + ctc_tlg_decoder_main ) foreach(bin_name IN LISTS BINS) diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 5a12c0b5..ff3298b2 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoder/ctc_beam_search_decoder.h" -#include "base/basic_types.h" +#include "base/common.h" #include "decoder/ctc_decoders/decoder_utils.h" +#include "decoder/ctc_beam_search_decoder.h" #include "utils/file_utils.h" namespace ppspeech { @@ -26,7 +26,7 @@ using FSTMATCH = fst::SortedMatcher; CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts), init_ext_scorer_(nullptr), - blank_id_(-1), + blank_id_(opts.blank), space_id_(-1), num_frame_decoded_(0), root_(nullptr) { @@ -43,9 +43,9 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); } - blank_id_ = 0; - auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); + CHECK(blank_id_==0); + auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); space_id_ = it - vocabulary_.begin(); // if no space in vocabulary if ((size_t)space_id_ >= vocabulary_.size()) { diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 19dbf2f6..e36eb4a0 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -14,67 +14,48 @@ // used by deepspeech2 -#include "base/common.h" +#pragma once + +#include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/scorer.h" -#include "kaldi/decoder/decodable-itf.h" -#include "util/parse-options.h" - -#pragma once +#include "decoder/decoder_itf.h" namespace ppspeech { -struct CTCBeamSearchOptions { - std::string dict_file; - std::string lm_path; - BaseFloat alpha; - BaseFloat beta; - BaseFloat cutoff_prob; - int beam_size; - int cutoff_top_n; - int num_proc_bsearch; - CTCBeamSearchOptions() - : dict_file("vocab.txt"), - lm_path(""), - alpha(1.9f), - beta(5.0), - beam_size(300), - cutoff_prob(0.99f), - cutoff_top_n(40), - num_proc_bsearch(10) {} - - void Register(kaldi::OptionsItf* opts) { - opts->Register("dict", &dict_file, "dict file "); - opts->Register("lm-path", &lm_path, "language model file"); - opts->Register("alpha", &alpha, "alpha"); - opts->Register("beta", &beta, "beta"); - opts->Register( - "beam-size", &beam_size, "beam size for beam search method"); - opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); - opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n"); - opts->Register( - "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); - } -}; - -class CTCBeamSearch { +class CTCBeamSearch : public DecoderInterface { public: explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); ~CTCBeamSearch() {} + void InitDecoder(); + + void Reset(); + + void AdvanceDecode( + const std::shared_ptr& decodable); + + std::string GetFinalBestPath(); + + std::string GetPartialResult() { + CHECK(false) << "Not implement."; + return {}; + } + void Decode(std::shared_ptr decodable); + std::string GetBestPath(); std::vector> GetNBestPath(); - std::string GetFinalBestPath(); + + int NumFrameDecoded(); + int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); - void AdvanceDecode( - const std::shared_ptr& decodable); - void Reset(); private: void ResetPrefixes(); + int32 SearchOneChar(const bool& full_beam, const std::pair& log_prob_idx, const BaseFloat& min_cutoff); @@ -93,4 +74,4 @@ class CTCBeamSearch { DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); }; -} // namespace basr \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_opt.h b/speechx/speechx/decoder/ctc_beam_search_opt.h new file mode 100644 index 00000000..dcb62258 --- /dev/null +++ b/speechx/speechx/decoder/ctc_beam_search_opt.h @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/common.h" +#include "util/parse-options.h" + +#pragma once + +namespace ppspeech { + +struct CTCBeamSearchOptions { + // common + int blank; + + // ds2 + std::string dict_file; + std::string lm_path; + int beam_size; + BaseFloat alpha; + BaseFloat beta; + BaseFloat cutoff_prob; + int cutoff_top_n; + int num_proc_bsearch; + + // u2 + int first_beam_size; + int second_beam_size; + CTCBeamSearchOptions() + : blank(0), + dict_file("vocab.txt"), + lm_path(""), + alpha(1.9f), + beta(5.0), + beam_size(300), + cutoff_prob(0.99f), + cutoff_top_n(40), + num_proc_bsearch(10), + first_beam_size(10), + second_beam_size(10) {} + + void Register(kaldi::OptionsItf* opts) { + std::string module = "Ds2BeamSearchConfig: "; + opts->Register("dict", &dict_file, module + "vocab file path."); + opts->Register( + "lm-path", &lm_path, module + "ngram language model path."); + opts->Register("alpha", &alpha, module + "alpha"); + opts->Register("beta", &beta, module + "beta"); + opts->Register("beam-size", + &beam_size, + module + "beam size for beam search method"); + opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs"); + opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n"); + opts->Register( + "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch"); + + opts->Register("blank", &blank, "blank id, default is 0."); + + module = "U2BeamSearchConfig: "; + opts->Register( + "first-beam-size", &first_beam_size, module + "first beam size."); + opts->Register("second-beam-size", + &second_beam_size, + module + "second beam size."); + } +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search.cc b/speechx/speechx/decoder/ctc_prefix_beam_search.cc deleted file mode 100644 index e69de29b..00000000 diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc new file mode 100644 index 00000000..0544a1e2 --- /dev/null +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc @@ -0,0 +1,13 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h new file mode 100644 index 00000000..745c4a83 --- /dev/null +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "decoder/ctc_beam_search_opt.h" +#include "decoder/ctc_prefix_beam_search_score.h" +#include "decoder/decoder_itf.h" + +#include "kaldi/decoder/decodable-itf.h" + +namespace ppspeech { + +class CTCPrefixBeamSearch : public DecoderInterface { + public: + explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts); + ~CTCPrefixBeamSearch() {} + + void InitDecoder(); + + void Decode(std::shared_ptr decodable); + + std::string GetBestPath(); + + std::vector> GetNBestPath(); + + std::string GetFinalBestPath(); + + int NumFrameDecoded(); + + int DecodeLikelihoods(const std::vector>& probs, + std::vector& nbest_words); + + void AdvanceDecode( + const std::shared_ptr& decodable); + void Reset(); + + private: + void ResetPrefixes(); + int32 SearchOneChar(const bool& full_beam, + const std::pair& log_prob_idx, + const BaseFloat& min_cutoff); + void CalculateApproxScore(); + void LMRescore(); + void AdvanceDecoding(const std::vector>& probs); + + CTCBeamSearchOptions opts_; + size_t blank_id_; + int num_frame_decoded_; + DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch); +}; + +} // namespace basr \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_score.h b/speechx/speechx/decoder/ctc_prefix_beam_search_score.h new file mode 100644 index 00000000..19423b5e --- /dev/null +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_score.h @@ -0,0 +1,68 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "utils/math.h" + +namespace ppspeech { + +struct PrefxiScore { + // decoding, unit in log scale + float b = -kFloatMax; // blank ending score + float nb = -kFloatMax; // none-blank ending score + + // timestamp, unit in log sclae + float v_b = -kFloatMax; // viterbi blank ending score + float v_nb = -kFloatMax; // niterbi none-blank ending score + float cur_token_prob = -kFloatMax; // prob of current token + std::vector times_b; // times of viterbi blank path + std::vector times_nb; // times of viterbi non-blank path + + // context state + bool has_context = false; + int context_state = 0; + float context_score = 0; + + // decoding score, sum + float Score() const { return LogSumExp(b, nb); } + + // decodign score with context bias + float TotalScore() const { return Score() + context_score; } + + // timestamp score, max + float ViterbiScore() const { return std::max(v_b, v_nb); } + + // get timestamp + const std::vector& Times() const { + return v_b > v_nb ? times_b : times_nb; + } +}; + +struct PrefixScoreHash { + // https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector + std::size_t operator()(const std::vector& prefix) const { + std::size_t seed = prefix.size(); + for (auto& i : prefix) { + seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +using PrefixWithScoreType = std::pair, PrefixScoreHash>; + +} // namespace ppspeech diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 712d27dd..de97f6ad 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -22,24 +22,24 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { fst::SymbolTable::ReadText(opts.word_symbol_table)); decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); decoder_->InitDecoding(); - frame_decoded_size_ = 0; + num_frame_decoded_ = 0; } void TLGDecoder::InitDecoder() { decoder_->InitDecoding(); - frame_decoded_size_ = 0; + num_frame_decoded_ = 0; } void TLGDecoder::AdvanceDecode( const std::shared_ptr& decodable) { - while (!decodable->IsLastFrame(frame_decoded_size_)) { + while (!decodable->IsLastFrame(num_frame_decoded_)) { AdvanceDecoding(decodable.get()); } } void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { decoder_->AdvanceDecoding(decodable, 1); - frame_decoded_size_++; + num_frame_decoded_++; } void TLGDecoder::Reset() { @@ -48,7 +48,7 @@ void TLGDecoder::Reset() { } std::string TLGDecoder::GetPartialResult() { - if (frame_decoded_size_ == 0) { + if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // BestPathEnd if no frames were decoded.") return std::string(""); @@ -68,7 +68,7 @@ std::string TLGDecoder::GetPartialResult() { } std::string TLGDecoder::GetFinalBestPath() { - if (frame_decoded_size_ == 0) { + if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // BestPathEnd if no frames were decoded.") return std::string(""); diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index 1ac46ac6..f2282cb8 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -14,8 +14,9 @@ #pragma once -#include "base/basic_types.h" -#include "kaldi/decoder/decodable-itf.h" +#include "base/common.h" +#include "decoder/decoder_itf.h" + #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" @@ -30,21 +31,31 @@ struct TLGDecoderOptions { TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} }; -class TLGDecoder { +class TLGDecoder : public DecoderInterface { public: explicit TLGDecoder(TLGDecoderOptions opts); + ~TLGDecoder() = default; + void InitDecoder(); + void Reset(); + + void AdvanceDecode( + const std::shared_ptr& decodable); + + + std::string GetFinalBestPath(); + std::string GetPartialResult(); + + void Decode(); + std::string GetBestPath(); std::vector> GetNBestPath(); - std::string GetFinalBestPath(); - std::string GetPartialResult(); + int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); - void AdvanceDecode( - const std::shared_ptr& decodable); - void Reset(); + private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); @@ -53,7 +64,7 @@ class TLGDecoder { std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; // the frame size which have decoded starts from 0. - int32 frame_decoded_size_; + int32 num_frame_decoded_; }; diff --git a/speechx/speechx/decoder/tlg_decoder_main.cc b/speechx/speechx/decoder/ctc_tlg_decoder_main.cc similarity index 99% rename from speechx/speechx/decoder/tlg_decoder_main.cc rename to speechx/speechx/decoder/ctc_tlg_decoder_main.cc index b633022a..cd1249d8 100644 --- a/speechx/speechx/decoder/tlg_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder_main.cc @@ -14,13 +14,15 @@ // todo refactor, repalce with gtest -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_tlg_decoder.h" +#include "base/common.h" + #include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/ds2_nnet.h" +#include "decoder/ctc_tlg_decoder.h" + +#include "kaldi/util/table-types.h" + DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/speechx/speechx/decoder/decoder_itf.h b/speechx/speechx/decoder/decoder_itf.h new file mode 100644 index 00000000..01061939 --- /dev/null +++ b/speechx/speechx/decoder/decoder_itf.h @@ -0,0 +1,56 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "kaldi/decoder/decodable-itf.h" + +namespace ppspeech { + +class DecoderInterface { + public: + virtual ~DecoderInterface() {} + + virtual void InitDecoder() = 0; + + virtual void Reset() = 0; + + virtual void AdvanceDecode( + const std::shared_ptr& decodable) = 0; + + + virtual std::string GetFinalBestPath() = 0; + + virtual std::string GetPartialResult() = 0; + + // void Decode(); + + // std::string GetBestPath(); + // std::vector> GetNBestPath(); + + // int NumFrameDecoded(); + // int DecodeLikelihoods(const std::vector>& probs, + // std::vector& nbest_words); + + + private: + // void AdvanceDecoding(kaldi::DecodableInterface* decodable); + + // current decoding frame number + int32 num_frame_decoded_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index bfb75067..70a16e2c 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -39,14 +39,14 @@ class Decodable : public kaldi::DecodableInterface { // forward nnet with feats bool AdvanceChunk(); - + // forward nnet with feats, and get nnet output bool AdvanceChunk(kaldi::Vector* logprobs, int* vocab_dim); - + void AttentionRescoring(const std::vector>& hyps, - float reverse_weight, - std::vector* rescoring_score); + float reverse_weight, + std::vector* rescoring_score); virtual bool IsLastFrame(int32 frame); diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/nnet/ds2_nnet.h index cd1648b4..e8a49c7d 100644 --- a/speechx/speechx/nnet/ds2_nnet.h +++ b/speechx/speechx/nnet/ds2_nnet.h @@ -56,9 +56,9 @@ class PaddleNnet : public NnetInterface { NnetOut* out) override; void AttentionRescoring(const std::vector>& hyps, - float reverse_weight, - std::vector* rescoring_score) override { - VLOG(2) << "deepspeech2 not has AttentionRescoring."; + float reverse_weight, + std::vector* rescoring_score) override { + VLOG(2) << "deepspeech2 not has AttentionRescoring."; } void Dim();