You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.0 KiB
137 lines
4.0 KiB
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
#include "nnet/decodable.h"
|
|
|
|
namespace ppspeech {
|
|
|
|
using kaldi::BaseFloat;
|
|
using kaldi::Matrix;
|
|
using kaldi::Vector;
|
|
using std::vector;
|
|
|
|
// Constructs a Decodable that reads likelihoods from `nnet_producer` and
// multiplies them by `acoustic_scale` in LogLikelihood().
// Both frame counters start at zero.
Decodable::Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
                     kaldi::BaseFloat acoustic_scale)
    : nnet_producer_(nnet_producer),
      frame_offset_(0),
      frames_ready_(0),
      acoustic_scale_(acoustic_scale) {
}
|
|
|
|
// for debug
|
|
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
|
|
nnet_producer_->Acceptlikelihood(likelihood);
|
|
}
|
|
|
|
// return the size of frame have computed.
|
|
int32 Decodable::NumFramesReady() const { return frames_ready_; }
|
|
|
|
|
|
// frame idx is from 0 to frame_ready_ -1;
|
|
bool Decodable::IsLastFrame(int32 frame) {
|
|
EnsureFrameHaveComputed(frame);
|
|
return frame >= frames_ready_;
|
|
}
|
|
|
|
// Always reports 0 indices.
int32 Decodable::NumIndices() const {
    return 0;
}
|
|
|
|
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob
|
|
// id.
|
|
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
|
|
|
|
|
|
// Returns true when `frame` is already below frames_ready_; otherwise pulls
// one more chunk via AdvanceChunk() and returns its result.
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
    // Frame already available: nothing to fetch.
    if (frame < frames_ready_) {
        return true;
    }
    // Decoding frame is not ready yet: advance by one chunk.
    return AdvanceChunk();
}
|
|
|
|
bool Decodable::AdvanceChunk() {
|
|
kaldi::Timer timer;
|
|
bool flag = nnet_producer_->Read(&framelikelihood_);
|
|
if (flag == false) return false;
|
|
frame_offset_ = frames_ready_;
|
|
frames_ready_ += 1;
|
|
VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
|
|
<< " sec.";
|
|
return true;
|
|
}
|
|
|
|
// Advances by one chunk and copies the freshly read frame into `logprobs`,
// storing its length in `*vocab_dim`.
// Returns false when the advance fails or the cached frame is empty; in those
// cases `logprobs` and `*vocab_dim` are not written.
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
                             int* vocab_dim) {
    if (!AdvanceChunk()) {
        return false;
    }

    if (framelikelihood_.empty()) {
        LOG(WARNING) << "No new nnet out in cache.";
        return false;
    }

    const size_t num_entries = framelikelihood_.size();
    const size_t num_bytes = num_entries * sizeof(kaldi::BaseFloat);

    logprobs->Resize(num_entries);
    std::memcpy(logprobs->Data(), framelikelihood_.data(), num_bytes);
    *vocab_dim = num_entries;
    return true;
}
|
|
|
|
// read one frame likelihood
|
|
bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
|
|
if (EnsureFrameHaveComputed(frame) == false) {
|
|
VLOG(3) << "framelikehood exit.";
|
|
return false;
|
|
}
|
|
|
|
CHECK_EQ(1, (frames_ready_ - frame_offset_));
|
|
*likelihood = framelikelihood_;
|
|
return true;
|
|
}
|
|
|
|
// Returns the acoustic-scaled score of token `index` at `frame`.
//
// `index` is a token id in the WFST label space, where id 0 is <eps>; it is
// mapped to an nnet output position with TokenId2NnetId() (index - 1), so it
// must lie in [1, framelikelihood_.size()].
// `frame` must be the frame currently cached (frame == frame_offset_).
//
// Returns 0 when the requested frame cannot be made available.
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
    if (!EnsureFrameHaveComputed(frame)) {
        // Explicit zero score; the previous `return false` relied on an
        // implicit bool -> BaseFloat conversion that yields the same value.
        return static_cast<BaseFloat>(0.0);
    }

    // index 0 (<eps>) would map to nnet id -1 and read out of bounds.
    CHECK_GT(index, 0);
    CHECK_LE(static_cast<size_t>(index), framelikelihood_.size());
    CHECK_LE(frame, frames_ready_);

    // Original author's note: the nnet output is prob rather than log prob.
    // Only a single frame is cached, so the requested frame must be it.
    const int32 frame_idx = frame - frame_offset_;
    CHECK_EQ(frame_idx, 0);

    // index - 1 because of the <eps> ilabel inserted at id 0.
    const BaseFloat logprob = framelikelihood_[TokenId2NnetId(index)];
    return acoustic_scale_ * logprob;
}
|
|
|
|
void Decodable::Reset() {
|
|
if (nnet_producer_ != nullptr) nnet_producer_->Reset();
|
|
frame_offset_ = 0;
|
|
frames_ready_ = 0;
|
|
framelikelihood_.clear();
|
|
}
|
|
|
|
// Forwards the hypotheses to nnet_producer_->AttentionRescoring(), which
// fills `rescoring_score`, and logs the elapsed time at VLOG level 1.
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                                   float reverse_weight,
                                   std::vector<float>* rescoring_score) {
    kaldi::Timer timer;

    nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score);

    VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
}
|
|
|
|
} // namespace ppspeech
|