add cmakelist of decoder, nnet

pull/1400/head
SmileGoat 3 years ago
parent e90438289d
commit d2d53cce1a

@ -39,16 +39,40 @@ FetchContent_Declare(
GIT_TAG "20210324.1" GIT_TAG "20210324.1"
) )
FetchContent_MakeAvailable(absl) FetchContent_MakeAvailable(absl)
include_directories(${absl_SOURCE_DIR}/absl) include_directories(${absl_SOURCE_DIR})
# libsndfile # libsndfile
#include(FetchContent)
#FetchContent_Declare(
# libsndfile
# GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
# GIT_TAG "1.0.31"
#)
#FetchContent_MakeAvailable(libsndfile)
# todo boost build
#include(FetchContent)
#FetchContent_Declare(
# Boost
# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip
# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
#)
#FetchContent_MakeAvailable(Boost)
#include_directories(${Boost_SOURCE_DIR})
# Boost 1.75.0 source tree under the fc_patch directory. Headers are taken
# from the tree root and libraries from its stage/lib directory.
# NOTE(review): nothing visible here downloads or builds Boost, so this tree is
# presumably populated out of band - confirm.
set(BOOST_ROOT "${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0")
include_directories("${BOOST_ROOT}")
link_directories("${BOOST_ROOT}/stage/lib")
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
libsndfile kenlm
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
GIT_TAG "1.0.31" GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
) )
FetchContent_MakeAvailable(libsndfile) FetchContent_MakeAvailable(kenlm)
add_dependencies(kenlm Boost)
include_directories(${kenlm_SOURCE_DIR})
# gflags # gflags
FetchContent_Declare( FetchContent_Declare(
@ -94,6 +118,22 @@ add_dependencies(openfst gflags glog)
link_directories(${openfst_PREFIX_DIR}/lib) link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include) include_directories(${openfst_PREFIX_DIR}/include)
# Paddle inference SDK location and the install prefix of the third-party
# packages bundled with it (note the trailing slash on the latter).
set(PADDLE_LIB "${fc_patch}/paddle-lib/paddle_inference")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")

# Header search paths: Paddle itself first, then the bundled packages.
# The bundled glog and gflags are intentionally not added.
include_directories("${PADDLE_LIB}/paddle/include")
foreach(paddle_bundled_dep protobuf xxhash cryptopp)
  include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}${paddle_bundled_dep}/include")
endforeach()

# Library search paths: bundled packages first, then Paddle itself.
foreach(paddle_bundled_dep protobuf xxhash cryptopp)
  link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}${paddle_bundled_dep}/lib")
endforeach()
link_directories("${PADDLE_LIB}/paddle/lib")
add_subdirectory(speechx) add_subdirectory(speechx)
#openblas #openblas
@ -122,4 +162,4 @@ add_subdirectory(speechx)
# if dir do not have CmakeLists.txt # if dir do not have CmakeLists.txt
#add_library(lib_name STATIC file.cc) #add_library(lib_name STATIC file.cc)
#target_link_libraries(lib_name item0 item1) #target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target) #add_dependencies(lib_name depend-target)

@ -12,14 +12,35 @@ ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
) )
add_subdirectory(kaldi) add_subdirectory(kaldi)
# utils: put this directory and utils/ on the include path, then descend into
# utils/ so its own CMakeLists is processed.
include_directories("${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/utils")
add_subdirectory(utils)
include_directories( include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/frontend ${CMAKE_CURRENT_SOURCE_DIR}/frontend
) )
add_subdirectory(frontend) add_subdirectory(frontend)
# nnet and decoder are registered the same way: the include paths are added
# first so the subdirectory inherits them, then the subdirectory is processed.
# The order (nnet before decoder) matches the original listing.
foreach(speechx_module nnet decoder)
  include_directories(
    "${CMAKE_CURRENT_SOURCE_DIR}"
    "${CMAKE_CURRENT_SOURCE_DIR}/${speechx_module}"
  )
  add_subdirectory("${speechx_module}")
endforeach()
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc) target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc) add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
#target_link_libraries(offline_decoder_main nnet decoder gflags glog)

@ -17,6 +17,7 @@
#include <deque> #include <deque>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
#include <fstream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <ostream> #include <ostream>
@ -27,7 +28,9 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <mutex>
#include "base/log.h" #include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h" #include "base/basic_types.h"
#include "base/macros.h" #include "base/macros.h"

@ -16,8 +16,10 @@
namespace ppspeech { namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \ TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete void operator=(const TypeName&) = delete
#endif
} // namespace pp_speech } // namespace pp_speech

@ -1,2 +1,10 @@
aux_source_directory(. DIR_LIB_SRCS) project(decoder)
add_library(decoder STATIC ${DIR_LIB_SRCS})
# CTC beam-search decoder library, built from the wrapper source in this
# directory plus the sources under ctc_decoders/.
add_library(decoder
  ctc_beam_search_decoder.cc
  ctc_decoders/decoder_utils.cpp
  ctc_decoders/path_trie.cpp
  ctc_decoders/scorer.cpp
)

# Make the headers under ctc_decoders/ reachable. The previous spelling,
# ${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}, had the closing brace in the wrong
# place: it named an undefined variable, expanded to nothing, and added no
# include directory at all.
target_include_directories(decoder
  PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders"
)

# The scorer is backed by the KenLM language-model library.
target_link_libraries(decoder kenlm)

@ -2,33 +2,32 @@
#include "base/basic_types.h" #include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h" #include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h"
namespace ppspeech { namespace ppspeech {
using std::vector; using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) : CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
opts_(opts), opts_(opts),
vocabulary_(nullptr),
init_ext_scorer_(nullptr), init_ext_scorer_(nullptr),
blank_id(-1), blank_id(-1),
space_id(-1), space_id(-1),
num_frame_decoded(0), num_frame_decoded_(0),
root(nullptr) { root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
vocabulary_ = std::make_shared<vector<string>>(); if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
} }
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size(); LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path; LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha, init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
opts_.beta, opts_.beta,
opts_.lm_path, opts_.lm_path,
*vocabulary_); vocabulary_);
} }
void CTCBeamSearch::Reset() { void CTCBeamSearch::Reset() {
@ -39,11 +38,11 @@ void CTCBeamSearch::Reset() {
void CTCBeamSearch::InitDecoder() { void CTCBeamSearch::InitDecoder() {
blank_id = 0; blank_id = 0;
auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " "); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id = it - vocabulary_->begin(); space_id = it - vocabulary_.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id >= vocabulary_->size()) { if ((size_t)space_id >= vocabulary_.size()) {
space_id = -2; space_id = -2;
} }
@ -63,19 +62,24 @@ void CTCBeamSearch::InitDecoder() {
} }
} }
// Placeholder: currently performs no decoding and ignores `decodable`.
void CTCBeamSearch::Decode(
    std::shared_ptr<kaldi::DecodableInterface> decodable) {
    // Intentionally empty.
}
int32 CTCBeamSearch::NumFrameDecoded() { int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_; return num_frame_decoded_;
} }
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, int max_frames) { void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) { while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break; break;
} }
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(result); AdvanceDecoding(likelihood);
max_frames--; max_frames--;
} }
} }
@ -91,32 +95,21 @@ void CTCBeamSearch::ResetPrefixes() {
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs, int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
vector<string>& nbest_words) { vector<string>& nbest_words) {
std::thread::id this_id = std::this_thread::get_id(); kaldi::Timer timer;
Timer timer;
vector<vector<double>> double_probs(probs.size(), vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
double_probs[i][j] = static_cast<double>(probs[i][j]);
}
}
timer.Reset(); timer.Reset();
AdvanceDecoding(double_probs); AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f; LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0; return 0;
} }
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size); return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
} }
string CTCBeamSearch::GetBestPath() { string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result; std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size); result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0]->second; return result[0].second;
} }
string CTCBeamSearch::GetFinalBestPath() { string CTCBeamSearch::GetFinalBestPath() {
@ -125,12 +118,22 @@ string CTCBeamSearch::GetFinalBestPath() {
return GetBestPath(); return GetBestPath();
} }
void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) { void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob; double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n; size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) { for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step]; const auto& prob = probs_seq[time_step];
@ -158,7 +161,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
size_t log_prob_idx_len = log_prob_idx.size(); size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) { for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear(); prefixes.clear();
// update log probs // update log probs
@ -177,9 +181,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
} // for probs_seq } // for probs_seq
} }
int CTCBeamSearch::SearchOneChar(const bool& full_beam, int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, float>& log_prob_idx, const std::pair<size_t, BaseFloat>& log_prob_idx,
const float& min_cutoff) { const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first; const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second; const auto& log_prob_c = log_prob_idx.second;

@ -1,5 +1,8 @@
#include "base/basic_types.h" #include "base/common.h"
#include "nnet/decodable-itf.h" #include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once #pragma once
@ -38,41 +41,39 @@ struct CTCBeamSearchOptions {
}; };
class CTCBeamSearch { class CTCBeamSearch {
public: public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts); ~CTCBeamSearch() {}
void InitDecoder();
~CTCBeamSearch() {
}
bool InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable); void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath(); std::string GetFinalBestPath();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset(); void Reset();
private:
private:
void ResetPrefixes(); void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam, int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx, const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff); const BaseFloat& min_cutoff);
void CalculateApproxScore(); void CalculateApproxScore();
void LMRescore(); void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq); void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
//std::vector<DecodeResult> decoder_results_; //std::vector<DecodeResult> decoder_results_;
std::vector<std::vector<std::string>> vocabulary_; // todo remove later std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id; size_t blank_id;
int space_id; int space_id;
std::shared_ptr<PathTrie> root; std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes; std::vector<PathTrie*> prefixes;
int num_frame_decoded_; int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
}; };
} // namespace basr } // namespace basr

@ -0,0 +1,2 @@
# Build every source file found in this directory into the static nnet library.
aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}" NNET_LIB_SRCS)
add_library(nnet STATIC ${NNET_LIB_SRCS})

@ -114,6 +114,8 @@ class DecodableInterface {
/// this is for compatibility with OpenFst). /// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0; virtual int32 NumIndices() const = 0;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
virtual ~DecodableInterface() {} virtual ~DecodableInterface() {}
}; };
/// @} /// @}

@ -2,15 +2,24 @@
namespace ppspeech { namespace ppspeech {
Decodable::Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood) { using kaldi::BaseFloat;
frames_ready_ += likelihood.NumRows(); using kaldi::Matrix;
// Builds a decodable on top of `nnet`: the network handle is stored, no
// frontend is attached, no frames are ready yet and the stream is not finished.
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
    : frontend_(nullptr),
      nnet_(nnet),
      finished_(false),
      frames_ready_(0) {}
Decodable::Init(DecodableConfig config) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
} }
Decodable::IsLastFrame(int32 frame) const { //Decodable::Init(DecodableConfig config) {
//}
bool Decodable::IsLastFrame(int32 frame) const {
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
return finished_ && (frame == frames_ready_ - 1); return finished_ && (frame == frames_ready_ - 1);
} }
@ -19,12 +28,11 @@ int32 Decodable::NumIndices() const {
return 0; return 0;
} }
void Decodable::LogLikelihood(int32 frame, int32 index) { BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return ; return 0;
} }
void Decodable::FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& features) { void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
// skip frame ???
nnet_->FeedForward(features, &nnet_cache_); nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return ; return ;
@ -35,4 +43,4 @@ void Decodable::Reset() {
nnet_->Reset(); nnet_->Reset();
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,14 +1,17 @@
#include "nnet/decodable-itf.h" #include "nnet/decodable-itf.h"
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h"
#include "nnet/nnet_interface.h"
namespace ppspeech { namespace ppspeech {
struct DecodableConfig; struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
virtual void Init(DecodableOpts config); explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
@ -25,4 +28,4 @@ class Decodable : public kaldi::DecodableInterface {
int32 frames_ready_; int32 frames_ready_;
}; };
} // namespace ppspeech } // namespace ppspeech

@ -3,13 +3,14 @@
#include "base/basic_types.h" #include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h" #include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h"
namespace ppspeech { namespace ppspeech {
class NnetInterface { class NnetInterface {
public: public:
virtual ~NnetInterface() {} virtual ~NnetInterface() {}
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences); kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset(); virtual void Reset();

@ -3,6 +3,11 @@
namespace ppspeech { namespace ppspeech {
using std::vector;
using std::string;
using std::shared_ptr;
using kaldi::Matrix;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names; std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ", "); cache_names = absl::StrSplit(opts.cache_names, ", ");
@ -25,14 +30,14 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
} }
} }
PaddleNet::PaddleNnet(const ModelOptions& opts) { PaddleNnet::PaddleNnet(const ModelOptions& opts) {
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path); config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) { if (opts.use_gpu) {
config.EnableUseGpu(500, 0); config.EnableUseGpu(500, 0);
} }
config.SwitchIrOptim(opts.switch_ir_optim); config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enbale_fc_padding) { if (opts.enable_fc_padding) {
config.DisableFCPadding(); config.DisableFCPadding();
} }
if (opts.enable_profile) { if (opts.enable_profile) {
@ -42,7 +47,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
if (pool == nullptr) { if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed"; LOG(ERROR) << "create the predictor pool failed";
} }
pool_usages.resize(num_thread); pool_usages.resize(opts.thread_num);
std::fill(pool_usages.begin(), pool_usages.end(), false); std::fill(pool_usages.begin(), pool_usages.end(), false);
LOG(INFO) << "load paddle model success"; LOG(INFO) << "load paddle model success";
@ -51,7 +56,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
LOG(INFO) << "output names: " << opts.output_names; LOG(INFO) << "output names: " << opts.output_names;
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", "); vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", ");
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", "); vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", ");
paddle_infer::Predictor* predictor = get_predictor(); paddle_infer::Predictor* predictor = GetPredictor();
std::vector<std::string> model_input_names = predictor->GetInputNames(); std::vector<std::string> model_input_names = predictor->GetInputNames();
assert(input_names_vec.size() == model_input_names.size()); assert(input_names_vec.size() == model_input_names.size());
@ -64,12 +69,12 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
for (size_t i = 0;i < output_names_vec.size(); i++) { for (size_t i = 0;i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]); assert(output_names_vec[i] == model_output_names[i]);
} }
release_predictor(predictor); ReleasePredictor(predictor);
InitCacheEncouts(opts); InitCacheEncouts(opts);
} }
paddle_infer::Predictor* PaddleNnet::get_predictor() { paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl; LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
paddle_infer::Predictor* predictor = nullptr; paddle_infer::Predictor* predictor = nullptr;
std::lock_guard<std::mutex> guard(pool_mutex); std::lock_guard<std::mutex> guard(pool_mutex);
@ -111,19 +116,18 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
return 0; return 0;
} }
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) { shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
auto iter = cache_names_idx_.find(name); auto iter = cache_names_idx_.find(name);
if (iter == cache_names_idx_.end()) { if (iter == cache_names_idx_.end()) {
return nullptr; return nullptr;
} }
assert(iter->second < cache_encouts_.size()); assert(iter->second < cache_encouts_.size());
return cache_encouts_[iter->second].get(); return cache_encouts_[iter->second];
} }
void PaddleNet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) const { void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor();
// 1. 得到所有的 input tensor 的名称 // 1. 得到所有的 input tensor 的名称
int row = features.NumRows(); int row = features.NumRows();
int col = features.NumCols(); int col = features.NumCols();
@ -144,15 +148,13 @@ void PaddleNet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>
input_len->CopyFromCpu(audio_len.data()); input_len->CopyFromCpu(audio_len.data());
// 输入流式的缓存数据 // 输入流式的缓存数据
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]); std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
share_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2])); shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape()); h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data()); h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]); std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
share_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]); shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape()); c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data()); c_box->CopyFromCpu(c_cache->get_data().data());
std::thread::id this_id = std::this_thread::get_id();
LOG(INFO) << this_id << " start to compute the probability";
bool success = predictor->Run(); bool success = predictor->Run();
if (success == false) { if (success == false) {
@ -172,8 +174,9 @@ void PaddleNet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1]; row = output_shape[1];
col = output_shape[2]; col = output_shape[2];
inference.Resize(row, col); inferences->Resize(row, col);
output_tensor->CopyToCpu(inference.Data()); output_tensor->CopyToCpu(inferences->Data());
ReleasePredictor(predictor);
} }
} // namespace ppspeech } // namespace ppspeech

@ -3,8 +3,12 @@
#include "nnet/nnet_interface.h" #include "nnet/nnet_interface.h"
#include "base/common.h" #include "base/common.h"
#include "paddle/paddle_inference_api.h" #include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include <numeric>
namespace ppspeech { namespace ppspeech {
@ -20,7 +24,7 @@ struct ModelOptions {
std::string cache_shape; std::string cache_shape;
bool enable_fc_padding; bool enable_fc_padding;
bool enable_profile; bool enable_profile;
ModelDecoderOptions() : ModelOptions() :
model_path("model/final.zip"), model_path("model/final.zip"),
params_path("model/avg_1.jit.pdmodel"), params_path("model/avg_1.jit.pdmodel"),
thread_num(2), thread_num(2),
@ -49,16 +53,6 @@ struct ModelOptions {
} }
}; };
void Register(kaldi::OptionsItf* opts) {
_model_opts.Register(opts);
opts->Register("subsampling-rate", &subsampling_rate,
"subsampling rate for deepspeech model");
opts->Register("receptive-field-length", &receptive_field_length,
"receptive field length for deepspeech model");
}
};
template<typename T> template<typename T>
class Tensor { class Tensor {
public: public:
@ -91,15 +85,19 @@ private:
class PaddleNnet : public NnetInterface { class PaddleNnet : public NnetInterface {
public: public:
PaddleNnet(const ModelOptions& opts); PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) const; kaldi::Matrix<kaldi::BaseFloat>* inferences);
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name); std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
void InitCacheEncouts(const ModelOptions& opts); void InitCacheEncouts(const ModelOptions& opts);
private: private:
paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor);
std::unique_ptr<paddle_infer::services::PredictorPool> pool; std::unique_ptr<paddle_infer::services::PredictorPool> pool;
std::vector<bool> pool_usages; std::vector<bool> pool_usages;
std::mutex pool_mutex; std::mutex pool_mutex;
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
std::map<std::string, int> cache_names_idx_; std::map<std::string, int> cache_names_idx_;
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_; std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;

@ -0,0 +1,4 @@
# Small helper library; it currently consists of the file-reading utilities only.
add_library(utils file_utils.cc)

@ -0,0 +1,17 @@
#include "utils/file_utils.h"
namespace ppspeech {

// Reads the text file `filename` line by line and appends every line (without
// its trailing newline) to `*vocabulary`. Existing elements are kept.
//
// Returns true once the whole file has been read. Returns false, after
// printing a message to stderr, when the file cannot be opened; `*vocabulary`
// is left untouched in that case.
//
// The definition lives in namespace ppspeech so that it matches the
// declaration in utils/file_utils.h. Defined at global scope it was a
// different function, and calls made from inside ppspeech failed to link.
bool ReadFileToVector(const std::string& filename,
                      std::vector<std::string>* vocabulary) {
    std::ifstream file_in(filename);
    if (!file_in) {
        std::cerr << "please input a valid file" << std::endl;
        return false;
    }

    std::string line;
    while (std::getline(file_in, line)) {
        vocabulary->emplace_back(line);
    }
    return true;
}

}  // namespace ppspeech

@ -0,0 +1,8 @@
#pragma once

#include "base/common.h"

namespace ppspeech {

// Reads the text file `filename` line by line, appending each line to `*data`.
// Returns false if the file cannot be opened, true otherwise.
bool ReadFileToVector(const std::string& filename,
                      std::vector<std::string>* data);

}  // namespace ppspeech
Loading…
Cancel
Save