diff --git a/speechx/examples/u2pp_ol/wenetspeech/README.md b/speechx/examples/u2pp_ol/wenetspeech/README.md index a9a4578ff..9a8f8af51 100644 --- a/speechx/examples/u2pp_ol/wenetspeech/README.md +++ b/speechx/examples/u2pp_ol/wenetspeech/README.md @@ -25,4 +25,4 @@ run.sh --stop_stage 0 ``` ./run.sh --stage 3 --stop_stage 3 -``` \ No newline at end of file +``` diff --git a/speechx/speechx/codelab/nnet/ds2_model_test_main.cc b/speechx/speechx/codelab/nnet/ds2_model_test_main.cc index 7d99e8571..09f9e2fbc 100644 --- a/speechx/speechx/codelab/nnet/ds2_model_test_main.cc +++ b/speechx/speechx/codelab/nnet/ds2_model_test_main.cc @@ -21,6 +21,7 @@ #include #include #include + #include "base/flags.h" #include "base/log.h" #include "paddle_inference_api.h" diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 3f00ee35b..c4b35ff0f 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -13,9 +13,10 @@ // limitations under the License. +#include "decoder/ctc_beam_search_decoder.h" + #include "base/common.h" #include "decoder/ctc_decoders/decoder_utils.h" -#include "decoder/ctc_beam_search_decoder.h" #include "utils/file_utils.h" namespace ppspeech { @@ -24,10 +25,7 @@ using std::vector; using FSTMATCH = fst::SortedMatcher; CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) - : opts_(opts), - init_ext_scorer_(nullptr), - space_id_(-1), - root_(nullptr) { + : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) { LOG(INFO) << "dict path: " << opts_.dict_file; if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { LOG(INFO) << "load the dict failed"; @@ -41,7 +39,7 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); } - CHECK(opts_.blank==0); + CHECK(opts_.blank == 0); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); space_id_ = it - vocabulary_.begin(); @@ -115,7 +113,7 @@ int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, } vector> CTCBeamSearch::GetNBestPath(int n) { - int beam_size = n == -1 ? opts_.beam_size: std::min(n, opts_.beam_size); + int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size); return get_beam_search_result(prefixes_, vocabulary_, beam_size); } diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc index ce2d4dc2f..a0fe5b2ac 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc @@ -16,11 +16,12 @@ #include "decoder/ctc_prefix_beam_search_decoder.h" + +#include "absl/strings/str_join.h" #include "base/common.h" #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_score.h" #include "utils/math.h" -#include "absl/strings/str_join.h" #ifdef USE_PROFILING #include "paddle/fluid/platform/profiler.h" @@ -30,18 +31,17 @@ using paddle::platform::TracerEventType; namespace ppspeech { -CTCPrefixBeamSearch::CTCPrefixBeamSearch( - const std::string vocab_path, - const CTCBeamSearchOptions& opts) +CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string vocab_path, + const CTCBeamSearchOptions& opts) : opts_(opts) { - - unit_table_ = std::shared_ptr(fst::SymbolTable::ReadText(vocab_path)); + unit_table_ = std::shared_ptr( + fst::SymbolTable::ReadText(vocab_path)); CHECK(unit_table_ != nullptr); Reset(); } -void CTCPrefixBeamSearch::Reset() { +void CTCPrefixBeamSearch::Reset() { num_frame_decoded_ = 0; cur_hyps_.clear(); @@ -65,10 +65,9 @@ void CTCPrefixBeamSearch::Reset() { hypotheses_.emplace_back(empty); likelihood_.emplace_back(prefix_score.TotalScore()); times_.emplace_back(empty); - } - -void CTCPrefixBeamSearch::InitDecoder() { Reset(); } +} +void CTCPrefixBeamSearch::InitDecoder() { Reset(); } void CTCPrefixBeamSearch::AdvanceDecode( @@ -296,9 +295,7 @@ void CTCPrefixBeamSearch::UpdateOutputs( outputs_.emplace_back(output); } -void CTCPrefixBeamSearch::FinalizeSearch() { - UpdateFinalContext(); -} +void CTCPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); } void CTCPrefixBeamSearch::UpdateFinalContext() { if (context_graph_ == nullptr) return; @@ -311,8 +308,8 @@ void CTCPrefixBeamSearch::UpdateFinalContext() { for (const auto& prefix : hypotheses_) { PrefixScore& prefix_score = cur_hyps_[prefix]; if (prefix_score.context_score != 0) { - prefix_score.UpdateContext(context_graph_, prefix_score, 0, - prefix.size()); + prefix_score.UpdateContext( + context_graph_, prefix_score, 0, prefix.size()); } } std::vector, PrefixScore>> arr(cur_hyps_.begin(), @@ -323,48 +320,44 @@ void CTCPrefixBeamSearch::UpdateFinalContext() { UpdateHypotheses(arr); } - std::string CTCPrefixBeamSearch::GetBestPath(int index) { +std::string CTCPrefixBeamSearch::GetBestPath(int index) { int n_hyps = Outputs().size(); CHECK(n_hyps > 0); CHECK(index < n_hyps); std::vector one = Outputs()[index]; std::string sentence; - for (int i = 0; i < one.size(); i++){ + for (int i = 0; i < one.size(); i++) { sentence += unit_table_->Find(one[i]); } return sentence; - } +} - std::string CTCPrefixBeamSearch::GetBestPath() { - return GetBestPath(0); - } +std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); } - std::vector> CTCPrefixBeamSearch::GetNBestPath(int n) { - int hyps_size = hypotheses_.size(); - CHECK(hyps_size > 0); +std::vector> CTCPrefixBeamSearch::GetNBestPath( + int n) { + int hyps_size = hypotheses_.size(); + CHECK(hyps_size > 0); - int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size); + int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size); - std::vector> n_best; - n_best.reserve(min_n); + std::vector> n_best; + n_best.reserve(min_n); - for (int i = 0; i < min_n; i++){ - n_best.emplace_back(Likelihood()[i], GetBestPath(i) ); - } - return n_best; - } + for (int i = 0; i < min_n; i++) { + n_best.emplace_back(Likelihood()[i], GetBestPath(i)); + } + return n_best; +} - std::vector> CTCPrefixBeamSearch::GetNBestPath() { +std::vector> +CTCPrefixBeamSearch::GetNBestPath() { return GetNBestPath(-1); - } - -std::string CTCPrefixBeamSearch::GetFinalBestPath() { - return GetBestPath(); } -std::string CTCPrefixBeamSearch::GetPartialResult() { - return GetBestPath(); -} +std::string CTCPrefixBeamSearch::GetFinalBestPath() { return GetBestPath(); } + +std::string CTCPrefixBeamSearch::GetPartialResult() { return GetBestPath(); } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc index 7a488bb0d..d9cca1471 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/strings/str_split.h" #include "base/common.h" #include "decoder/ctc_prefix_beam_search_decoder.h" #include "frontend/audio/data_cache.h" +#include "fst/symbol-table.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/u2_nnet.h" -#include "absl/strings/str_split.h" -#include "fst/symbol-table.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); @@ -64,8 +64,7 @@ int main(int argc, char* argv[]) { // nnet ppspeech::ModelOptions model_opts; model_opts.model_path = FLAGS_model_path; - std::shared_ptr nnet( - new ppspeech::U2Nnet(model_opts)); + std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); // decodeable std::shared_ptr raw_data(new ppspeech::DataCache()); @@ -114,9 +113,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) << "utt: " << utt << " skip last " - << this_chunk_size << " frames, expect is " - << receptive_field_length; + LOG(WARNING) + << "utt: " << utt << " skip last " << this_chunk_size + << " frames, expect is " << receptive_field_length; break; } @@ -127,7 +126,7 @@ int main(int argc, char* argv[]) { for (int row_id = 0; row_id < this_chunk_size; ++row_id) { kaldi::SubVector feat_row(feature, start); kaldi::SubVector feature_chunk_row( - feature_chunk.Data() + row_id * feat_dim, feat_dim); + feature_chunk.Data() + row_id * feat_dim, feat_dim); feature_chunk_row.CopyFromVec(feat_row); ++start; @@ -151,7 +150,7 @@ int main(int argc, char* argv[]) { // get 1-best result std::string result = decoder.GetFinalBestPath(); - + // after process one utt, then reset state. decodable->Reset(); decoder.Reset(); diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 4d0a21d58..2c2b6d3c9 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -33,9 +33,7 @@ void TLGDecoder::Reset() { return; } -void TLGDecoder::InitDecoder() { - Reset(); -} +void TLGDecoder::InitDecoder() { Reset(); } void TLGDecoder::AdvanceDecode( const std::shared_ptr& decodable) { @@ -50,7 +48,6 @@ void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { } - std::string TLGDecoder::GetPartialResult() { if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call @@ -93,4 +90,4 @@ std::string TLGDecoder::GetFinalBestPath() { return words; } -} +} // namespace ppspeech diff --git a/speechx/speechx/decoder/ctc_tlg_decoder_main.cc b/speechx/speechx/decoder/ctc_tlg_decoder_main.cc index f262101ac..e9bd8a3f4 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder_main.cc @@ -15,14 +15,12 @@ // todo refactor, repalce with gtest #include "base/common.h" - +#include "decoder/ctc_tlg_decoder.h" +#include "decoder/param.h" #include "frontend/audio/data_cache.h" +#include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/ds2_nnet.h" -#include "decoder/param.h" -#include "decoder/ctc_tlg_decoder.h" - -#include "kaldi/util/table-types.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); @@ -47,12 +45,13 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; - ppspeech::TLGDecoderOptions opts = ppspeech::TLGDecoderOptions::InitFromFlags(); + ppspeech::TLGDecoderOptions opts = + ppspeech::TLGDecoderOptions::InitFromFlags(); opts.opts.beam = 15.0; opts.opts.lattice_beam = 7.5; ppspeech::TLGDecoder decoder(opts); - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); @@ -67,7 +66,7 @@ int main(int argc, char* argv[]) { LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "receptive field (frame): " << receptive_field_length; - + decoder.InitDecoder(); kaldi::Timer timer; for (; !feature_reader.Done(); feature_reader.Next()) { diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc index bb9ea1872..870aa40ac 100644 --- a/speechx/speechx/decoder/recognizer.cc +++ b/speechx/speechx/decoder/recognizer.cc @@ -17,12 +17,12 @@ namespace ppspeech { -using kaldi::Vector; -using kaldi::VectorBase; using kaldi::BaseFloat; -using std::vector; using kaldi::SubVector; +using kaldi::Vector; +using kaldi::VectorBase; using std::unique_ptr; +using std::vector; Recognizer::Recognizer(const RecognizerResource& resource) { diff --git a/speechx/speechx/decoder/recognizer_main.cc b/speechx/speechx/decoder/recognizer_main.cc index 662943b57..8e83b1888 100644 --- a/speechx/speechx/decoder/recognizer_main.cc +++ b/speechx/speechx/decoder/recognizer_main.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoder/recognizer.h" #include "decoder/param.h" +#include "decoder/recognizer.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/table-types.h" @@ -25,8 +25,9 @@ DEFINE_int32(sample_rate, 16000, "sample rate"); ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource resource; resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); - resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); + resource.feature_pipeline_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); + resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); return resource; } diff --git a/speechx/speechx/decoder/u2_recognizer.cc b/speechx/speechx/decoder/u2_recognizer.cc index 8fcc5d79b..04712e7b5 100644 --- a/speechx/speechx/decoder/u2_recognizer.cc +++ b/speechx/speechx/decoder/u2_recognizer.cc @@ -13,18 +13,20 @@ // limitations under the License. #include "decoder/u2_recognizer.h" + #include "nnet/u2_nnet.h" namespace ppspeech { -using kaldi::Vector; -using kaldi::VectorBase; using kaldi::BaseFloat; -using std::vector; using kaldi::SubVector; +using kaldi::Vector; +using kaldi::VectorBase; using std::unique_ptr; +using std::vector; -U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource) { +U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) + : opts_(resource) { const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; feature_pipeline_.reset(new FeaturePipeline(feature_opts)); @@ -34,7 +36,8 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource): opts_(resource decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); CHECK(resource.vocab_path != ""); - decoder_.reset(new CTCPrefixBeamSearch(resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); unit_table_ = decoder_->VocabTable(); symbol_table_ = unit_table_; @@ -70,140 +73,140 @@ void U2Recognizer::Accept(const VectorBase& waves) { } -void U2Recognizer::Decode() { - decoder_->AdvanceDecode(decodable_); +void U2Recognizer::Decode() { + decoder_->AdvanceDecode(decodable_); UpdateResult(false); } void U2Recognizer::Rescoring() { - // Do attention Rescoring - kaldi::Timer timer; - AttentionRescoring(); - VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec."; + // Do attention Rescoring + kaldi::Timer timer; + AttentionRescoring(); + VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec."; } void U2Recognizer::UpdateResult(bool finish) { - const auto& hypotheses = decoder_->Outputs(); - const auto& inputs = decoder_->Inputs(); - const auto& likelihood = decoder_->Likelihood(); - const auto& times = decoder_->Times(); - result_.clear(); - - CHECK_EQ(hypotheses.size(), likelihood.size()); - for (size_t i = 0; i < hypotheses.size(); i++) { - const std::vector& hypothesis = hypotheses[i]; - - DecodeResult path; - path.score = likelihood[i]; - for (size_t j = 0; j < hypothesis.size(); j++) { - std::string word = symbol_table_->Find(hypothesis[j]); - // A detailed explanation of this if-else branch can be found in - // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 - if (decoder_->Type() == kWfstBeamSearch) { - path.sentence += (" " + word); - } else { - path.sentence += (word); - } - } + const auto& hypotheses = decoder_->Outputs(); + const auto& inputs = decoder_->Inputs(); + const auto& likelihood = decoder_->Likelihood(); + const auto& times = decoder_->Times(); + result_.clear(); - // TimeStamp is only supported in final result - // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to - // various FST operations when building the decoding graph. So here we use - // time stamp of the input(e2e model unit), which is more accurate, and it - // requires the symbol table of the e2e model used in training. - if (unit_table_ != nullptr && finish) { - int offset = global_frame_offset_ * FrameShiftInMs(); + CHECK_EQ(hypotheses.size(), likelihood.size()); + for (size_t i = 0; i < hypotheses.size(); i++) { + const std::vector& hypothesis = hypotheses[i]; + + DecodeResult path; + path.score = likelihood[i]; + for (size_t j = 0; j < hypothesis.size(); j++) { + std::string word = symbol_table_->Find(hypothesis[j]); + // A detailed explanation of this if-else branch can be found in + // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 + if (decoder_->Type() == kWfstBeamSearch) { + path.sentence += (" " + word); + } else { + path.sentence += (word); + } + } + + // TimeStamp is only supported in final result + // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to + // various FST operations when building the decoding graph. So here we + // use time stamp of the input(e2e model unit), which is more accurate, + // and it requires the symbol table of the e2e model used in training. + if (unit_table_ != nullptr && finish) { + int offset = global_frame_offset_ * FrameShiftInMs(); - const std::vector& input = inputs[i]; - const std::vector time_stamp = times[i]; - CHECK_EQ(input.size(), time_stamp.size()); + const std::vector& input = inputs[i]; + const std::vector time_stamp = times[i]; + CHECK_EQ(input.size(), time_stamp.size()); - for (size_t j = 0; j < input.size(); j++) { - std::string word = unit_table_->Find(input[j]); + for (size_t j = 0; j < input.size(); j++) { + std::string word = unit_table_->Find(input[j]); - int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 + int start = + time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 ? time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ : 0; - if (j > 0) { - start = (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() < - time_stamp_gap_ - ? (time_stamp[j - 1] + time_stamp[j]) / 2 * - FrameShiftInMs() - : start; + if (j > 0) { + start = + (time_stamp[j] - time_stamp[j - 1]) * FrameShiftInMs() < + time_stamp_gap_ + ? (time_stamp[j - 1] + time_stamp[j]) / 2 * + FrameShiftInMs() + : start; + } + + int end = time_stamp[j] * FrameShiftInMs(); + if (j < input.size() - 1) { + end = + (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() < + time_stamp_gap_ + ? (time_stamp[j + 1] + time_stamp[j]) / 2 * + FrameShiftInMs() + : end; + } + + WordPiece word_piece(word, offset + start, offset + end); + path.word_pieces.emplace_back(word_piece); + } } - int end = time_stamp[j] * FrameShiftInMs(); - if (j < input.size() - 1) { - end = (time_stamp[j + 1] - time_stamp[j]) * FrameShiftInMs() < - time_stamp_gap_ - ? (time_stamp[j + 1] + time_stamp[j]) / 2 * - FrameShiftInMs() - : end; - } + // if (post_processor_ != nullptr) { + // path.sentence = post_processor_->Process(path.sentence, finish); + // } - WordPiece word_piece(word, offset + start, offset + end); - path.word_pieces.emplace_back(word_piece); - } + result_.emplace_back(path); } - // if (post_processor_ != nullptr) { - // path.sentence = post_processor_->Process(path.sentence, finish); - // } - - result_.emplace_back(path); - } - - if (DecodedSomething()) { - VLOG(1) << "Partial CTC result " << result_[0].sentence; - } + if (DecodedSomething()) { + VLOG(1) << "Partial CTC result " << result_[0].sentence; + } } void U2Recognizer::AttentionRescoring() { - decoder_->FinalizeSearch(); - UpdateResult(true); - - // No need to do rescoring - if (0.0 == opts_.decoder_opts.rescoring_weight) { - LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!"; - return; - } - LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!"; - - // Inputs() returns N-best input ids, which is the basic unit for rescoring - // In CtcPrefixBeamSearch, inputs are the same to outputs - const auto& hypotheses = decoder_->Inputs(); - int num_hyps = hypotheses.size(); - if (num_hyps <= 0) { - return; - } - - kaldi::Timer timer; - std::vector rescoring_score; - decodable_->AttentionRescoring( - hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score); - VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec."; - - // combine ctc score and rescoring score - for (size_t i = 0; i < num_hyps; i++) { - VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i] - << " ctc_score: " << result_[i].score; - result_[i].score = opts_.decoder_opts.rescoring_weight * rescoring_score[i] + - opts_.decoder_opts.ctc_weight * result_[i].score; - } - - std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc); - VLOG(1) << "result: " << result_[0].sentence - << " score: " << result_[0].score; -} + decoder_->FinalizeSearch(); + UpdateResult(true); -std::string U2Recognizer::GetFinalResult() { - return result_[0].sentence; -} + // No need to do rescoring + if (0.0 == opts_.decoder_opts.rescoring_weight) { + LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!"; + return; + } + LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!"; + + // Inputs() returns N-best input ids, which is the basic unit for rescoring + // In CtcPrefixBeamSearch, inputs are the same to outputs + const auto& hypotheses = decoder_->Inputs(); + int num_hyps = hypotheses.size(); + if (num_hyps <= 0) { + return; + } + + kaldi::Timer timer; + std::vector rescoring_score; + decodable_->AttentionRescoring( + hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score); + VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec."; + + // combine ctc score and rescoring score + for (size_t i = 0; i < num_hyps; i++) { + VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i] + << " ctc_score: " << result_[i].score; + result_[i].score = + opts_.decoder_opts.rescoring_weight * rescoring_score[i] + + opts_.decoder_opts.ctc_weight * result_[i].score; + } -std::string U2Recognizer::GetPartialResult() { - return result_[0].sentence; + std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc); + VLOG(1) << "result: " << result_[0].sentence + << " score: " << result_[0].score; } +std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } + +std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } + void U2Recognizer::SetFinished() { feature_pipeline_->SetFinished(); input_finished_ = true; diff --git a/speechx/speechx/decoder/u2_recognizer_main.cc b/speechx/speechx/decoder/u2_recognizer_main.cc index b1a7b2e8e..9eb0441b1 100644 --- a/speechx/speechx/decoder/u2_recognizer_main.cc +++ b/speechx/speechx/decoder/u2_recognizer_main.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "decoder/u2_recognizer.h" #include "decoder/param.h" +#include "decoder/u2_recognizer.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/table-types.h" @@ -43,7 +43,8 @@ int main(int argc, char* argv[]) { LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - ppspeech::U2RecognizerResource resource = ppspeech::U2RecognizerResource::InitFromFlags(); + ppspeech::U2RecognizerResource resource = + ppspeech::U2RecognizerResource::InitFromFlags(); ppspeech::U2Recognizer recognizer(resource); kaldi::Timer timer; @@ -103,7 +104,7 @@ int main(int argc, char* argv[]) { } double elapsed = timer.Elapsed(); - + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "cost:" << elapsed << " sec"; LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; diff --git a/speechx/speechx/frontend/audio/cmvn.cc b/speechx/speechx/frontend/audio/cmvn.cc index 5e84a1a12..7997e8a7b 100644 --- a/speechx/speechx/frontend/audio/cmvn.cc +++ b/speechx/speechx/frontend/audio/cmvn.cc @@ -14,17 +14,18 @@ #include "frontend/audio/cmvn.h" + #include "kaldi/feat/cmvn.h" #include "kaldi/util/kaldi-io.h" namespace ppspeech { -using kaldi::Vector; -using kaldi::VectorBase; using kaldi::BaseFloat; -using std::vector; using kaldi::SubVector; +using kaldi::Vector; +using kaldi::VectorBase; using std::unique_ptr; +using std::vector; CMVN::CMVN(std::string cmvn_file, unique_ptr base_extractor) @@ -57,7 +58,7 @@ bool CMVN::Read(kaldi::Vector* feats) { // feats contain num_frames feature. void CMVN::Compute(VectorBase* feats) const { KALDI_ASSERT(feats != NULL); - + if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim_ != 0) { KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ',' diff --git a/speechx/speechx/frontend/audio/compute_fbank_main.cc b/speechx/speechx/frontend/audio/compute_fbank_main.cc index 93a6d4072..bb7e449fe 100644 --- a/speechx/speechx/frontend/audio/compute_fbank_main.cc +++ b/speechx/speechx/frontend/audio/compute_fbank_main.cc @@ -16,16 +16,15 @@ #include "base/flags.h" #include "base/log.h" -#include "kaldi/feat/wave-reader.h" -#include "kaldi/util/kaldi-io.h" -#include "kaldi/util/table-types.h" - #include "frontend/audio/audio_cache.h" #include "frontend/audio/data_cache.h" #include "frontend/audio/fbank.h" #include "frontend/audio/feature_cache.h" #include "frontend/audio/frontend_itf.h" #include "frontend/audio/normalizer.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); @@ -86,24 +85,27 @@ int main(int argc, char* argv[]) { LOG(INFO) << "chunk size (sec): " << streaming_chunk; LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - for (; !wav_reader.Done() && !wav_info_reader.Done(); wav_reader.Next(), wav_info_reader.Next()) { + for (; !wav_reader.Done() && !wav_info_reader.Done(); + wav_reader.Next(), wav_info_reader.Next()) { const std::string& utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); const std::string& utt2 = wav_info_reader.Key(); const kaldi::WaveInfo& wave_info = wav_info_reader.Value(); - CHECK(utt == utt2) << "wav reader and wav info reader using diff rspecifier!!!"; + CHECK(utt == utt2) + << "wav reader and wav info reader using diff rspecifier!!!"; LOG(INFO) << "utt: " << utt; LOG(INFO) << "samples: " << wave_info.SampleCount(); LOG(INFO) << "dur: " << wave_info.Duration() << " sec"; - CHECK(wave_info.SampFreq() == FLAGS_sample_rate) << "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq(); + CHECK(wave_info.SampFreq() == FLAGS_sample_rate) + << "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq(); // load first channel wav int32 this_channel = 0; kaldi::SubVector waveform(wave_data.Data(), this_channel); - + // compute feat chunk by chunk int tot_samples = waveform.Dim(); int sample_offset = 0; @@ -157,7 +159,8 @@ int main(int argc, char* argv[]) { ++cur_idx; } } - LOG(INFO) << "feat shape: " << features.NumRows() << " , " << features.NumCols(); + LOG(INFO) << "feat shape: " << features.NumRows() << " , " + << features.NumCols(); feat_writer.Write(utt, features); // reset frontend pipeline state diff --git a/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc b/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc index 889f5663d..42693c0c6 100644 --- a/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc +++ b/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc @@ -14,16 +14,15 @@ #include "base/flags.h" #include "base/log.h" -#include "kaldi/feat/wave-reader.h" -#include "kaldi/util/kaldi-io.h" -#include "kaldi/util/table-types.h" - #include "frontend/audio/audio_cache.h" #include "frontend/audio/data_cache.h" #include "frontend/audio/feature_cache.h" #include "frontend/audio/frontend_itf.h" #include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/normalizer.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc index 7232efc44..65493e422 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.cc +++ b/speechx/speechx/frontend/audio/feature_pipeline.cc @@ -18,7 +18,8 @@ namespace ppspeech { using std::unique_ptr; -FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) { +FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) + : opts_(opts) { unique_ptr data_source( new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); @@ -43,4 +44,4 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opt new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); } -} // ppspeech +} // namespace ppspeech diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index b76c6280a..dc971e0f6 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -18,8 +18,8 @@ namespace ppspeech { using kaldi::BaseFloat; using kaldi::Matrix; -using std::vector; using kaldi::Vector; +using std::vector; Decodable::Decodable(const std::shared_ptr& nnet, const std::shared_ptr& frontend, @@ -56,7 +56,6 @@ int32 Decodable::NumIndices() const { return 0; } int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; } - bool Decodable::EnsureFrameHaveComputed(int32 frame) { // decoding frame if (frame >= frames_ready_) { @@ -92,14 +91,15 @@ bool Decodable::AdvanceChunk() { return true; } -bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, int* vocab_dim) { +bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, + int* vocab_dim) { if (AdvanceChunk() == false) { return false; } int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - if (nrows <= 0){ + CHECK(nrows == (frames_ready_ - frame_offset_)); + if (nrows <= 0) { LOG(WARNING) << "No new nnet out in cache."; return false; } @@ -107,7 +107,7 @@ bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, int* voc logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); logprobs->CopyRowsFromMat(nnet_out_cache_); - *vocab_dim = nnet_out_cache_.NumCols(); + *vocab_dim = nnet_out_cache_.NumCols(); return true; } @@ -140,7 +140,7 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { BaseFloat logprob = 0.0; int32 frame_idx = frame - frame_offset_; BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); - if (nnet_->IsLogProb()){ + if (nnet_->IsLogProb()) { logprob = nnet_out; } else { logprob = std::log(nnet_out + std::numeric_limits::epsilon()); @@ -158,8 +158,8 @@ void Decodable::Reset() { } void Decodable::AttentionRescoring(const std::vector>& hyps, - float reverse_weight, - std::vector* rescoring_score){ + float reverse_weight, + std::vector* rescoring_score) { nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); } diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/nnet/u2_nnet.cc index 71252477e..4bafdf831 100644 --- a/speechx/speechx/nnet/u2_nnet.cc +++ b/speechx/speechx/nnet/u2_nnet.cc @@ -242,7 +242,6 @@ void U2Nnet::ForwardEncoderChunkImpl( const int32& feat_dim, std::vector* out_prob, int32* vocab_dim) { - #ifdef USE_PROFILING RecordEvent event( "ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1); @@ -349,8 +348,9 @@ void U2Nnet::ForwardEncoderChunkImpl( // current offset in decoder frame // not used in nnet offset_ += chunk_out.shape()[1]; - VLOG(2) << "encoder out chunk size: " << chunk_out.shape()[1] << " total: " << offset_ ; - + VLOG(2) << "encoder out chunk size: " << chunk_out.shape()[1] + << " total: " << offset_; + // collects encoder outs. encoder_outs_.push_back(chunk_out); @@ -706,12 +706,13 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } -void U2Nnet::EncoderOuts(std::vector>* encoder_out) const { +void U2Nnet::EncoderOuts( + std::vector>* encoder_out) const { // list of (B=1,T,D) int size = encoder_outs_.size(); VLOG(1) << "encoder_outs_ size: " << size; - for (int i = 0; i < size; i++){ + for (int i = 0; i < size; i++) { const paddle::Tensor& item = encoder_outs_[i]; const std::vector shape = item.shape(); CHECK(shape.size() == 3); @@ -719,16 +720,17 @@ void U2Nnet::EncoderOuts(std::vector>* encoder_o const int& T = shape[1]; const int& D = shape[2]; CHECK(B == 1) << "Only support batch one."; - VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")"; + VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << "," + << D << ")"; - const float *this_tensor_ptr = item.data(); - for (int j = 0; j < T; j++){ - const float* cur = this_tensor_ptr + j * D; + const float* this_tensor_ptr = item.data(); + for (int j = 0; j < T; j++) { + const float* cur = this_tensor_ptr + j * D; kaldi::Vector out(D); std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat)); encoder_out->emplace_back(out); } } - } +} } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/speechx/speechx/nnet/u2_nnet_main.cc index adbbf0e80..5039a59ad 100644 --- a/speechx/speechx/nnet/u2_nnet_main.cc +++ b/speechx/speechx/nnet/u2_nnet_main.cc @@ -14,11 +14,11 @@ #include "base/common.h" +#include "decoder/param.h" #include "frontend/audio/assembler.h" #include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "decoder/param.h" #include "nnet/u2_nnet.h" @@ -46,15 +46,16 @@ int main(int argc, char* argv[]) { LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; LOG(INFO) << "model path: " << FLAGS_model_path; - kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_rspecifier); + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_rspecifier); kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); - kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer(FLAGS_nnet_encoder_outs_wspecifier); + kaldi::BaseFloatMatrixWriter nnet_encoder_outs_writer( + FLAGS_nnet_encoder_outs_wspecifier); ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - int32 chunk_size = - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate + - FLAGS_receptive_field_length; + int32 chunk_size = (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate + + FLAGS_receptive_field_length; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; @@ -92,9 +93,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) << "utt: " << utt << " skip last " - << this_chunk_size << " frames, expect is " - << receptive_field_length; + LOG(WARNING) + << "utt: " << utt << " skip last " << this_chunk_size + << " frames, expect is " << receptive_field_length; break; } @@ -123,13 +124,17 @@ int main(int argc, char* argv[]) { kaldi::Vector logprobs; bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim); CHECK(isok == true); - for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx ++) { + for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; + row_idx++) { kaldi::Vector vec_tmp(vocab_dim); - std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx*vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim); + std::memcpy(vec_tmp.Data(), + logprobs.Data() + row_idx * vocab_dim, + sizeof(kaldi::BaseFloat) * vocab_dim); prob_vec.push_back(vec_tmp); } - VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec."; + VLOG(2) << "frame_idx: " << frame_idx + << " elapsed: " << timer.Elapsed() << " sec."; } // get encoder out @@ -141,7 +146,8 @@ int main(int argc, char* argv[]) { if (prob_vec.size() == 0 || encoder_out_vec.size() == 0) { // the TokenWriter can not write empty string. ++num_err; - LOG(WARNING) << " the nnet prob/encoder_out of " << utt << " is empty"; + LOG(WARNING) << " the nnet prob/encoder_out of " << utt + << " is empty"; continue; } @@ -168,7 +174,8 @@ int main(int argc, char* argv[]) { kaldi::Matrix encoder_outs(nrow, ncol); for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { - encoder_outs(row_idx, col_idx) = encoder_out_vec[row_idx](col_idx); + encoder_outs(row_idx, col_idx) = + encoder_out_vec[row_idx](col_idx); } } nnet_encoder_outs_writer.Write(utt, encoder_outs); diff --git a/speechx/speechx/protocol/websocket/websocket_server_main.cc b/speechx/speechx/protocol/websocket/websocket_server_main.cc index 827b164f3..5c32caf27 100644 --- a/speechx/speechx/protocol/websocket/websocket_server_main.cc +++ b/speechx/speechx/protocol/websocket/websocket_server_main.cc @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "websocket/websocket_server.h" #include "decoder/param.h" +#include "websocket/websocket_server.h" DEFINE_int32(port, 8082, "websocket listening port"); ppspeech::RecognizerResource InitRecognizerResoure() { ppspeech::RecognizerResource resource; resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags(); + resource.feature_pipeline_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); - resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); + resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); return resource; } diff --git a/speechx/speechx/utils/math.cc b/speechx/speechx/utils/math.cc index c218990a0..289470f6c 100644 --- a/speechx/speechx/utils/math.cc +++ b/speechx/speechx/utils/math.cc @@ -16,13 +16,13 @@ #include "utils/math.h" -#include "base/common.h" - #include #include #include #include +#include "base/common.h" + namespace ppspeech { @@ -89,8 +89,8 @@ void TopK(const std::vector& data, } template void TopK(const std::vector& data, - int32_t k, - std::vector* values, - std::vector* indices) ; + int32_t k, + std::vector* values, + std::vector* indices); } // namespace ppspeech \ No newline at end of file