fix LogLikelihood and add AdvanceChunk

pull/2524/head
Hui Zhang 3 years ago
parent 5cc874e1c3
commit 6987751ff8

@ -15,6 +15,7 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <cmath>
#include <condition_variable> #include <condition_variable>
#include <cstring> #include <cstring>
#include <deque> #include <deque>

@ -47,13 +47,13 @@ int main(int argc, char* argv[]) {
for (auto obj : value.as_object()) { for (auto obj : value.as_object()) {
if (obj.key() == "mean_stat") { if (obj.key() == "mean_stat") {
LOG(INFO) << "mean_stat:" << obj.value(); VLOG(2) << "mean_stat:" << obj.value();
} }
if (obj.key() == "var_stat") { if (obj.key() == "var_stat") {
LOG(INFO) << "var_stat: " << obj.value(); VLOG(2) << "var_stat: " << obj.value();
} }
if (obj.key() == "frame_num") { if (obj.key() == "frame_num") {
LOG(INFO) << "frame_num: " << obj.value(); VLOG(2) << "frame_num: " << obj.value();
} }
} }
@ -79,7 +79,7 @@ int main(int argc, char* argv[]) {
cmvn_stats(1, idx) = var_stat_vec[idx]; cmvn_stats(1, idx) = var_stat_vec[idx];
} }
cmvn_stats(0, mean_size) = frame_num; cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats; VLOG(2) << cmvn_stats;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;

@ -101,7 +101,9 @@ namespace kaldi {
*/ */
class DecodableInterface { class DecodableInterface {
public: public:
/// Returns the log likelihood, which will be negated in the decoder. virtual ~DecodableInterface() {}
/// Returns the log likelihood(logprob), which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > /// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame /// frame
/// before calling this. /// before calling this.
@ -143,11 +145,12 @@ class DecodableInterface {
/// this is for compatibility with OpenFst). /// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0; virtual int32 NumIndices() const = 0;
/// Returns the likelihood(prob), which will be postive in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this.
virtual bool FrameLikelihood( virtual bool FrameLikelihood(
int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0; int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0;
virtual ~DecodableInterface() {}
}; };
/// @} /// @}
} // namespace Kaldi } // namespace Kaldi

@ -55,18 +55,10 @@ int32 Decodable::NumIndices() const { return 0; }
// id. // id.
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; } int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_out_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_;
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
return acoustic_scale_ *
std::log(nnet_out_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
}
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
// decoding frame
if (frame >= frames_ready_) { if (frame >= frames_ready_) {
return AdvanceChunk(); return AdvanceChunk();
} }
@ -74,26 +66,48 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
} }
bool Decodable::AdvanceChunk() { bool Decodable::AdvanceChunk() {
kaldi::Timer timer;
// read feats // read feats
Vector<BaseFloat> features; Vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) { if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init. // no feat or frontend_ not init.
return false; return false;
} }
VLOG(2) << "Forward with " << features.Dim() << " frames.";
// forward feats // forward feats
NnetOut out; NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out); nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim; int32& vocab_dim = out.vocab_dim;
Vector<BaseFloat>& probs = out.logprobs; Vector<BaseFloat>& logprobs = out.logprobs;
// cache nnet outupts // cache nnet outupts
nnet_out_cache_.Resize(probs.Dim() / vocab_dim, vocab_dim); nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim);
nnet_out_cache_.CopyRowsFromVec(probs); nnet_out_cache_.CopyRowsFromVec(logprobs);
// update state // update state, decoding frame.
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_out_cache_.NumRows(); frames_ready_ += nnet_out_cache_.NumRows();
VLOG(2) << "Forward feat chunk cost: " << timer.Elapsed() << " sec.";
return true;
}
bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, int* vocab_dim) {
if (AdvanceChunk() == false) {
return false;
}
int nrows = nnet_out_cache_.NumRows();
CHECK(nrows == (frames_ready_ - frame_offset_));
if (nrows <= 0){
LOG(WARNING) << "No new nnet out in cache.";
return false;
}
logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols());
logprobs->CopyRowsFromMat(nnet_out_cache_);
*vocab_dim = nnet_out_cache_.NumCols();
return true; return true;
} }
@ -113,6 +127,28 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
return true; return true;
} }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
if (EnsureFrameHaveComputed(frame) == false) {
return false;
}
CHECK_LE(index, nnet_out_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
BaseFloat logprob = 0.0;
int32 frame_idx = frame - frame_offset_;
BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index));
if (nnet_->IsLogProb()){
logprob = nnet_out;
} else {
logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon());
}
CHECK(!std::isnan(logprob) && !std::isinf(logprob));
return acoustic_scale_ * logprob;
}
void Decodable::Reset() { void Decodable::Reset() {
if (frontend_ != nullptr) frontend_->Reset(); if (frontend_ != nullptr) frontend_->Reset();
if (nnet_ != nullptr) nnet_->Reset(); if (nnet_ != nullptr) nnet_->Reset();

@ -57,9 +57,13 @@ class Decodable : public kaldi::DecodableInterface {
std::shared_ptr<NnetInterface> Nnet() { return nnet_; } std::shared_ptr<NnetInterface> Nnet() { return nnet_; }
private: // forward nnet with feats
bool AdvanceChunk(); bool AdvanceChunk();
// forward nnet with feats, and get nnet output
bool AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim);
private:
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;

@ -104,6 +104,8 @@ class PaddleNnet : public NnetInterface {
void Reset() override; void Reset() override;
bool IsLogProb() override { return false; }
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder( std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name); const std::string& name);

@ -39,7 +39,8 @@ class NnetInterface {
// forward feat with nnet. // forward feat with nnet.
// nnet do not cache feats, feats cached by frontend. // nnet do not cache feats, feats cached by frontend.
// nnet cache model outputs, i.e. logprobs/encoder_outs. // nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
const int32& feature_dim, const int32& feature_dim,
NnetOut* out) = 0; NnetOut* out) = 0;
@ -47,6 +48,9 @@ class NnetInterface {
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_. // reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
virtual void Reset() = 0; virtual void Reset() = 0;
// true, nnet output is logprob; otherwise is prob,
virtual bool IsLogProb() = 0;
// using to get encoder outs. e.g. seq2seq with Attention model. // using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts( virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0; std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;

@ -111,6 +111,8 @@ class U2Nnet : public U2NnetBase {
void Reset() override; void Reset() override;
bool IsLogProb() override { return true; }
void Dim(); void Dim();
void LoadModel(const std::string& model_path_w_prefix); void LoadModel(const std::string& model_path_w_prefix);

@ -98,6 +98,7 @@ int main(int argc, char* argv[]) {
// } // }
int32 frame_idx = 0; int32 frame_idx = 0;
int vocab_dim = 0;
std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec; std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
std::vector<kaldi::Vector<kaldi::BaseFloat>> encoder_out_vec; std::vector<kaldi::Vector<kaldi::BaseFloat>> encoder_out_vec;
int32 ori_feature_len = feature.NumRows(); int32 ori_feature_len = feature.NumRows();
@ -138,17 +139,17 @@ int main(int argc, char* argv[]) {
} }
// get nnet outputs // get nnet outputs
vector<kaldi::BaseFloat> prob; kaldi::Timer timer;
while (decodable->FrameLikelihood(frame_idx, &prob)) { kaldi::Vector<kaldi::BaseFloat> logprobs;
kaldi::Vector<kaldi::BaseFloat> vec_tmp(prob.size()); bool isok = decodable->AdvanceChunk(&logprobs, &vocab_dim);
std::memcpy(vec_tmp.Data(), CHECK(isok == true);
prob.data(), for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; row_idx ++) {
sizeof(kaldi::BaseFloat) * prob.size()); kaldi::Vector<kaldi::BaseFloat> vec_tmp(vocab_dim);
std::memcpy(vec_tmp.Data(), logprobs.Data() + row_idx*vocab_dim, sizeof(kaldi::BaseFloat) * vocab_dim);
prob_vec.push_back(vec_tmp); prob_vec.push_back(vec_tmp);
frame_idx++;
} }
VLOG(2) << "frame_idx: " << frame_idx << " elapsed: " << timer.Elapsed() << " sec.";
} }
// get encoder out // get encoder out
@ -196,8 +197,9 @@ int main(int argc, char* argv[]) {
++num_done; ++num_done;
} }
double elapsed = timer.Elapsed(); double elapsed = timer.Elapsed();
LOG(INFO) << " cost:" << elapsed << " sec"; LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";

Loading…
Cancel
Save