// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "nnet/decodable.h" namespace ppspeech { using kaldi::BaseFloat; using kaldi::Matrix; using kaldi::Vector; using std::vector; Decodable::Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale) : nnet_producer_(nnet_producer), frame_offset_(0), frames_ready_(0), acoustic_scale_(acoustic_scale) {} // for debug void Decodable::Acceptlikelihood(const Matrix& likelihood) { nnet_producer_->Acceptlikelihood(likelihood); } // return the size of frame have computed. int32 Decodable::NumFramesReady() const { return frames_ready_; } // frame idx is from 0 to frame_ready_ -1; bool Decodable::IsLastFrame(int32 frame) { EnsureFrameHaveComputed(frame); return frame >= frames_ready_; } int32 Decodable::NumIndices() const { return 0; } // the ilable(TokenId) of wfst(TLG) insert (id = 0) in front of Nnet prob // id. int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; } bool Decodable::EnsureFrameHaveComputed(int32 frame) { // decoding frame if (frame >= frames_ready_) { return AdvanceChunk(); } return true; } bool Decodable::AdvanceChunk() { kaldi::Timer timer; bool flag = nnet_producer_->Read(&framelikelihood_); if (flag == false) return false; frame_offset_ = frames_ready_; frames_ready_ += 1; VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed() << " sec."; return true; } bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, int* vocab_dim) { if (AdvanceChunk() == false) { return false; } if (framelikelihood_.empty()) { LOG(WARNING) << "No new nnet out in cache."; return false; } size_t dim = framelikelihood_.size(); logprobs->Resize(framelikelihood_.size()); std::memcpy(logprobs->Data(), framelikelihood_.data(), dim * sizeof(kaldi::BaseFloat)); *vocab_dim = framelikelihood_.size(); return true; } // read one frame likelihood bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { if (EnsureFrameHaveComputed(frame) == false) { VLOG(3) << "framelikehood exit."; return false; } CHECK_EQ(1, (frames_ready_ - frame_offset_)); *likelihood = framelikelihood_; return true; } BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { if (EnsureFrameHaveComputed(frame) == false) { return false; } CHECK_LE(index, framelikelihood_.size()); CHECK_LE(frame, frames_ready_); // the nnet output is prob ranther than log prob // the index - 1, because the ilabel BaseFloat logprob = 0.0; int32 frame_idx = frame - frame_offset_; CHECK_EQ(frame_idx, 0); logprob = framelikelihood_[TokenId2NnetId(index)]; return acoustic_scale_ * logprob; } void Decodable::Reset() { if (nnet_producer_ != nullptr) nnet_producer_->Reset(); frame_offset_ = 0; frames_ready_ = 0; framelikelihood_.clear(); } void Decodable::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { kaldi::Timer timer; nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score); VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec."; } } // namespace ppspeech