diff --git a/speechx/examples/codelab/u2/local/decode.sh b/speechx/examples/codelab/u2/local/decode.sh index 12297661d..24e9fca5b 100755 --- a/speechx/examples/codelab/u2/local/decode.sh +++ b/speechx/examples/codelab/u2/local/decode.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -x +set +x set -e . path.sh diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 8cf94a100..472d93324 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -9,6 +9,7 @@ add_library(decoder STATIC ctc_prefix_beam_search_decoder.cc ctc_tlg_decoder.cc recognizer.cc + u2_recognizer.cc ) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) @@ -28,10 +29,16 @@ endforeach() # u2 -set(bin_name ctc_prefix_beam_search_decoder_main) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) -target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) -target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) \ No newline at end of file +set(TEST_BINS + u2_recognizer_main + ctc_prefix_beam_search_decoder_main +) + +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) 
+endforeach() \ No newline at end of file diff --git a/speechx/speechx/decoder/common.h b/speechx/speechx/decoder/common.h index 52deffac9..0ae732771 100644 --- a/speechx/speechx/decoder/common.h +++ b/speechx/speechx/decoder/common.h @@ -1,3 +1,4 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,10 +13,36 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "base/basic_types.h" +#pragma once + +#include "base/common.h" struct DecoderResult { BaseFloat acoustic_score; std::vector words_idx; - std::vector> time_stamp; + std::vector> time_stamp; +}; + + +namespace ppspeech { + +struct WordPiece { + std::string word; + int start = -1; + int end = -1; + + WordPiece(std::string word, int start, int end) + : word(std::move(word)), start(start), end(end) {} }; + +struct DecodeResult { + float score = -kBaseFloatMax; + std::string sentence; + std::vector word_pieces; + + static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) { + return a.score > b.score; + } +}; + +} // namespace ppspeech diff --git a/speechx/speechx/decoder/ctc_beam_search_opt.h b/speechx/speechx/decoder/ctc_beam_search_opt.h index af92fad05..d21b3abd8 100644 --- a/speechx/speechx/decoder/ctc_beam_search_opt.h +++ b/speechx/speechx/decoder/ctc_beam_search_opt.h @@ -76,68 +76,4 @@ struct CTCBeamSearchOptions { } }; - -// used by u2 model -struct CTCBeamSearchDecoderOptions { - // chunk_size is the frame number of one chunk after subsampling. - // e.g. 
if subsample rate is 4 and chunk_size = 16, the frames in - // one chunk are 67=16*4 + 3, stride is 64=16*4 - int chunk_size; - int num_left_chunks; - - // final_score = rescoring_weight * rescoring_score + ctc_weight * - // ctc_score; - // rescoring_score = left_to_right_score * (1 - reverse_weight) + - // right_to_left_score * reverse_weight - // Please note the concept of ctc_scores - // in the following two search methods are different. For - // CtcPrefixBeamSerch, - // it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a - // max(viterbi) path score + context score So we should carefully set - // ctc_weight accroding to the search methods. - float ctc_weight; - float rescoring_weight; - float reverse_weight; - - // CtcEndpointConfig ctc_endpoint_opts; - - CTCBeamSearchOptions ctc_prefix_search_opts; - - CTCBeamSearchDecoderOptions() - : chunk_size(16), - num_left_chunks(-1), - ctc_weight(0.5), - rescoring_weight(1.0), - reverse_weight(0.0) {} - - void Register(kaldi::OptionsItf* opts) { - std::string module = "DecoderConfig: "; - opts->Register( - "chunk-size", - &chunk_size, - module + "the frame number of one chunk after subsampling."); - opts->Register("num-left-chunks", - &num_left_chunks, - module + "the left history chunks number."); - opts->Register("ctc-weight", - &ctc_weight, - module + - "ctc weight for rescore. final_score = " - "rescoring_weight * rescoring_score + ctc_weight * " - "ctc_score."); - opts->Register("rescoring-weight", - &rescoring_weight, - module + - "attention score weight for rescore. final_score = " - "rescoring_weight * rescoring_score + ctc_weight * " - "ctc_score."); - opts->Register("reverse-weight", - &reverse_weight, - module + - "reverse decoder weight. 
rescoring_score = " - "left_to_right_score * (1 - reverse_weight) + " - "right_to_left_score * reverse_weight."); - } -}; - } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc index f22bfea27..ce2d4dc2f 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc @@ -30,8 +30,14 @@ using paddle::platform::TracerEventType; namespace ppspeech { -CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts) +CTCPrefixBeamSearch::CTCPrefixBeamSearch( + const std::string vocab_path, + const CTCBeamSearchOptions& opts) : opts_(opts) { + + unit_table_ = std::shared_ptr(fst::SymbolTable::ReadText(vocab_path)); + CHECK(unit_table_ != nullptr); + Reset(); } @@ -322,7 +328,11 @@ void CTCPrefixBeamSearch::UpdateFinalContext() { CHECK(n_hyps > 0); CHECK(index < n_hyps); std::vector one = Outputs()[index]; - return std::string(absl::StrJoin(one, kSpaceSymbol)); + std::string sentence; + for (int i = 0; i < one.size(); i++){ + sentence += unit_table_->Find(one[i]); + } + return sentence; } std::string CTCPrefixBeamSearch::GetBestPath() { diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h index ba44b0a20..2c28bee1b 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h @@ -15,17 +15,21 @@ #pragma once #include "decoder/ctc_beam_search_opt.h" -#include "decoder/ctc_prefix_beam_search_result.h" #include "decoder/ctc_prefix_beam_search_score.h" #include "decoder/decoder_itf.h" +#include "fst/symbol-table.h" + namespace ppspeech { class ContextGraph; class CTCPrefixBeamSearch : public DecoderInterface { public: - explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts); + explicit CTCPrefixBeamSearch(const 
std::string vocab_path, + const CTCBeamSearchOptions& opts); ~CTCPrefixBeamSearch() {} + SearchType Type() const { return SearchType::kPrefixBeamSearch; } + void InitDecoder() override; void Reset() override; @@ -38,10 +42,9 @@ class CTCPrefixBeamSearch : public DecoderInterface { void FinalizeSearch(); - protected: - std::string GetBestPath() override; - std::vector> GetNBestPath() override; - std::vector> GetNBestPath(int n) override; + const std::shared_ptr VocabTable() const { + return unit_table_; + } const std::vector>& Inputs() const { return hypotheses_; } const std::vector>& Outputs() const { return outputs_; } @@ -52,6 +55,11 @@ class CTCPrefixBeamSearch : public DecoderInterface { const std::vector>& Times() const { return times_; } + protected: + std::string GetBestPath() override; + std::vector> GetNBestPath() override; + std::vector> GetNBestPath(int n) override; + private: std::string GetBestPath(int index); @@ -66,6 +74,7 @@ class CTCPrefixBeamSearch : public DecoderInterface { private: CTCBeamSearchOptions opts_; + std::shared_ptr unit_table_; std::unordered_map, PrefixScore, PrefixScoreHash> cur_hyps_; @@ -86,28 +95,4 @@ class CTCPrefixBeamSearch : public DecoderInterface { }; -class CTCPrefixBeamSearchDecoder : public CTCPrefixBeamSearch { - public: - explicit CTCPrefixBeamSearchDecoder(const CTCBeamSearchDecoderOptions& opts) - : CTCPrefixBeamSearch(opts.ctc_prefix_search_opts), opts_(opts) {} - - ~CTCPrefixBeamSearchDecoder() {} - - private: - CTCBeamSearchDecoderOptions opts_; - - // cache feature - bool start_ = false; // false, this is first frame. 
- // for continues decoding - int num_frames_ = 0; - int global_frame_offset_ = 0; - const int time_stamp_gap_ = - 100; // timestamp gap between words in a sentence - - // std::unique_ptr ctc_endpointer_; - - int num_frames_in_current_chunk_ = 0; - std::vector result_; -}; - } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc index 8927a5f45..dd3523786 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -55,14 +55,12 @@ int main(int argc, char* argv[]) { CHECK(FLAGS_vocab_path != ""); CHECK(FLAGS_model_path != ""); LOG(INFO) << "model path: " << FLAGS_model_path; + LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path; kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path; - fst::SymbolTable* unit_table = fst::SymbolTable::ReadText(FLAGS_vocab_path); - // nnet ppspeech::ModelOptions model_opts; model_opts.model_path = FLAGS_model_path; @@ -75,16 +73,11 @@ int main(int argc, char* argv[]) { new ppspeech::Decodable(nnet, raw_data)); // decoder - ppspeech::CTCBeamSearchDecoderOptions opts; - opts.chunk_size = 16; - opts.num_left_chunks = -1; - opts.ctc_weight = 0.5; - opts.rescoring_weight = 1.0; - opts.reverse_weight = 0.3; - opts.ctc_prefix_search_opts.blank = 0; - opts.ctc_prefix_search_opts.first_beam_size = 10; - opts.ctc_prefix_search_opts.second_beam_size = 10; - ppspeech::CTCPrefixBeamSearchDecoder decoder(opts); + ppspeech::CTCBeamSearchOptions opts; + opts.blank = 0; + opts.first_beam_size = 10; + opts.second_beam_size = 10; + ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts); int32 chunk_size = FLAGS_receptive_field_length + @@ -150,17 +143,14 @@ int main(int argc, 
char* argv[]) { // forward nnet decoder.AdvanceDecode(decodable); + + LOG(INFO) << "Partial result: " << decoder.GetPartialResult(); } decoder.FinalizeSearch(); // get 1-best result - std::string result_ints = decoder.GetFinalBestPath(); - std::vector tokenids = absl::StrSplit(result_ints, ppspeech::kSpaceSymbol); - std::string result; - for (int i = 0; i < tokenids.size(); i++){ - result += unit_table->Find(std::stoi(tokenids[i])); - } + std::string result = decoder.GetFinalBestPath(); // after process one utt, then reset state. decodable->Reset(); diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_result.h b/speechx/speechx/decoder/ctc_prefix_beam_search_result.h deleted file mode 100644 index caa3e37e6..000000000 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_result.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "base/common.h" - -namespace ppspeech { - -struct WordPiece { - std::string word; - int start = -1; - int end = -1; - - WordPiece(std::string word, int start, int end) - : word(std::move(word)), start(start), end(end) {} -}; - -struct DecodeResult { - float score = -kBaseFloatMax; - std::string sentence; - std::vector word_pieces; - - static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) { - return a.score > b.score; - } -}; - -} // namespace ppspeech diff --git a/speechx/speechx/decoder/decoder_itf.h b/speechx/speechx/decoder/decoder_itf.h index fe4e7408d..eec9bc3d4 100644 --- a/speechx/speechx/decoder/decoder_itf.h +++ b/speechx/speechx/decoder/decoder_itf.h @@ -20,6 +20,10 @@ namespace ppspeech { +enum SearchType { + kPrefixBeamSearch = 0, + kWfstBeamSearch = 1, +}; class DecoderInterface { public: virtual ~DecoderInterface() {} diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h index 8a5990dc8..e0f22d8c6 100644 --- a/speechx/speechx/decoder/param.h +++ b/speechx/speechx/decoder/param.h @@ -19,12 +19,15 @@ #include "decoder/ctc_tlg_decoder.h" #include "frontend/audio/feature_pipeline.h" + // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); // DEFINE_bool(to_float32, true, "audio convert to pcm32. 
True for linear // feature, or fbank"); DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_string(cmvn_file, "", "read cmvn"); + + // feature sliding window DEFINE_int32(receptive_field_length, 7, @@ -33,6 +36,8 @@ DEFINE_int32(downsampling_rate, 4, "two CNN(kernel=3) module downsampling rate."); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); + + // nnet DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); @@ -89,34 +94,4 @@ FeaturePipelineOptions InitFeaturePipelineOptions() { return opts; } -ModelOptions InitModelOptions() { - ModelOptions model_opts; - model_opts.model_path = FLAGS_model_path; - model_opts.param_path = FLAGS_param_path; - model_opts.cache_names = FLAGS_model_cache_names; - model_opts.cache_shape = FLAGS_model_cache_shapes; - model_opts.input_names = FLAGS_model_input_names; - model_opts.output_names = FLAGS_model_output_names; - return model_opts; -} - -TLGDecoderOptions InitDecoderOptions() { - TLGDecoderOptions decoder_opts; - decoder_opts.word_symbol_table = FLAGS_word_symbol_table; - decoder_opts.fst_path = FLAGS_graph_path; - decoder_opts.opts.max_active = FLAGS_max_active; - decoder_opts.opts.beam = FLAGS_beam; - decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; - return decoder_opts; -} - -RecognizerResource InitRecognizerResoure() { - RecognizerResource resource; - resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = InitFeaturePipelineOptions(); - resource.model_opts = InitModelOptions(); - resource.tlg_opts = InitDecoderOptions(); - return resource; -} - } // namespace ppspeech diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc index 44c3911c9..bb9ea1872 100644 --- a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/decoder/recognizer.cc @@ -14,6 +14,7 @@ #include "decoder/recognizer.h" + namespace ppspeech { using kaldi::Vector; @@ -23,14 
+24,19 @@ using std::vector; using kaldi::SubVector; using std::unique_ptr; + Recognizer::Recognizer(const RecognizerResource& resource) { // resource_ = resource; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; feature_pipeline_.reset(new FeaturePipeline(feature_opts)); + std::shared_ptr nnet(new PaddleNnet(resource.model_opts)); + BaseFloat ac_scale = resource.acoustic_scale; decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale)); + decoder_.reset(new TLGDecoder(resource.tlg_opts)); + input_finished_ = false; } diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h index e47ca433d..4965e7a3d 100644 --- a/speechx/speechx/decoder/recognizer.h +++ b/speechx/speechx/decoder/recognizer.h @@ -25,16 +25,11 @@ namespace ppspeech { struct RecognizerResource { - FeaturePipelineOptions feature_pipeline_opts; - ModelOptions model_opts; - TLGDecoderOptions tlg_opts; + FeaturePipelineOptions feature_pipeline_opts{}; + ModelOptions model_opts{}; + TLGDecoderOptions tlg_opts{}; // CTCBeamSearchOptions beam_search_opts; - kaldi::BaseFloat acoustic_scale; - RecognizerResource() - : acoustic_scale(1.0), - feature_pipeline_opts(), - model_opts(), - tlg_opts() {} + kaldi::BaseFloat acoustic_scale{1.0}; }; class Recognizer { diff --git a/speechx/speechx/decoder/recognizer_main.cc b/speechx/speechx/decoder/recognizer_main.cc index 050266462..2b497d6ea 100644 --- a/speechx/speechx/decoder/recognizer_main.cc +++ b/speechx/speechx/decoder/recognizer_main.cc @@ -22,6 +22,33 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); +ppspeech::RecognizerResource InitRecognizerResoure() { + ppspeech::RecognizerResource resource; + resource.acoustic_scale = FLAGS_acoustic_scale; + resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); + + ppspeech::ModelOptions model_opts; + 
model_opts.model_path = FLAGS_model_path; + model_opts.param_path = FLAGS_param_path; + model_opts.cache_names = FLAGS_model_cache_names; + model_opts.cache_shape = FLAGS_model_cache_shapes; + model_opts.input_names = FLAGS_model_input_names; + model_opts.output_names = FLAGS_model_output_names; + model_opts.subsample_rate = FLAGS_downsampling_rate; + resource.model_opts = model_opts; + + ppspeech::TLGDecoderOptions decoder_opts; + decoder_opts.word_symbol_table = FLAGS_word_symbol_table; + decoder_opts.fst_path = FLAGS_graph_path; + decoder_opts.opts.max_active = FLAGS_max_active; + decoder_opts.opts.beam = FLAGS_beam; + decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + + resource.tlg_opts = decoder_opts; + + return resource; +} + int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -29,7 +56,7 @@ int main(int argc, char* argv[]) { google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; - ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); + ppspeech::RecognizerResource resource = InitRecognizerResoure(); ppspeech::Recognizer recognizer(resource); kaldi::SequentialTableReader wav_reader( diff --git a/speechx/speechx/decoder/u2_recognizer.cc b/speechx/speechx/decoder/u2_recognizer.cc new file mode 100644 index 000000000..0ace086c4 --- /dev/null +++ b/speechx/speechx/decoder/u2_recognizer.cc @@ -0,0 +1,209 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/u2_recognizer.h" +#include "nnet/u2_nnet.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::VectorBase; +using kaldi::BaseFloat; +using std::vector; +using kaldi::SubVector; +using std::unique_ptr; + +U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource) { + const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; + feature_pipeline_.reset(new FeaturePipeline(feature_opts)); + + std::shared_ptr nnet(new U2Nnet(resource.model_opts)); + + BaseFloat am_scale = resource.acoustic_scale; + decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); + + decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + + unit_table_ = decoder_->VocabTable(); + symbol_table_ = unit_table_; + + input_finished_ = false; +} + +void U2Recognizer::Reset() { + global_frame_offset_ = 0; + num_frames_ = 0; + result_.clear(); + + feature_pipeline_->Reset(); + decodable_->Reset(); + decoder_->Reset(); +} + +void U2Recognizer::ResetContinuousDecoding() { + global_frame_offset_ = num_frames_; + num_frames_ = 0; + result_.clear(); + + feature_pipeline_->Reset(); + decodable_->Reset(); + decoder_->Reset(); +} + + +void U2Recognizer::Accept(const VectorBase& waves) { + feature_pipeline_->Accept(waves); +} + + +void U2Recognizer::Decode() { + decoder_->AdvanceDecode(decodable_); +} + +void U2Recognizer::Rescoring() { + // Do attention Rescoring + kaldi::Timer timer; + AttentionRescoring(); + VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec."; +} + +void U2Recognizer::UpdateResult(bool finish) { + const auto& hypotheses = decoder_->Outputs(); + const auto& inputs = decoder_->Inputs(); + const auto& likelihood = decoder_->Likelihood(); + const auto& times = decoder_->Times(); + result_.clear(); + + CHECK_EQ(hypotheses.size(), 
likelihood.size()); + for (size_t i = 0; i < hypotheses.size(); i++) { + const std::vector& hypothesis = hypotheses[i]; + + DecodeResult path; + path.score = likelihood[i]; + for (size_t j = 0; j < hypothesis.size(); j++) { + std::string word = symbol_table_->Find(hypothesis[j]); + // A detailed explanation of this if-else branch can be found in + // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 + if (decoder_->Type() == kWfstBeamSearch) { + path.sentence += (" " + word); + } else { + path.sentence += (word); + } + } + + // TimeStamp is only supported in final result + // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to + // various FST operations when building the decoding graph. So here we use + // time stamp of the input(e2e model unit), which is more accurate, and it + // requires the symbol table of the e2e model used in training. + if (unit_table_ != nullptr && finish) { + int offset = global_frame_offset_ * FrameShiftInMs(); + + const std::vector& input = inputs[i]; + const std::vector time_stamp = times[i]; + CHECK_EQ(input.size(), time_stamp.size()); + + for (size_t j = 0; j < input.size(); j++) { + std::string word = unit_table_->Find(input[j]); + + int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 + ? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ + : 0; + if (j > 0) { + start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() < + time_stamp_gap_ + ? (time_stamp[j - 1] + time_stamp[j]) / 2 * + FrameShiftInMs() + : start; + } + + int end = time_stamp[j] * FrameShiftInMs(); + if (j < input.size() - 1) { + end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() < + time_stamp_gap_ + ? 
(time_stamp[j + 1] + time_stamp[j]) / 2 * + FrameShiftInMs() + : end; + } + + WordPiece word_piece(word, offset + start, offset + end); + path.word_pieces.emplace_back(word_piece); + } + } + + // if (post_processor_ != nullptr) { + // path.sentence = post_processor_->Process(path.sentence, finish); + // } + + result_.emplace_back(path); + } + + if (DecodedSomething()) { + VLOG(1) << "Partial CTC result " << result_[0].sentence; + } +} + +void U2Recognizer::AttentionRescoring() { + decoder_->FinalizeSearch(); + UpdateResult(true); + + // No need to do rescoring + if (0.0 == opts_.decoder_opts.rescoring_weight) { + LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!"; + return; + } + LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!"; + + // Inputs() returns N-best input ids, which is the basic unit for rescoring + // In CtcPrefixBeamSearch, inputs are the same to outputs + const auto& hypotheses = decoder_->Inputs(); + int num_hyps = hypotheses.size(); + if (num_hyps <= 0) { + return; + } + + kaldi::Timer timer; + std::vector rescoring_score; + decodable_->AttentionRescoring( + hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score); + VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec."; + + // combine ctc score and rescoring score + for (size_t i = 0; i < num_hyps; i++) { + VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i] + << " ctc_score: " << result_[i].score; + result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] + + opts_.decoder_opts.ctc_weight * result_[i].score; + } + + std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc); + VLOG(1) << "result: " << result_[0].sentence + << " score: " << result_[0].score; +} + +std::string U2Recognizer::GetFinalResult() { + return result_[0].sentence; +} + +std::string U2Recognizer::GetPartialResult() { + return result_[0].sentence; +} + +void U2Recognizer::SetFinished() { + feature_pipeline_->SetFinished(); + input_finished_ = true; +} 
+ + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/u2_recognizer.h b/speechx/speechx/decoder/u2_recognizer.h new file mode 100644 index 000000000..0947e5933 --- /dev/null +++ b/speechx/speechx/decoder/u2_recognizer.h @@ -0,0 +1,164 @@ + + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "decoder/common.h" +#include "decoder/ctc_beam_search_opt.h" +#include "decoder/ctc_prefix_beam_search_decoder.h" +#include "decoder/decoder_itf.h" +#include "frontend/audio/feature_pipeline.h" +#include "nnet/decodable.h" + +#include "fst/fstlib.h" +#include "fst/symbol-table.h" + +namespace ppspeech { + + +struct DecodeOptions { + // chunk_size is the frame number of one chunk after subsampling. + // e.g. if subsample rate is 4 and chunk_size = 16, the frames in + // one chunk are 67=16*4 + 3, stride is 64=16*4 + int chunk_size; + int num_left_chunks; + + // final_score = rescoring_weight * rescoring_score + ctc_weight * + // ctc_score; + // rescoring_score = left_to_right_score * (1 - reverse_weight) + + // right_to_left_score * reverse_weight + // Please note the concept of ctc_scores + // in the following two search methods are different. 
For
+    // CtcPrefixBeamSearch,
+    // it's a sum(prefix) score + context score For CtcWfstBeamSearch, it's a
+    // max(viterbi) path score + context score So we should carefully set
+    // ctc_weight according to the search methods.
+    float ctc_weight;
+    float rescoring_weight;
+    float reverse_weight;
+
+    // CtcEndpointConfig ctc_endpoint_opts;
+    CTCBeamSearchOptions ctc_prefix_search_opts;
+
+    DecodeOptions()
+        : chunk_size(16),
+          num_left_chunks(-1),
+          ctc_weight(0.5),
+          rescoring_weight(1.0),
+          reverse_weight(0.0) {}
+
+    void Register(kaldi::OptionsItf* opts) {
+        std::string module = "DecoderConfig: ";
+        opts->Register(
+            "chunk-size",
+            &chunk_size,
+            module + "the frame number of one chunk after subsampling.");
+        opts->Register("num-left-chunks",
+                       &num_left_chunks,
+                       module + "the left history chunks number.");
+        opts->Register("ctc-weight",
+                       &ctc_weight,
+                       module +
+                           "ctc weight for rescore. final_score = "
+                           "rescoring_weight * rescoring_score + ctc_weight * "
+                           "ctc_score.");
+        opts->Register("rescoring-weight",
+                       &rescoring_weight,
+                       module +
+                           "attention score weight for rescore. final_score = "
+                           "rescoring_weight * rescoring_score + ctc_weight * "
+                           "ctc_score.");
+        opts->Register("reverse-weight",
+                       &reverse_weight,
+                       module +
+                           "reverse decoder weight. 
rescoring_score = " + "left_to_right_score * (1 - reverse_weight) + " + "right_to_left_score * reverse_weight."); + } +}; + + +struct U2RecognizerResource { + FeaturePipelineOptions feature_pipeline_opts{}; + ModelOptions model_opts{}; + DecodeOptions decoder_opts{}; + // CTCBeamSearchOptions beam_search_opts; + kaldi::BaseFloat acoustic_scale{1.0}; + std::string vocab_path{}; +}; + + +class U2Recognizer { + public: + explicit U2Recognizer(const U2RecognizerResource& resouce); + void Reset(); + void ResetContinuousDecoding(); + + void Accept(const kaldi::VectorBase& waves); + void Decode(); + void Rescoring(); + + + std::string GetFinalResult(); + std::string GetPartialResult(); + + void SetFinished(); + bool IsFinished() { return input_finished_; } + + bool DecodedSomething() const { + return !result_.empty() && !result_[0].sentence.empty(); + } + + + int FrameShiftInMs() const { + // one decoder frame length in ms + return decodable_->Nnet()->SubsamplingRate() * + feature_pipeline_->FrameShift(); + } + + + const std::vector& Result() const { return result_; } + + private: + void AttentionRescoring(); + void UpdateResult(bool finish = false); + + private: + U2RecognizerResource opts_; + + // std::shared_ptr resource_; + // U2RecognizerResource resource_; + std::shared_ptr feature_pipeline_; + std::shared_ptr decodable_; + std::unique_ptr decoder_; + + // e2e unit symbol table + std::shared_ptr unit_table_ = nullptr; + std::shared_ptr symbol_table_ = nullptr; + + std::vector result_; + + // global decoded frame offset + int global_frame_offset_; + // cur decoded frame num + int num_frames_; + // timestamp gap between words in a sentence + const int time_stamp_gap_ = 100; + + bool input_finished_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/u2_recognizer_main.cc b/speechx/speechx/decoder/u2_recognizer_main.cc new file mode 100644 index 000000000..70bc7d675 --- /dev/null +++ 
b/speechx/speechx/decoder/u2_recognizer_main.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/u2_recognizer.h" +#include "decoder/param.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + + +ppspeech::U2RecognizerResource InitOpts() { + ppspeech::U2RecognizerResource resource; + resource.acoustic_scale = FLAGS_acoustic_scale; + resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); + + ppspeech::ModelOptions model_opts; + model_opts.model_path = FLAGS_model_path; + + resource.model_opts = model_opts; + + ppspeech::DecodeOptions decoder_opts; + decoder_opts.chunk_size=16; + decoder_opts.num_left_chunks = -1; + decoder_opts.ctc_weight = 0.5; + decoder_opts.rescoring_weight = 1.0; + decoder_opts.reverse_weight = 0.3; + decoder_opts.ctc_prefix_search_opts.blank = 0; + decoder_opts.ctc_prefix_search_opts.first_beam_size = 10; + decoder_opts.ctc_prefix_search_opts.second_beam_size = 10; + + resource.decoder_opts = decoder_opts; + return resource; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, 
false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + + ppspeech::U2RecognizerResource resource = InitOpts(); + ppspeech::U2Recognizer recognizer(resource); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + kaldi::Timer timer; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + tot_wav_duration += wave_data.Duration(); + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); + + recognizer.Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer.SetFinished(); + } + recognizer.Decode(); + LOG(INFO) << "Partial result: " << recognizer.GetPartialResult(); + + // no overlap + sample_offset += cur_chunk_size; + } + // second pass decoding + recognizer.Rescoring(); + + std::string result = recognizer.GetFinalResult(); + + recognizer.Reset(); + + if (result.empty()) { + // the TokenWriter 
can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + continue; + } + + LOG(INFO) << " the result of " << utt << " is " << result; + + result_writer.Write(utt, result); + + ++num_done; + } + + double elapsed = timer.Elapsed(); + + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "cost:" << elapsed << " sec"; + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "the RTF is: " << elapsed / tot_wav_duration; +} diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc index 9cacff9f7..9fc35c958 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.cc +++ b/speechx/speechx/frontend/audio/feature_pipeline.cc @@ -18,7 +18,7 @@ namespace ppspeech { using std::unique_ptr; -FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { +FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) { unique_ptr data_source( new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h index 48f95e3f3..613f69c6a 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.h +++ b/speechx/speechx/frontend/audio/feature_pipeline.h @@ -26,7 +26,6 @@ #include "frontend/audio/normalizer.h" namespace ppspeech { - struct FeaturePipelineOptions { std::string cmvn_file; bool to_float32; // true, only for linear feature @@ -60,7 +59,21 @@ class FeaturePipeline : public FrontendInterface { virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual void Reset() { base_extractor_->Reset(); } + const FeaturePipelineOptions& Config() { return opts_; } + + const BaseFloat FrameShift() const { + return opts_.fbank_opts.frame_opts.frame_shift_ms; + } + const BaseFloat FrameLength() const { + return opts_.fbank_opts.frame_opts.frame_length_ms; + } + 
const BaseFloat SampleRate() const { + return opts_.fbank_opts.frame_opts.samp_freq; + } + private: + FeaturePipelineOptions opts_; std::unique_ptr base_extractor_; }; -} + +} // namespace ppspeech diff --git a/speechx/speechx/nnet/ds2_nnet.cc b/speechx/speechx/nnet/ds2_nnet.cc index c6add03c3..8c83f8324 100644 --- a/speechx/speechx/nnet/ds2_nnet.cc +++ b/speechx/speechx/nnet/ds2_nnet.cc @@ -48,6 +48,7 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { } PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { + subsampling_rate_ = opts.subsample_rate; paddle_infer::Config config; config.SetModel(opts.model_path, opts.param_path); if (opts.use_gpu) { diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/nnet/ds2_nnet.h index e8a49c7d3..2a53e5f7e 100644 --- a/speechx/speechx/nnet/ds2_nnet.h +++ b/speechx/speechx/nnet/ds2_nnet.h @@ -67,6 +67,7 @@ class PaddleNnet : public NnetInterface { bool IsLogProb() override { return false; } + std::shared_ptr> GetCacheEncoder( const std::string& name); @@ -85,6 +86,7 @@ class PaddleNnet : public NnetInterface { std::map predictor_to_thread_id; std::map cache_names_idx_; std::vector>> cache_encouts_; + ModelOptions opts_; public: diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/nnet/nnet_itf.h index 2e21ff9bf..109f54e0f 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/speechx/speechx/nnet/nnet_itf.h @@ -35,6 +35,7 @@ struct ModelOptions { std::string cache_shape; bool enable_fc_padding; bool enable_profile; + int subsample_rate; ModelOptions() : model_path(""), param_path(""), @@ -46,7 +47,8 @@ struct ModelOptions { cache_shape(""), switch_ir_optim(false), enable_fc_padding(false), - enable_profile(false) {} + enable_profile(false), + subsample_rate(0) {} void Register(kaldi::OptionsItf* opts) { opts->Register("model-path", &model_path, "model file path"); @@ -102,9 +104,14 @@ class NnetInterface { // true, nnet output is logprob; otherwise is prob, virtual bool IsLogProb() = 0; + 
int SubsamplingRate() const { return subsampling_rate_; } + // using to get encoder outs. e.g. seq2seq with Attention model. virtual void EncoderOuts( std::vector>* encoder_out) const = 0; + + protected: + int subsampling_rate_{1}; }; } // namespace ppspeech diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h index 1bac652e8..7058ea949 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/speechx/speechx/nnet/u2_nnet.h @@ -30,7 +30,7 @@ class U2NnetBase : public NnetInterface { public: virtual int context() const { return right_context_ + 1; } virtual int right_context() const { return right_context_; } - virtual int subsampling_rate() const { return subsampling_rate_; } + virtual int eos() const { return eos_; } virtual int sos() const { return sos_; } virtual int is_bidecoder() const { return is_bidecoder_; } @@ -64,7 +64,6 @@ class U2NnetBase : public NnetInterface { protected: // model specification int right_context_{0}; - int subsampling_rate_{1}; int sos_{0}; int eos_{0}; diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/speechx/speechx/protocol/websocket/CMakeLists.txt index 0f73fd24c..a171d84d0 100644 --- a/speechx/speechx/protocol/websocket/CMakeLists.txt +++ b/speechx/speechx/protocol/websocket/CMakeLists.txt @@ -1,5 +1,3 @@ -# project(websocket) - add_library(websocket STATIC websocket_server.cc websocket_client.cc diff --git a/speechx/speechx/protocol/websocket/websocket_server_main.cc b/speechx/speechx/protocol/websocket/websocket_server_main.cc index 109da96b6..9c01a0a1b 100644 --- a/speechx/speechx/protocol/websocket/websocket_server_main.cc +++ b/speechx/speechx/protocol/websocket/websocket_server_main.cc @@ -17,11 +17,38 @@ DEFINE_int32(port, 8082, "websocket listening port"); +ppspeech::RecognizerResource InitRecognizerResource() { + ppspeech::RecognizerResource resource; + resource.acoustic_scale = FLAGS_acoustic_scale; + resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions(); + + 
ppspeech::ModelOptions model_opts; + model_opts.model_path = FLAGS_model_path; + model_opts.param_path = FLAGS_param_path; + model_opts.cache_names = FLAGS_model_cache_names; + model_opts.cache_shape = FLAGS_model_cache_shapes; + model_opts.input_names = FLAGS_model_input_names; + model_opts.output_names = FLAGS_model_output_names; + model_opts.subsample_rate = FLAGS_downsampling_rate; + resource.model_opts = model_opts; + + ppspeech::TLGDecoderOptions decoder_opts; + decoder_opts.word_symbol_table = FLAGS_word_symbol_table; + decoder_opts.fst_path = FLAGS_graph_path; + decoder_opts.opts.max_active = FLAGS_max_active; + decoder_opts.opts.beam = FLAGS_beam; + decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + + resource.tlg_opts = decoder_opts; + + return resource; +} + int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); - ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); + ppspeech::RecognizerResource resource = InitRecognizerResource(); ppspeech::WebSocketServer server(FLAGS_port, resource); LOG(INFO) << "Listening at port " << FLAGS_port;