From 143ab13679f6221b0a3757899a4c6b715ba06037 Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Thu, 31 Mar 2022 20:13:05 +0800 Subject: [PATCH] add decoder_test_main --- speechx/examples/decoder/CMakeLists.txt | 9 ++ speechx/examples/decoder/decoder_test_main.cc | 69 ++++++++++ .../examples/decoder/offline_decoder_main.cc | 1 + .../offline_decoder_sliding_chunk_main.cc | 119 ++++++++++++++++++ .../decoder/ctc_beam_search_decoder.cc | 6 +- .../speechx/decoder/ctc_beam_search_decoder.h | 4 +- speechx/speechx/nnet/decodable.cc | 8 +- 7 files changed, 209 insertions(+), 7 deletions(-) create mode 100644 speechx/examples/decoder/decoder_test_main.cc create mode 100644 speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt index 4bd5c6cf..ded423e9 100644 --- a/speechx/examples/decoder/CMakeLists.txt +++ b/speechx/examples/decoder/CMakeLists.txt @@ -1,5 +1,14 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc) +target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc) target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + +add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc) +target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + diff --git a/speechx/examples/decoder/decoder_test_main.cc b/speechx/examples/decoder/decoder_test_main.cc new file mode 100644 index 00000000..79fe63fc --- /dev/null +++ b/speechx/examples/decoder/decoder_test_main.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "decoder/ctc_beam_search_decoder.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" + +DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); +DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); +DEFINE_string(lm_path, "lm.klm", "language model"); + + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader likelihood_reader( + FLAGS_nnet_prob_respecifier); + std::string dict_file = FLAGS_dict_file; + std::string lm_path = FLAGS_lm_path; + + int32 num_done = 0, num_err = 0; + + ppspeech::CTCBeamSearchOptions opts; + opts.dict_file = dict_file; + opts.lm_path = lm_path; + ppspeech::CTCBeamSearch decoder(opts); + + std::shared_ptr decodable( + new ppspeech::Decodable(nullptr, nullptr)); + + decoder.InitDecoder(); + + for (; !likelihood_reader.Done(); likelihood_reader.Next()) { + string utt = likelihood_reader.Key(); + const kaldi::Matrix likelihood = likelihood_reader.Value(); + decodable->Acceptlikelihood(likelihood); + decoder.AdvanceDecode(decodable); + std::string result; + result = decoder.GetFinalBestPath(); + KALDI_LOG << " the result of " << utt << " is " << result; + decodable->Reset(); + decoder.Reset(); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc index 44127c73..63ca868b 100644 --- a/speechx/examples/decoder/offline_decoder_main.cc +++ b/speechx/examples/decoder/offline_decoder_main.cc @@ -52,6 +52,7 @@ int main(int argc, char* argv[]) { ppspeech::CTCBeamSearch decoder(opts); ppspeech::ModelOptions model_opts; + model_opts.cache_shape = "5-1-1024,5-1-1024"; model_opts.model_path = model_graph; model_opts.params_path = model_params; std::shared_ptr nnet( diff --git a/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc new file mode 100644 index 00000000..ad72b772 --- /dev/null +++ b/speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "decoder/ctc_beam_search_decoder.h" +#include "frontend/raw_audio.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +DEFINE_string(feature_respecifier, "", "test feature rspecifier"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); +DEFINE_string(lm_path, "lm.klm", "language model"); + + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_respecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + std::string dict_file = FLAGS_dict_file; + std::string lm_path = FLAGS_lm_path; + + int32 num_done = 0, num_err = 0; + + ppspeech::CTCBeamSearchOptions opts; + opts.dict_file = dict_file; + opts.lm_path = lm_path; + ppspeech::CTCBeamSearch decoder(opts); + + ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.params_path = model_params; + model_opts.cache_shape = "5-1-1024,5-1-1024"; + std::shared_ptr nnet( + new ppspeech::PaddleNnet(model_opts)); + std::shared_ptr raw_data( + new ppspeech::RawDataCache()); + std::shared_ptr decodable( + new ppspeech::Decodable(nnet, raw_data)); + + int32 chunk_size = 7; + int32 chunk_stride = 4; + int32 receptive_field_length = 7; + decoder.InitDecoder(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + kaldi::Matrix feature = feature_reader.Value(); + raw_data->SetDim(feature.NumCols()); + int32 row_idx = 0; + int32 padding_len = 0; + int32 ori_feature_len = feature.NumRows(); + if ( (feature.NumRows() - chunk_size) % chunk_stride != 0) { + padding_len = chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; + feature.Resize(feature.NumRows() + padding_len, feature.NumCols(), kaldi::kCopyData); + } + int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector feature_chunk(chunk_size * + feature.NumCols()); + int32 feature_chunk_size = 0; + if ( ori_feature_len > chunk_idx * chunk_stride) { + feature_chunk_size = std::min(ori_feature_len - chunk_idx * chunk_stride, chunk_size); + } + if (feature_chunk_size < receptive_field_length) break; + + int32 start = chunk_idx * chunk_stride; + int32 end = start + chunk_size; + + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector tmp(feature, start); + kaldi::SubVector f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + ++start; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + decoder.AdvanceDecode(decodable); + } + std::string result; + result = decoder.GetFinalBestPath(); + KALDI_LOG << " the result of " << utt << " is " << result; + decodable->Reset(); + decoder.Reset(); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 84f1453c..5d7a4f77 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -38,8 +38,10 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) << vocabulary_.size(); LOG(INFO) << "language model path: " << opts_.lm_path; - init_ext_scorer_ = std::make_shared( - opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); + if (opts_.lm_path != "") { + init_ext_scorer_ = std::make_shared( + opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); + } blank_id_ = 0; auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 451f33c0..1387eee7 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -33,13 +33,13 @@ struct CTCBeamSearchOptions { int num_proc_bsearch; CTCBeamSearchOptions() : dict_file("vocab.txt"), - lm_path("lm.klm"), + lm_path(""), alpha(1.9f), beta(5.0), beam_size(300), cutoff_prob(0.99f), cutoff_top_n(40), - num_proc_bsearch(0) {} + num_proc_bsearch(10) {} void Register(kaldi::OptionsItf* opts) { opts->Register("dict", &dict_file, "dict file "); diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 6c0909ca..cd72bf76 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -26,6 +26,7 @@ Decodable::Decodable(const std::shared_ptr& nnet, : frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {} void Decodable::Acceptlikelihood(const Matrix& likelihood) { + nnet_cache_ = likelihood; frames_ready_ += likelihood.NumRows(); } @@ -53,7 +54,7 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::AdvanceChunk() { Vector features; - if (frontend_->Read(&features) == false) { + if (frontend_ == NULL || frontend_->Read(&features) == false) { return false; } int32 nnet_dim = 0; @@ -77,10 +78,11 @@ bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { } void Decodable::Reset() { - frontend_->Reset(); - nnet_->Reset(); + if (frontend_ != nullptr) frontend_->Reset(); + if (nnet_ != nullptr) nnet_->Reset(); frame_offset_ = 0; frames_ready_ = 0; + nnet_cache_.Resize(0,0); } } // namespace ppspeech \ No newline at end of file