|
|
@ -55,18 +55,10 @@ int32 Decodable::NumIndices() const { return 0; }
|
|
|
|
// id.
|
|
|
|
// id.
|
|
|
|
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
|
|
|
|
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
|
|
|
|
|
|
|
|
|
|
|
|
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
|
|
|
|
|
|
|
|
CHECK_LE(index, nnet_out_cache_.NumCols());
|
|
|
|
|
|
|
|
CHECK_LE(frame, frames_ready_);
|
|
|
|
|
|
|
|
int32 frame_idx = frame - frame_offset_;
|
|
|
|
|
|
|
|
// the nnet output is prob ranther than log prob
|
|
|
|
|
|
|
|
// the index - 1, because the ilabel
|
|
|
|
|
|
|
|
return acoustic_scale_ *
|
|
|
|
|
|
|
|
std::log(nnet_out_cache_(frame_idx, TokenId2NnetId(index)) +
|
|
|
|
|
|
|
|
std::numeric_limits<float>::min());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
|
|
|
|
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
|
|
|
|
|
|
|
|
// decoding frame
|
|
|
|
if (frame >= frames_ready_) {
|
|
|
|
if (frame >= frames_ready_) {
|
|
|
|
return AdvanceChunk();
|
|
|
|
return AdvanceChunk();
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -74,26 +66,48 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool Decodable::AdvanceChunk() {
|
|
|
|
bool Decodable::AdvanceChunk() {
|
|
|
|
|
|
|
|
kaldi::Timer timer;
|
|
|
|
// read feats
|
|
|
|
// read feats
|
|
|
|
Vector<BaseFloat> features;
|
|
|
|
Vector<BaseFloat> features;
|
|
|
|
if (frontend_ == NULL || frontend_->Read(&features) == false) {
|
|
|
|
if (frontend_ == NULL || frontend_->Read(&features) == false) {
|
|
|
|
// no feat or frontend_ not init.
|
|
|
|
// no feat or frontend_ not init.
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
VLOG(2) << "Forward with " << features.Dim() << " frames.";
|
|
|
|
|
|
|
|
|
|
|
|
// forward feats
|
|
|
|
// forward feats
|
|
|
|
NnetOut out;
|
|
|
|
NnetOut out;
|
|
|
|
nnet_->FeedForward(features, frontend_->Dim(), &out);
|
|
|
|
nnet_->FeedForward(features, frontend_->Dim(), &out);
|
|
|
|
int32& vocab_dim = out.vocab_dim;
|
|
|
|
int32& vocab_dim = out.vocab_dim;
|
|
|
|
Vector<BaseFloat>& probs = out.logprobs;
|
|
|
|
Vector<BaseFloat>& logprobs = out.logprobs;
|
|
|
|
|
|
|
|
|
|
|
|
// cache nnet outupts
|
|
|
|
// cache nnet outupts
|
|
|
|
nnet_out_cache_.Resize(probs.Dim() / vocab_dim, vocab_dim);
|
|
|
|
nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim);
|
|
|
|
nnet_out_cache_.CopyRowsFromVec(probs);
|
|
|
|
nnet_out_cache_.CopyRowsFromVec(logprobs);
|
|
|
|
|
|
|
|
|
|
|
|
// update state
|
|
|
|
// update state, decoding frame.
|
|
|
|
frame_offset_ = frames_ready_;
|
|
|
|
frame_offset_ = frames_ready_;
|
|
|
|
frames_ready_ += nnet_out_cache_.NumRows();
|
|
|
|
frames_ready_ += nnet_out_cache_.NumRows();
|
|
|
|
|
|
|
|
VLOG(2) << "Forward feat chunk cost: " << timer.Elapsed() << " sec.";
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, int* vocab_dim) {
|
|
|
|
|
|
|
|
if (AdvanceChunk() == false) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int nrows = nnet_out_cache_.NumRows();
|
|
|
|
|
|
|
|
CHECK(nrows == (frames_ready_ - frame_offset_));
|
|
|
|
|
|
|
|
if (nrows <= 0){
|
|
|
|
|
|
|
|
LOG(WARNING) << "No new nnet out in cache.";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols());
|
|
|
|
|
|
|
|
logprobs->CopyRowsFromMat(nnet_out_cache_);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
*vocab_dim = nnet_out_cache_.NumCols();
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -113,6 +127,28 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
|
|
|
|
|
|
|
|
if (EnsureFrameHaveComputed(frame) == false) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CHECK_LE(index, nnet_out_cache_.NumCols());
|
|
|
|
|
|
|
|
CHECK_LE(frame, frames_ready_);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// the nnet output is prob ranther than log prob
|
|
|
|
|
|
|
|
// the index - 1, because the ilabel
|
|
|
|
|
|
|
|
BaseFloat logprob = 0.0;
|
|
|
|
|
|
|
|
int32 frame_idx = frame - frame_offset_;
|
|
|
|
|
|
|
|
BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index));
|
|
|
|
|
|
|
|
if (nnet_->IsLogProb()){
|
|
|
|
|
|
|
|
logprob = nnet_out;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
CHECK(!std::isnan(logprob) && !std::isinf(logprob));
|
|
|
|
|
|
|
|
return acoustic_scale_ * logprob;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void Decodable::Reset() {
|
|
|
|
void Decodable::Reset() {
|
|
|
|
if (frontend_ != nullptr) frontend_->Reset();
|
|
|
|
if (frontend_ != nullptr) frontend_->Reset();
|
|
|
|
if (nnet_ != nullptr) nnet_->Reset();
|
|
|
|
if (nnet_ != nullptr) nnet_->Reset();
|
|
|
|