From 2455d8899710ad5322fa0ebc2281e0e650cd80fb Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Mon, 11 Apr 2022 15:43:21 +0800 Subject: [PATCH] make wfst work & align frame --- .../decoder/offline_wfst_decoder_main.cc | 56 ++++++++++++---- speechx/speechx/decoder/ctc_tlg_decoder.cc | 64 ++++++++++++------- speechx/speechx/decoder/ctc_tlg_decoder.h | 42 +++++++----- speechx/speechx/nnet/decodable.cc | 41 ++++++------ speechx/speechx/nnet/decodable.h | 8 +-- 5 files changed, 135 insertions(+), 76 deletions(-) diff --git a/speechx/examples/decoder/offline_wfst_decoder_main.cc b/speechx/examples/decoder/offline_wfst_decoder_main.cc index 758942b5..90dc8840 100644 --- a/speechx/examples/decoder/offline_wfst_decoder_main.cc +++ b/speechx/examples/decoder/offline_wfst_decoder_main.cc @@ -27,14 +27,20 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table"); DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_double(acoustic_scale, 10.0, "acoustic scale"); -DEFINE_int32(max_active, 5000, "decoder graph"); - +DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); +DEFINE_int32(max_active, 7500, "decoder graph"); +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=5) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=5) module downsampling rate."); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; +// test clg decoder by feeding speech feature. int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -52,7 +58,8 @@ int main(int argc, char* argv[]) { opts.word_symbol_table = word_symbol_table; opts.fst_path = graph_path; opts.opts.max_active = FLAGS_max_active; - opts.opts.beam = + opts.opts.beam = 15.0; + opts.opts.lattice_beam = 7.5; ppspeech::TLGDecoder decoder(opts); ppspeech::ModelOptions model_opts; @@ -61,30 +68,55 @@ int main(int argc, char* argv[]) { model_opts.cache_shape = "5-1-1024,5-1-1024"; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data( - new ppspeech::DataCache()); + std::shared_ptr raw_data(new ppspeech::DataCache()); std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data)); + new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - int32 chunk_size = 35; + int32 chunk_size = FLAGS_receptive_field_length; + int32 chunk_stride = FLAGS_downsampling_rate; + int32 receptive_field_length = FLAGS_receptive_field_length; + LOG(INFO) << "chunk size (frame): " << chunk_size; + LOG(INFO) << "chunk stride (frame): " << chunk_stride; + LOG(INFO) << "receptive field (frame): " << receptive_field_length; decoder.InitDecoder(); for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); - const kaldi::Matrix feature = feature_reader.Value(); + kaldi::Matrix feature = feature_reader.Value(); raw_data->SetDim(feature.NumCols()); + LOG(INFO) << "process utt: " << utt; + LOG(INFO) << "rows: " << feature.NumRows(); + LOG(INFO) << "cols: " << feature.NumCols(); + int32 row_idx = 0; - int32 num_chunks = feature.NumRows() / chunk_size; + int32 padding_len = 0; + int32 ori_feature_len = feature.NumRows(); + if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { + padding_len = + chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; + feature.Resize(feature.NumRows() + padding_len, + feature.NumCols(), + kaldi::kCopyData); + } + int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { kaldi::Vector feature_chunk(chunk_size * feature.NumCols()); + int32 feature_chunk_size = 0; + if (ori_feature_len > chunk_idx * chunk_stride) { + feature_chunk_size = std::min( + ori_feature_len - chunk_idx * chunk_stride, chunk_size); + } + if (feature_chunk_size < receptive_field_length) break; + + int32 start = chunk_idx * chunk_stride; for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, row_idx); + kaldi::SubVector tmp(feature, start); kaldi::SubVector f_chunk_tmp( feature_chunk.Data() + row_id * feature.NumCols(), feature.NumCols()); f_chunk_tmp.CopyFromVec(tmp); - row_idx++; + ++start; } raw_data->Accept(feature_chunk); if (chunk_idx == num_chunks - 1) { diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index c08a7d5b..5365e709 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -1,50 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "decoder/ctc_tlg_decoder.h" namespace ppspeech { TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { fst_.reset(fst::Fst::Read(opts.fst_path)); CHECK(fst_ != nullptr); - word_symbol_table_.reset(fst::SymbolTable::ReadText(opts.word_symbol_table)); + word_symbol_table_.reset( + fst::SymbolTable::ReadText(opts.word_symbol_table)); decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); decoder_->InitDecoding(); + frame_decoded_size_ = 0; } void TLGDecoder::InitDecoder() { decoder_->InitDecoding(); + frame_decoded_size_ = 0; } -void TLGDecoder::AdvanceDecode(const std::shared_ptr& decodable) { - while (1) { - AdvanceDecoding(decodable.get()); - if (decodable->IsLastFrame(num_frame_decoded_)) break; +void TLGDecoder::AdvanceDecode( + const std::shared_ptr& decodable) { + while (!decodable->IsLastFrame(frame_decoded_size_)) { + LOG(INFO) << "num frame decode: " << frame_decoded_size_; + AdvanceDecoding(decodable.get()); } } void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { - // skip blank frame? - decoder_->AdvanceDecoding(decodable, 1); - num_frame_decoded_++; + decoder_->AdvanceDecoding(decodable, 1); + frame_decoded_size_++; } void TLGDecoder::Reset() { - decoder_->InitDecoding(); - return; + InitDecoder(); + return; } std::string TLGDecoder::GetFinalBestPath() { - decoder_->FinalizeDecoding(); - kaldi::Lattice lat; - kaldi::LatticeWeight weight; - std::vector alignment; - std::vector words_id; - decoder_->GetBestPath(&lat, true); - fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); - std::string words; - for (int32 idx = 0; idx < words_id.size(); ++idx) { - std::string word = word_symbol_table_->Find(words_id[idx]); - words += word; - } - return words; + decoder_->FinalizeDecoding(); + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, true); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; } - } \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h index b4cd8c34..361c44af 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -1,21 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once -#include "kaldi/decoder/lattice-faster-online-decoder.h" +#include "base/basic_types.h" #include "kaldi/decoder/decodable-itf.h" +#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" -#include "base/basic_types.h" namespace ppspeech { struct TLGDecoderOptions { - kaldi::LatticeFasterDecoderConfig opts; - // todo remove later, add into decode resource - std::string word_symbol_table; - std::string fst_path; - - TLGDecoderOptions() - : word_symbol_table(""), - fst_path("") {} + kaldi::LatticeFasterDecoderConfig opts; + // todo remove later, add into decode resource + std::string word_symbol_table; + std::string fst_path; + + TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} }; class TLGDecoder { @@ -34,14 +46,14 @@ class TLGDecoder { void Reset(); private: - void AdvanceDecoding(kaldi::DecodableInterface* decodable); + void AdvanceDecoding(kaldi::DecodableInterface* decodable); std::shared_ptr decoder_; - std::shared_ptr> fst_; + std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; - int32 num_frame_decoded_; - }; + // the frame size which have decoded starts from 0. + int32 frame_decoded_size_; +}; - } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 805c0dca..1bf870d8 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -22,8 +22,13 @@ using std::vector; using kaldi::Vector; Decodable::Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend) - : frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {} + const std::shared_ptr& frontend, + kaldi::BaseFloat acoustic_scale) + : frontend_(frontend), + nnet_(nnet), + frame_offset_(0), + frames_ready_(0), + acoustic_scale_(acoustic_scale) {} void Decodable::Acceptlikelihood(const Matrix& likelihood) { nnet_cache_ = likelihood; @@ -32,14 +37,14 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { // Decodable::Init(DecodableConfig config) { //} -int32 Decodable::NumFramesReady() const { - return frames_ready_; -} +// return the size of frame have computed. +int32 Decodable::NumFramesReady() const { return frames_ready_; } + +// frame idx is from 0 to frame_ready_ -1; bool Decodable::IsLastFrame(int32 frame) { bool flag = EnsureFrameHaveComputed(frame); - //CHECK_LE(frame, frames_ready_); - return (flag == false) && (frame == frames_ready_); + return frame >= frames_ready_; } int32 Decodable::NumIndices() const { return 0; } @@ -48,7 +53,8 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(frame, frames_ready_); int32 frame_idx = frame - frame_offset_; - return std::log(nnet_cache_(frame_idx, index) + std::numeric_limits::min()); + return acoustic_scale_ * std::log(nnet_cache_(frame_idx, index - 1) + + std::numeric_limits::min()); } bool Decodable::EnsureFrameHaveComputed(int32 frame) { @@ -67,20 +73,12 @@ bool Decodable::AdvanceChunk() { Vector inferences; Matrix nnet_cache_tmp; nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); - nnet_cache_tmp.Resize(inferences.Dim() / nnet_dim, nnet_dim); - nnet_cache_tmp.CopyRowsFromVec(inferences); - // skip blank - vector no_blank_record; - BaseFloat blank_threshold = 0.98; - for (int32 idx = 0; idx < nnet_cache_.NumRows(); ++idx) { - if (nnet_cache_(idx, 0) > blank_threshold) { - - } - } - - + nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); + nnet_cache_.CopyRowsFromVec(inferences); + frame_offset_ = frames_ready_; frames_ready_ += nnet_cache_.NumRows(); + LOG(INFO) << "nnet size: " << nnet_cache_.NumRows(); return true; } @@ -89,7 +87,8 @@ bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { if (EnsureFrameHaveComputed(frame) == false) return false; likelihood->resize(nnet_cache_.NumCols()); for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { - (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx); + (*likelihood)[idx] = + nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_; } return true; } diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 5b687f3d..b18ef07c 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -14,8 +14,8 @@ #include "base/common.h" #include "frontend/audio/frontend_itf.h" -#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/decoder/decodable-itf.h" +#include "kaldi/matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" namespace ppspeech { @@ -25,7 +25,8 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: explicit Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend); + const std::shared_ptr& frontend, + kaldi::BaseFloat acoustic_scale = 1.0); // void Init(DecodableOpts config); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame); @@ -38,14 +39,12 @@ class Decodable : public kaldi::DecodableInterface { void Reset(); bool IsInputFinished() const { return frontend_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); - private: bool AdvanceChunk(); std::shared_ptr frontend_; std::shared_ptr nnet_; kaldi::Matrix nnet_cache_; - // std::vector> nnet_cache_; int32 frame_offset_; int32 frames_ready_; // todo: feature frame mismatch with nnet inference frame @@ -53,6 +52,7 @@ class Decodable : public kaldi::DecodableInterface { // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_; + kaldi::BaseFloat acoustic_scale_; }; } // namespace ppspeech