diff --git a/speechx/examples/decoder/offline_wfst_decoder_main.cc b/speechx/examples/decoder/offline_wfst_decoder_main.cc index f0b9cc4f..758942b5 100644 --- a/speechx/examples/decoder/offline_wfst_decoder_main.cc +++ b/speechx/examples/decoder/offline_wfst_decoder_main.cc @@ -17,7 +17,7 @@ #include "base/flags.h" #include "base/log.h" #include "decoder/ctc_tlg_decoder.h" -#include "frontend/raw_audio.h" +#include "frontend/audio/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" @@ -27,7 +27,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table"); DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); +DEFINE_double(acoustic_scale, 10.0, "acoustic scale"); DEFINE_int32(max_active, 5000, "decoder graph"); @@ -52,6 +52,7 @@ int main(int argc, char* argv[]) { opts.word_symbol_table = word_symbol_table; opts.fst_path = graph_path; opts.opts.max_active = FLAGS_max_active; + opts.opts.beam = ppspeech::TLGDecoder decoder(opts); ppspeech::ModelOptions model_opts; @@ -60,8 +61,8 @@ int main(int argc, char* argv[]) { model_opts.cache_shape = "5-1-1024,5-1-1024"; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data( - new ppspeech::RawDataCache()); + std::shared_ptr raw_data( + new ppspeech::DataCache()); std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data)); diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index de15c4f8..805c0dca 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -48,7 +48,7 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(frame, frames_ready_); int32 frame_idx = frame - frame_offset_; - return nnet_cache_(frame_idx, index); + return std::log(nnet_cache_(frame_idx, index) + std::numeric_limits::min()); } bool Decodable::EnsureFrameHaveComputed(int32 frame) { @@ -65,9 +65,20 @@ bool Decodable::AdvanceChunk() { } int32 nnet_dim = 0; Vector inferences; + Matrix nnet_cache_tmp; nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); - nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); - nnet_cache_.CopyRowsFromVec(inferences); + nnet_cache_tmp.Resize(inferences.Dim() / nnet_dim, nnet_dim); + nnet_cache_tmp.CopyRowsFromVec(inferences); + // skip blank + vector no_blank_record; + BaseFloat blank_threshold = 0.98; + for (int32 idx = 0; idx < nnet_cache_.NumRows(); ++idx) { + if (nnet_cache_(idx, 0) > blank_threshold) { + + } + } + + frame_offset_ = frames_ready_; frames_ready_ += nnet_cache_.NumRows(); return true;