[engine] replace onnx with fastdeploy ()

* onnxruntime change to fastdeploy
pull/3156/head
masimeng1994 2 years ago committed by GitHub
parent d03ebe872a
commit 11ce08b260
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -88,7 +88,7 @@ include(pybind)
#onnx
if(WITH_ONNX)
include(onnx)
add_definitions(-DUSE_ONNX)
endif()
# gtest

@ -26,13 +26,10 @@ if(NOT FASTDEPLOY_INSTALL_DIR)
else() # Linux
FetchContent_Declare(
fastdeploy
URL https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.4.tgz
URL_HASH MD5=125df3bfce603521960cc5c8b47faab0
URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
URL_HASH MD5=33900d986ea71aa78635e52f0733227c
${EXTERNAL_PROJECT_LOG_ARGS}
)
add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
# add_definitions("-DUSE_ORT_BACKEND")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
endif()

@ -1,52 +0,0 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 ZeXuan Li (lizexuan@huya.com)
# Xingchen Song(sxc19@mails.tsinghua.edu.cn)
# hamddct@gmail.com (Mddct)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(WITH_ONNX)
set(ONNX_VERSION "1.12.0")
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip")
set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176")
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac")
else()
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600")
else()
message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
endif()
FetchContent_Declare(onnxruntime
URL ${ONNX_URL}
URL_HASH ${URL_HASH}
)
FetchContent_MakeAvailable(onnxruntime)
include_directories(${onnxruntime_SOURCE_DIR}/include)
link_directories(${onnxruntime_SOURCE_DIR}/lib)
if(MSVC)
file(GLOB ONNX_DLLS "${onnxruntime_SOURCE_DIR}/lib/*.dll")
file(COPY ${ONNX_DLLS} DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE})
endif()
add_definitions(-DUSE_ONNX)
endif()

@ -7,7 +7,7 @@ endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet utils)
if(WITH_ONNX)
target_link_libraries(nnet onnxruntime)
target_link_libraries(nnet ${FASTDEPLOY_LIBS})
endif()
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
@ -18,4 +18,4 @@ target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURC
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})

@ -59,9 +59,6 @@ int main(int argc, char* argv[]) {
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
#ifdef USE_ONNX
ppspeech::U2OnnxNnet::InitEngineThreads(1);
#endif
ppspeech::FeaturePipelineOptions feature_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
feature_opts.assembler_opts.fill_zero = false;

@ -1,7 +1,5 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 ZeXuan Li (lizexuan@huya.com)
// Xingchen Song(sxc19@mails.tsinghua.edu.cn)
// hamddct@gmail.com (Mddct)
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -15,67 +13,52 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
#include "nnet/u2_onnx_nnet.h"
#include "common/base/config.h"
namespace ppspeech {
Ort::Env U2OnnxNnet::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
Ort::SessionOptions U2OnnxNnet::session_options_ = Ort::SessionOptions();
void U2OnnxNnet::InitEngineThreads(int num_threads) {
session_options_.SetIntraOpNumThreads(num_threads);
}
void U2OnnxNnet::LoadModel(const std::string& model_dir) {
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
std::string param_path = model_dir + "/param.onnx";
// 1. Load sessions
try {
#ifdef _MSC_VER
encoder_session_ = std::make_shared<Ort::Session>(
env_, ToWString(encoder_onnx_path).c_str(), session_options_);
rescore_session_ = std::make_shared<Ort::Session>(
env_, ToWString(rescore_onnx_path).c_str(), session_options_);
ctc_session_ = std::make_shared<Ort::Session>(
env_, ToWString(ctc_onnx_path).c_str(), session_options_);
#else
encoder_session_ = std::make_shared<Ort::Session>(
env_, encoder_onnx_path.c_str(), session_options_);
rescore_session_ = std::make_shared<Ort::Session>(
env_, rescore_onnx_path.c_str(), session_options_);
ctc_session_ = std::make_shared<Ort::Session>(env_, ctc_onnx_path.c_str(),
session_options_);
#endif
encoder_ = std::make_shared<fastdeploy::Runtime>();
ctc_ = std::make_shared<fastdeploy::Runtime>();
rescore_ = std::make_shared<fastdeploy::Runtime>();
fastdeploy::RuntimeOption runtime_option;
runtime_option.UseOrtBackend();
runtime_option.UseCpu();
runtime_option.SetCpuThreadNum(1);
runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(encoder_->Init(runtime_option));
runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(rescore_->Init(runtime_option));
runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(ctc_->Init(runtime_option));
} catch (std::exception const& e) {
LOG(ERROR) << "error when load onnx model: " << e.what();
exit(0);
}
// 2. Read metadata
auto model_metadata = encoder_session_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;
encoder_output_size_ =
atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator));
num_blocks_ =
atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator));
head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator));
cnn_module_kernel_ = atoi(
model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator));
subsampling_rate_ = atoi(
model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator));
right_context_ =
atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator));
sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator));
eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator));
is_bidecoder_ = atoi(model_metadata.LookupCustomMetadataMap(
"is_bidirectional_decoder", allocator));
chunk_size_ =
atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator));
num_left_chunks_ =
atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator));
Config conf(param_path);
encoder_output_size_ = conf.Read("output_size", encoder_output_size_);
num_blocks_ = conf.Read("num_blocks", num_blocks_);
head_ = conf.Read("head", head_);
cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_);
subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_);
right_context_ = conf.Read("right_context", right_context_);
sos_= conf.Read("sos_symbol", sos_);
eos_= conf.Read("eos_symbol", eos_);
is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_);
chunk_size_= conf.Read("chunk_size", chunk_size_);
num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_);
LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
LOG(INFO) << "\tnum_blocks " << num_blocks_;
@ -91,11 +74,11 @@ void U2OnnxNnet::LoadModel(const std::string& model_dir) {
// 3. Read model nodes
LOG(INFO) << "Onnx Encoder:";
GetInputOutputInfo(encoder_session_, &encoder_in_names_, &encoder_out_names_);
GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_);
LOG(INFO) << "Onnx CTC:";
GetInputOutputInfo(ctc_session_, &ctc_in_names_, &ctc_out_names_);
GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_);
LOG(INFO) << "Onnx Rescore:";
GetInputOutputInfo(rescore_session_, &rescore_in_names_, &rescore_out_names_);
GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_);
}
U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) {
@ -117,11 +100,11 @@ U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// sessions
encoder_session_ = other.encoder_session_;
ctc_session_ = other.ctc_session_;
rescore_session_ = other.rescore_session_;
// session
encoder_ = other.encoder_;
ctc_ = other.ctc_;
rescore_ = other.rescore_;
// node names
encoder_in_names_ = other.encoder_in_names_;
@ -132,46 +115,36 @@ U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
rescore_out_names_ = other.rescore_out_names_;
}
void U2OnnxNnet::GetInputOutputInfo(
const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names, std::vector<const char*>* out_names) {
Ort::AllocatorWithDefaultOptions allocator;
// Input info
int num_nodes = session->GetInputCount();
in_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetInputName(i, allocator);
Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
std::vector<int64_t> node_dims = tensor_info.GetShape();
std::stringstream shape;
for (auto j : node_dims) {
shape << j;
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type
void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
std::vector<fastdeploy::TensorInfo> inputs_info = runtime->GetInputInfos();
(*in_names).resize(inputs_info.size());
for (int i = 0; i < inputs_info.size(); ++i){
fastdeploy::TensorInfo info = inputs_info[i];
std::stringstream shape;
for(int j = 0; j < info.shape.size(); ++j){
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*in_names)[i] = name;
}
// Output info
num_nodes = session->GetOutputCount();
out_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetOutputName(i, allocator);
Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
std::vector<int64_t> node_dims = tensor_info.GetShape();
std::stringstream shape;
for (auto j : node_dims) {
shape << j;
shape << " ";
(*in_names)[i] = info.name;
}
LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type
std::vector<fastdeploy::TensorInfo> outputs_info = runtime->GetOutputInfos();
(*out_names).resize(outputs_info.size());
for (int i = 0; i < outputs_info.size(); ++i){
fastdeploy::TensorInfo info = outputs_info[i];
std::stringstream shape;
for(int j = 0; j < info.shape.size(); ++j){
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*out_names)[i] = name;
}
(*out_names)[i] = info.name;
}
}
std::shared_ptr<NnetBase> U2OnnxNnet::Clone() const {
@ -186,33 +159,28 @@ void U2OnnxNnet::Reset() {
encoder_outs_.clear();
cached_feats_.clear();
// Reset att_cache
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
if (num_left_chunks_ > 0) {
int required_cache_size = chunk_size_ * num_left_chunks_;
offset_ = required_cache_size;
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
encoder_output_size_ / head_ * 2,
0.0);
const int64_t att_cache_shape[] = {num_blocks_, head_, required_cache_size,
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, required_cache_size,
encoder_output_size_ / head_ * 2};
att_cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, att_cache_.data(), att_cache_.size(), att_cache_shape, 4);
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
} else {
att_cache_.resize(0, 0.0);
const int64_t att_cache_shape[] = {num_blocks_, head_, 0,
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, 0,
encoder_output_size_ / head_ * 2};
att_cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, att_cache_.data(), att_cache_.size(), att_cache_shape, 4);
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
}
// Reset cnn_cache
cnn_cache_.resize(
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
const int64_t cnn_cache_shape[] = {num_blocks_, 1, encoder_output_size_,
const std::vector<int64_t> cnn_cache_shape = {num_blocks_, 1, encoder_output_size_,
cnn_module_kernel_ - 1};
cnn_cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, cnn_cache_.data(), cnn_cache_.size(), cnn_cache_shape, 4);
cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data());
}
void U2OnnxNnet::FeedForward(const std::vector<BaseFloat>& features,
@ -233,8 +201,6 @@ void U2OnnxNnet::ForwardEncoderChunkImpl(
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
// chunk
int num_frames = chunk_feats.size() / feat_dim;
@ -243,73 +209,79 @@ void U2OnnxNnet::ForwardEncoderChunkImpl(
const int feature_dim = feat_dim;
std::vector<float> feats;
feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end());
const int64_t feats_shape[3] = {1, num_frames, feature_dim};
Ort::Value feats_ort = Ort::Value::CreateTensor<float>(
memory_info, feats.data(), feats.size(), feats_shape, 3);
fastdeploy::FDTensor feats_ort;
const std::vector<int64_t> feats_shape = {1, num_frames, feature_dim};
feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data());
// offset
int64_t offset_int64 = static_cast<int64_t>(offset_);
Ort::Value offset_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, &offset_int64, 1, std::vector<int64_t>{}.data(), 0);
fastdeploy::FDTensor offset_ort;
offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64);
// required_cache_size
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
Ort::Value required_cache_size_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, &required_cache_size, 1, std::vector<int64_t>{}.data(), 0);
fastdeploy::FDTensor required_cache_size_ort("");
required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size);
// att_mask
Ort::Value att_mask_ort{nullptr};
fastdeploy::FDTensor att_mask_ort;
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
if (num_left_chunks_ > 0) {
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
if (chunk_idx < num_left_chunks_) {
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
att_mask[i] = 0;
}
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
att_mask[i] = 0;
}
}
const int64_t att_mask_shape[] = {1, 1, required_cache_size + chunk_size_};
att_mask_ort = Ort::Value::CreateTensor<bool>(
memory_info, reinterpret_cast<bool*>(att_mask.data()), att_mask.size(),
att_mask_shape, 3);
const std::vector<int64_t> att_mask_shape = {1, 1, required_cache_size + chunk_size_};
att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast<bool*>(att_mask.data()));
}
// 2. Encoder chunk forward
std::vector<Ort::Value> inputs;
for (auto name : encoder_in_names_) {
if (!strcmp(name, "chunk")) {
inputs.emplace_back(std::move(feats_ort));
} else if (!strcmp(name, "offset")) {
inputs.emplace_back(std::move(offset_ort));
} else if (!strcmp(name, "required_cache_size")) {
inputs.emplace_back(std::move(required_cache_size_ort));
} else if (!strcmp(name, "att_cache")) {
inputs.emplace_back(std::move(att_cache_ort_));
} else if (!strcmp(name, "cnn_cache")) {
inputs.emplace_back(std::move(cnn_cache_ort_));
} else if (!strcmp(name, "att_mask")) {
inputs.emplace_back(std::move(att_mask_ort));
std::vector<fastdeploy::FDTensor> inputs(encoder_in_names_.size());
for (int i = 0; i < encoder_in_names_.size(); ++i) {
std::string name = encoder_in_names_[i];
if (!strcmp(name.data(), "chunk")) {
inputs[i] = std::move(feats_ort);
inputs[i].name = "chunk";
} else if (!strcmp(name.data(), "offset")) {
inputs[i] = std::move(offset_ort);
inputs[i].name = "offset";
} else if (!strcmp(name.data(), "required_cache_size")) {
inputs[i] = std::move(required_cache_size_ort);
inputs[i].name = "required_cache_size";
} else if (!strcmp(name.data(), "att_cache")) {
inputs[i] = std::move(att_cache_ort_);
inputs[i].name = "att_cache";
} else if (!strcmp(name.data(), "cnn_cache")) {
inputs[i] = std::move(cnn_cache_ort_);
inputs[i].name = "cnn_cache";
} else if (!strcmp(name.data(), "att_mask")) {
inputs[i] = std::move(att_mask_ort);
inputs[i].name = "att_mask";
}
}
std::vector<fastdeploy::FDTensor> ort_outputs;
assert(encoder_->Infer(inputs, &ort_outputs));
std::vector<Ort::Value> ort_outputs = encoder_session_->Run(
Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(),
inputs.size(), encoder_out_names_.data(), encoder_out_names_.size());
offset_ += static_cast<int>(
ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
offset_ += static_cast<int>(ort_outputs[0].shape[1]);
att_cache_ort_ = std::move(ort_outputs[1]);
cnn_cache_ort_ = std::move(ort_outputs[2]);
std::vector<Ort::Value> ctc_inputs;
std::vector<fastdeploy::FDTensor> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
// ctc_inputs[0] = std::move(ort_outputs[0]);
ctc_inputs[0].name = ctc_in_names_[0];
std::vector<Ort::Value> ctc_ort_outputs = ctc_session_->Run(
Ort::RunOptions{nullptr}, ctc_in_names_.data(), ctc_inputs.data(),
ctc_inputs.size(), ctc_out_names_.data(), ctc_out_names_.size());
encoder_outs_.push_back(std::move(ctc_inputs[0]));
std::vector<fastdeploy::FDTensor> ctc_ort_outputs;
assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs));
encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // *****
float* logp_data = ctc_ort_outputs[0].GetTensorMutableData<float>();
auto type_info = ctc_ort_outputs[0].GetTensorTypeAndShapeInfo();
float* logp_data = reinterpret_cast<float*>(ctc_ort_outputs[0].Data());
// Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = type_info.GetShape();
std::vector<int64_t> ctc_log_probs_shape = ctc_ort_outputs[0].shape;
CHECK_EQ(ctc_log_probs_shape.size(), 3);
int B = ctc_log_probs_shape[0];
CHECK_EQ(B, 1);
@ -337,8 +309,6 @@ float U2OnnxNnet::ComputeAttentionScore(const float* prob,
void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
@ -362,16 +332,13 @@ void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
std::vector<float> rescore_input;
int encoder_len = 0;
for (int i = 0; i < encoder_outs_.size(); i++) {
float* encoder_outs_data = encoder_outs_[i].GetTensorMutableData<float>();
auto type_info = encoder_outs_[i].GetTensorTypeAndShapeInfo();
for (int j = 0; j < type_info.GetElementCount(); j++) {
rescore_input.emplace_back(encoder_outs_data[j]);
float* encoder_outs_data = reinterpret_cast<float*>(encoder_outs_[i].Data());
for (int j = 0; j < encoder_outs_[i].Numel(); j++) {
rescore_input.emplace_back(encoder_outs_data[j]);
}
encoder_len += type_info.GetShape()[1];
encoder_len += encoder_outs_[i].shape[1];
}
const int64_t decode_input_shape[] = {1, encoder_len, encoder_output_size_};
std::vector<int64_t> hyps_pad;
for (size_t i = 0; i < num_hyps; ++i) {
@ -379,44 +346,43 @@ void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
hyps_pad.emplace_back(sos_);
size_t j = 0;
for (; j < hyp.size(); ++j) {
hyps_pad.emplace_back(hyp[j]);
hyps_pad.emplace_back(hyp[j]);
}
if (j == max_hyps_len - 1) {
continue;
continue;
}
for (; j < max_hyps_len - 1; ++j) {
hyps_pad.emplace_back(0);
hyps_pad.emplace_back(0);
}
}
const int64_t hyps_pad_shape[] = {num_hyps, max_hyps_len};
const int64_t hyps_lens_shape[] = {num_hyps};
const std::vector<int64_t> hyps_pad_shape = {num_hyps, max_hyps_len};
const std::vector<int64_t> hyps_lens_shape = {num_hyps};
const std::vector<int64_t> decode_input_shape = {1, encoder_len, encoder_output_size_};
Ort::Value decode_input_tensor_ = Ort::Value::CreateTensor<float>(
memory_info, rescore_input.data(), rescore_input.size(),
decode_input_shape, 3);
Ort::Value hyps_pad_tensor_ = Ort::Value::CreateTensor<int64_t>(
memory_info, hyps_pad.data(), hyps_pad.size(), hyps_pad_shape, 2);
Ort::Value hyps_lens_tensor_ = Ort::Value::CreateTensor<int64_t>(
memory_info, hyps_lens.data(), hyps_lens.size(), hyps_lens_shape, 1);
fastdeploy::FDTensor hyps_pad_tensor_;
hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
fastdeploy::FDTensor hyps_lens_tensor_;
hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
fastdeploy::FDTensor decode_input_tensor_;
decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());
std::vector<Ort::Value> rescore_inputs;
std::vector<fastdeploy::FDTensor> rescore_inputs(3);
rescore_inputs.emplace_back(std::move(hyps_pad_tensor_));
rescore_inputs.emplace_back(std::move(hyps_lens_tensor_));
rescore_inputs.emplace_back(std::move(decode_input_tensor_));
rescore_inputs[0] = std::move(hyps_pad_tensor_);
rescore_inputs[0].name = rescore_in_names_[0];
rescore_inputs[1] = std::move(hyps_lens_tensor_);
rescore_inputs[1].name = rescore_in_names_[1];
rescore_inputs[2] = std::move(decode_input_tensor_);
rescore_inputs[2].name = rescore_in_names_[2];
std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
rescore_inputs.size(), rescore_out_names_.data(),
rescore_out_names_.size());
std::vector<fastdeploy::FDTensor> rescore_outputs;
assert(rescore_->Infer(rescore_inputs, &rescore_outputs));
float* decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
float* r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData<float>();
float* decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[0].Data());
float* r_decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[1].Data());
auto type_info = rescore_outputs[0].GetTensorTypeAndShapeInfo();
int decode_out_len = type_info.GetShape()[2];
int decode_out_len = rescore_outputs[0].shape[2];
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];

@ -1,7 +1,5 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 ZeXuan Li (lizexuan@huya.com)
// Xingchen Song(sxc19@mails.tsinghua.edu.cn)
// hamddct@gmail.com (Mddct)
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -15,6 +13,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h
#pragma once
#include "base/common.h"
@ -22,14 +23,11 @@
#include "nnet/nnet_itf.h"
#include "nnet/u2_nnet.h"
#include "onnxruntime_cxx_api.h" // NOLINT
#include "fastdeploy/runtime.h"
namespace ppspeech {
class U2OnnxNnet : public U2NnetBase {
public:
static void InitEngineThreads(int num_threads = 1);
public:
explicit U2OnnxNnet(const ModelOptions& opts);
@ -46,9 +44,6 @@ class U2OnnxNnet : public U2NnetBase {
void Dim();
void LoadModel(const std::string& model_dir);
// void Warmup();
// std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetBase> Clone() const override;
@ -58,9 +53,6 @@ class U2OnnxNnet : public U2NnetBase {
std::vector<kaldi::BaseFloat>* ctc_probs,
int32* vocab_dim) override;
// float ComputePathScore(const paddle::Tensor& prob,
// const std::vector<int>& hyp,
// int eos);
float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
int eos, int decode_out_len);
@ -68,16 +60,12 @@ class U2OnnxNnet : public U2NnetBase {
float reverse_weight,
std::vector<float>* rescoring_score) override;
// debug
// void FeedEncoderOuts(const paddle::Tensor& encoder_out);
void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
// copy from wenet
void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names,
std::vector<const char*>* out_names);
void GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names,
std::vector<std::string>* out_names);
private:
ModelOptions opts_;
@ -87,28 +75,21 @@ class U2OnnxNnet : public U2NnetBase {
int head_ = 0;
// sessions
// NOTE(Mddct): The Env holds the logging state used by all other objects.
// One Env must be created before using any other Onnxruntime functionality.
static Ort::Env env_; // shared environment across threads.
static Ort::SessionOptions session_options_;
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
std::shared_ptr<Ort::Session> rescore_session_ = nullptr;
std::shared_ptr<Ort::Session> ctc_session_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> encoder_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> rescore_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> ctc_ = nullptr;
// node names
std::vector<const char*> encoder_in_names_, encoder_out_names_;
std::vector<const char*> ctc_in_names_, ctc_out_names_;
std::vector<const char*> rescore_in_names_, rescore_out_names_;
std::vector<std::string> encoder_in_names_, encoder_out_names_;
std::vector<std::string> ctc_in_names_, ctc_out_names_;
std::vector<std::string> rescore_in_names_, rescore_out_names_;
// caches
Ort::Value att_cache_ort_{nullptr};
Ort::Value cnn_cache_ort_{nullptr};
std::vector<Ort::Value> encoder_outs_;
// NOTE: Instead of making a copy of the xx_cache, ONNX only maintains
// its data pointer when initializing xx_cache_ort (see https://github.com/
// microsoft/onnxruntime/blob/master/onnxruntime/core/framework
// /tensor.cc#L102-L129), so we need the following variables to keep
// our data "alive" during the lifetime of decoder.
fastdeploy::FDTensor att_cache_ort_;
fastdeploy::FDTensor cnn_cache_ort_;
std::vector<fastdeploy::FDTensor> encoder_outs_;
std::vector<float> att_cache_;
std::vector<float> cnn_cache_;
};

@ -1 +1,3 @@
# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
add_definitions("-DUSE_ORT_BACKEND")
add_subdirectory(nnet)
Loading…
Cancel
Save