diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index e24744d6d..d056ebbc1 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -33,7 +33,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch}) # compiler option # Keep the same with openfst, -fPIC or -fpic -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl") SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall") diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt index b2f507080..07adda956 100644 --- a/speechx/speechx/asr/decoder/CMakeLists.txt +++ b/speechx/speechx/asr/decoder/CMakeLists.txt @@ -1,6 +1,7 @@ set(srcs) list(APPEND srcs ctc_prefix_beam_search_decoder.cc + ctc_tlg_decoder.cc ) add_library(decoder STATIC ${srcs}) @@ -9,6 +10,7 @@ target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) # test set(TEST_BINS ctc_prefix_beam_search_decoder_main + ctc_tlg_decoder_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h index 5013246a4..3fe1944c7 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h @@ -45,7 +45,7 @@ class CTCPrefixBeamSearch : public DecoderBase { void FinalizeSearch(); - const std::shared_ptr VocabTable() const { + const std::shared_ptr WordSymbolTable() const override { return unit_table_; } @@ -57,7 +57,6 @@ class CTCPrefixBeamSearch : public DecoderBase { } const std::vector>& Times() const { return times_; } - protected: std::string GetBestPath() override; std::vector> GetNBestPath() override; diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc b/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc index 2c2b6d3c9..ca7d65c8f 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc @@ -15,7 +15,7 @@ #include "decoder/ctc_tlg_decoder.h" namespace ppspeech { -TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) { fst_.reset(fst::Fst::Read(opts.fst_path)); CHECK(fst_ != nullptr); @@ -68,14 +68,52 @@ std::string TLGDecoder::GetPartialResult() { return words; } +void TLGDecoder::FinalizeSearch() { + decoder_->FinalizeDecoding(); + kaldi::CompactLattice clat; + decoder_->GetLattice(&clat, true); + kaldi::Lattice lat, nbest_lat; + fst::ConvertLattice(clat, &lat); + fst::ShortestPath(lat, &nbest_lat, opts_.nbest); + std::vector nbest_lats; + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + + hypotheses_.clear(); + hypotheses_.reserve(nbest_lats.size()); + likelihood_.clear(); + likelihood_.reserve(nbest_lats.size()); + times_.clear(); + times_.reserve(nbest_lats.size()); + for (auto lat : nbest_lats) { + kaldi::LatticeWeight weight; + std::vector hypothese; + std::vector time; + std::vector alignment; + std::vector words_id; + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + int idx = 0; + for (; idx < alignment.size() - 1; ++idx) { + if (alignment[idx] == 0) continue; + if (alignment[idx] != alignment[idx + 1]) { + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + } + } + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + hypotheses_.push_back(hypothese); + times_.push_back(time); + olabels.push_back(words_id); + likelihood_.push_back(-(weight.Value2() + weight.Value1())); + } +} + std::string TLGDecoder::GetFinalBestPath() { if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // BestPathEnd if no frames were decoded.") return std::string(""); } - - decoder_->FinalizeDecoding(); kaldi::Lattice lat; kaldi::LatticeWeight weight; std::vector alignment; diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder.h b/speechx/speechx/asr/decoder/ctc_tlg_decoder.h index 8be69dadd..1ea6d634a 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/asr/decoder/ctc_tlg_decoder.h @@ -19,9 +19,8 @@ #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" - -DECLARE_string(graph_path); DECLARE_string(word_symbol_table); +DECLARE_string(graph_path); DECLARE_int32(max_active); DECLARE_double(beam); DECLARE_double(lattice_beam); @@ -33,6 +32,9 @@ struct TLGDecoderOptions { // todo remove later, add into decode resource std::string word_symbol_table; std::string fst_path; + int nbest; + + TLGDecoderOptions() : word_symbol_table(""), fst_path(""), nbest(10) {} static TLGDecoderOptions InitFromFlags() { TLGDecoderOptions decoder_opts; @@ -44,6 +46,7 @@ struct TLGDecoderOptions { decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + // decoder_opts.nbest = FLAGS_lattice_nbest; LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active; LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam; @@ -59,20 +62,38 @@ class TLGDecoder : public DecoderBase { explicit TLGDecoder(TLGDecoderOptions opts); ~TLGDecoder() = default; - void InitDecoder(); - void Reset(); + void InitDecoder() override; + void Reset() override; void AdvanceDecode( - const std::shared_ptr& decodable); + const std::shared_ptr& decodable) override; void Decode(); std::string GetFinalBestPath() override; std::string GetPartialResult() override; + const std::shared_ptr WordSymbolTable() const override { + return word_symbol_table_; + } + int DecodeLikelihoods(const std::vector>& probs, const std::vector& nbest_words); + void FinalizeSearch() override; + const std::vector>& Inputs() const override { + return hypotheses_; + } + const std::vector>& Outputs() const override { + return olabels; + } // outputs_; } + const std::vector& Likelihood() const override { + return likelihood_; + } + const std::vector>& Times() const override { + return times_; + } + protected: std::string GetBestPath() override { CHECK(false); @@ -90,9 +111,15 @@ class TLGDecoder : public DecoderBase { private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); + std::vector> hypotheses_; + std::vector> olabels; + std::vector likelihood_; + std::vector> times_; + std::shared_ptr decoder_; std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; + TLGDecoderOptions opts_; }; diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc index e9bd8a3f4..148ee15e3 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc @@ -14,16 +14,16 @@ // todo refactor, repalce with gtest -#include "base/common.h" #include "decoder/ctc_tlg_decoder.h" +#include "base/common.h" #include "decoder/param.h" -#include "frontend/audio/data_cache.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" +#include "nnet/nnet_producer.h" -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); @@ -39,8 +39,8 @@ int main(int argc, char* argv[]) { google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); + kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader( + FLAGS_nnet_prob_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); int32 num_done = 0, num_err = 0; @@ -53,66 +53,19 @@ int main(int argc, char* argv[]) { ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); + std::shared_ptr nnet_producer = + std::make_shared(nullptr); std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; + new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale)); decoder.InitDecoder(); kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - decoder.AdvanceDecode(decodable); - } + + for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) { + string utt = nnet_prob_reader.Key(); + kaldi::Matrix prob = nnet_prob_reader.Value(); + decodable->Acceptlikelihood(prob); + decoder.AdvanceDecode(decodable); std::string result; result = decoder.GetFinalBestPath(); decodable->Reset(); diff --git a/speechx/speechx/asr/decoder/decoder_itf.h b/speechx/speechx/asr/decoder/decoder_itf.h index 2289b3173..cb7717e8e 100644 --- a/speechx/speechx/asr/decoder/decoder_itf.h +++ b/speechx/speechx/asr/decoder/decoder_itf.h @@ -1,4 +1,3 @@ - // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +15,7 @@ #pragma once #include "base/common.h" +#include "fst/symbol-table.h" #include "kaldi/decoder/decodable-itf.h" namespace ppspeech { @@ -41,6 +41,14 @@ class DecoderInterface { virtual std::string GetPartialResult() = 0; + virtual const std::shared_ptr WordSymbolTable() const = 0; + virtual void FinalizeSearch() = 0; + + virtual const std::vector>& Inputs() const = 0; + virtual const std::vector>& Outputs() const = 0; + virtual const std::vector& Likelihood() const = 0; + virtual const std::vector>& Times() const = 0; + protected: // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0; diff --git a/speechx/speechx/asr/decoder/param.h b/speechx/speechx/asr/decoder/param.h index cad6dbd8d..83e2c7fb4 100644 --- a/speechx/speechx/asr/decoder/param.h +++ b/speechx/speechx/asr/decoder/param.h @@ -57,8 +57,8 @@ DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); // decoder DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); -DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "", "decoder graph"); +DEFINE_string(word_symbol_table, "", "word symbol table"); DEFINE_int32(max_active, 7500, "max active"); DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam"); diff --git a/speechx/speechx/asr/nnet/decodable.h b/speechx/speechx/asr/nnet/decodable.h index 44c7a0c33..c1dbb4b89 100644 --- a/speechx/speechx/asr/nnet/decodable.h +++ b/speechx/speechx/asr/nnet/decodable.h @@ -27,8 +27,6 @@ class Decodable : public kaldi::DecodableInterface { explicit Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale = 1.0); - // void Init(DecodableOpts config); - // nnet logprob output, used by wfst virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index b83b59767..29daa709d 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -23,25 +23,25 @@ using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) : nnet_(nnet), frontend_(frontend) { - abort_ = false; - Reset(); - thread_ = std::thread(RunNnetEvaluation, this); - } + abort_ = false; + Reset(); + if (nnet_ != nullptr) thread_ = std::thread(RunNnetEvaluation, this); +} void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); condition_variable_.notify_one(); } -void NnetProducer::UnLock() { +void NnetProducer::WaitProduce() { std::unique_lock lock(read_mutex_); while (frontend_->IsFinished() == false && cache_.empty()) { - condition_read_ready_.wait(lock); + condition_read_ready_.wait(lock); } return; } -void NnetProducer::RunNnetEvaluation(NnetProducer *me) { +void NnetProducer::RunNnetEvaluation(NnetProducer* me) { me->RunNnetEvaluationInteral(); } @@ -55,7 +55,7 @@ void NnetProducer::RunNnetEvaluationInteral() { result = Compute(); } while (result); if (frontend_->IsFinished() == true) { - if (cache_.empty()) finished_ = true; + if (cache_.empty()) finished_ = true; } } LOG(INFO) << "NnetEvaluationInteral exit"; diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index 14c74d043..9eb3a4f78 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -34,9 +34,9 @@ class NnetProducer { // nnet bool Read(std::vector* nnet_prob); bool ReadandCompute(std::vector* nnet_prob); - static void RunNnetEvaluation(NnetProducer *me); + static void RunNnetEvaluation(NnetProducer* me); void RunNnetEvaluationInteral(); - void UnLock(); + void WaitProduce(); void Wait() { abort_ = true; @@ -56,12 +56,12 @@ class NnetProducer { bool IsFinished() const { return finished_; } ~NnetProducer() { - if (thread_.joinable()) thread_.join(); + if (thread_.joinable()) thread_.join(); } void Reset() { - frontend_->Reset(); - nnet_->Reset(); + if (frontend_ != NULL) frontend_->Reset(); + if (nnet_ != NULL) nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); finished_ = false; diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index 30595d79f..f31ceb3b6 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -33,11 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) decodable_.reset(new Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") { + LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path; + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + } else { + decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts)); + } - unit_table_ = decoder_->VocabTable(); - symbol_table_ = unit_table_; + symbol_table_ = decoder_->WordSymbolTable(); global_frame_offset_ = 0; input_finished_ = false; @@ -56,11 +60,14 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, decodable_.reset(new Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") { + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + } else { + decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts)); + } - unit_table_ = decoder_->VocabTable(); - symbol_table_ = unit_table_; + symbol_table_ = decoder_->WordSymbolTable(); global_frame_offset_ = 0; input_finished_ = false; @@ -109,10 +116,11 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { void U2Recognizer::RunDecoderSearchInternal() { LOG(INFO) << "DecoderSearchInteral begin"; while (!nnet_producer_->IsFinished()) { - nnet_producer_->UnLock(); + nnet_producer_->WaitProduce(); decoder_->AdvanceDecode(decodable_); } - Decode(); + decoder_->AdvanceDecode(decodable_); + UpdateResult(false); LOG(INFO) << "DecoderSearchInteral exit"; } @@ -140,7 +148,7 @@ void U2Recognizer::UpdateResult(bool finish) { const auto& times = decoder_->Times(); result_.clear(); - CHECK_EQ(hypotheses.size(), likelihood.size()); + CHECK_EQ(inputs.size(), likelihood.size()); for (size_t i = 0; i < hypotheses.size(); i++) { const std::vector& hypothesis = hypotheses[i]; @@ -148,13 +156,9 @@ void U2Recognizer::UpdateResult(bool finish) { path.score = likelihood[i]; for (size_t j = 0; j < hypothesis.size(); j++) { std::string word = symbol_table_->Find(hypothesis[j]); - // A detailed explanation of this if-else branch can be found in - // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 - if (decoder_->Type() == kWfstBeamSearch) { - path.sentence += (" " + word); - } else { - path.sentence += (word); - } + // path.sentence += (" " + word); // todo SmileGoat: add blank + // processor + path.sentence += word; // todo SmileGoat: add blank processor } // TimeStamp is only supported in final result @@ -162,7 +166,7 @@ void U2Recognizer::UpdateResult(bool finish) { // various FST operations when building the decoding graph. So here we // use time stamp of the input(e2e model unit), which is more accurate, // and it requires the symbol table of the e2e model used in training. - if (unit_table_ != nullptr && finish) { + if (symbol_table_ != nullptr && finish) { int offset = global_frame_offset_ * FrameShiftInMs(); const std::vector& input = inputs[i]; @@ -170,7 +174,7 @@ void U2Recognizer::UpdateResult(bool finish) { CHECK_EQ(input.size(), time_stamp.size()); for (size_t j = 0; j < input.size(); j++) { - std::string word = unit_table_->Find(input[j]); + std::string word = symbol_table_->Find(input[j]); int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 @@ -214,7 +218,7 @@ void U2Recognizer::UpdateResult(bool finish) { void U2Recognizer::AttentionRescoring() { decoder_->FinalizeSearch(); - UpdateResult(true); + UpdateResult(false); // No need to do rescoring if (0.0 == opts_.decoder_opts.rescoring_weight) { diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 5d628e3a3..889da85bf 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -17,6 +17,7 @@ #include "decoder/common.h" #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_decoder.h" +#include "decoder/ctc_tlg_decoder.h" #include "decoder/decoder_itf.h" #include "frontend/feature_pipeline.h" #include "fst/fstlib.h" @@ -33,6 +34,8 @@ DECLARE_int32(blank); DECLARE_double(acoustic_scale); DECLARE_string(vocab_path); +DECLARE_string(word_symbol_table); +// DECLARE_string(fst_path); namespace ppspeech { @@ -59,6 +62,7 @@ struct DecodeOptions { // CtcEndpointConfig ctc_endpoint_opts; CTCBeamSearchOptions ctc_prefix_search_opts{}; + TLGDecoderOptions tlg_decoder_opts{}; static DecodeOptions InitFromFlags() { DecodeOptions decoder_opts; @@ -70,6 +74,13 @@ struct DecodeOptions { decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; + // decoder_opts.tlg_decoder_opts.fst_path = "";//FLAGS_fst_path; + // decoder_opts.tlg_decoder_opts.word_symbol_table = + // FLAGS_word_symbol_table; + // decoder_opts.tlg_decoder_opts.nbest = FLAGS_nbest; + decoder_opts.tlg_decoder_opts = + ppspeech::TLGDecoderOptions::InitFromFlags(); + LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size; LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks; LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight; @@ -113,7 +124,7 @@ class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); explicit U2Recognizer(const U2RecognizerResource& resource, - std::shared_ptr nnet); + std::shared_ptr nnet); ~U2Recognizer(); void InitDecoder(); void ResetContinuousDecoding(); @@ -154,10 +165,9 @@ class U2Recognizer { std::shared_ptr nnet_producer_; std::shared_ptr decodable_; - std::unique_ptr decoder_; + std::unique_ptr decoder_; // e2e unit symbol table - std::shared_ptr unit_table_ = nullptr; std::shared_ptr symbol_table_ = nullptr; std::vector result_;