From 93c3e03bc846053889c823a349d7921686a329c3 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 1 Apr 2022 07:49:53 +0000
Subject: [PATCH] more comment

---
 speechx/examples/decoder/decoder_test_main.cc |  7 ++-
 .../examples/decoder/offline_decoder_main.cc  |  2 +-
 .../offline_decoder_sliding_chunk_main.cc     | 48 ++++++++++++++-----
 speechx/speechx/nnet/decodable.cc             |  2 +-
 4 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/speechx/examples/decoder/decoder_test_main.cc b/speechx/examples/decoder/decoder_test_main.cc
index 79fe63fcd..0e249cc6b 100644
--- a/speechx/examples/decoder/decoder_test_main.cc
+++ b/speechx/examples/decoder/decoder_test_main.cc
@@ -24,11 +24,11 @@
 DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
 DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
 DEFINE_string(lm_path, "lm.klm", "language model");
-
 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+// test decoder by feeding nnet posterior probability
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -37,6 +37,8 @@ int main(int argc, char* argv[]) {
         FLAGS_nnet_prob_respecifier);
     std::string dict_file = FLAGS_dict_file;
     std::string lm_path = FLAGS_lm_path;
+    LOG(INFO) << "dict path: " << dict_file;
+    LOG(INFO) << "lm path: " << lm_path;
 
     int32 num_done = 0, num_err = 0;
 
@@ -53,6 +55,9 @@ int main(int argc, char* argv[]) {
     for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
         string utt = likelihood_reader.Key();
         const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
+        LOG(INFO) << "process utt: " << utt;
+        LOG(INFO) << "rows: " << likelihood.NumRows();
+        LOG(INFO) << "cols: " << likelihood.NumCols();
         decodable->Acceptlikelihood(likelihood);
         decoder.AdvanceDecode(decodable);
         std::string result;
diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc
index c73d59682..6bd83b9b1 100644
--- a/speechx/examples/decoder/offline_decoder_main.cc
+++ b/speechx/examples/decoder/offline_decoder_main.cc
@@ -34,6 +34,7 @@ using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+// test decoder by feeding speech feature, deprecated.
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -55,7 +56,6 @@ int main(int argc, char* argv[]) {
 
     // frontend + nnet is decodable
     ppspeech::ModelOptions model_opts;
-    model_opts.cache_shape = "5-1-1024,5-1-1024";
     model_opts.model_path = model_graph;
     model_opts.params_path = model_params;
     std::shared_ptr<ppspeech::NnetInterface> nnet(
diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
index 27bd7b1bc..4d5ffe145 100644
--- a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
+++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
@@ -27,12 +27,19 @@
 DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
 DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
 DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
 DEFINE_string(lm_path, "lm.klm", "language model");
-
+DEFINE_int32(receptive_field_length,
+             7,
+             "receptive field of two CNN(kernel=5) downsampling module.");
+DEFINE_int32(downsampling_rate,
+             4,
+             "two CNN(kernel=5) module downsampling rate.");
 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+
+// test ds2 online decoder by feeding speech feature
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -43,6 +50,11 @@ int main(int argc, char* argv[]) {
     std::string model_params = FLAGS_param_path;
     std::string dict_file = FLAGS_dict_file;
     std::string lm_path = FLAGS_lm_path;
+    LOG(INFO) << "model path: " << model_graph;
+    LOG(INFO) << "model param: " << model_params;
+    LOG(INFO) << "dict path: " << dict_file;
+    LOG(INFO) << "lm path: " << lm_path;
+
 
     int32 num_done = 0, num_err = 0;
 
@@ -57,34 +69,44 @@ int main(int argc, char* argv[]) {
     model_opts.cache_shape = "5-1-1024,5-1-1024";
     std::shared_ptr<ppspeech::NnetInterface> nnet(
         new ppspeech::PaddleNnet(model_opts));
-    std::shared_ptr<ppspeech::DataCache> raw_data(
-        new ppspeech::DataCache());
+    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
     std::shared_ptr<ppspeech::Decodable> decodable(
         new ppspeech::Decodable(nnet, raw_data));
 
-    int32 chunk_size = 7;
-    int32 chunk_stride = 4;
-    int32 receptive_field_length = 7;
+    int32 chunk_size = FLAGS_receptive_field_length;
+    int32 chunk_stride = FLAGS_downsampling_rate;
+    int32 receptive_field_length = FLAGS_receptive_field_length;
+    LOG(INFO) << "chunk size (frame): " << chunk_size;
+    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
+    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
 
     decoder.InitDecoder();
 
     for (; !feature_reader.Done(); feature_reader.Next()) {
         string utt = feature_reader.Key();
         kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
         raw_data->SetDim(feature.NumCols());
+        LOG(INFO) << "process utt: " << utt;
+        LOG(INFO) << "rows: " << feature.NumRows();
+        LOG(INFO) << "cols: " << feature.NumCols();
+
         int32 row_idx = 0;
         int32 padding_len = 0;
-        int32 ori_feature_len = feature.NumRows(); 
-        if ( (feature.NumRows() - chunk_size) % chunk_stride != 0) {
-          padding_len = chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
-          feature.Resize(feature.NumRows() + padding_len, feature.NumCols(), kaldi::kCopyData);
+        int32 ori_feature_len = feature.NumRows();
+        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
+            padding_len =
+                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
+            feature.Resize(feature.NumRows() + padding_len,
+                           feature.NumCols(),
+                           kaldi::kCopyData);
         }
         int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
         for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
             kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                           feature.NumCols());
-           int32 feature_chunk_size = 0;
-           if ( ori_feature_len > chunk_idx * chunk_stride) {
-             feature_chunk_size = std::min(ori_feature_len - chunk_idx * chunk_stride, chunk_size);
+            int32 feature_chunk_size = 0;
+            if (ori_feature_len > chunk_idx * chunk_stride) {
+                feature_chunk_size = std::min(
+                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
             }
             if (feature_chunk_size < receptive_field_length) break;
diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc
index cd72bf767..e6315d07a 100644
--- a/speechx/speechx/nnet/decodable.cc
+++ b/speechx/speechx/nnet/decodable.cc
@@ -82,7 +82,7 @@ void Decodable::Reset() {
     if (nnet_ != nullptr) nnet_->Reset();
     frame_offset_ = 0;
     frames_ready_ = 0;
-    nnet_cache_.Resize(0,0);
+    nnet_cache_.Resize(0, 0);
 }
 
 }  // namespace ppspeech
\ No newline at end of file
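
For reference, below is a minimal standalone sketch (not part of the patch) of the sliding-chunk arithmetic that offline_decoder_sliding_chunk_main.cc now takes from the new flags: chunk_size comes from FLAGS_receptive_field_length (default 7), chunk_stride from FLAGS_downsampling_rate (default 4). The utterance length num_frames is a made-up example value, and the program has no Kaldi dependency; it only reproduces the padding, chunk-count, and early-break bookkeeping.

// chunking_sketch.cc -- illustrates the chunk/padding math from the example main.
#include <algorithm>
#include <iostream>

int main() {
    const int chunk_size = 7;       // mirrors FLAGS_receptive_field_length
    const int chunk_stride = 4;     // mirrors FLAGS_downsampling_rate
    const int receptive_field_length = 7;

    int num_frames = 93;            // hypothetical utterance length in frames
    const int ori_feature_len = num_frames;

    // Pad so that (num_frames - chunk_size) is a multiple of chunk_stride,
    // mirroring the feature.Resize(...) call in the patched main.
    int padding_len = 0;
    if ((num_frames - chunk_size) % chunk_stride != 0) {
        padding_len = chunk_stride - (num_frames - chunk_size) % chunk_stride;
        num_frames += padding_len;
    }

    const int num_chunks = (num_frames - chunk_size) / chunk_stride + 1;
    std::cout << "padding: " << padding_len << ", chunks: " << num_chunks << "\n";

    for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
        // Real (un-padded) frames available to this chunk.
        int feature_chunk_size = 0;
        if (ori_feature_len > chunk_idx * chunk_stride) {
            feature_chunk_size =
                std::min(ori_feature_len - chunk_idx * chunk_stride, chunk_size);
        }
        // A chunk shorter than the receptive field cannot produce nnet output,
        // so the decoder loop stops here.
        if (feature_chunk_size < receptive_field_length) break;
        std::cout << "chunk " << chunk_idx << ": frames ["
                  << chunk_idx * chunk_stride << ", "
                  << chunk_idx * chunk_stride + chunk_size << ")\n";
    }
    return 0;
}

With num_frames = 93 this prints padding 2 and 23 chunks, and the loop stops once fewer than 7 real frames remain, which is the same condition the patch checks before decoding each chunk.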