You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/runtime/engine/asr/nnet/u2_onnx_nnet.h

116 lines
4.1 KiB

// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 ZeXuan Li (lizexuan@huya.com)
// Xingchen Song(sxc19@mails.tsinghua.edu.cn)
// hamddct@gmail.com (Mddct)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/u2_nnet.h"
#include "onnxruntime_cxx_api.h" // NOLINT
namespace ppspeech {
/// ONNX Runtime flavour of the U2 network interface declared by U2NnetBase.
///
/// Only the declaration lives in this header; apart from IsLogProb(), every
/// member function is defined out of line.
class U2OnnxNnet : public U2NnetBase {
  public:
    /// Static threading hook; `num_threads` defaults to 1.
    /// NOTE(review): presumably applied to the shared session_options_ --
    /// confirm against the definition.
    static void InitEngineThreads(int num_threads = 1);

    /// Builds an instance from `opts` (explicit: no implicit conversion from
    /// ModelOptions).
    explicit U2OnnxNnet(const ModelOptions& opts);

    /// User-provided copy constructor.
    /// NOTE(review): Ort::Value is move-only in the ONNX Runtime C++ API, so
    /// a member-wise default copy is unavailable; what is shared versus
    /// re-created is decided by the definition.
    U2OnnxNnet(const U2OnnxNnet& other);

    /// Base-interface entry point: takes a flat feature buffer plus its
    /// dimension and fills `out`.
    void FeedForward(const std::vector<kaldi::BaseFloat>& features,
                     const int32& feature_dim,
                     NnetOut* out) override;

    /// Base-interface reset hook.
    /// NOTE(review): presumably clears the streaming caches -- confirm.
    void Reset() override;

    /// This implementation always reports log-probability outputs.
    bool IsLogProb() override { return true; }

    void Dim();

    /// Loads the model(s) located under `model_dir`.
    void LoadModel(const std::string& model_dir);

    // void Warmup();

    // std::shared_ptr<paddle::jit::Layer> model() const { return model_; }

    /// Returns a copy of this object behind the NnetBase interface.
    std::shared_ptr<NnetBase> Clone() const override;

    /// Encoder forward pass for one chunk of features.
    /// NOTE(review): `ctc_probs` and `vocab_dim` look like output
    /// parameters -- confirm against the definition.
    void ForwardEncoderChunkImpl(
        const std::vector<kaldi::BaseFloat>& chunk_feats,
        const int32& feat_dim,
        std::vector<kaldi::BaseFloat>* ctc_probs,
        int32* vocab_dim) override;

    // float ComputePathScore(const paddle::Tensor& prob,
    //                        const std::vector<int>& hyp,
    //                        int eos);

    /// Scores hypothesis `hyp` against the raw probability buffer `prob`.
    /// NOTE(review): layout of `prob` and the role of `decode_out_len` are
    /// not visible from this header.
    float ComputeAttentionScore(const float* prob,
                                const std::vector<int>& hyp,
                                int eos,
                                int decode_out_len);

    /// Rescoring over the candidate hypotheses `hyps`; results are delivered
    /// through `rescoring_score`.
    void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                            float reverse_weight,
                            std::vector<float>* rescoring_score) override;

    // Debug helpers.
    // void FeedEncoderOuts(const paddle::Tensor& encoder_out);
    void EncoderOuts(
        std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;

    /// Collects the input and output node names of `session` into `in_names`
    /// and `out_names`. Adapted from WeNet.
    void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
                            std::vector<const char*>* in_names,
                            std::vector<const char*>* out_names);

  private:
    ModelOptions opts_;

    // Model hyper-parameters; all start at zero.
    int encoder_output_size_{0};
    int num_blocks_{0};
    int cnn_module_kernel_{0};
    int head_{0};

    // --- ONNX Runtime sessions ---
    // NOTE(Mddct): Ort::Env carries the logging state that every other ONNX
    // Runtime object relies on, so one Env has to exist before any other
    // ONNX Runtime functionality is used.
    static Ort::Env env_;  // one environment shared across threads
    static Ort::SessionOptions session_options_;
    std::shared_ptr<Ort::Session> encoder_session_{nullptr};
    std::shared_ptr<Ort::Session> rescore_session_{nullptr};
    std::shared_ptr<Ort::Session> ctc_session_{nullptr};

    // --- Graph node names, per session ---
    std::vector<const char*> encoder_in_names_;
    std::vector<const char*> encoder_out_names_;
    std::vector<const char*> ctc_in_names_;
    std::vector<const char*> ctc_out_names_;
    std::vector<const char*> rescore_in_names_;
    std::vector<const char*> rescore_out_names_;

    // --- Caches ---
    Ort::Value att_cache_ort_{nullptr};
    Ort::Value cnn_cache_ort_{nullptr};
    std::vector<Ort::Value> encoder_outs_;

    // NOTE: ONNX Runtime does not copy the xx_cache buffer when xx_cache_ort
    // is initialised from it; it only keeps the data pointer (see
    // https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/
    // framework/tensor.cc#L102-L129). The vectors below therefore own the
    // backing storage and keep it alive for the lifetime of the decoder.
    std::vector<float> att_cache_;
    std::vector<float> cnn_cache_;
};
} // namespace ppspeech