diff --git a/runtime/engine/asr/decoder/ctc_beam_search_opt.h b/runtime/engine/asr/decoder/ctc_beam_search_opt.h index f4a81b3a6..4c145370c 100644 --- a/runtime/engine/asr/decoder/ctc_beam_search_opt.h +++ b/runtime/engine/asr/decoder/ctc_beam_search_opt.h @@ -22,51 +22,22 @@ namespace ppspeech { struct CTCBeamSearchOptions { // common int blank; - - // ds2 - std::string dict_file; - std::string lm_path; - int beam_size; - BaseFloat alpha; - BaseFloat beta; - BaseFloat cutoff_prob; - int cutoff_top_n; - int num_proc_bsearch; + std::string word_symbol_table; // u2 int first_beam_size; int second_beam_size; + CTCBeamSearchOptions() : blank(0), - dict_file("vocab.txt"), - lm_path(""), - beam_size(300), - alpha(1.9f), - beta(5.0), - cutoff_prob(0.99f), - cutoff_top_n(40), - num_proc_bsearch(10), + word_symbol_table("vocab.txt"), first_beam_size(10), second_beam_size(10) {} void Register(kaldi::OptionsItf* opts) { - std::string module = "Ds2BeamSearchConfig: "; - opts->Register("dict", &dict_file, module + "vocab file path."); - opts->Register( - "lm-path", &lm_path, module + "ngram language model path."); - opts->Register("alpha", &alpha, module + "alpha"); - opts->Register("beta", &beta, module + "beta"); - opts->Register("beam-size", - &beam_size, - module + "beam size for beam search method"); - opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs"); - opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n"); - opts->Register( - "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch"); - + std::string module = "CTCBeamSearchOptions: "; + opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path."); opts->Register("blank", &blank, "blank id, default is 0."); - - module = "U2BeamSearchConfig: "; opts->Register( "first-beam-size", &first_beam_size, module + "first beam size."); opts->Register("second-beam-size", diff --git a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc 
b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc index fda8aab0e..f54f21fa2 100644 --- a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -30,11 +30,10 @@ using paddle::platform::TracerEventType; namespace ppspeech { -CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path, - const CTCBeamSearchOptions& opts) +CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts) { unit_table_ = std::shared_ptr( - fst::SymbolTable::ReadText(vocab_path)); + fst::SymbolTable::ReadText(opts.word_symbol_table)); CHECK(unit_table_ != nullptr); Reset(); diff --git a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h index 3fe1944c7..391b40733 100644 --- a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h @@ -27,8 +27,7 @@ namespace ppspeech { class ContextGraph; class CTCPrefixBeamSearch : public DecoderBase { public: - CTCPrefixBeamSearch(const std::string& vocab_path, - const CTCBeamSearchOptions& opts); + CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts); ~CTCPrefixBeamSearch() {} SearchType Type() const { return SearchType::kPrefixBeamSearch; } diff --git a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index 1673bdad1..1fa56cffd 100644 --- a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -23,7 +23,7 @@ DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(vocab_path, "", "vocab path"); +DEFINE_string(word_symbol_table, "", "vocab path"); DEFINE_string(model_path, "", "paddle nnet model"); @@ -52,10 +52,10 @@ int main(int 
argc, char* argv[]) { CHECK_NE(FLAGS_result_wspecifier, ""); CHECK_NE(FLAGS_feature_rspecifier, ""); - CHECK_NE(FLAGS_vocab_path, ""); + CHECK_NE(FLAGS_word_symbol_table, ""); CHECK_NE(FLAGS_model_path, ""); LOG(INFO) << "model path: " << FLAGS_model_path; - LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path; + LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table; kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); @@ -80,7 +80,8 @@ int main(int argc, char* argv[]) { opts.blank = 0; opts.first_beam_size = 10; opts.second_beam_size = 10; - ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts); + opts.word_symbol_table = FLAGS_word_symbol_table; + ppspeech::CTCPrefixBeamSearch decoder(opts); int32 chunk_size = FLAGS_receptive_field_length + diff --git a/runtime/engine/asr/decoder/ctc_tlg_decoder.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc index ac30da922..51ded499d 100644 --- a/runtime/engine/asr/decoder/ctc_tlg_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc @@ -13,12 +13,14 @@ // limitations under the License. 
#include "decoder/ctc_tlg_decoder.h" + namespace ppspeech { TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) { - fst_.reset(fst::Fst::Read(opts.fst_path)); + fst_ = opts.fst_ptr; CHECK(fst_ != nullptr); + CHECK(!opts.word_symbol_table.empty()); word_symbol_table_.reset( fst::SymbolTable::ReadText(opts.word_symbol_table)); diff --git a/runtime/engine/asr/decoder/ctc_tlg_decoder.h b/runtime/engine/asr/decoder/ctc_tlg_decoder.h index 4540bc465..2d40f0b91 100644 --- a/runtime/engine/asr/decoder/ctc_tlg_decoder.h +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.h @@ -18,6 +18,7 @@ #include "decoder/decoder_itf.h" #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" +#include "utils/file_utils.h" DECLARE_string(word_symbol_table); DECLARE_string(graph_path); @@ -33,9 +34,10 @@ struct TLGDecoderOptions { // todo remove later, add into decode resource std::string word_symbol_table; std::string fst_path; + std::shared_ptr> fst_ptr; int nbest; - TLGDecoderOptions() : word_symbol_table(""), fst_path(""), nbest(10) {} + TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {} static TLGDecoderOptions InitFromFlags() { TLGDecoderOptions decoder_opts; @@ -44,6 +46,11 @@ struct TLGDecoderOptions { LOG(INFO) << "fst path: " << decoder_opts.fst_path; LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table; + if (!decoder_opts.fst_path.empty()) { + CHECK(FileExists(decoder_opts.fst_path)); + decoder_opts.fst_ptr.reset(fst::Fst::Read(FLAGS_graph_path)); + } + decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; diff --git a/runtime/engine/asr/decoder/param.h b/runtime/engine/asr/decoder/param.h index 0d67d77e0..bef5514fb 100644 --- a/runtime/engine/asr/decoder/param.h +++ b/runtime/engine/asr/decoder/param.h @@ -37,28 +37,14 @@ DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); // nnet 
-DEFINE_string(vocab_path, "", "nnet vocab path."); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); #ifdef USE_ONNX DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path"); #endif -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); - +//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); // decoder DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); - DEFINE_string(graph_path, "", "decoder graph"); DEFINE_string(word_symbol_table, "", "word symbol table"); DEFINE_int32(max_active, 7500, "max active"); diff --git a/runtime/engine/asr/nnet/nnet_itf.h b/runtime/engine/asr/nnet/nnet_itf.h index 5106894fd..ac105d119 100644 --- a/runtime/engine/asr/nnet/nnet_itf.h +++ b/runtime/engine/asr/nnet/nnet_itf.h @@ -33,24 +33,12 @@ namespace ppspeech { struct ModelOptions { // common int subsample_rate{1}; - int thread_num{1}; // predictor thread pool size for ds2; bool use_gpu{false}; std::string model_path; #ifdef USE_ONNX bool with_onnx_model{false}; #endif - std::string param_path; - - // ds2 for inference - std::string input_names{}; - std::string output_names{}; - std::string cache_names{}; - std::string cache_shape{}; - bool switch_ir_optim{false}; - bool enable_fc_padding{false}; - bool enable_profile{false}; - static ModelOptions InitFromFlags() { ModelOptions opts; opts.subsample_rate = FLAGS_subsampling_rate; @@ -61,19 +49,6 @@ struct ModelOptions { opts.with_onnx_model = FLAGS_with_onnx_model; LOG(INFO) << "with onnx model: " << opts.with_onnx_model; 
#endif - - opts.param_path = FLAGS_param_path; - LOG(INFO) << "param path: " << opts.param_path; - - LOG(INFO) << "DS2 param: "; - opts.cache_names = FLAGS_model_cache_names; - LOG(INFO) << " cache names: " << opts.cache_names; - opts.cache_shape = FLAGS_model_cache_shapes; - LOG(INFO) << " cache shape: " << opts.cache_shape; - opts.input_names = FLAGS_model_input_names; - LOG(INFO) << " input names: " << opts.input_names; - opts.output_names = FLAGS_model_output_names; - LOG(INFO) << " output names: " << opts.output_names; return opts; } }; @@ -121,7 +96,7 @@ class NnetInterface { class NnetBase : public NnetInterface { public: int SubsamplingRate() const { return subsampling_rate_; } - + virtual std::shared_ptr Clone() const = 0; protected: int subsampling_rate_{1}; }; diff --git a/runtime/engine/asr/nnet/nnet_producer.cc b/runtime/engine/asr/nnet/nnet_producer.cc index 7368b7c4b..b7bc8a33c 100644 --- a/runtime/engine/asr/nnet/nnet_producer.cc +++ b/runtime/engine/asr/nnet/nnet_producer.cc @@ -45,7 +45,7 @@ void NnetProducer::Acceptlikelihood( bool NnetProducer::Read(std::vector* nnet_prob) { bool flag = cache_.pop(nnet_prob); - LOG(INFO) << "nnet cache_ size: " << cache_.size(); + VLOG(1) << "nnet cache_ size: " << cache_.size(); return flag; } @@ -53,7 +53,6 @@ bool NnetProducer::Compute() { vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. 
- LOG(INFO) << "no feat avalible"; if (frontend_->IsFinished() == true) { finished_ = true; } diff --git a/runtime/engine/asr/recognizer/CMakeLists.txt b/runtime/engine/asr/recognizer/CMakeLists.txt index 0e9c8fb2b..e8c865059 100644 --- a/runtime/engine/asr/recognizer/CMakeLists.txt +++ b/runtime/engine/asr/recognizer/CMakeLists.txt @@ -3,6 +3,8 @@ set(srcs) list(APPEND srcs recognizer_controller.cc recognizer_controller_impl.cc + recognizer_instance.cc + recognizer.cc ) add_library(recognizer STATIC ${srcs}) @@ -10,6 +12,7 @@ target_link_libraries(recognizer PUBLIC decoder) set(TEST_BINS recognizer_batch_main + recognizer_batch_main2 recognizer_main ) diff --git a/runtime/engine/asr/recognizer/recognizer.cc b/runtime/engine/asr/recognizer/recognizer.cc index dcd21a4ca..3a95bcc8d 100644 --- a/runtime/engine/asr/recognizer/recognizer.cc +++ b/runtime/engine/asr/recognizer/recognizer.cc @@ -10,4 +10,37 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. 
+ +#include "recognizer/recognizer.h" +#include "recognizer/recognizer_instance.h" + +bool InitRecognizer(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance) { + return ppspeech::RecognizerInstance::GetInstance().Init(model_file, + word_symbol_table_file, + fst_file, + num_instance); +} + +int GetRecognizerInstanceId() { + return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId(); +} + +void InitDecoder(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id); +} + +void AcceptData(const std::vector& waves, int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id); +} + +void SetInputFinished(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id); +} + +std::string GetFinalResult(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id); +} \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer.h b/runtime/engine/asr/recognizer/recognizer.h index dcd21a4ca..bd7fb1294 100644 --- a/runtime/engine/asr/recognizer/recognizer.h +++ b/runtime/engine/asr/recognizer/recognizer.h @@ -10,4 +10,19 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. 
+ +#pragma once + +#include +#include + +bool InitRecognizer(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance); +int GetRecognizerInstanceId(); +void InitDecoder(int instance_id); +void AcceptData(const std::vector& waves, int instance_id); +void SetInputFinished(int instance_id); +std::string GetFinalResult(int instance_id); \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main2.cc b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc new file mode 100644 index 000000000..fc99bf0bd --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" +#include "recognizer/recognizer.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + CHECK_EQ(utt_wav.size(), size_t(2)); // validate before indexing utt_wav + LOG(INFO) << utt_wav[0]; + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 recog_id = -1; + while (recog_id == -1) { + recog_id = GetRecognizerInstanceId(); + } + InitDecoder(recog_id); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = 
wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + AcceptData(wav_chunk, recog_id); + // no overlap + sample_offset += cur_chunk_size; + } + SetInputFinished(recog_id); + CHECK(sample_offset == tot_samples); + std::string result = GetFinalResult(recog_id); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << 
streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx < njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/runtime/engine/asr/recognizer/recognizer_controller.cc b/runtime/engine/asr/recognizer/recognizer_controller.cc index d85fff093..ef549263f 100644 --- a/runtime/engine/asr/recognizer/recognizer_controller.cc +++ b/runtime/engine/asr/recognizer/recognizer_controller.cc @@ -18,10 +18,9 @@ namespace ppspeech { RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) { - nnet_ = std::make_shared(resource.model_opts); recognizer_workers.resize(num_worker); for (size_t i = 0; i < num_worker; ++i) { - recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource, nnet_->Clone())); + recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource)); waiting_workers.push(i); } } diff --git a/runtime/engine/asr/recognizer/recognizer_controller.h b/runtime/engine/asr/recognizer/recognizer_controller.h index ee92e1931..16a8dd137 100644 --- a/runtime/engine/asr/recognizer/recognizer_controller.h +++ b/runtime/engine/asr/recognizer/recognizer_controller.h @@ -18,7 +18,6 @@ #include #include "recognizer/recognizer_controller_impl.h" -#include "nnet/u2_nnet.h" namespace ppspeech 
{ @@ -34,7 +33,6 @@ class RecognizerController { private: std::queue waiting_workers; - std::shared_ptr nnet_; std::mutex mutex_; std::vector> recognizer_workers; diff --git a/runtime/engine/asr/recognizer/recognizer_controller_impl.cc b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc index 5168e43fb..3d141752d 100644 --- a/runtime/engine/asr/recognizer/recognizer_controller_impl.cc +++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc @@ -26,24 +26,24 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res new FeaturePipeline(feature_opts)); std::shared_ptr nnet; #ifndef USE_ONNX - nnet.reset(new U2Nnet(resource.model_opts)); + nnet = resource.nnet->Clone(); #else if (resource.model_opts.with_onnx_model){ nnet.reset(new U2OnnxNnet(resource.model_opts)); } else { - nnet.reset(new U2Nnet(resource.model_opts)); + nnet = resource.nnet->Clone(); } #endif nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); nnet_thread_ = std::thread(RunNnetEvaluation, this); decodable_.reset(new Decodable(nnet_producer_, am_scale)); - CHECK_NE(resource.vocab_path, ""); if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) { - LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path; + LOG(INFO) << "Init PrefixBeamSearch Decoder"; decoder_ = std::make_unique( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts); + resource.decoder_opts.ctc_prefix_search_opts); } else { + LOG(INFO) << "Init TLGDecoder"; decoder_ = std::make_unique( resource.decoder_opts.tlg_decoder_opts); } @@ -55,33 +55,6 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res result_.clear(); } -RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource, - std::shared_ptr nnet) - :opts_(resource) { - BaseFloat am_scale = resource.acoustic_scale; - const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - std::shared_ptr feature_pipeline = - 
std::make_shared(feature_opts); - nnet_producer_ = std::make_shared(nnet, feature_pipeline); - nnet_thread_ = std::thread(RunNnetEvaluation, this); - decodable_.reset(new Decodable(nnet_producer_, am_scale)); - - CHECK_NE(resource.vocab_path, ""); - if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") { - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); - } else { - decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts)); - } - - symbol_table_ = decoder_->WordSymbolTable(); - - global_frame_offset_ = 0; - input_finished_ = false; - num_frames_ = 0; - result_.clear(); -} - RecognizerControllerImpl::~RecognizerControllerImpl() { WaitFinished(); } diff --git a/runtime/engine/asr/recognizer/recognizer_controller_impl.h b/runtime/engine/asr/recognizer/recognizer_controller_impl.h index 006de74c8..fe1f8e112 100644 --- a/runtime/engine/asr/recognizer/recognizer_controller_impl.h +++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.h @@ -32,8 +32,8 @@ namespace ppspeech { class RecognizerControllerImpl { public: explicit RecognizerControllerImpl(const RecognizerResource& resource); - explicit RecognizerControllerImpl(const RecognizerResource& resource, - std::shared_ptr nnet); + //explicit RecognizerControllerImpl(const RecognizerResource& resource, + // std::shared_ptr nnet); ~RecognizerControllerImpl(); void Accept(std::vector data); void InitDecoder(); diff --git a/runtime/engine/asr/recognizer/recognizer_impl.cc b/runtime/engine/asr/recognizer/recognizer_impl.cc deleted file mode 100644 index dcd21a4ca..000000000 --- a/runtime/engine/asr/recognizer/recognizer_impl.cc +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_impl.h b/runtime/engine/asr/recognizer/recognizer_impl.h deleted file mode 100644 index dcd21a4ca..000000000 --- a/runtime/engine/asr/recognizer/recognizer_impl.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_instance.cc b/runtime/engine/asr/recognizer/recognizer_instance.cc new file mode 100644 index 000000000..b9019ec4e --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_instance.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/recognizer_instance.h" + + +namespace ppspeech { + +RecognizerInstance& RecognizerInstance::GetInstance() { + static RecognizerInstance instance; + return instance; +} + +bool RecognizerInstance::Init(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance) { + RecognizerResource resource = RecognizerResource::InitFromFlags(); + resource.model_opts.model_path = model_file; + // Route the symbol table to whichever decoder fst_file selects. + if (!fst_file.empty()) { + resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file; + resource.decoder_opts.tlg_decoder_opts.word_symbol_table = word_symbol_table_file; + } else { + resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table = + word_symbol_table_file; + } + recognizer_controller_ = std::make_unique(num_instance, resource); + return true; +} + +void RecognizerInstance::InitDecoder(int idx) { + recognizer_controller_->InitDecoder(idx); + return; +} + +int RecognizerInstance::GetRecognizerInstanceId() { + return recognizer_controller_->GetRecognizerInstanceId(); +} + +void RecognizerInstance::Accept(const std::vector& waves, int idx) const { + recognizer_controller_->Accept(waves, idx); + return; +} + +void RecognizerInstance::SetInputFinished(int idx) const { + recognizer_controller_->SetInputFinished(idx); + return; +} + +std::string RecognizerInstance::GetResult(int idx) const { + return recognizer_controller_->GetFinalResult(idx); +} + +} \ No newline at end of file diff --git 
a/runtime/engine/asr/recognizer/recognizer_instance.h b/runtime/engine/asr/recognizer/recognizer_instance.h new file mode 100644 index 000000000..ef8f524d6 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_instance.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "recognizer/recognizer_controller.h" + +namespace ppspeech { + +class RecognizerInstance { + public: + static RecognizerInstance& GetInstance(); + RecognizerInstance() {} + ~RecognizerInstance() {} + bool Init(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance); + int GetRecognizerInstanceId(); + void InitDecoder(int idx); + void Accept(const std::vector& waves, int idx) const; + void SetInputFinished(int idx) const; + std::string GetResult(int idx) const; + + private: + std::unique_ptr recognizer_controller_; +}; + + +} // namespace ppspeech diff --git a/runtime/engine/asr/recognizer/recognizer_resource.h b/runtime/engine/asr/recognizer/recognizer_resource.h index 2a83d960c..963149dfd 100644 --- a/runtime/engine/asr/recognizer/recognizer_resource.h +++ b/runtime/engine/asr/recognizer/recognizer_resource.h @@ -12,7 +12,6 @@ DECLARE_double(reverse_weight); DECLARE_int32(nbest); DECLARE_int32(blank); DECLARE_double(acoustic_scale); -DECLARE_string(vocab_path); 
DECLARE_string(word_symbol_table); namespace ppspeech { @@ -52,6 +51,8 @@ struct DecodeOptions { decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; + decoder_opts.ctc_prefix_search_opts.word_symbol_table = + FLAGS_word_symbol_table; decoder_opts.tlg_decoder_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); @@ -68,18 +69,17 @@ struct DecodeOptions { }; struct RecognizerResource { + // decodable opt kaldi::BaseFloat acoustic_scale{1.0}; - std::string vocab_path{}; FeaturePipelineOptions feature_pipeline_opts{}; ModelOptions model_opts{}; DecodeOptions decoder_opts{}; + std::shared_ptr nnet; static RecognizerResource InitFromFlags() { RecognizerResource resource; - resource.vocab_path = FLAGS_vocab_path; resource.acoustic_scale = FLAGS_acoustic_scale; - LOG(INFO) << "vocab path: " << resource.vocab_path; LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale; resource.feature_pipeline_opts = @@ -89,6 +89,15 @@ struct RecognizerResource { << resource.feature_pipeline_opts.assembler_opts.fill_zero; resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags(); + #ifndef USE_ONNX + resource.nnet.reset(new U2Nnet(resource.model_opts)); + #else + if (resource.model_opts.with_onnx_model){ + resource.nnet.reset(new U2OnnxNnet(resource.model_opts)); + } else { + resource.nnet.reset(new U2Nnet(resource.model_opts)); + } + #endif return resource; } }; diff --git a/runtime/engine/common/utils/file_utils.cc b/runtime/engine/common/utils/file_utils.cc index c42a642c7..385f2b656 100644 --- a/runtime/engine/common/utils/file_utils.cc +++ b/runtime/engine/common/utils/file_utils.cc @@ -14,6 +14,8 @@ #include "utils/file_utils.h" +#include + namespace ppspeech { bool ReadFileToVector(const std::string& filename, @@ -40,4 +42,31 @@ std::string ReadFile2String(const 
std::string& path) { return std::string((std::istreambuf_iterator(input_file)), std::istreambuf_iterator()); } + +bool FileExists(const std::string& strFilename) { + // This function is from: + // https://github.com/kaldi-asr/kaldi/blob/master/src/fstext/deterministic-fst-test.cc + struct stat stFileInfo; + bool blnReturn; + int intStat; + + // Attempt to get the file attributes + intStat = stat(strFilename.c_str(), &stFileInfo); + if (intStat == 0) { + // We were able to get the file attributes + // so the file obviously exists. + blnReturn = true; + } else { + // We were not able to get the file attributes. + // This may mean that we don't have permission to + // access the folder which contains this file. If you + // need to do that level of checking, lookup the + // return values of stat which will give you + // more details on why stat failed. + blnReturn = false; + } + + return blnReturn; +} + } // namespace ppspeech diff --git a/runtime/engine/common/utils/file_utils.h b/runtime/engine/common/utils/file_utils.h index a471e024e..420740dbb 100644 --- a/runtime/engine/common/utils/file_utils.h +++ b/runtime/engine/common/utils/file_utils.h @@ -20,4 +20,7 @@ bool ReadFileToVector(const std::string& filename, std::vector* data); std::string ReadFile2String(const std::string& path); + +bool FileExists(const std::string& filename); + } // namespace ppspeech diff --git a/runtime/examples/u2pp_ol/wenetspeech/local/decode.sh b/runtime/examples/u2pp_ol/wenetspeech/local/decode.sh index 059ed1b36..ab48596f4 100755 --- a/runtime/examples/u2pp_ol/wenetspeech/local/decode.sh +++ b/runtime/examples/u2pp_ol/wenetspeech/local/decode.sh @@ -14,7 +14,7 @@ text=$data/test/text utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.log \ ctc_prefix_beam_search_decoder_main \ --model_path=$model_dir/export.jit \ - --vocab_path=$model_dir/unit.txt \ + --word_symbol_table=$model_dir/unit.txt \ --nnet_decoder_chunk=16 \ --receptive_field_length=7 \ --subsampling_rate=4 \ @@ -23,4 +23,4 
@@ ctc_prefix_beam_search_decoder_main \ cat $data/split${nj}/*/result_decode.ark > $exp/aishell.decode.rsl utils/compute-wer.py --char=1 --v=1 $text $exp/aishell.decode.rsl > $exp/aishell.decode.err -tail -n 7 $exp/aishell.decode.err \ No newline at end of file +tail -n 7 $exp/aishell.decode.err diff --git a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh index 0e76b183e..87f0c612d 100755 --- a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh +++ b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh @@ -21,7 +21,7 @@ recognizer_main \ --num_bins=80 \ --cmvn_file=$model_dir/mean_std.json \ --model_path=$model_dir/export.jit \ - --vocab_path=$model_dir/unit.txt \ + --word_symbol_table=$model_dir/unit.txt \ --nnet_decoder_chunk=16 \ --receptive_field_length=7 \ --subsampling_rate=4 \ diff --git a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh index 555feb83f..fe919facb 100755 --- a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh +++ b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh @@ -21,7 +21,7 @@ u2_recognizer_main \ --num_bins=80 \ --cmvn_file=$model_dir/mean_std.json \ --model_path=$model_dir/export \ - --vocab_path=$model_dir/unit.txt \ + --word_symbol_table=$model_dir/unit.txt \ --nnet_decoder_chunk=16 \ --receptive_field_length=7 \ --subsampling_rate=4 \ diff --git a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh new file mode 100755 index 000000000..ed4ebdad6 --- /dev/null +++ b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Run recognizer_main with the TLG.fst graph over the split aishell_test.scp jobs, then score the merged results with utils/compute-wer.py. +set -e + +data=data +exp=exp +nj=40 + +. 
utils/parse_options.sh + +mkdir -p $exp +ckpt_dir=./data/model +model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/ +aishell_wav_scp=aishell_test.scp +text=$data/test/text + +./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj + +lang_dir=./data/lang_test/ +graph=$lang_dir/TLG.fst +word_table=$lang_dir/words.txt + +utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer_wfst.log \ +recognizer_main \ + --use_fbank=true \ + --num_bins=80 \ + --cmvn_file=$model_dir/mean_std.json \ + --model_path=$model_dir/export.jit \ + --graph_path=$graph \ + --word_symbol_table=$word_table \ + --nnet_decoder_chunk=16 \ + --receptive_field_length=7 \ + --subsampling_rate=4 \ + --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ + --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer_wfst.ark + + +cat $data/split${nj}/*/result_recognizer_wfst.ark > $exp/aishell_recognizer_wfst +utils/compute-wer.py --char=1 --v=1 $text $exp/aishell_recognizer_wfst > $exp/aishell.recognizer_wfst.err +echo "recognizer test have finished!!!" +echo "please checkout in $exp/aishell.recognizer_wfst.err" +tail -n 7 $exp/aishell.recognizer_wfst.err