From 616fc4594b2484f12400fb937c4b0ff0e9de4a15 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 17 Oct 2022 08:19:43 +0000 Subject: [PATCH] refactor options --- speechx/examples/codelab/u2/local/decode.sh | 2 +- speechx/examples/codelab/u2/local/nnet.sh | 2 +- .../examples/codelab/u2/local/recognizer.sh | 2 +- speechx/speechx/decoder/CMakeLists.txt | 83 +++++++++++------- .../decoder/ctc_beam_search_decoder_main.cc | 15 ++-- .../ctc_prefix_beam_search_decoder_main.cc | 6 +- speechx/speechx/decoder/ctc_tlg_decoder.h | 32 +++++-- .../speechx/decoder/ctc_tlg_decoder_main.cc | 53 ++---------- speechx/speechx/decoder/param.h | 60 +++++-------- speechx/speechx/decoder/recognizer.h | 14 ++- speechx/speechx/decoder/recognizer_main.cc | 24 +----- speechx/speechx/decoder/u2_recognizer.h | 84 +++++++++--------- speechx/speechx/decoder/u2_recognizer_main.cc | 31 +------ .../speechx/frontend/audio/feature_pipeline.h | 77 +++++++++++++---- speechx/speechx/nnet/ds2_nnet_main.cc | 35 ++------ speechx/speechx/nnet/nnet_itf.h | 85 ++++++++++--------- speechx/speechx/nnet/u2_nnet.h | 1 - speechx/speechx/nnet/u2_nnet_main.cc | 23 ++--- .../websocket/websocket_server_main.cc | 24 +----- 19 files changed, 293 insertions(+), 360 deletions(-) diff --git a/speechx/examples/codelab/u2/local/decode.sh b/speechx/examples/codelab/u2/local/decode.sh index 24e9fca5b..c22ad7f07 100755 --- a/speechx/examples/codelab/u2/local/decode.sh +++ b/speechx/examples/codelab/u2/local/decode.sh @@ -14,7 +14,7 @@ ctc_prefix_beam_search_decoder_main \ --model_path=$model_dir/export.jit \ --nnet_decoder_chunk=16 \ --receptive_field_length=7 \ - --downsampling_rate=4 \ + --subsampling_rate=4 \ --vocab_path=$model_dir/unit.txt \ --feature_rspecifier=ark,t:$exp/fbank.ark \ --result_wspecifier=ark,t:$exp/result.ark diff --git a/speechx/examples/codelab/u2/local/nnet.sh b/speechx/examples/codelab/u2/local/nnet.sh index 78663e9c7..4419201cf 100755 --- a/speechx/examples/codelab/u2/local/nnet.sh +++ b/speechx/examples/codelab/u2/local/nnet.sh @@ -15,7 +15,7 @@ u2_nnet_main \ --feature_rspecifier=ark,t:$exp/fbank.ark \ --nnet_decoder_chunk=16 \ --receptive_field_length=7 \ - --downsampling_rate=4 \ + --subsampling_rate=4 \ --acoustic_scale=1.0 \ --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \ --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark diff --git a/speechx/examples/codelab/u2/local/recognizer.sh b/speechx/examples/codelab/u2/local/recognizer.sh index a73597538..9f697b459 100755 --- a/speechx/examples/codelab/u2/local/recognizer.sh +++ b/speechx/examples/codelab/u2/local/recognizer.sh @@ -16,7 +16,7 @@ u2_recognizer_main \ --model_path=$model_dir/export.jit \ --nnet_decoder_chunk=16 \ --receptive_field_length=7 \ - --downsampling_rate=4 \ + --subsampling_rate=4 \ --vocab_path=$model_dir/unit.txt \ --wav_rspecifier=scp:$data/wav.scp \ --result_wspecifier=ark,t:$exp/result.ark diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 472d93324..d06c3529b 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -1,44 +1,61 @@ project(decoder) include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) -add_library(decoder STATIC - ctc_decoders/decoder_utils.cpp - ctc_decoders/path_trie.cpp - ctc_decoders/scorer.cpp - ctc_beam_search_decoder.cc - ctc_prefix_beam_search_decoder.cc - ctc_tlg_decoder.cc - recognizer.cc - u2_recognizer.cc + +set(decoder_src ) + +if (USING_DS2) +list(APPEND decoder_src +ctc_decoders/decoder_utils.cpp 
+ctc_decoders/path_trie.cpp +ctc_decoders/scorer.cpp +ctc_beam_search_decoder.cc +ctc_tlg_decoder.cc +recognizer.cc ) +endif() + +if (USING_U2) + list(APPEND decoder_src + ctc_prefix_beam_search_decoder.cc + u2_recognizer.cc + ) +endif() + +add_library(decoder STATIC ${decoder_src}) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) # test -set(BINS - ctc_beam_search_decoder_main - nnet_logprob_decoder_main - recognizer_main - ctc_tlg_decoder_main -) +if (USING_DS2) + set(BINS + ctc_beam_search_decoder_main + nnet_logprob_decoder_main + recognizer_main + ctc_tlg_decoder_main + ) -foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) -endforeach() + foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + endforeach() +endif() -# u2 -set(TEST_BINS - u2_recognizer_main - ctc_prefix_beam_search_decoder_main -) +if (USING_U2) + set(TEST_BINS + ctc_prefix_beam_search_decoder_main + u2_recognizer_main + ) + + foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + endforeach() + +endif() -foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -endforeach() \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_beam_search_decoder_main.cc index 7e245e9b8..edf9215ad 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder_main.cc @@ -31,7 +31,7 @@ DEFINE_string(lm_path, "", "language model"); DEFINE_int32(receptive_field_length, 7, "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(downsampling_rate, +DEFINE_int32(subsampling_rate, 4, "two CNN(kernel=3) module downsampling rate."); DEFINE_string( @@ -81,13 +81,8 @@ int main(int argc, char* argv[]) { opts.lm_path = lm_path; ppspeech::CTCBeamSearch decoder(opts); - ppspeech::ModelOptions model_opts; - model_opts.model_path = model_path; - model_opts.param_path = model_params; - model_opts.cache_names = FLAGS_model_cache_names; - model_opts.cache_shape = 
FLAGS_model_cache_shapes; - model_opts.input_names = FLAGS_model_input_names; - model_opts.output_names = FLAGS_model_output_names; + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); std::shared_ptr raw_data(new ppspeech::DataCache()); @@ -95,8 +90,8 @@ int main(int argc, char* argv[]) { new ppspeech::Decodable(nnet, raw_data)); int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; - int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; + int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc index dd3523786..7a488bb0d 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -30,7 +30,7 @@ DEFINE_string(model_path, "", "paddle nnet model"); DEFINE_int32(receptive_field_length, 7, "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(downsampling_rate, +DEFINE_int32(subsampling_rate, 4, "two CNN(kernel=3) module downsampling rate."); @@ -81,8 +81,8 @@ int main(int argc, char* argv[]) { int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; - int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; + int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index 2f1d6c10a..76bbcf42d 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -20,15 +20,37 @@ #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" + +DECLARE_string(graph_path); +DECLARE_string(word_symbol_table); +DECLARE_int32(max_active); +DECLARE_double(beam); +DECLARE_double(lattice_beam); + namespace ppspeech { struct TLGDecoderOptions { - kaldi::LatticeFasterDecoderConfig opts; + kaldi::LatticeFasterDecoderConfig opts{}; // todo remove later, add into decode resource - std::string word_symbol_table; - std::string fst_path; - - TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} + std::string word_symbol_table{}; + std::string fst_path{}; + + static TLGDecoderOptions InitFromFlags(){ + TLGDecoderOptions decoder_opts; + decoder_opts.word_symbol_table = FLAGS_word_symbol_table; + decoder_opts.fst_path = FLAGS_graph_path; + LOG(INFO) << "fst path: " << decoder_opts.fst_path; + LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table; + + decoder_opts.opts.max_active = FLAGS_max_active; + decoder_opts.opts.beam = FLAGS_beam; + decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active ; + LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam ; + LOG(INFO) << "LatticeFasterDecoder 
lattice_beam: " << decoder_opts.opts.lattice_beam ; + + return decoder_opts; + } }; class TLGDecoder : public DecoderInterface { diff --git a/speechx/speechx/decoder/ctc_tlg_decoder_main.cc b/speechx/speechx/decoder/ctc_tlg_decoder_main.cc index cd1249d84..f262101ac 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder_main.cc @@ -19,6 +19,7 @@ #include "frontend/audio/data_cache.h" #include "nnet/decodable.h" #include "nnet/ds2_nnet.h" +#include "decoder/param.h" #include "decoder/ctc_tlg_decoder.h" #include "kaldi/util/table-types.h" @@ -26,30 +27,7 @@ DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); -DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); -DEFINE_int32(max_active, 7500, "decoder graph"); -DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); -DEFINE_int32(receptive_field_length, - 7, - "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(downsampling_rate, - 4, - "two CNN(kernel=3) module downsampling rate."); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); + using kaldi::BaseFloat; using kaldi::Matrix; @@ -66,32 +44,16 @@ int main(int argc, char* argv[]) { kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - std::string word_symbol_table = FLAGS_word_symbol_table; - std::string graph_path = FLAGS_graph_path; - LOG(INFO) << "model path: " << model_graph; - LOG(INFO) << "model param: " << model_params; - LOG(INFO) << "word symbol path: " << word_symbol_table; - LOG(INFO) << "graph path: " << graph_path; int32 num_done = 0, num_err = 0; - ppspeech::TLGDecoderOptions opts; - opts.word_symbol_table = word_symbol_table; - opts.fst_path = graph_path; - opts.opts.max_active = FLAGS_max_active; + ppspeech::TLGDecoderOptions opts = ppspeech::TLGDecoderOptions::InitFromFlags(); opts.opts.beam = 15.0; opts.opts.lattice_beam = 7.5; ppspeech::TLGDecoder decoder(opts); - ppspeech::ModelOptions model_opts; - model_opts.model_path = model_graph; - model_opts.param_path = model_params; - model_opts.cache_names = FLAGS_model_cache_names; - model_opts.cache_shape = FLAGS_model_cache_shapes; - model_opts.input_names = FLAGS_model_input_names; - model_opts.output_names = FLAGS_model_output_names; + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); std::shared_ptr raw_data(new ppspeech::DataCache()); @@ -99,12 +61,13 @@ int main(int argc, char* argv[]) { new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; 
- int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; + int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "receptive field (frame): " << receptive_field_length; + decoder.InitDecoder(); kaldi::Timer timer; for (; !feature_reader.Done(); feature_reader.Next()) { diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h index 1827e82d6..5e1120ad8 100644 --- a/speechx/speechx/decoder/param.h +++ b/speechx/speechx/decoder/param.h @@ -17,8 +17,6 @@ #include "base/common.h" #include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_tlg_decoder.h" -#include "frontend/audio/feature_pipeline.h" - // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); @@ -27,18 +25,18 @@ DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_string(cmvn_file, "", "read cmvn"); - // feature sliding window DEFINE_int32(receptive_field_length, 7, "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(downsampling_rate, - 4, +DEFINE_int32(subsampling_rate, + 4, "two CNN(kernel=3) module downsampling rate."); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); // nnet +DEFINE_string(vocab_path, "", "nnet vocab path."); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string( @@ -52,10 +50,11 @@ DEFINE_string(model_cache_names, "chunk_state_h_box,chunk_state_c_box", "model cache names"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); -DEFINE_string(vocab_path, "", "nnet vocab path."); + // decoder DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); + DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_int32(max_active, 7500, "max active"); @@ -63,37 +62,20 @@ DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam"); -namespace ppspeech { - -// todo refactor later -FeaturePipelineOptions InitFeaturePipelineOptions() { - FeaturePipelineOptions opts; - opts.cmvn_file = FLAGS_cmvn_file; - kaldi::FrameExtractionOptions frame_opts; - frame_opts.dither = 0.0; - frame_opts.frame_shift_ms = 10; - opts.use_fbank = FLAGS_use_fbank; - LOG(INFO) << "feature type: " << (opts.use_fbank ? 
"fbank" : "linear"); - if (opts.use_fbank) { - opts.to_float32 = false; - frame_opts.window_type = "povey"; - frame_opts.frame_length_ms = 25; - opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; - opts.fbank_opts.frame_opts = frame_opts; - LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; - } else { - opts.to_float32 = true; - frame_opts.remove_dc_offset = false; - frame_opts.frame_length_ms = 20; - frame_opts.window_type = "hanning"; - frame_opts.preemph_coeff = 0.0; - opts.linear_spectrogram_opts.frame_opts = frame_opts; - } - opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate; - opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length; - opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk; - - return opts; -} -} // namespace ppspeech +// DecodeOptions flags +// DEFINE_int32(chunk_size, -1, "decoding chunk size"); +DEFINE_int32(num_left_chunks, -1, "left chunks in decoding"); +DEFINE_double(ctc_weight, + 0.5, + "ctc weight when combining ctc score and rescoring score"); +DEFINE_double(rescoring_weight, + 1.0, + "rescoring weight when combining ctc score and rescoring score"); +DEFINE_double(reverse_weight, + 0.3, + "used for bitransformer rescoring. it must be 0.0 if decoder is" + "conventional transformer decoder, and only reverse_weight > 0.0" + "dose the right to left decoder will be calculated and used"); +DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search"); +DEFINE_int32(blank, 0, "blank id in vocab"); diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h index 4965e7a3d..51b666739 100644 --- a/speechx/speechx/decoder/recognizer.h +++ b/speechx/speechx/decoder/recognizer.h @@ -22,14 +22,26 @@ #include "nnet/decodable.h" #include "nnet/ds2_nnet.h" +DECLARE_double(acoustic_scale); + namespace ppspeech { struct RecognizerResource { + kaldi::BaseFloat acoustic_scale{1.0}; FeaturePipelineOptions feature_pipeline_opts{}; ModelOptions model_opts{}; TLGDecoderOptions tlg_opts{}; // CTCBeamSearchOptions beam_search_opts; - kaldi::BaseFloat acoustic_scale{1.0}; + + static RecognizerResource InitFromFlags(){ + RecognizerResource resource; + resource.acoustic_scale = FLAGS_acoustic_scale; + resource.feature_pipeline_opts = FeaturePipelineOptions::InitFromFlags(); + resource.model_opts = ModelOptions::InitFromFlags(); + resource.tlg_opts = TLGDecoderOptions::InitFromFlags(); + return resource; + + } }; class Recognizer { diff --git a/speechx/speechx/decoder/recognizer_main.cc b/speechx/speechx/decoder/recognizer_main.cc index 2b497d6ea..662943b57 100644 --- a/speechx/speechx/decoder/recognizer_main.cc +++ b/speechx/speechx/decoder/recognizer_main.cc @@ -25,27 +25,9 @@ DEFINE_int32(sample_rate, 16000, "sample rate"); ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource resource; resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); - - ppspeech::ModelOptions model_opts; - model_opts.model_path = FLAGS_model_path; - model_opts.param_path = FLAGS_param_path; - model_opts.cache_names = FLAGS_model_cache_names; - model_opts.cache_shape = FLAGS_model_cache_shapes; - model_opts.input_names = FLAGS_model_input_names; - model_opts.output_names = FLAGS_model_output_names; - model_opts.subsample_rate = FLAGS_downsampling_rate; - resource.model_opts = model_opts; - - ppspeech::TLGDecoderOptions decoder_opts; - decoder_opts.word_symbol_table = FLAGS_word_symbol_table; - decoder_opts.fst_path = 
FLAGS_graph_path; - decoder_opts.opts.max_active = FLAGS_max_active; - decoder_opts.opts.beam = FLAGS_beam; - decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; - - resource.tlg_opts = decoder_opts; - + resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); + resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); + resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); return resource; } diff --git a/speechx/speechx/decoder/u2_recognizer.h b/speechx/speechx/decoder/u2_recognizer.h index a65cae3b3..86bd48216 100644 --- a/speechx/speechx/decoder/u2_recognizer.h +++ b/speechx/speechx/decoder/u2_recognizer.h @@ -26,15 +26,25 @@ #include "fst/fstlib.h" #include "fst/symbol-table.h" -namespace ppspeech { +DECLARE_int32(nnet_decoder_chunk); +DECLARE_int32(num_left_chunks); +DECLARE_double(ctc_weight); +DECLARE_double(rescoring_weight); +DECLARE_double(reverse_weight); +DECLARE_int32(nbest); +DECLARE_int32(blank); + +DECLARE_double(acoustic_scale); +DECLARE_string(vocab_path); +namespace ppspeech { struct DecodeOptions { // chunk_size is the frame number of one chunk after subsampling. // e.g. if subsample rate is 4 and chunk_size = 16, the frames in // one chunk are 67=16*4 + 3, stride is 64=16*4 - int chunk_size; - int num_left_chunks; + int chunk_size{16}; + int num_left_chunks{-1}; // final_score = rescoring_weight * rescoring_score + ctc_weight * // ctc_score; @@ -46,51 +56,27 @@ struct DecodeOptions { // it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a // max(viterbi) path score + context score So we should carefully set // ctc_weight accroding to the search methods. - float ctc_weight; - float rescoring_weight; - float reverse_weight; + float ctc_weight{0.0}; + float rescoring_weight{1.0}; + float reverse_weight{0.0}; // CtcEndpointConfig ctc_endpoint_opts; - CTCBeamSearchOptions ctc_prefix_search_opts; - - DecodeOptions() - : chunk_size(16), - num_left_chunks(-1), - ctc_weight(0.5), - rescoring_weight(1.0), - reverse_weight(0.0) {} - - void Register(kaldi::OptionsItf* opts) { - std::string module = "DecoderConfig: "; - opts->Register( - "chunk-size", - &chunk_size, - module + "the frame number of one chunk after subsampling."); - opts->Register("num-left-chunks", - &num_left_chunks, - module + "the left history chunks number."); - opts->Register("ctc-weight", - &ctc_weight, - module + - "ctc weight for rescore. final_score = " - "rescoring_weight * rescoring_score + ctc_weight * " - "ctc_score."); - opts->Register("rescoring-weight", - &rescoring_weight, - module + - "attention score weight for rescore. final_score = " - "rescoring_weight * rescoring_score + ctc_weight * " - "ctc_score."); - opts->Register("reverse-weight", - &reverse_weight, - module + - "reverse decoder weight. 
rescoring_score = " - "left_to_right_score * (1 - reverse_weight) + " - "right_to_left_score * reverse_weight."); + CTCBeamSearchOptions ctc_prefix_search_opts{}; + + static DecodeOptions InitFromFlags(){ + DecodeOptions decoder_opts; + decoder_opts.chunk_size=FLAGS_nnet_decoder_chunk; + decoder_opts.num_left_chunks = FLAGS_num_left_chunks; + decoder_opts.ctc_weight = FLAGS_ctc_weight; + decoder_opts.rescoring_weight = FLAGS_rescoring_weight; + decoder_opts.reverse_weight = FLAGS_reverse_weight; + decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; + decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; + decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; + return decoder_opts; } }; - struct U2RecognizerResource { kaldi::BaseFloat acoustic_scale{1.0}; std::string vocab_path{}; @@ -98,7 +84,17 @@ struct U2RecognizerResource { FeaturePipelineOptions feature_pipeline_opts{}; ModelOptions model_opts{}; DecodeOptions decoder_opts{}; - // CTCBeamSearchOptions beam_search_opts; + + static U2RecognizerResource InitFromFlags() { + U2RecognizerResource resource; + resource.vocab_path = FLAGS_vocab_path; + resource.acoustic_scale = FLAGS_acoustic_scale; + + resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); + resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); + resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags(); + return resource; +} }; diff --git a/speechx/speechx/decoder/u2_recognizer_main.cc b/speechx/speechx/decoder/u2_recognizer_main.cc index ab2c66950..b1a7b2e8e 100644 --- a/speechx/speechx/decoder/u2_recognizer_main.cc +++ b/speechx/speechx/decoder/u2_recognizer_main.cc @@ -22,35 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); - -ppspeech::U2RecognizerResource InitOpts() { - ppspeech::U2RecognizerResource resource; - resource.vocab_path = FLAGS_vocab_path; - resource.acoustic_scale = FLAGS_acoustic_scale; - - resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); - LOG(INFO) << "feature!"; - ppspeech::ModelOptions model_opts; - model_opts.model_path = FLAGS_model_path; - - resource.model_opts = model_opts; - LOG(INFO) << "model!"; - - ppspeech::DecodeOptions decoder_opts; - decoder_opts.chunk_size=16; - decoder_opts.num_left_chunks = -1; - decoder_opts.ctc_weight = 0.5; - decoder_opts.rescoring_weight = 1.0; - decoder_opts.reverse_weight = 0.3; - decoder_opts.ctc_prefix_search_opts.blank = 0; - decoder_opts.ctc_prefix_search_opts.first_beam_size = 10; - decoder_opts.ctc_prefix_search_opts.second_beam_size = 10; - - resource.decoder_opts = decoder_opts; - LOG(INFO) << "decoder!"; - return resource; -} - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -72,7 +43,7 @@ int main(int argc, char* argv[]) { LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - ppspeech::U2RecognizerResource resource = InitOpts(); + ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags(); ppspeech::U2Recognizer recognizer(resource); kaldi::Timer timer; diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h index 613f69c6a..38a47433f 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.h +++ 
b/speechx/speechx/frontend/audio/feature_pipeline.h @@ -25,26 +25,71 @@ #include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/normalizer.h" +// feature +DECLARE_bool(use_fbank); +DECLARE_int32(num_bins); +DECLARE_string(cmvn_file); + +// feature sliding window +DECLARE_int32(receptive_field_length); +DECLARE_int32(subsampling_rate); +DECLARE_int32(nnet_decoder_chunk); + namespace ppspeech { + struct FeaturePipelineOptions { - std::string cmvn_file; - bool to_float32; // true, only for linear feature - bool use_fbank; - LinearSpectrogramOptions linear_spectrogram_opts; - kaldi::FbankOptions fbank_opts; - FeatureCacheOptions feature_cache_opts; - AssemblerOptions assembler_opts; - - FeaturePipelineOptions() - : cmvn_file(""), - to_float32(false), // true, only for linear feature - use_fbank(true), - linear_spectrogram_opts(), - fbank_opts(), - feature_cache_opts(), - assembler_opts() {} + std::string cmvn_file{}; + bool to_float32{false}; // true, only for linear feature + bool use_fbank{true}; + LinearSpectrogramOptions linear_spectrogram_opts{}; + kaldi::FbankOptions fbank_opts{}; + FeatureCacheOptions feature_cache_opts{}; + AssemblerOptions assembler_opts{}; + + static FeaturePipelineOptions InitFromFlags(){ + FeaturePipelineOptions opts; + opts.cmvn_file = FLAGS_cmvn_file; + LOG(INFO) << "cmvn file: " << opts.cmvn_file; + + // frame options + kaldi::FrameExtractionOptions frame_opts; + frame_opts.dither = 0.0; + LOG(INFO) << "dither: " << frame_opts.dither; + frame_opts.frame_shift_ms = 10; + LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms; + opts.use_fbank = FLAGS_use_fbank; + LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear"); + if (opts.use_fbank) { + opts.to_float32 = false; + frame_opts.window_type = "povey"; + frame_opts.frame_length_ms = 25; + opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; + + opts.fbank_opts.frame_opts = frame_opts; + } else { + opts.to_float32 = true; + frame_opts.remove_dc_offset = false; + frame_opts.frame_length_ms = 20; + frame_opts.window_type = "hanning"; + frame_opts.preemph_coeff = 0.0; + + opts.linear_spectrogram_opts.frame_opts = frame_opts; + } + LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms; + + // assembler opts + opts.assembler_opts.subsampling_rate = FLAGS_subsampling_rate; + LOG(INFO) << "subsampling rate: " << opts.assembler_opts.subsampling_rate; + opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length; + LOG(INFO) << "nnet receptive filed length: " << opts.assembler_opts.receptive_filed_length; + opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk; + LOG(INFO) << "nnet chunk size: " << opts.assembler_opts.nnet_decoder_chunk; + return opts; + } }; + class FeaturePipeline : public FrontendInterface { public: explicit FeaturePipeline(const FeaturePipelineOptions& opts); diff --git a/speechx/speechx/nnet/ds2_nnet_main.cc b/speechx/speechx/nnet/ds2_nnet_main.cc index 943d7e5f2..d8d33e982 100644 --- a/speechx/speechx/nnet/ds2_nnet_main.cc +++ b/speechx/speechx/nnet/ds2_nnet_main.cc @@ -14,6 +14,7 @@ #include "nnet/ds2_nnet.h" #include "base/common.h" +#include "decoder/param.h" #include "frontend/audio/assembler.h" #include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" @@ -21,27 +22,6 @@ DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); -DEFINE_string(model_path, 
"avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); -DEFINE_int32(receptive_field_length, - 7, - "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(downsampling_rate, - 4, - "two CNN(kernel=3) module downsampling rate."); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); -DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); using kaldi::BaseFloat; using kaldi::Matrix; @@ -64,13 +44,8 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; - ppspeech::ModelOptions model_opts; - model_opts.model_path = model_graph; - model_opts.param_path = model_params; - model_opts.cache_names = FLAGS_model_cache_names; - model_opts.cache_shape = FLAGS_model_cache_shapes; - model_opts.input_names = FLAGS_model_input_names; - model_opts.output_names = FLAGS_model_output_names; + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); std::shared_ptr raw_data(new ppspeech::DataCache()); @@ -78,8 +53,8 @@ int main(int argc, char* argv[]) { new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; - int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; + int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/nnet/nnet_itf.h index 109f54e0f..f8105b7f0 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/speechx/speechx/nnet/nnet_itf.h @@ -20,53 +20,54 @@ #include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/util/options-itf.h" -namespace ppspeech { +DECLARE_int32(subsampling_rate); +DECLARE_string(model_path); +DECLARE_string(param_path); +DECLARE_string(model_input_names); +DECLARE_string(model_output_names); +DECLARE_string(model_cache_names); +DECLARE_string(model_cache_shapes); +namespace ppspeech { struct ModelOptions { + // common + int subsample_rate{1}; + int thread_num{1}; // predictor thread pool size for ds2; + bool use_gpu{false}; std::string model_path; + std::string param_path; - int thread_num; // predictor thread pool size for ds2; - bool use_gpu; - bool switch_ir_optim; - std::string input_names; - std::string output_names; - std::string cache_names; - std::string cache_shape; - bool enable_fc_padding; - bool enable_profile; - int subsample_rate; - ModelOptions() - : model_path(""), - param_path(""), - thread_num(1), - use_gpu(false), - input_names(""), - output_names(""), - cache_names(""), - cache_shape(""), - switch_ir_optim(false), - enable_fc_padding(false), - enable_profile(false), - subsample_rate(0) {} - - void Register(kaldi::OptionsItf* opts) { - opts->Register("model-path", &model_path, "model file path"); - 
opts->Register("model-param", ¶m_path, "params model file path"); - opts->Register("thread-num", &thread_num, "thread num"); - opts->Register("use-gpu", &use_gpu, "if use gpu"); - opts->Register("input-names", &input_names, "paddle input names"); - opts->Register("output-names", &output_names, "paddle output names"); - opts->Register("cache-names", &cache_names, "cache names"); - opts->Register("cache-shape", &cache_shape, "cache shape"); - opts->Register("switch-ir-optiom", - &switch_ir_optim, - "paddle SwitchIrOptim option"); - opts->Register("enable-fc-padding", - &enable_fc_padding, - "paddle EnableFCPadding option"); - opts->Register( - "enable-profile", &enable_profile, "paddle EnableProfile option"); + + // ds2 for inference + std::string input_names{}; + std::string output_names{}; + std::string cache_names{}; + std::string cache_shape{}; + bool switch_ir_optim{false}; + bool enable_fc_padding{false}; + bool enable_profile{false}; + + static ModelOptions InitFromFlags(){ + ModelOptions opts; + opts.subsample_rate = FLAGS_subsampling_rate; + LOG(INFO) << "subsampling rate: " << opts.subsample_rate; + opts.model_path = FLAGS_model_path; + LOG(INFO) << "model path: " << opts.model_path ; + + opts.param_path = FLAGS_param_path; + LOG(INFO) << "param path: " << opts.param_path ; + + LOG(INFO) << "DS2 param: "; + opts.cache_names = FLAGS_model_cache_names; + LOG(INFO) << " cache names: " << opts.cache_names; + opts.cache_shape = FLAGS_model_cache_shapes; + LOG(INFO) << " cache shape: " << opts.cache_shape; + opts.input_names = FLAGS_model_input_names; + LOG(INFO) << " input names: " << opts.input_names; + opts.output_names = FLAGS_model_output_names; + LOG(INFO) << " output names: " << opts.output_names; + return opts; } }; diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h index 7058ea949..697ac20c8 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/speechx/speechx/nnet/u2_nnet.h @@ -17,7 +17,6 @@ #include "base/common.h" #include "kaldi/matrix/kaldi-matrix.h" - #include "nnet/nnet_itf.h" #include "paddle/extension.h" #include "paddle/jit/all.h" diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/speechx/speechx/nnet/u2_nnet_main.cc index 4b30f6b4a..adbbf0e80 100644 --- a/speechx/speechx/nnet/u2_nnet_main.cc +++ b/speechx/speechx/nnet/u2_nnet_main.cc @@ -12,28 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "nnet/u2_nnet.h" + #include "base/common.h" #include "frontend/audio/assembler.h" #include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "decoder/param.h" +#include "nnet/u2_nnet.h" + DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); DEFINE_string(nnet_encoder_outs_wspecifier, "", "nnet encoder outs wspecifier"); -DEFINE_string(model_path, "", "paddle nnet model"); - -DEFINE_int32(nnet_decoder_chunk, 16, "nnet forward chunk"); -DEFINE_int32(receptive_field_length, - 7, - "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(downsampling_rate, - 4, - "two CNN(kernel=3) module downsampling rate."); -DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); - using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; @@ -58,13 +50,12 @@ int main(int argc, char* argv[]) { kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier); - ppspeech::ModelOptions model_opts; - model_opts.model_path = FLAGS_model_path; + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); int32 chunk_size = - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate + FLAGS_receptive_field_length; - int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; + int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; diff --git a/speechx/speechx/protocol/websocket/websocket_server_main.cc b/speechx/speechx/protocol/websocket/websocket_server_main.cc index 9c01a0a1b..827b164f3 100644 --- a/speechx/speechx/protocol/websocket/websocket_server_main.cc +++ b/speechx/speechx/protocol/websocket/websocket_server_main.cc @@ -20,27 +20,9 @@ DEFINE_int32(port, 8082, "websocket listening port"); ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource resource; resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); - - ppspeech::ModelOptions model_opts; - model_opts.model_path = FLAGS_model_path; - model_opts.param_path = FLAGS_param_path; - model_opts.cache_names = FLAGS_model_cache_names; - model_opts.cache_shape = FLAGS_model_cache_shapes; - model_opts.input_names = FLAGS_model_input_names; - model_opts.output_names = FLAGS_model_output_names; - model_opts.subsample_rate = FLAGS_downsampling_rate; - resource.model_opts = model_opts; - - ppspeech::TLGDecoderOptions decoder_opts; - decoder_opts.word_symbol_table = FLAGS_word_symbol_table; - decoder_opts.fst_path = FLAGS_graph_path; - decoder_opts.opts.max_active = FLAGS_max_active; - decoder_opts.opts.beam = FLAGS_beam; - decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; - - resource.tlg_opts = decoder_opts; - + resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); + resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); + resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); return resource; }