From 85a1744ecc4509cafa48e4888d60704764836b79 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Sun, 23 Apr 2023 10:31:13 +0800 Subject: [PATCH] add skip blank --- .../ctc_prefix_beam_search_decoder_main.cc | 2 +- runtime/engine/asr/decoder/ctc_tlg_decoder.h | 2 +- .../asr/decoder/ctc_tlg_decoder_main.cc | 2 +- runtime/engine/asr/decoder/param.h | 5 +---- runtime/engine/asr/nnet/nnet_producer.cc | 22 ++++++++++++++++--- runtime/engine/asr/nnet/nnet_producer.h | 7 +++++- .../recognizer/recognizer_controller_impl.cc | 3 ++- .../asr/recognizer/recognizer_resource.h | 3 +++ runtime/engine/kaldi/fstbin/CMakeLists.txt | 2 +- .../wenetspeech/local/recognizer_wfst.sh | 2 +- 10 files changed, 36 insertions(+), 14 deletions(-) diff --git a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index 1fa56cffd..0935c6e6f 100644 --- a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -71,7 +71,7 @@ int main(int argc, char* argv[]) { std::shared_ptr raw_data = std::make_shared(); std::shared_ptr nnet_producer = - std::make_shared(nnet, raw_data); + std::make_shared(nnet, raw_data, 1.0); std::shared_ptr decodable = std::make_shared(nnet_producer); diff --git a/runtime/engine/asr/decoder/ctc_tlg_decoder.h b/runtime/engine/asr/decoder/ctc_tlg_decoder.h index 2d40f0b91..80896361c 100644 --- a/runtime/engine/asr/decoder/ctc_tlg_decoder.h +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.h @@ -44,7 +44,7 @@ struct TLGDecoderOptions { decoder_opts.word_symbol_table = FLAGS_word_symbol_table; decoder_opts.fst_path = FLAGS_graph_path; LOG(INFO) << "fst path: " << decoder_opts.fst_path; - LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table; + LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table; if (!decoder_opts.fst_path.empty()) { CHECK(FileExists(decoder_opts.fst_path)); diff --git 
a/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc index 410574dcb..dcd18b810 100644 --- a/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc @@ -54,7 +54,7 @@ int main(int argc, char* argv[]) { ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); std::shared_ptr nnet_producer = - std::make_shared(nullptr); + std::make_shared(nullptr, nullptr, 1.0); std::shared_ptr decodable( new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale)); diff --git a/runtime/engine/asr/decoder/param.h b/runtime/engine/asr/decoder/param.h index bef5514fb..0cad75bfc 100644 --- a/runtime/engine/asr/decoder/param.h +++ b/runtime/engine/asr/decoder/param.h @@ -35,13 +35,11 @@ DEFINE_int32(subsampling_rate, "two CNN(kernel=3) module downsampling rate."); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); - // nnet DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); #ifdef USE_ONNX DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path"); #endif -//DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); // decoder DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); @@ -50,10 +48,9 @@ DEFINE_string(word_symbol_table, "", "word symbol table"); DEFINE_int32(max_active, 7500, "max active"); DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam"); - +DEFINE_double(blank_threshold, 0.98, "blank skip threshold"); // DecodeOptions flags -// DEFINE_int32(chunk_size, -1, "decoding chunk size"); DEFINE_int32(num_left_chunks, -1, "left chunks in decoding"); DEFINE_double(ctc_weight, 0.5, diff --git a/runtime/engine/asr/nnet/nnet_producer.cc b/runtime/engine/asr/nnet/nnet_producer.cc index b7bc8a33c..431b70251 100644 --- a/runtime/engine/asr/nnet/nnet_producer.cc +++ b/runtime/engine/asr/nnet/nnet_producer.cc @@ -22,8 +22,9 @@ using kaldi::BaseFloat; 
using std::vector; NnetProducer::NnetProducer(std::shared_ptr nnet, - std::shared_ptr frontend) - : nnet_(nnet), frontend_(frontend) { + std::shared_ptr frontend, + float blank_threshold) + : nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) { Reset(); } @@ -70,7 +71,22 @@ bool NnetProducer::Compute() { std::vector logprob( out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + (idx + 1) * vocab_dim); - cache_.push_back(logprob); + // skip frames whose blank prob (index 0) exceeds blank_threshold_; the last skipped frame is re-inserted when the arg-max index repeats + float blank_prob = std::exp(logprob[0]); + if (blank_prob > blank_threshold_) { + last_frame_logprob_ = logprob; + is_last_frame_skip_ = true; + continue; + } else { + int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin(); + if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) { + cache_.push_back(last_frame_logprob_); + last_max_elem_ = cur_max; + } + last_max_elem_ = cur_max; + is_last_frame_skip_ = false; + cache_.push_back(logprob); + } } return true; } diff --git a/runtime/engine/asr/nnet/nnet_producer.h b/runtime/engine/asr/nnet/nnet_producer.h index 83521ea76..21aee067e 100644 --- a/runtime/engine/asr/nnet/nnet_producer.h +++ b/runtime/engine/asr/nnet/nnet_producer.h @@ -24,7 +24,8 @@ namespace ppspeech { class NnetProducer { public: explicit NnetProducer(std::shared_ptr nnet, - std::shared_ptr frontend = NULL); + std::shared_ptr frontend, + float blank_threshold); // Feed feats or waves void Accept(const std::vector& inputs); @@ -64,6 +65,10 @@ class NnetProducer { std::shared_ptr frontend_; std::shared_ptr nnet_; SafeQueue> cache_; + std::vector last_frame_logprob_; + bool is_last_frame_skip_ = false; + int last_max_elem_ = -1; + float blank_threshold_ = 0.0; bool finished_; DISALLOW_COPY_AND_ASSIGN(NnetProducer); diff --git a/runtime/engine/asr/recognizer/recognizer_controller_impl.cc b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc index 3d141752d..cc4d3c78a 100644 --- a/runtime/engine/asr/recognizer/recognizer_controller_impl.cc 
+++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc @@ -21,6 +21,7 @@ namespace ppspeech { RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource) : opts_(resource) { BaseFloat am_scale = resource.acoustic_scale; + BaseFloat blank_threshold = resource.blank_threshold; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; std::shared_ptr feature_pipeline( new FeaturePipeline(feature_opts)); @@ -34,7 +35,7 @@ RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& res nnet = resource.nnet->Clone(); } #endif - nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); + nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold)); nnet_thread_ = std::thread(RunNnetEvaluation, this); decodable_.reset(new Decodable(nnet_producer_, am_scale)); diff --git a/runtime/engine/asr/recognizer/recognizer_resource.h b/runtime/engine/asr/recognizer/recognizer_resource.h index 963149dfd..064a5b5ba 100644 --- a/runtime/engine/asr/recognizer/recognizer_resource.h +++ b/runtime/engine/asr/recognizer/recognizer_resource.h @@ -12,6 +12,7 @@ DECLARE_double(reverse_weight); DECLARE_int32(nbest); DECLARE_int32(blank); DECLARE_double(acoustic_scale); +DECLARE_double(blank_threshold); DECLARE_string(word_symbol_table); namespace ppspeech { @@ -71,6 +72,7 @@ struct DecodeOptions { struct RecognizerResource { // decodable opt kaldi::BaseFloat acoustic_scale{1.0}; + kaldi::BaseFloat blank_threshold{0.98}; FeaturePipelineOptions feature_pipeline_opts{}; ModelOptions model_opts{}; @@ -80,6 +82,7 @@ struct RecognizerResource { static RecognizerResource InitFromFlags() { RecognizerResource resource; resource.acoustic_scale = FLAGS_acoustic_scale; + resource.blank_threshold = FLAGS_blank_threshold; LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale; resource.feature_pipeline_opts = diff --git a/runtime/engine/kaldi/fstbin/CMakeLists.txt b/runtime/engine/kaldi/fstbin/CMakeLists.txt 
index f53be578d..05d0501f3 100644 --- a/runtime/engine/kaldi/fstbin/CMakeLists.txt +++ b/runtime/engine/kaldi/fstbin/CMakeLists.txt @@ -11,5 +11,5 @@ fsttablecompose foreach(binary IN LISTS BINS) add_executable(${binary} ${CMAKE_CURRENT_SOURCE_DIR}/${binary}.cc) target_include_directories(${binary} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${binary} PUBLIC kaldi-fstext glog libgflags_nothreads.so fst dl) + target_link_libraries(${binary} PUBLIC kaldi-fstext glog gflags fst dl) endforeach() diff --git a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh index 7b8a81f77..57d69a4c0 100755 --- a/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh +++ b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_wfst.sh @@ -3,7 +3,7 @@ set -e data=data exp=exp -nj=40 +nj=20 . utils/parse_options.sh