From 3456ae4a516aa192e86bf1391e18f734170c6cf2 Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Tue, 12 Apr 2022 12:07:57 +0800 Subject: [PATCH] add log & rename LogFrameLikelihood --- speechx/speechx/decoder/ctc_beam_search_decoder.cc | 2 +- speechx/speechx/kaldi/decoder/decodable-itf.h | 2 +- speechx/speechx/nnet/decodable.cc | 11 +++++++++-- speechx/speechx/nnet/decodable.h | 10 +++++++--- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 5d7a4f77..b4caa8e7 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode( vector> likelihood; vector frame_prob; bool flag = - decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); + decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); if (flag == false) break; likelihood.push_back(frame_prob); AdvanceDecoding(likelihood); diff --git a/speechx/speechx/kaldi/decoder/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h index 19e07498..b8ce9143 100644 --- a/speechx/speechx/kaldi/decoder/decodable-itf.h +++ b/speechx/speechx/kaldi/decoder/decodable-itf.h @@ -143,7 +143,7 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; - virtual bool FrameLogLikelihood( + virtual bool FrameLikelihood( int32 frame, std::vector* likelihood) = 0; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index ce269650..d52b249f 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -49,11 +49,18 @@ bool Decodable::IsLastFrame(int32 frame) { int32 Decodable::NumIndices() const { return 0; } +// the ilable(TokenId) of wfst(TLG) insert (id = 0) in front of Nnet prob id. +int32 Decodable::TokenId2NnetId(int32 token_id) { + return token_id - 1; +} + BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(frame, frames_ready_); int32 frame_idx = frame - frame_offset_; - return acoustic_scale_ * std::log(nnet_cache_(frame_idx, index - 1) + + // the nnet output is prob ranther than log prob + // the index - 1, because the ilabel + return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) + std::numeric_limits::min()); } @@ -81,7 +88,7 @@ bool Decodable::AdvanceChunk() { return true; } -bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { +bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { std::vector result; if (EnsureFrameHaveComputed(frame) == false) return false; likelihood->resize(nnet_cache_.NumCols()); diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index b18ef07c..9555fea7 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -31,24 +31,28 @@ class Decodable : public kaldi::DecodableInterface { virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame); virtual int32 NumIndices() const; - virtual bool FrameLogLikelihood(int32 frame, - std::vector* likelihood); + // not logprob + virtual bool FrameLikelihood(int32 frame, + std::vector* likelihood); virtual int32 NumFramesReady() const; // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); void Reset(); bool IsInputFinished() const { return frontend_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); + int32 TokenId2NnetId(int32 token_id); private: bool AdvanceChunk(); std::shared_ptr frontend_; std::shared_ptr nnet_; kaldi::Matrix nnet_cache_; + // the frame is nnet prob frame rather than audio feature frame + // nnet frame subsample the feature frame + // eg: 35 frame features output 8 frame inferences int32 frame_offset_; int32 frames_ready_; // todo: feature frame mismatch with nnet inference frame - // eg: 35 frame features output 8 frame inferences // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_;