pull/1599/head
Yang Zhou 3 years ago
parent c2ee6bc67d
commit 642e0840b4

@ -17,7 +17,7 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_tlg_decoder.h"
#include "frontend/raw_audio.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/paddle_nnet.h"
@ -27,7 +27,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table"); DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 10.0, "acoustic scale");
DEFINE_int32(max_active, 5000, "decoder graph"); DEFINE_int32(max_active, 5000, "decoder graph");
@ -52,6 +52,7 @@ int main(int argc, char* argv[]) {
opts.word_symbol_table = word_symbol_table; opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path; opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active; opts.opts.max_active = FLAGS_max_active;
opts.opts.beam =
ppspeech::TLGDecoder decoder(opts); ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
@ -60,8 +61,8 @@ int main(int argc, char* argv[]) {
model_opts.cache_shape = "5-1-1024,5-1-1024"; model_opts.cache_shape = "5-1-1024,5-1-1024";
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data( std::shared_ptr<ppspeech::DataCache> raw_data(
new ppspeech::RawDataCache()); new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data)); new ppspeech::Decodable(nnet, raw_data));

@ -48,7 +48,7 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_; int32 frame_idx = frame - frame_offset_;
return nnet_cache_(frame_idx, index); return std::log(nnet_cache_(frame_idx, index) + std::numeric_limits<float>::min());
} }
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
@ -65,9 +65,20 @@ bool Decodable::AdvanceChunk() {
} }
int32 nnet_dim = 0; int32 nnet_dim = 0;
Vector<BaseFloat> inferences; Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_tmp.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences); nnet_cache_tmp.CopyRowsFromVec(inferences);
// skip blank
vector<int> no_blank_record;
BaseFloat blank_threshold = 0.98;
for (int32 idx = 0; idx < nnet_cache_.NumRows(); ++idx) {
if (nnet_cache_(idx, 0) > blank_threshold) {
}
}
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return true; return true;

Loading…
Cancel
Save