From d2d53cce1a9e4a89963ae16cd4bf983036fd5e7a Mon Sep 17 00:00:00 2001 From: SmileGoat Date: Tue, 22 Feb 2022 20:09:57 +0800 Subject: [PATCH] add cmakelist of decoder, nnet --- speechx/CMakeLists.txt | 52 +++++++++++-- speechx/speechx/CMakeLists.txt | 21 ++++++ speechx/speechx/base/common.h | 3 + speechx/speechx/base/macros.h | 2 + speechx/speechx/decoder/CMakeLists.txt | 12 ++- .../decoder/ctc_beam_search_decoder.cc | 74 ++++++++++--------- .../speechx/decoder/ctc_beam_search_decoder.h | 29 ++++---- speechx/speechx/nnet/CMakeLists.txt | 2 + speechx/speechx/nnet/decodable-itf.h | 2 + speechx/speechx/nnet/decodable.cc | 28 ++++--- speechx/speechx/nnet/decodable.h | 11 ++- speechx/speechx/nnet/nnet_interface.h | 3 +- speechx/speechx/nnet/paddle_nnet.cc | 35 +++++---- speechx/speechx/nnet/paddle_nnet.h | 28 ++++--- speechx/speechx/utils/CMakeLists.txt | 4 + speechx/speechx/utils/file_utils.cc | 17 +++++ speechx/speechx/utils/file_utils.h | 8 ++ 17 files changed, 228 insertions(+), 103 deletions(-) create mode 100644 speechx/speechx/nnet/CMakeLists.txt create mode 100644 speechx/speechx/utils/file_utils.cc create mode 100644 speechx/speechx/utils/file_utils.h diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index ac3c683d..3b8c8788 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -39,16 +39,40 @@ FetchContent_Declare( GIT_TAG "20210324.1" ) FetchContent_MakeAvailable(absl) -include_directories(${absl_SOURCE_DIR}/absl) +include_directories(${absl_SOURCE_DIR}) # libsndfile +#include(FetchContent) +#FetchContent_Declare( +# libsndfile +# GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" +# GIT_TAG "1.0.31" +#) +#FetchContent_MakeAvailable(libsndfile) + +# todo boost build +#include(FetchContent) +#FetchContent_Declare( +# Boost +# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip +# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a +#) 
+#FetchContent_MakeAvailable(Boost) +#include_directories(${Boost_SOURCE_DIR}) + + +set(BOOST_ROOT ${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0) +include_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0) +link_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib) include(FetchContent) FetchContent_Declare( - libsndfile - GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" - GIT_TAG "1.0.31" + kenlm + GIT_REPOSITORY "https://github.com/kpu/kenlm.git" + GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a" ) -FetchContent_MakeAvailable(libsndfile) +FetchContent_MakeAvailable(kenlm) +add_dependencies(kenlm Boost) +include_directories(${kenlm_SOURCE_DIR}) # gflags FetchContent_Declare( @@ -94,6 +118,22 @@ add_dependencies(openfst gflags glog) link_directories(${openfst_PREFIX_DIR}/lib) include_directories(${openfst_PREFIX_DIR}/include) +set(PADDLE_LIB ${fc_patch}/paddle-lib/paddle_inference) +include_directories("${PADDLE_LIB}/paddle/include") +set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") +#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include") +#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include") + +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") +#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib") +#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") +link_directories("${PADDLE_LIB}/paddle/lib") + add_subdirectory(speechx) #openblas @@ -122,4 +162,4 @@ add_subdirectory(speechx) # if dir do not have CmakeLists.txt #add_library(lib_name STATIC file.cc) 
#target_link_libraries(lib_name item0 item1) -#add_dependencies(lib_name depend-target) \ No newline at end of file +#add_dependencies(lib_name depend-target) diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index 25e7b1e3..bdf82146 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -12,14 +12,35 @@ ${CMAKE_CURRENT_SOURCE_DIR}/kaldi ) add_subdirectory(kaldi) +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/utils +) +add_subdirectory(utils) + include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/frontend ) add_subdirectory(frontend) +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/nnet +) +add_subdirectory(nnet) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/decoder +) +add_subdirectory(decoder) + add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) target_link_libraries(mfcc-test kaldi-mfcc) add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc) target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) + +#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc) +#target_link_libraries(offline_decoder_main nnet decoder gflags glog) diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index a16fc55b..f4261e55 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -27,7 +28,9 @@ #include #include #include +#include #include "base/log.h" +#include "base/flags.h" #include "base/basic_types.h" #include "base/macros.h" diff --git a/speechx/speechx/base/macros.h b/speechx/speechx/base/macros.h index c8d254d6..17254887 100644 --- a/speechx/speechx/base/macros.h +++ b/speechx/speechx/base/macros.h @@ -16,8 +16,10 @@ namespace ppspeech { +#ifndef 
DISALLOW_COPY_AND_ASSIGN #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ TypeName(const TypeName&) = delete; \ void operator=(const TypeName&) = delete +#endif } // namespace pp_speech \ No newline at end of file diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 259261bd..8885dca9 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -1,2 +1,10 @@ -aux_source_directory(. DIR_LIB_SRCS) -add_library(decoder STATIC ${DIR_LIB_SRCS}) +project(decoder) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders) +add_library(decoder + ctc_beam_search_decoder.cc + ctc_decoders/decoder_utils.cpp + ctc_decoders/path_trie.cpp + ctc_decoders/scorer.cpp +) +target_link_libraries(decoder kenlm) \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index d4407b53..62abf377 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -2,33 +2,32 @@ #include "base/basic_types.h" #include "decoder/ctc_decoders/decoder_utils.h" +#include "utils/file_utils.h" namespace ppspeech { using std::vector; using FSTMATCH = fst::SortedMatcher; -CTCBeamSearch::CTCBeamSearch(std::shared_ptr opts) : +CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : opts_(opts), - vocabulary_(nullptr), init_ext_scorer_(nullptr), blank_id(-1), space_id(-1), - num_frame_decoded(0), + num_frame_decoded_(0), root(nullptr) { LOG(INFO) << "dict path: " << opts_.dict_file; - vocabulary_ = std::make_shared>(); - if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) { + if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { LOG(INFO) << "load the dict failed"; } - LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size(); + LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size(); LOG(INFO) << "language model
path: " << opts_.lm_path; init_ext_scorer_ = std::make_shared(opts_.alpha, opts_.beta, opts_.lm_path, - *vocabulary_); + vocabulary_); } void CTCBeamSearch::Reset() { @@ -39,11 +38,11 @@ void CTCBeamSearch::Reset() { void CTCBeamSearch::InitDecoder() { blank_id = 0; - auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " "); + auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); - space_id = it - vocabulary_->begin(); + space_id = it - vocabulary_.begin(); // if no space in vocabulary - if ((size_t)space_id >= vocabulary_->size()) { + if ((size_t)space_id >= vocabulary_.size()) { space_id = -2; } @@ -63,19 +62,24 @@ void CTCBeamSearch::InitDecoder() { } } +void CTCBeamSearch::Decode(std::shared_ptr decodable) { + return; +} + int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; } // todo rename, refactor -void CTCBeamSearch::AdvanceDecode(const std::shared_ptr& decodable, int max_frames) { +void CTCBeamSearch::AdvanceDecode(const std::shared_ptr& decodable, + int max_frames) { while (max_frames > 0) { vector> likelihood; if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { break; } likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); - AdvanceDecoding(result); + AdvanceDecoding(likelihood); max_frames--; } } @@ -91,32 +95,21 @@ void CTCBeamSearch::ResetPrefixes() { int CTCBeamSearch::DecodeLikelihoods(const vector>&probs, vector& nbest_words) { - std::thread::id this_id = std::this_thread::get_id(); - Timer timer; - vector> double_probs(probs.size(), vector(probs[0].size(), 0)); - - int row = probs.size(); - int col = probs[0].size(); - for(int i = 0; i < row; i++) { - for (int j = 0; j < col; j++){ - double_probs[i][j] = static_cast(probs[i][j]); - } - } - + kaldi::Timer timer; timer.Reset(); - AdvanceDecoding(double_probs); + AdvanceDecoding(probs); LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast(timer.Elapsed()) / 1000.0f; return 0; } vector> CTCBeamSearch::GetNBestPath() { - return 
get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size); + return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); } string CTCBeamSearch::GetBestPath() { std::vector> result; - result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size); - return result[0]->second; + result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); + return result[0].second; } string CTCBeamSearch::GetFinalBestPath() { @@ -125,12 +118,22 @@ string CTCBeamSearch::GetFinalBestPath() { return GetBestPath(); } -void CTCBeamSearch::AdvanceDecoding(const vector>& probs_seq) { - size_t num_time_steps = probs_seq.size(); +void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { + size_t num_time_steps = probs.size(); size_t beam_size = opts_.beam_size; double cutoff_prob = opts_.cutoff_prob; size_t cutoff_top_n = opts_.cutoff_top_n; - + + vector> probs_seq(probs.size(), vector(probs[0].size(), 0)); + + int row = probs.size(); + int col = probs[0].size(); + for(int i = 0; i < row; i++) { + for (int j = 0; j < col; j++){ + probs_seq[i][j] = static_cast(probs[i][j]); + } + } + for (size_t time_step = 0; time_step < num_time_steps; time_step++) { const auto& prob = probs_seq[time_step]; @@ -158,7 +161,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector>& probs_seq) { size_t log_prob_idx_len = log_prob_idx.size(); for (size_t index = 0; index < log_prob_idx_len; index++) { SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); - + } + prefixes.clear(); // update log probs @@ -177,9 +181,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector>& probs_seq) { } // for probs_seq } -int CTCBeamSearch::SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const float& min_cutoff) { +int32 CTCBeamSearch::SearchOneChar(const bool& full_beam, + const std::pair& log_prob_idx, + const BaseFloat& min_cutoff) { size_t beam_size = opts_.beam_size; const auto& c = log_prob_idx.first; const auto& log_prob_c = log_prob_idx.second; diff --git 
a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index b461db88..53af449e 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -1,5 +1,8 @@ -#include "base/basic_types.h" +#include "base/common.h" #include "nnet/decodable-itf.h" +#include "util/parse-options.h" +#include "decoder/ctc_decoders/scorer.h" +#include "decoder/ctc_decoders/path_trie.h" #pragma once @@ -38,41 +41,39 @@ struct CTCBeamSearchOptions { }; class CTCBeamSearch { -public: - - CTCBeamSearch(std::shared_ptr opts); - - ~CTCBeamSearch() { - } - bool InitDecoder(); + public: + explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); + ~CTCBeamSearch() {} + void InitDecoder(); void Decode(std::shared_ptr decodable); std::string GetBestPath(); std::vector> GetNBestPath(); - std::string GetFinalBestPath(); + std::string GetFinalBestPath(); int NumFrameDecoded(); int DecodeLikelihoods(const std::vector>&probs, std::vector& nbest_words); + void AdvanceDecode(const std::shared_ptr& decodable, + int max_frames); void Reset(); - -private: + private: void ResetPrefixes(); int32 SearchOneChar(const bool& full_beam, const std::pair& log_prob_idx, const BaseFloat& min_cutoff); void CalculateApproxScore(); void LMRescore(); - void AdvanceDecoding(const std::vector>& probs_seq); + void AdvanceDecoding(const std::vector>& probs); CTCBeamSearchOptions opts_; std::shared_ptr init_ext_scorer_; // todo separate later //std::vector decoder_results_; - std::vector> vocabulary_; // todo remove later - + std::vector vocabulary_; // todo remove later size_t blank_id; int space_id; std::shared_ptr root; std::vector prefixes; int num_frame_decoded_; + DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); }; } // namespace basr \ No newline at end of file diff --git a/speechx/speechx/nnet/CMakeLists.txt b/speechx/speechx/nnet/CMakeLists.txt new file mode 100644 index 00000000..4d336b86 --- /dev/null +++ 
b/speechx/speechx/nnet/CMakeLists.txt @@ -0,0 +1,2 @@ +aux_source_directory(. DIR_LIB_SRCS) +add_library(nnet STATIC ${DIR_LIB_SRCS}) diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/nnet/decodable-itf.h index 20934dde..93f7db76 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/nnet/decodable-itf.h @@ -114,6 +114,8 @@ class DecodableInterface { /// this is for compatibility with OpenFst). virtual int32 NumIndices() const = 0; + virtual std::vector FrameLogLikelihood(int32 frame); + virtual ~DecodableInterface() {} }; /// @} diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 6c03b4a4..984f3ad3 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -2,15 +2,24 @@ namespace ppspeech { -Decodable::Acceptlikelihood(const kaldi::Matrix& likelihood) { - frames_ready_ += likelihood.NumRows(); +using kaldi::BaseFloat; +using kaldi::Matrix; + +Decodable::Decodable(const std::shared_ptr& nnet): + frontend_(NULL), + nnet_(nnet), + finished_(false), + frames_ready_(0) { } -Decodable::Init(DecodableConfig config) { - +void Decodable::Acceptlikelihood(const Matrix& likelihood) { + frames_ready_ += likelihood.NumRows(); } -Decodable::IsLastFrame(int32 frame) const { +//Decodable::Init(DecodableConfig config) { +//} + +bool Decodable::IsLastFrame(int32 frame) const { CHECK_LE(frame, frames_ready_); return finished_ && (frame == frames_ready_ - 1); } @@ -19,12 +28,11 @@ int32 Decodable::NumIndices() const { return 0; } -void Decodable::LogLikelihood(int32 frame, int32 index) { - return ; +BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { + return 0; } -void Decodable::FeedFeatures(const kaldi::Matrix& features) { - // skip frame ??? 
+void Decodable::FeedFeatures(const Matrix& features) { nnet_->FeedForward(features, &nnet_cache_); frames_ready_ += nnet_cache_.NumRows(); return ; @@ -35,4 +43,4 @@ void Decodable::Reset() { nnet_->Reset(); } -} // namespace ppspeech +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 0bf28d94..6f06d69a 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -1,14 +1,17 @@ #include "nnet/decodable-itf.h" - #include "base/common.h" +#include "kaldi/matrix/kaldi-matrix.h" +#include "frontend/feature_extractor_interface.h" +#include "nnet/nnet_interface.h" namespace ppspeech { -struct DecodableConfig; +struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - virtual void Init(DecodableOpts config); + explicit Decodable(const std::shared_ptr& nnet); + //void Init(DecodableOpts config); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame) const; virtual int32 NumIndices() const; @@ -25,4 +28,4 @@ class Decodable : public kaldi::DecodableInterface { int32 frames_ready_; }; -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index 5965f7e8..577662f3 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -3,13 +3,14 @@ #include "base/basic_types.h" #include "kaldi/base/kaldi-types.h" +#include "kaldi/matrix/kaldi-matrix.h" namespace ppspeech { class NnetInterface { public: virtual ~NnetInterface() {} - virtual void FeedForward(const kaldi::Matrix& features, + virtual void FeedForward(const kaldi::Matrix& features, kaldi::Matrix* inferences); virtual void Reset(); diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc index e64850cb..61690872 100644 --- 
a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -3,6 +3,11 @@ namespace ppspeech { +using std::vector; +using std::string; +using std::shared_ptr; +using kaldi::Matrix; + void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { std::vector cache_names; cache_names = absl::StrSplit(opts.cache_names, ", "); @@ -25,14 +30,14 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { } } -PaddleNet::PaddleNnet(const ModelOptions& opts) { +PaddleNnet::PaddleNnet(const ModelOptions& opts) { paddle_infer::Config config; config.SetModel(opts.model_path, opts.params_path); if (opts.use_gpu) { config.EnableUseGpu(500, 0); } config.SwitchIrOptim(opts.switch_ir_optim); - if (opts.enbale_fc_padding) { + if (opts.enable_fc_padding) { config.DisableFCPadding(); } if (opts.enable_profile) { @@ -42,7 +47,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) { if (pool == nullptr) { LOG(ERROR) << "create the predictor pool failed"; } - pool_usages.resize(num_thread); + pool_usages.resize(opts.thread_num); std::fill(pool_usages.begin(), pool_usages.end(), false); LOG(INFO) << "load paddle model success"; @@ -51,7 +56,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) { LOG(INFO) << "output names: " << opts.output_names; vector input_names_vec = absl::StrSplit(opts.input_names, ", "); vector output_names_vec = absl::StrSplit(opts.output_names, ", "); - paddle_infer::Predictor* predictor = get_predictor(); + paddle_infer::Predictor* predictor = GetPredictor(); std::vector model_input_names = predictor->GetInputNames(); assert(input_names_vec.size() == model_input_names.size()); @@ -64,12 +69,12 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) { for (size_t i = 0;i < output_names_vec.size(); i++) { assert(output_names_vec[i] == model_output_names[i]); } - release_predictor(predictor); + ReleasePredictor(predictor); InitCacheEncouts(opts); } -paddle_infer::Predictor* PaddleNnet::get_predictor() { +paddle_infer::Predictor* 
PaddleNnet::GetPredictor() { LOG(INFO) << "attempt to get a new predictor instance " << std::endl; paddle_infer::Predictor* predictor = nullptr; std::lock_guard guard(pool_mutex); @@ -111,19 +116,18 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { return 0; } - - shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { auto iter = cache_names_idx_.find(name); if (iter == cache_names_idx_.end()) { return nullptr; } assert(iter->second < cache_encouts_.size()); - return cache_encouts_[iter->second].get(); + return cache_encouts_[iter->second]; } -void PaddleNet::FeedForward(const Matrix& features, Matrix* inferences) const { +void PaddleNnet::FeedForward(const Matrix& features, Matrix* inferences) { + paddle_infer::Predictor* predictor = GetPredictor(); // 1. 得到所有的 input tensor 的名称 int row = features.NumRows(); int col = features.NumCols(); @@ -144,15 +148,13 @@ void PaddleNet::FeedForward(const Matrix& features, Matrix input_len->CopyFromCpu(audio_len.data()); // 输入流式的缓存数据 std::unique_ptr h_box = predictor->GetInputHandle(input_names[2]); - share_ptr> h_cache = GetCacheEncoder(input_names[2])); + shared_ptr> h_cache = GetCacheEncoder(input_names[2]); h_box->Reshape(h_cache->get_shape()); h_box->CopyFromCpu(h_cache->get_data().data()); std::unique_ptr c_box = predictor->GetInputHandle(input_names[3]); - share_ptr> c_cache = GetCacheEncoder(input_names[3]); + shared_ptr> c_cache = GetCacheEncoder(input_names[3]); c_box->Reshape(c_cache->get_shape()); c_box->CopyFromCpu(c_cache->get_data().data()); - std::thread::id this_id = std::this_thread::get_id(); - LOG(INFO) << this_id << " start to compute the probability"; bool success = predictor->Run(); if (success == false) { @@ -172,8 +174,9 @@ void PaddleNet::FeedForward(const Matrix& features, Matrix std::vector output_shape = output_tensor->shape(); row = output_shape[1]; col = output_shape[2]; - inference.Resize(row, col); - output_tensor->CopyToCpu(inference.Data()); + 
inferences->Resize(row, col); + output_tensor->CopyToCpu(inferences->Data()); + ReleasePredictor(predictor); } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h index 7f34eeaf..b659a12b 100644 --- a/speechx/speechx/nnet/paddle_nnet.h +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -3,8 +3,12 @@ #include "nnet/nnet_interface.h" #include "base/common.h" -#include "paddle/paddle_inference_api.h" +#include "paddle_inference_api.h" +#include "kaldi/matrix/kaldi-matrix.h" +#include "kaldi/util/options-itf.h" + +#include namespace ppspeech { @@ -20,7 +24,7 @@ struct ModelOptions { std::string cache_shape; bool enable_fc_padding; bool enable_profile; - ModelDecoderOptions() : + ModelOptions() : model_path("model/final.zip"), params_path("model/avg_1.jit.pdmodel"), thread_num(2), @@ -49,16 +53,6 @@ struct ModelOptions { } }; - void Register(kaldi::OptionsItf* opts) { - _model_opts.Register(opts); - opts->Register("subsampling-rate", &subsampling_rate, - "subsampling rate for deepspeech model"); - opts->Register("receptive-field-length", &receptive_field_length, - "receptive field length for deepspeech model"); - } -}; - - template class Tensor { public: @@ -91,15 +85,19 @@ private: class PaddleNnet : public NnetInterface { public: PaddleNnet(const ModelOptions& opts); - virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Matrix* inferences) const; + virtual void FeedForward(const kaldi::Matrix& features, + kaldi::Matrix* inferences); std::shared_ptr> GetCacheEncoder(const std::string& name); - void InitCacheEncouts(const ModelOptions& opts); + void InitCacheEncouts(const ModelOptions& opts); private: + paddle_infer::Predictor* GetPredictor(); + int ReleasePredictor(paddle_infer::Predictor* predictor); + std::unique_ptr pool; std::vector pool_usages; std::mutex pool_mutex; + std::map predictor_to_thread_id; std::map cache_names_idx_; std::vector>> cache_encouts_; diff --git 
a/speechx/speechx/utils/CMakeLists.txt b/speechx/speechx/utils/CMakeLists.txt
index e69de29b..b5e2495e 100644
--- a/speechx/speechx/utils/CMakeLists.txt
+++ b/speechx/speechx/utils/CMakeLists.txt
@@ -0,0 +1,4 @@
+
+add_library(utils
+    file_utils.cc
+)
diff --git a/speechx/speechx/utils/file_utils.cc b/speechx/speechx/utils/file_utils.cc
new file mode 100644
index 00000000..8b2758ba
--- /dev/null
+++ b/speechx/speechx/utils/file_utils.cc
@@ -0,0 +1,28 @@
+#include "utils/file_utils.h"
+
+namespace ppspeech {
+
+// Appends every line of `filename` to `*vocabulary`, one element per line
+// and without the trailing '\n'. Returns false, leaving `*vocabulary`
+// untouched, when the file cannot be opened.
+//
+// Defined inside namespace ppspeech so that it matches the declaration in
+// utils/file_utils.h; a definition at global scope would be a different
+// function and leave ppspeech::ReadFileToVector unresolved at link time.
+bool ReadFileToVector(const std::string& filename,
+                      std::vector<std::string>* vocabulary) {
+    std::ifstream file_in(filename);
+    if (!file_in) {
+        std::cerr << "please input a valid file" << std::endl;
+        return false;
+    }
+
+    std::string line;
+    while (std::getline(file_in, line)) {
+        vocabulary->emplace_back(line);
+    }
+
+    return true;
+}
+
+}  // namespace ppspeech
diff --git a/speechx/speechx/utils/file_utils.h b/speechx/speechx/utils/file_utils.h
new file mode 100644
index 00000000..0011b6c5
--- /dev/null
+++ b/speechx/speechx/utils/file_utils.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "base/common.h"
+
+namespace ppspeech {
+
+// Reads the text file `filename` and appends each of its lines to `*data`.
+// Returns false if the file cannot be opened.
+bool ReadFileToVector(const std::string& filename,
+                      std::vector<std::string>* data);
+
+}  // namespace ppspeech