You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/asr/nnet/decodable.cc

137 lines
4.0 KiB

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
using kaldi::Vector;
using std::vector;
// Creates a decodable that reads per-frame nnet output from `nnet_producer`
// and multiplies the scores returned by LogLikelihood() by `acoustic_scale`.
// Frame bookkeeping starts at zero: nothing has been read yet.
Decodable::Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
                     kaldi::BaseFloat acoustic_scale)
    : nnet_producer_(nnet_producer),
      frame_offset_(0),
      frames_ready_(0),
      acoustic_scale_(acoustic_scale) {
}
// Debug helper: hands an externally supplied likelihood matrix straight to
// the underlying nnet producer.
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
    auto& producer = *nnet_producer_;
    producer.Acceptlikelihood(likelihood);
}
// Returns the number of frames that have been computed so far.
int32 Decodable::NumFramesReady() const {
    return frames_ready_;
}
// Valid frame indices run from 0 to frames_ready_ - 1.
// First tries to make `frame` available, then reports true when `frame` is
// still not below frames_ready_.
bool Decodable::IsLastFrame(int32 frame) {
    // The return value is ignored; only the updated frames_ready_ matters.
    EnsureFrameHaveComputed(frame);
    const bool beyond_ready = !(frame < frames_ready_);
    return beyond_ready;
}
// Always reports 0 indices.
int32 Decodable::NumIndices() const {
    return 0;
}
// The WFST (TLG) input label (token id) space puts <eps> (id = 0) in front of
// the nnet output ids, so a token id is one larger than its nnet id.
int32 Decodable::TokenId2NnetId(int32 token_id) {
    const int32 nnet_id = token_id - 1;
    return nnet_id;
}
// Makes sure `frame` is backed by computed nnet output.
// Returns true right away if the frame is already counted in frames_ready_;
// otherwise makes a single AdvanceChunk() call and returns its result.
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
    const bool already_ready = frame < frames_ready_;
    if (already_ready) {
        return true;
    }
    // The decoder asked for a frame that has not been read yet.
    return AdvanceChunk();
}
// Reads the next frame of nnet output from the producer into
// framelikelihood_. On success the newly read frame becomes the cached one
// (frame_offset_ points at it) and frames_ready_ grows by one.
// Returns false, leaving the frame counters untouched, when the producer has
// nothing to read.
bool Decodable::AdvanceChunk() {
    kaldi::Timer timer;
    if (!nnet_producer_->Read(&framelikelihood_)) {
        return false;
    }

    frame_offset_ = frames_ready_;
    ++frames_ready_;

    VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
            << " sec.";
    return true;
}
// Advances by one frame and copies that frame's scores out to the caller.
//   logprobs:  resized to the score count and filled with the frame's scores.
//   vocab_dim: receives the number of scores in the frame.
// Returns false if no frame could be read or the cached frame is empty; the
// outputs are not written in that case.
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
                             int* vocab_dim) {
    if (!AdvanceChunk()) {
        return false;
    }

    const size_t num_scores = framelikelihood_.size();
    if (num_scores == 0) {
        LOG(WARNING) << "No new nnet out in cache.";
        return false;
    }

    logprobs->Resize(num_scores);
    const size_t num_bytes = num_scores * sizeof(kaldi::BaseFloat);
    std::memcpy(logprobs->Data(), framelikelihood_.data(), num_bytes);
    *vocab_dim = num_scores;
    return true;
}
// Copies the scores of one frame into `likelihood`.
// Returns false (leaving `likelihood` untouched) when the frame cannot be
// made available.
bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
    const bool frame_available = EnsureFrameHaveComputed(frame);
    if (!frame_available) {
        VLOG(3) << "framelikehood exit.";
        return false;
    }

    // Exactly one frame is cached at any time.
    CHECK_EQ(1, (frames_ready_ - frame_offset_));
    *likelihood = framelikelihood_;
    return true;
}
// Returns the score of token `index` at `frame`, multiplied by
// acoustic_scale_.
//
// `index` is a one-based WFST input label (0 is <eps>); it is mapped to a
// zero-based nnet output id with TokenId2NnetId() before the lookup.
// Returns 0 when the requested frame cannot be made available.
// NOTE(review): whether the cached values are probabilities or log
// probabilities depends on the nnet producer -- confirm.
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
    if (!EnsureFrameHaveComputed(frame)) {
        // No frame available: report a score of 0.
        return 0.0;
    }

    // index <= 0 would map to a negative nnet id and read outside
    // framelikelihood_, so reject it before the lookup.
    CHECK_GT(index, 0);
    CHECK_LE(static_cast<size_t>(index), framelikelihood_.size());
    CHECK_LE(frame, frames_ready_);

    // Only the most recently read frame is cached, so the requested frame
    // must be exactly the one at frame_offset_.
    const int32 frame_idx = frame - frame_offset_;
    CHECK_EQ(frame_idx, 0);

    const BaseFloat logprob = framelikelihood_[TokenId2NnetId(index)];
    return acoustic_scale_ * logprob;
}
// Returns the decodable to its freshly constructed state: resets the nnet
// producer (if one is set), drops the cached frame and zeroes the frame
// counters.
void Decodable::Reset() {
    if (nnet_producer_ != nullptr) {
        nnet_producer_->Reset();
    }
    framelikelihood_.clear();
    frames_ready_ = 0;
    frame_offset_ = 0;
}
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
kaldi::Timer timer;
nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
}
} // namespace ppspeech