You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/runtime/engine/asr/nnet/u2_onnx_nnet.h

97 lines
3.1 KiB

// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h
#pragma once
#include "base/common.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/u2_nnet.h"
#include "fastdeploy/runtime.h"
namespace ppspeech {
// U2OnnxNnet is a U2NnetBase implementation backed by three
// fastdeploy::Runtime sessions: encoder, CTC head, and attention rescorer.
// It was adapted from wenet's onnx_asr_model.h.
//
// All non-inline member functions are defined out of line. Ownership and
// thread-safety of the shared sessions are decided there, not in this header.
// NOTE(review): confirm those details in the .cc before sharing instances.
class U2OnnxNnet : public U2NnetBase {
  public:
    // ------------------------------------------------------------------
    // Construction / lifecycle
    // ------------------------------------------------------------------

    // Builds the network from `opts`.
    explicit U2OnnxNnet(const ModelOptions& opts);

    // Copy constructor (defined out of line).
    U2OnnxNnet(const U2OnnxNnet& src);

    // Returns a fresh instance of this network behind the NnetBase interface.
    std::shared_ptr<NnetBase> Clone() const override;

    // Resets decoding state between utterances. Presumably this clears the
    // att/cnn caches and encoder outputs below -- TODO confirm in the .cc.
    void Reset() override;

    // Loads the model artifacts found under `dir`.
    void LoadModel(const std::string& dir);

    // Presumably fills the model-shape metadata (encoder_output_size_,
    // num_blocks_, cnn_module_kernel_, head_) -- TODO confirm in the .cc.
    void Dim();

    // ------------------------------------------------------------------
    // Inference
    // ------------------------------------------------------------------

    // Runs the network on `feats` (`feat_dim` values per frame, presumably)
    // and writes the result into `nnet_out`.
    void FeedForward(const std::vector<kaldi::BaseFloat>& feats,
                     const int32& feat_dim,
                     NnetOut* nnet_out) override;

    // Encodes one chunk of features. Writes the CTC output into `probs` and
    // its vocabulary width into `vocab_size`.
    void ForwardEncoderChunkImpl(const std::vector<kaldi::BaseFloat>& chunk,
                                 const int32& feat_dim,
                                 std::vector<kaldi::BaseFloat>* probs,
                                 int32* vocab_size) override;

    // Rescores the candidate hypotheses with the attention decoder, writing
    // results into `scores`. `reverse_weight` presumably weights a
    // right-to-left decoder pass -- TODO confirm in the .cc.
    void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                            float reverse_weight,
                            std::vector<float>* scores) override;

    // Returns the attention score of a single hypothesis `hyp`, computed from
    // the decoder output buffer `prob`. `eos` is presumably the end-of-sentence
    // token id and `decode_out_len` the per-step row width -- TODO confirm.
    float ComputeAttentionScore(const float* prob,
                                const std::vector<int>& hyp,
                                int eos,
                                int decode_out_len);

    // ------------------------------------------------------------------
    // Introspection
    // ------------------------------------------------------------------

    // This network always reports its outputs as log-probabilities.
    bool IsLogProb() override { return true; }

    // Copies the accumulated encoder outputs into `outs`.
    void EncoderOuts(
        std::vector<std::vector<kaldi::BaseFloat>>* outs) const;

    // Fills `inputs` / `outputs` with the node names reported by `session`.
    void GetInputOutputInfo(
        const std::shared_ptr<fastdeploy::Runtime>& session,
        std::vector<std::string>* inputs,
        std::vector<std::string>* outputs);

  private:
    ModelOptions opts_;

    // Model-shape metadata; zero until populated (presumably by Dim()).
    int encoder_output_size_{0};
    int num_blocks_{0};
    int cnn_module_kernel_{0};
    int head_{0};

    // Inference sessions. A default-constructed shared_ptr is already null.
    std::shared_ptr<fastdeploy::Runtime> encoder_;
    std::shared_ptr<fastdeploy::Runtime> rescore_;
    std::shared_ptr<fastdeploy::Runtime> ctc_;

    // Input/output node names of each session.
    std::vector<std::string> encoder_in_names_;
    std::vector<std::string> encoder_out_names_;
    std::vector<std::string> ctc_in_names_;
    std::vector<std::string> ctc_out_names_;
    std::vector<std::string> rescore_in_names_;
    std::vector<std::string> rescore_out_names_;

    // Streaming state carried across encoder chunks: the cache tensors in
    // runtime form, the encoder outputs, and the plain-vector cache copies.
    fastdeploy::FDTensor att_cache_ort_;
    fastdeploy::FDTensor cnn_cache_ort_;
    std::vector<fastdeploy::FDTensor> encoder_outs_;
    std::vector<float> att_cache_;
    std::vector<float> cnn_cache_;
};
} // namespace ppspeech