|
|
@ -18,9 +18,56 @@
|
|
|
|
#include "base/basic_types.h"
|
|
|
|
#include "base/basic_types.h"
|
|
|
|
#include "kaldi/base/kaldi-types.h"
|
|
|
|
#include "kaldi/base/kaldi-types.h"
|
|
|
|
#include "kaldi/matrix/kaldi-matrix.h"
|
|
|
|
#include "kaldi/matrix/kaldi-matrix.h"
|
|
|
|
|
|
|
|
#include "kaldi/util/options-itf.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace ppspeech {
|
|
|
|
namespace ppspeech {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct ModelOptions {
|
|
|
|
|
|
|
|
std::string model_path;
|
|
|
|
|
|
|
|
std::string param_path;
|
|
|
|
|
|
|
|
int thread_num; // predictor thread pool size for ds2;
|
|
|
|
|
|
|
|
bool use_gpu;
|
|
|
|
|
|
|
|
bool switch_ir_optim;
|
|
|
|
|
|
|
|
std::string input_names;
|
|
|
|
|
|
|
|
std::string output_names;
|
|
|
|
|
|
|
|
std::string cache_names;
|
|
|
|
|
|
|
|
std::string cache_shape;
|
|
|
|
|
|
|
|
bool enable_fc_padding;
|
|
|
|
|
|
|
|
bool enable_profile;
|
|
|
|
|
|
|
|
ModelOptions()
|
|
|
|
|
|
|
|
: model_path(""),
|
|
|
|
|
|
|
|
param_path(""),
|
|
|
|
|
|
|
|
thread_num(1),
|
|
|
|
|
|
|
|
use_gpu(false),
|
|
|
|
|
|
|
|
input_names(""),
|
|
|
|
|
|
|
|
output_names(""),
|
|
|
|
|
|
|
|
cache_names(""),
|
|
|
|
|
|
|
|
cache_shape(""),
|
|
|
|
|
|
|
|
switch_ir_optim(false),
|
|
|
|
|
|
|
|
enable_fc_padding(false),
|
|
|
|
|
|
|
|
enable_profile(false) {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Register(kaldi::OptionsItf* opts) {
|
|
|
|
|
|
|
|
opts->Register("model-path", &model_path, "model file path");
|
|
|
|
|
|
|
|
opts->Register("model-param", ¶m_path, "params model file path");
|
|
|
|
|
|
|
|
opts->Register("thread-num", &thread_num, "thread num");
|
|
|
|
|
|
|
|
opts->Register("use-gpu", &use_gpu, "if use gpu");
|
|
|
|
|
|
|
|
opts->Register("input-names", &input_names, "paddle input names");
|
|
|
|
|
|
|
|
opts->Register("output-names", &output_names, "paddle output names");
|
|
|
|
|
|
|
|
opts->Register("cache-names", &cache_names, "cache names");
|
|
|
|
|
|
|
|
opts->Register("cache-shape", &cache_shape, "cache shape");
|
|
|
|
|
|
|
|
opts->Register("switch-ir-optiom",
|
|
|
|
|
|
|
|
&switch_ir_optim,
|
|
|
|
|
|
|
|
"paddle SwitchIrOptim option");
|
|
|
|
|
|
|
|
opts->Register("enable-fc-padding",
|
|
|
|
|
|
|
|
&enable_fc_padding,
|
|
|
|
|
|
|
|
"paddle EnableFCPadding option");
|
|
|
|
|
|
|
|
opts->Register(
|
|
|
|
|
|
|
|
"enable-profile", &enable_profile, "paddle EnableProfile option");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
struct NnetOut {
|
|
|
|
struct NnetOut {
|
|
|
|
// nnet out. maybe logprob or prob. Almost time this is logprob.
|
|
|
|
// nnet out. maybe logprob or prob. Almost time this is logprob.
|
|
|
|
kaldi::Vector<kaldi::BaseFloat> logprobs;
|
|
|
|
kaldi::Vector<kaldi::BaseFloat> logprobs;
|
|
|
@ -45,6 +92,10 @@ class NnetInterface {
|
|
|
|
const int32& feature_dim,
|
|
|
|
const int32& feature_dim,
|
|
|
|
NnetOut* out) = 0;
|
|
|
|
NnetOut* out) = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
|
|
|
|
|
|
|
|
float reverse_weight,
|
|
|
|
|
|
|
|
std::vector<float>* rescoring_score) = 0;
|
|
|
|
|
|
|
|
|
|
|
|
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
|
|
|
|
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
|
|
|
|
virtual void Reset() = 0;
|
|
|
|
virtual void Reset() = 0;
|
|
|
|
|
|
|
|
|
|
|
|