add cmakelist of decoder, nnet

pull/1400/head
SmileGoat 2 years ago
parent e90438289d
commit d2d53cce1a

@@ -39,16 +39,40 @@ FetchContent_Declare(
 GIT_TAG "20210324.1"
 )
 FetchContent_MakeAvailable(absl)
-include_directories(${absl_SOURCE_DIR}/absl)
+include_directories(${absl_SOURCE_DIR})
 # libsndfile
+#include(FetchContent)
+#FetchContent_Declare(
+# libsndfile
+# GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
+# GIT_TAG "1.0.31"
+#)
+#FetchContent_MakeAvailable(libsndfile)
+# todo boost build
+#include(FetchContent)
+#FetchContent_Declare(
+# Boost
+# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip
+# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
+#)
+#FetchContent_MakeAvailable(Boost)
+#include_directories(${Boost_SOURCE_DIR})
+set(BOOST_ROOT ${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
+include_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0)
+link_directories(${fc_patch}/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib)
 include(FetchContent)
 FetchContent_Declare(
-libsndfile
-GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
-GIT_TAG "1.0.31"
+kenlm
+GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
+GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
 )
-FetchContent_MakeAvailable(libsndfile)
+FetchContent_MakeAvailable(kenlm)
+add_dependencies(kenlm Boost)
+include_directories(${kenlm_SOURCE_DIR})
 # gflags
 FetchContent_Declare(
@@ -94,6 +118,22 @@ add_dependencies(openfst gflags glog)
 link_directories(${openfst_PREFIX_DIR}/lib)
 include_directories(${openfst_PREFIX_DIR}/include)
+set(PADDLE_LIB ${fc_patch}/paddle-lib/paddle_inference)
+include_directories("${PADDLE_LIB}/paddle/include")
+set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
+#include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
+#link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
+link_directories("${PADDLE_LIB}/paddle/lib")
 add_subdirectory(speechx)
 #openblas
@@ -122,4 +162,4 @@ add_subdirectory(speechx)
 # if dir do not have CmakeLists.txt
 #add_library(lib_name STATIC file.cc)
 #target_link_libraries(lib_name item0 item1)
-#add_dependencies(lib_name depend-target)
\ No newline at end of file
+#add_dependencies(lib_name depend-target)

@@ -12,14 +12,35 @@ ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
 )
 add_subdirectory(kaldi)
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/utils
+)
+add_subdirectory(utils)
 include_directories(
 ${CMAKE_CURRENT_SOURCE_DIR}
 ${CMAKE_CURRENT_SOURCE_DIR}/frontend
 )
 add_subdirectory(frontend)
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/nnet
+)
+add_subdirectory(nnet)
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/decoder
+)
+add_subdirectory(decoder)
 add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
 target_link_libraries(mfcc-test kaldi-mfcc)
 add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
 target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
+#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
+#target_link_libraries(offline_decoder_main nnet decoder gflags glog)

@@ -17,6 +17,7 @@
 #include <deque>
 #include <iostream>
 #include <istream>
+#include <fstream>
 #include <map>
 #include <memory>
 #include <ostream>
@@ -27,7 +28,9 @@
 #include <vector>
 #include <unordered_map>
 #include <unordered_set>
+#include <mutex>
 #include "base/log.h"
+#include "base/flags.h"
 #include "base/basic_types.h"
 #include "base/macros.h"

@@ -16,8 +16,10 @@
 namespace ppspeech {
+#ifndef DISALLOW_COPY_AND_ASSIGN
 #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
 TypeName(const TypeName&) = delete; \
 void operator=(const TypeName&) = delete
+#endif
 } // namespace ppspeech

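For reference, a minimal sketch of how the now-guarded DISALLOW_COPY_AND_ASSIGN macro is used inside a class; the Widget class here is hypothetical, not part of this commit:

#include "base/macros.h"

namespace ppspeech {

class Widget {
  public:
    Widget() = default;

  private:
    // Expands to a deleted copy constructor and a deleted copy assignment
    // operator, so accidental copies fail to compile:
    //   Widget a;
    //   Widget b = a;  // error: use of deleted function
    DISALLOW_COPY_AND_ASSIGN(Widget);
};

}  // namespace ppspeech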
@ -1,2 +1,10 @@
aux_source_directory(. DIR_LIB_SRCS)
add_library(decoder STATIC ${DIR_LIB_SRCS})
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
)
target_link_libraries(decoder kenlm)

@@ -2,33 +2,32 @@
 #include "base/basic_types.h"
 #include "decoder/ctc_decoders/decoder_utils.h"
+#include "utils/file_utils.h"
 namespace ppspeech {
 using std::vector;
 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
-CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) :
+CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
 opts_(opts),
-vocabulary_(nullptr),
 init_ext_scorer_(nullptr),
 blank_id(-1),
 space_id(-1),
-num_frame_decoded(0),
+num_frame_decoded_(0),
 root(nullptr) {
 LOG(INFO) << "dict path: " << opts_.dict_file;
-vocabulary_ = std::make_shared<vector<string>>();
-if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
+if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
 LOG(INFO) << "load the dict failed";
 }
-LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size();
+LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
 LOG(INFO) << "language model path: " << opts_.lm_path;
 init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
 opts_.beta,
 opts_.lm_path,
-*vocabulary_);
+vocabulary_);
 }
 void CTCBeamSearch::Reset() {
@@ -39,11 +38,11 @@ void CTCBeamSearch::Reset() {
 void CTCBeamSearch::InitDecoder() {
 blank_id = 0;
-auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " ");
+auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
-space_id = it - vocabulary_->begin();
+space_id = it - vocabulary_.begin();
 // if no space in vocabulary
-if ((size_t)space_id >= vocabulary_->size()) {
+if ((size_t)space_id >= vocabulary_.size()) {
 space_id = -2;
 }
@@ -63,19 +62,24 @@ void CTCBeamSearch::InitDecoder() {
 }
 }
+void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
+return;
+}
+int32 CTCBeamSearch::NumFrameDecoded() {
+return num_frame_decoded_;
+}
 // todo rename, refactor
-void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, int max_frames) {
+void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
+int max_frames) {
 while (max_frames > 0) {
 vector<vector<BaseFloat>> likelihood;
 if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
 break;
 }
+likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
-AdvanceDecoding(result);
+AdvanceDecoding(likelihood);
 max_frames--;
 }
 }
@@ -91,32 +95,21 @@ void CTCBeamSearch::ResetPrefixes() {
 int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
 vector<string>& nbest_words) {
-std::thread::id this_id = std::this_thread::get_id();
-Timer timer;
-vector<vector<double>> double_probs(probs.size(), vector<double>(probs[0].size(), 0));
-int row = probs.size();
-int col = probs[0].size();
-for(int i = 0; i < row; i++) {
-for (int j = 0; j < col; j++){
-double_probs[i][j] = static_cast<double>(probs[i][j]);
-}
-}
+kaldi::Timer timer;
 timer.Reset();
-AdvanceDecoding(double_probs);
+AdvanceDecoding(probs);
 LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
 return 0;
 }
 vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
-return get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
+return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
 }
 string CTCBeamSearch::GetBestPath() {
 std::vector<std::pair<double, std::string>> result;
-result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
-return result[0]->second;
+result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
+return result[0].second;
 }
 string CTCBeamSearch::GetFinalBestPath() {
@@ -125,12 +118,22 @@ string CTCBeamSearch::GetFinalBestPath() {
 return GetBestPath();
 }
-void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
-size_t num_time_steps = probs_seq.size();
+void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
+size_t num_time_steps = probs.size();
 size_t beam_size = opts_.beam_size;
 double cutoff_prob = opts_.cutoff_prob;
 size_t cutoff_top_n = opts_.cutoff_top_n;
+vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
+int row = probs.size();
+int col = probs[0].size();
+for(int i = 0; i < row; i++) {
+for (int j = 0; j < col; j++){
+probs_seq[i][j] = static_cast<double>(probs[i][j]);
+}
+}
 for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
 const auto& prob = probs_seq[time_step];
@@ -158,7 +161,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
 size_t log_prob_idx_len = log_prob_idx.size();
 for (size_t index = 0; index < log_prob_idx_len; index++) {
 SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
+}
 prefixes.clear();
 // update log probs
@@ -177,9 +181,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
 } // for probs_seq
 }
-int CTCBeamSearch::SearchOneChar(const bool& full_beam,
-const std::pair<size_t, float>& log_prob_idx,
-const float& min_cutoff) {
+int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
+const std::pair<size_t, BaseFloat>& log_prob_idx,
+const BaseFloat& min_cutoff) {
 size_t beam_size = opts_.beam_size;
 const auto& c = log_prob_idx.first;
 const auto& log_prob_c = log_prob_idx.second;

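AdvanceDecoding now accepts BaseFloat (float) frames and widens them to double internally, since the ctc_decoders primitives work in double. A self-contained sketch of that widening step, under the assumption that BaseFloat is kaldi's float typedef:

#include <cstddef>
#include <vector>

// Minimal sketch of the BaseFloat -> double widening that AdvanceDecoding
// performs before running the beam search over each time step.
std::vector<std::vector<double>> WidenProbs(
        const std::vector<std::vector<float>>& probs) {
    size_t rows = probs.size();
    size_t cols = probs.empty() ? 0 : probs[0].size();
    std::vector<std::vector<double>> probs_seq(rows, std::vector<double>(cols, 0.0));
    for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
            probs_seq[i][j] = static_cast<double>(probs[i][j]);
        }
    }
    return probs_seq;
}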
@@ -1,5 +1,8 @@
 #include "base/basic_types.h"
+#include "base/common.h"
 #include "nnet/decodable-itf.h"
 #include "util/parse-options.h"
+#include "decoder/ctc_decoders/scorer.h"
+#include "decoder/ctc_decoders/path_trie.h"
 #pragma once
@@ -38,41 +41,39 @@ struct CTCBeamSearchOptions {
 };
 class CTCBeamSearch {
-public:
-CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts);
-~CTCBeamSearch() {
-}
-bool InitDecoder();
+public:
+explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
+~CTCBeamSearch() {}
+void InitDecoder();
+void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
 std::string GetBestPath();
 std::vector<std::pair<double, std::string>> GetNBestPath();
-std::string GetFinalBestPath();
+std::string GetFinalBestPath();
+int NumFrameDecoded();
 int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
 std::vector<std::string>& nbest_words);
 void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
 int max_frames);
 void Reset();
-private:
+private:
 void ResetPrefixes();
 int32 SearchOneChar(const bool& full_beam,
 const std::pair<size_t, BaseFloat>& log_prob_idx,
 const BaseFloat& min_cutoff);
 void CalculateApproxScore();
 void LMRescore();
-void AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
+void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
 CTCBeamSearchOptions opts_;
 std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
 //std::vector<DecodeResult> decoder_results_;
-std::vector<std::vector<std::string>> vocabulary_; // todo remove later
+std::vector<std::string> vocabulary_; // todo remove later
 size_t blank_id;
 int space_id;
 std::shared_ptr<PathTrie> root;
 std::vector<PathTrie*> prefixes;
 int num_frame_decoded_;
 DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
 };
 } // namespace ppspeech

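A hedged usage sketch of the revised interface: options are now passed by const reference, InitDecoder returns void, and frames are pulled from a DecodableInterface. The file paths and frame budget below are illustrative only:

#include <memory>
#include <string>

#include "decoder/ctc_beam_search_decoder.h"

std::string DecodeUtterance(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
    ppspeech::CTCBeamSearchOptions opts;
    opts.dict_file = "data/vocab.txt";  // hypothetical paths
    opts.lm_path = "data/lm.arpa";

    ppspeech::CTCBeamSearch decoder(opts);  // const ref, no shared_ptr any more
    decoder.InitDecoder();
    decoder.AdvanceDecode(decodable, /*max_frames=*/10);
    return decoder.GetFinalBestPath();
}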
@@ -0,0 +1,2 @@
+aux_source_directory(. DIR_LIB_SRCS)
+add_library(nnet STATIC ${DIR_LIB_SRCS})

@@ -114,6 +114,8 @@
 /// this is for compatibility with OpenFst).
 virtual int32 NumIndices() const = 0;
+virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
 virtual ~DecodableInterface() {}
 };
 /// @}

@@ -2,15 +2,24 @@
 namespace ppspeech {
-Decodable::Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood) {
-frames_ready_ += likelihood.NumRows();
+using kaldi::BaseFloat;
+using kaldi::Matrix;
+Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
+frontend_(NULL),
+nnet_(nnet),
+finished_(false),
+frames_ready_(0) {
 }
-Decodable::Init(DecodableConfig config) {
+void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
+frames_ready_ += likelihood.NumRows();
+}
-Decodable::IsLastFrame(int32 frame) const {
+//Decodable::Init(DecodableConfig config) {
+//}
+bool Decodable::IsLastFrame(int32 frame) const {
 CHECK_LE(frame, frames_ready_);
 return finished_ && (frame == frames_ready_ - 1);
 }
@@ -19,12 +28,11 @@ int32 Decodable::NumIndices() const {
 return 0;
 }
-void Decodable::LogLikelihood(int32 frame, int32 index) {
-return ;
+BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
+return 0;
 }
-void Decodable::FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& features) {
-// skip frame ???
+void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
 nnet_->FeedForward(features, &nnet_cache_);
 frames_ready_ += nnet_cache_.NumRows();
 return ;
@@ -35,4 +43,4 @@ void Decodable::Reset() {
 nnet_->Reset();
 }
-} // namespace ppspeech
\ No newline at end of file
+} // namespace ppspeech

@@ -1,14 +1,17 @@
 #include "nnet/decodable-itf.h"
+#include "base/common.h"
 #include "kaldi/matrix/kaldi-matrix.h"
+#include "frontend/feature_extractor_interface.h"
+#include "nnet/nnet_interface.h"
 namespace ppspeech {
-struct DecodableConfig;
 struct DecodableOpts;
 class Decodable : public kaldi::DecodableInterface {
 public:
-virtual void Init(DecodableOpts config);
+explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
+//void Init(DecodableOpts config);
 virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
 virtual bool IsLastFrame(int32 frame) const;
 virtual int32 NumIndices() const;
@@ -25,4 +28,4 @@ class Decodable : public kaldi::DecodableInterface {
 int32 frames_ready_;
 };
-} // namespace ppspeech
\ No newline at end of file
+} // namespace ppspeech

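A sketch of how the new constructor is meant to be wired: Decodable wraps an NnetInterface, FeedFeatures pushes a feature matrix through FeedForward, and the decoder then queries frames. The feats matrix and function name here are placeholders, not part of this commit:

#include <memory>

#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

void WireUpDecodable(const ppspeech::ModelOptions& model_opts,
                     const kaldi::Matrix<kaldi::BaseFloat>& feats) {
    std::shared_ptr<ppspeech::NnetInterface> nnet =
        std::make_shared<ppspeech::PaddleNnet>(model_opts);
    ppspeech::Decodable decodable(nnet);   // new explicit constructor
    decodable.FeedFeatures(feats);         // runs FeedForward, grows frames_ready_
    bool done = decodable.IsLastFrame(0);  // frame queries as in the decoder loop
    (void)done;
}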
@@ -3,13 +3,14 @@
 #include "base/basic_types.h"
 #include "kaldi/base/kaldi-types.h"
+#include "kaldi/matrix/kaldi-matrix.h"
 namespace ppspeech {
 class NnetInterface {
 public:
 virtual ~NnetInterface() {}
-virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
+virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
 kaldi::Matrix<kaldi::BaseFloat>* inferences);
 virtual void Reset();

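Since FeedForward and Reset are plain virtuals, a throwaway test double can be sketched directly against this interface; the IdentityNnet class is hypothetical, not part of this commit:

#include "nnet/nnet_interface.h"

namespace ppspeech {

// Hypothetical no-op nnet for unit tests: echoes the input features back
// as the "inference" output and keeps no state to reset.
class IdentityNnet : public NnetInterface {
  public:
    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
                             kaldi::Matrix<kaldi::BaseFloat>* inferences) {
        inferences->Resize(features.NumRows(), features.NumCols());
        inferences->CopyFromMat(features);
    }
    virtual void Reset() {}
};

}  // namespace ppspeech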
@@ -3,6 +3,11 @@
 namespace ppspeech {
 using std::vector;
 using std::string;
+using std::shared_ptr;
+using kaldi::Matrix;
 void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
 std::vector<std::string> cache_names;
 cache_names = absl::StrSplit(opts.cache_names, ", ");
@@ -25,14 +30,14 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
 }
 }
-PaddleNet::PaddleNnet(const ModelOptions& opts) {
+PaddleNnet::PaddleNnet(const ModelOptions& opts) {
 paddle_infer::Config config;
 config.SetModel(opts.model_path, opts.params_path);
 if (opts.use_gpu) {
 config.EnableUseGpu(500, 0);
 }
 config.SwitchIrOptim(opts.switch_ir_optim);
-if (opts.enbale_fc_padding) {
+if (opts.enable_fc_padding) {
 config.DisableFCPadding();
 }
 if (opts.enable_profile) {
@@ -42,7 +47,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
 if (pool == nullptr) {
 LOG(ERROR) << "create the predictor pool failed";
 }
-pool_usages.resize(num_thread);
+pool_usages.resize(opts.thread_num);
 std::fill(pool_usages.begin(), pool_usages.end(), false);
 LOG(INFO) << "load paddle model success";
@@ -51,7 +56,7 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
 LOG(INFO) << "output names: " << opts.output_names;
 vector<string> input_names_vec = absl::StrSplit(opts.input_names, ", ");
 vector<string> output_names_vec = absl::StrSplit(opts.output_names, ", ");
-paddle_infer::Predictor* predictor = get_predictor();
+paddle_infer::Predictor* predictor = GetPredictor();
 std::vector<std::string> model_input_names = predictor->GetInputNames();
 assert(input_names_vec.size() == model_input_names.size());
@@ -64,12 +69,12 @@ PaddleNet::PaddleNnet(const ModelOptions& opts) {
 for (size_t i = 0;i < output_names_vec.size(); i++) {
 assert(output_names_vec[i] == model_output_names[i]);
 }
-release_predictor(predictor);
+ReleasePredictor(predictor);
 InitCacheEncouts(opts);
 }
-paddle_infer::Predictor* PaddleNnet::get_predictor() {
+paddle_infer::Predictor* PaddleNnet::GetPredictor() {
 LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
 paddle_infer::Predictor* predictor = nullptr;
 std::lock_guard<std::mutex> guard(pool_mutex);
@@ -111,19 +116,18 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
 return 0;
 }
 shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
 auto iter = cache_names_idx_.find(name);
 if (iter == cache_names_idx_.end()) {
 return nullptr;
 }
 assert(iter->second < cache_encouts_.size());
-return cache_encouts_[iter->second].get();
+return cache_encouts_[iter->second];
 }
-void PaddleNet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) const {
+void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
 paddle_infer::Predictor* predictor = GetPredictor();
 // 1. get the names of all input tensors
 int row = features.NumRows();
 int col = features.NumCols();
@@ -144,15 +148,13 @@ void PaddleNet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>
 input_len->CopyFromCpu(audio_len.data());
 // feed the streaming cache data
 std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
-share_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]));
+shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
 h_box->Reshape(h_cache->get_shape());
 h_box->CopyFromCpu(h_cache->get_data().data());
 std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
-share_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
+shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
 c_box->Reshape(c_cache->get_shape());
 c_box->CopyFromCpu(c_cache->get_data().data());
-std::thread::id this_id = std::this_thread::get_id();
-LOG(INFO) << this_id << " start to compute the probability";
 bool success = predictor->Run();
 if (success == false) {
@@ -172,8 +174,9 @@ void PaddleNet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>
 std::vector<int> output_shape = output_tensor->shape();
 row = output_shape[1];
 col = output_shape[2];
-inference.Resize(row, col);
-output_tensor->CopyToCpu(inference.Data());
+inferences->Resize(row, col);
+output_tensor->CopyToCpu(inferences->Data());
+ReleasePredictor(predictor);
 }
 } // namespace ppspeech

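Every GetPredictor() must now be paired with a ReleasePredictor() on all paths, as FeedForward above does before returning. If that pair were exposed to callers, a small RAII guard would make the pairing automatic; a sketch only, not part of this commit, and note both methods are private in the header below:

// Hypothetical RAII guard over the predictor pool, assuming GetPredictor()
// and ReleasePredictor() were accessible from the call site.
class ScopedPredictor {
  public:
    explicit ScopedPredictor(ppspeech::PaddleNnet* nnet)
        : nnet_(nnet), predictor_(nnet->GetPredictor()) {}
    ~ScopedPredictor() { nnet_->ReleasePredictor(predictor_); }
    paddle_infer::Predictor* get() const { return predictor_; }

  private:
    ppspeech::PaddleNnet* nnet_;
    paddle_infer::Predictor* predictor_;
};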
@@ -3,8 +3,12 @@
 #include "nnet/nnet_interface.h"
 #include "base/common.h"
-#include "paddle/paddle_inference_api.h"
+#include "paddle_inference_api.h"
 #include "kaldi/matrix/kaldi-matrix.h"
+#include "kaldi/util/options-itf.h"
+#include <numeric>
 namespace ppspeech {
@@ -20,7 +24,7 @@ struct ModelOptions {
 std::string cache_shape;
 bool enable_fc_padding;
 bool enable_profile;
-ModelDecoderOptions() :
+ModelOptions() :
 model_path("model/final.zip"),
 params_path("model/avg_1.jit.pdmodel"),
 thread_num(2),
@@ -49,16 +53,6 @@ struct ModelOptions {
 }
 };
-void Register(kaldi::OptionsItf* opts) {
-_model_opts.Register(opts);
-opts->Register("subsampling-rate", &subsampling_rate,
-"subsampling rate for deepspeech model");
-opts->Register("receptive-field-length", &receptive_field_length,
-"receptive field length for deepspeech model");
-}
-};
 template<typename T>
 class Tensor {
 public:
@@ -91,15 +85,19 @@ private:
 class PaddleNnet : public NnetInterface {
 public:
 PaddleNnet(const ModelOptions& opts);
-virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
-kaldi::Matrix<kaldi::BaseFloat>* inferences) const;
+virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
+kaldi::Matrix<kaldi::BaseFloat>* inferences);
 std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
-void InitCacheEncouts(const ModelOptions& opts);
+void InitCacheEncouts(const ModelOptions& opts);
 private:
+paddle_infer::Predictor* GetPredictor();
+int ReleasePredictor(paddle_infer::Predictor* predictor);
 std::unique_ptr<paddle_infer::services::PredictorPool> pool;
 std::vector<bool> pool_usages;
 std::mutex pool_mutex;
 std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
 std::map<std::string, int> cache_names_idx_;
 std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;

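The new <numeric> include points at shape arithmetic on the cached tensors; a hedged sketch of the usual element-count computation over a shape vector like the one get_shape() appears to return (the helper name is hypothetical):

#include <functional>
#include <numeric>
#include <vector>

// Sketch: total element count of a tensor shape, e.g. to size the buffer
// behind get_data(). This is the typical job for std::accumulate here.
int NumElements(const std::vector<int>& shape) {
    return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}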
@@ -0,0 +1,4 @@
+add_library(utils
+file_utils.cc
+)

@@ -0,0 +1,19 @@
+#include "utils/file_utils.h"
+
+namespace ppspeech {
+
+bool ReadFileToVector(const std::string& filename,
+std::vector<std::string>* vocabulary) {
+std::ifstream file_in(filename);
+if (!file_in) {
+std::cerr << "please input a valid file" << std::endl;
+return false;
+}
+std::string line;
+while (std::getline(file_in, line)) {
+vocabulary->emplace_back(line);
+}
+return true;
+}
+
+} // namespace ppspeech

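Usage mirrors the call site in CTCBeamSearch's constructor; a minimal sketch, with an illustrative path:

#include <string>
#include <vector>

#include "utils/file_utils.h"

bool LoadVocabulary(std::vector<std::string>* vocabulary) {
    // Reads one vocabulary entry per line; returns false on a bad path.
    return ppspeech::ReadFileToVector("data/vocab.txt", vocabulary);
}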
@@ -0,0 +1,8 @@
+#pragma once
+
+#include "base/common.h"
+
+namespace ppspeech {
+bool ReadFileToVector(const std::string& filename,
+std::vector<std::string>* data);
+} // namespace ppspeech