// Copyright 2022 Horizon Robotics. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // modified from // https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h #pragma once #include "base/common.h" #include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" #include "nnet/u2_nnet.h" #include "fastdeploy/runtime.h" namespace ppspeech { class U2OnnxNnet : public U2NnetBase { public: explicit U2OnnxNnet(const ModelOptions& opts); U2OnnxNnet(const U2OnnxNnet& other); void FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) override; void Reset() override; bool IsLogProb() override { return true; } void Dim(); void LoadModel(const std::string& model_dir); std::shared_ptr Clone() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, const int32& feat_dim, std::vector* ctc_probs, int32* vocab_dim) override; float ComputeAttentionScore(const float* prob, const std::vector& hyp, int eos, int decode_out_len); void AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) override; void EncoderOuts( std::vector>* encoder_out) const; void GetInputOutputInfo(const std::shared_ptr& runtime, std::vector* in_names, std::vector* out_names); private: ModelOptions opts_; int encoder_output_size_ = 0; int num_blocks_ = 0; int cnn_module_kernel_ = 0; int head_ = 0; // sessions std::shared_ptr encoder_ = nullptr; std::shared_ptr rescore_ = nullptr; std::shared_ptr ctc_ = nullptr; // node names std::vector encoder_in_names_, encoder_out_names_; std::vector ctc_in_names_, ctc_out_names_; std::vector rescore_in_names_, rescore_out_names_; // caches fastdeploy::FDTensor att_cache_ort_; fastdeploy::FDTensor cnn_cache_ort_; std::vector encoder_outs_; std::vector att_cache_; std::vector cnn_cache_; }; } // namespace ppspeech