diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index 90fc96a18..70b11b691 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include diff --git a/speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc b/speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc index 93bad6886..713c9ef1e 100644 --- a/speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc +++ b/speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc @@ -47,13 +47,13 @@ int main(int argc, char* argv[]) { for (auto obj : value.as_object()) { if (obj.key() == "mean_stat") { - LOG(INFO) << "mean_stat:" << obj.value(); + VLOG(2) << "mean_stat:" << obj.value(); } if (obj.key() == "var_stat") { - LOG(INFO) << "var_stat: " << obj.value(); + VLOG(2) << "var_stat: " << obj.value(); } if (obj.key() == "frame_num") { - LOG(INFO) << "frame_num: " << obj.value(); + VLOG(2) << "frame_num: " << obj.value(); } } @@ -79,7 +79,7 @@ int main(int argc, char* argv[]) { cmvn_stats(1, idx) = var_stat_vec[idx]; } cmvn_stats(0, mean_size) = frame_num; - LOG(INFO) << cmvn_stats; + VLOG(2) << cmvn_stats; kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; diff --git a/speechx/speechx/kaldi/decoder/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h index b8ce9143e..a7c12588b 100644 --- a/speechx/speechx/kaldi/decoder/decodable-itf.h +++ b/speechx/speechx/kaldi/decoder/decodable-itf.h @@ -101,7 +101,9 @@ namespace kaldi { */ class DecodableInterface { public: - /// Returns the log likelihood, which will be negated in the decoder. + virtual ~DecodableInterface() {} + + /// Returns the log likelihood(logprob), which will be negated in the decoder. /// The "frame" starts from zero. You should verify that NumFramesReady() > /// frame /// before calling this. 
@@ -143,11 +145,12 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; + /// Returns the likelihood(prob), which will be positive in the decoder. + /// The "frame" starts from zero. You should verify that NumFramesReady() > + /// frame + /// before calling this. virtual bool FrameLikelihood( int32 frame, std::vector* likelihood) = 0; - - - virtual ~DecodableInterface() {} }; /// @} } // namespace Kaldi diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 40fac182f..1483949b9 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -55,18 +55,10 @@ int32 Decodable::NumIndices() const { return 0; } // id. int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; } -BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { - CHECK_LE(index, nnet_out_cache_.NumCols()); - CHECK_LE(frame, frames_ready_); - int32 frame_idx = frame - frame_offset_; - // the nnet output is prob ranther than log prob - // the index - 1, because the ilabel - return acoustic_scale_ * - std::log(nnet_out_cache_(frame_idx, TokenId2NnetId(index)) + - std::numeric_limits::min()); -} + bool Decodable::EnsureFrameHaveComputed(int32 frame) { + // decoding frame if (frame >= frames_ready_) { return AdvanceChunk(); } @@ -74,26 +66,48 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { } bool Decodable::AdvanceChunk() { + kaldi::Timer timer; // read feats Vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. 
return false; } + VLOG(2) << "Forward with " << features.Dim() << " frames."; // forward feats NnetOut out; nnet_->FeedForward(features, frontend_->Dim(), &out); int32& vocab_dim = out.vocab_dim; - Vector& probs = out.logprobs; + Vector& logprobs = out.logprobs; // cache nnet outupts - nnet_out_cache_.Resize(probs.Dim() / vocab_dim, vocab_dim); - nnet_out_cache_.CopyRowsFromVec(probs); + nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim); + nnet_out_cache_.CopyRowsFromVec(logprobs); - // update state + // update state, decoding frame. frame_offset_ = frames_ready_; frames_ready_ += nnet_out_cache_.NumRows(); + VLOG(2) << "Forward feat chunk cost: " << timer.Elapsed() << " sec."; + return true; +} + +bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, int* vocab_dim) { + if (AdvanceChunk() == false) { + return false; + } + + int nrows = nnet_out_cache_.NumRows(); + CHECK(nrows == (frames_ready_ - frame_offset_)); + if (nrows <= 0){ + LOG(WARNING) << "No new nnet out in cache."; + return false; + } + + logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); + logprobs->CopyRowsFromMat(nnet_out_cache_); + + *vocab_dim = nnet_out_cache_.NumCols(); return true; } @@ -113,6 +127,28 @@ bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { return true; } +BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { + if (EnsureFrameHaveComputed(frame) == false) { + return false; + } + + CHECK_LE(index, nnet_out_cache_.NumCols()); + CHECK_LE(frame, frames_ready_); + + // the nnet output may be prob rather than log prob + // the index - 1, because the ilabel + BaseFloat logprob = 0.0; + int32 frame_idx = frame - frame_offset_; + BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); + if (nnet_->IsLogProb()){ + logprob = nnet_out; + } else { + logprob = std::log(nnet_out + std::numeric_limits::epsilon()); + } + CHECK(!std::isnan(logprob) && !std::isinf(logprob)); + return acoustic_scale_ * logprob; +} + void 
Decodable::Reset() { if (frontend_ != nullptr) frontend_->Reset(); if (nnet_ != nullptr) nnet_->Reset(); diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 39b38dc11..1ee6afbf8 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -57,9 +57,13 @@ class Decodable : public kaldi::DecodableInterface { std::shared_ptr Nnet() { return nnet_; } - private: + // forward nnet with feats bool AdvanceChunk(); + // forward nnet with feats, and get nnet output + bool AdvanceChunk(kaldi::Vector* logprobs, + int* vocab_dim); + private: std::shared_ptr frontend_; std::shared_ptr nnet_; diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/nnet/ds2_nnet.h index 80be69271..9e2cb77b7 100644 --- a/speechx/speechx/nnet/ds2_nnet.h +++ b/speechx/speechx/nnet/ds2_nnet.h @@ -104,6 +104,8 @@ class PaddleNnet : public NnetInterface { void Reset() override; + bool IsLogProb() override { return false; } + std::shared_ptr> GetCacheEncoder( const std::string& name); diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/nnet/nnet_itf.h index 5dde72a81..d05aabea4 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/speechx/speechx/nnet/nnet_itf.h @@ -39,7 +39,8 @@ class NnetInterface { // forward feat with nnet. // nnet do not cache feats, feats cached by frontend. - // nnet cache model outputs, i.e. logprobs/encoder_outs. + // nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache, + // frame_offset. virtual void FeedForward(const kaldi::Vector& features, const int32& feature_dim, NnetOut* out) = 0; @@ -47,6 +48,9 @@ class NnetInterface { // reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_. virtual void Reset() = 0; + // returns true if the nnet output is logprob; otherwise it is prob. + virtual bool IsLogProb() = 0; + // using to get encoder outs. e.g. seq2seq with Attention model. 
virtual void EncoderOuts( std::vector>* encoder_out) const = 0; diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h index 8ce45f43a..4ecbac26f 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/speechx/speechx/nnet/u2_nnet.h @@ -111,6 +111,8 @@ class U2Nnet : public U2NnetBase { void Reset() override; + bool IsLogProb() override { return true; } + void Dim(); void LoadModel(const std::string& model_path_w_prefix); diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/speechx/speechx/nnet/u2_nnet_main.cc index fb9fec230..0c5aed54e 100644 --- a/speechx/speechx/nnet/u2_nnet_main.cc +++ b/speechx/speechx/nnet/u2_nnet_main.cc @@ -98,6 +98,7 @@ int main(int argc, char* argv[]) { // } int32 frame_idx = 0; + int vocab_dim = 0; std::vector> prob_vec; std::vector> encoder_out_vec; int32 ori_feature_len = feature.NumRows(); @@ -138,17 +139,17 @@ int main(int argc, char* argv[]) { } // get nnet outputs - vector prob; - while (decodable->FrameLikelihood(frame_idx, &prob)) { - kaldi::Vector vec_tmp(prob.size()); - std::memcpy(vec_tmp.Data(), - prob.data(), - sizeof(kaldi::BaseFloat) * prob.size()); + kaldi::Timer timer; + kaldi::Vector logprobs; + bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim); + CHECK(isok == true); + for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx ++) { + kaldi::Vector vec_tmp(vocab_dim); + std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx*vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim); prob_vec.push_back(vec_tmp); - frame_idx++; } - + VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec."; } // get encoder out @@ -196,8 +197,9 @@ int main(int argc, char* argv[]) { ++num_done; } + double elapsed = timer.Elapsed(); - LOG(INFO) << " cost:" << elapsed << " sec"; + LOG(INFO) << "Program cost:" << elapsed << " sec"; LOG(INFO) << "Done " << num_done << " utterances, " << num_err << " with errors.";