You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/asr/nnet/nnet_itf.h

105 lines
3.2 KiB

3 years ago
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
3 years ago
#include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h"
#include "kaldi/util/options-itf.h"
2 years ago
DECLARE_int32(subsampling_rate);
DECLARE_string(model_path);
DECLARE_string(param_path);
DECLARE_string(model_input_names);
DECLARE_string(model_output_names);
DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes);
#ifdef USE_ONNX
DECLARE_bool(with_onnx_model);
#endif
2 years ago
namespace ppspeech {
struct ModelOptions {
2 years ago
// common
int subsample_rate{1};
bool use_gpu{false};
std::string model_path;
#ifdef USE_ONNX
bool with_onnx_model{false};
#endif
2 years ago
2 years ago
static ModelOptions InitFromFlags() {
2 years ago
ModelOptions opts;
opts.subsample_rate = FLAGS_subsampling_rate;
2 years ago
LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
2 years ago
opts.model_path = FLAGS_model_path;
2 years ago
LOG(INFO) << "model path: " << opts.model_path;
#ifdef USE_ONNX
opts.with_onnx_model = FLAGS_with_onnx_model;
LOG(INFO) << "with onnx model: " << opts.with_onnx_model;
#endif
2 years ago
return opts;
}
};
struct NnetOut {
// nnet out. maybe logprob or prob. Almost time this is logprob.
std::vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim;
// nnet state. Only using in Attention model.
std::vector<std::vector<kaldi::BaseFloat>> encoder_outs;
NnetOut() : logprobs({}), vocab_dim(-1), encoder_outs({}) {}
};
3 years ago
class NnetInterface {
public:
virtual ~NnetInterface() {}
// forward feat with nnet.
// nnet do not cache feats, feats cached by frontend.
// nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) = 0;
virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) = 0;
// reset nnet state, e.g. nnet_logprob_cache_, offset_, encoder_outs_.
virtual void Reset() = 0;
// true, nnet output is logprob; otherwise is prob,
virtual bool IsLogProb() = 0;
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
class NnetBase : public NnetInterface {
public:
int SubsamplingRate() const { return subsampling_rate_; }
virtual std::shared_ptr<NnetBase> Clone() const = 0;
2 years ago
protected:
int subsampling_rate_{1};
};
} // namespace ppspeech