commit
4851d1d3a2
@ -0,0 +1,3 @@
data
log
exp
@ -0,0 +1,37 @@
# DeepSpeech2 ONNX model

1. Convert the DeepSpeech2 model to ONNX with Paddle2ONNX.
2. Check that the Paddle Inference and ONNX Runtime outputs are equal.
3. Optimize the ONNX model.
4. Check that the Paddle Inference and optimized ONNX Runtime outputs are equal.
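In shell terms, these four steps correspond to the stages of `run.sh`. Roughly, for the aishell model (a sketch; paths, tensor names, and shapes are taken from `run.sh` below, which also handles the wenetspeech variant):

```
dir=data/exp/deepspeech2_online/checkpoints
model=avg_1.jit.pdmodel
param=avg_1.jit.pdiparams
output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0

# stage 1: prune the paddle model down to the four decoder outputs
./local/prune.sh $dir $model $param $output_names exp/prune
# stage 2: write fixed input shapes into the paddle model (aishell rnn hidden is 1024)
python3 local/pd_infer_shape.py --model_dir $dir --model_filename $model \
    --params_filename $param --save_dir exp/shape \
    --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5,1,1024], 'chunk_state_h_box':[5,1,1024]}"
# stage 3: export to ONNX, then check ONNX Runtime against Paddle Inference
./local/tonnx.sh $dir $model $param exp/model.onnx
./local/infer_check.py --input_file exp/static_ds2online_inputs.pickle \
    --model_type aishell --model_dir $dir --model_prefix avg_1.jit \
    --onnx_model exp/model.onnx
# stage 4: simplify the ONNX model and re-check
./local/onnx_opt.sh exp/model.onnx exp/model.opt.onnx \
    "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
./local/infer_check.py --input_file exp/static_ds2online_inputs.pickle \
    --model_type aishell --model_dir $dir --model_prefix avg_1.jit \
    --onnx_model exp/model.opt.onnx
```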

Please make sure the [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) versions are correct.

This example was tested with the following packages installed:

```
paddle2onnx              0.9.8rc0   # develop af4354b4e9a61a93be6490640059a02a4499bc7a
paddleaudio              0.2.1
paddlefsl                1.1.0
paddlenlp                2.2.6
paddlepaddle-gpu         2.2.2
paddlespeech             0.0.0      # develop
paddlespeech-ctcdecoders 0.2.0
paddlespeech-feat        0.1.0
onnx                     1.11.0
onnx-simplifier          0.0.0      # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape
onnxoptimizer            0.2.7
onnxruntime              1.11.0
```

## Usage

```
bash run.sh
```

For more details, please see `run.sh`.

## Outputs

The optimized ONNX model is `exp/model.opt.onnx`.

To visualize the model graph, use `local/netron.sh`.
@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle

import numpy as np
import onnxruntime
import paddle


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--input_file',
        type=str,
        default="static_ds2online_inputs.pickle",
        help="aishell ds2 input data file. For wenetspeech, we only feed it to the infer model.", )
    parser.add_argument(
        '--model_type',
        type=str,
        default="aishell",
        help="aishell (rnn hidden 1024) or wenetspeech (rnn hidden 2048)", )
    parser.add_argument(
        '--model_dir', type=str, default=".", help="paddle model dir.")
    parser.add_argument(
        '--model_prefix',
        type=str,
        default="avg_1.jit",
        help="paddle model prefix.")
    parser.add_argument(
        '--onnx_model',
        type=str,
        default='./model.old.onnx',
        help="onnx model.")

    return parser.parse_args()


if __name__ == '__main__':
    FLAGS = parse_args()

    # input and output
    with open(FLAGS.input_file, 'rb') as f:
        iodict = pickle.load(f)
        print(iodict.keys())

    audio_chunk = iodict['audio_chunk']
    audio_chunk_lens = iodict['audio_chunk_lens']
    chunk_state_h_box = iodict['chunk_state_h_box']
    chunk_state_c_box = iodict['chunk_state_c_box']  # key assumed; the original read 'chunk_state_c_bos', which looks like a typo
    print("raw state shape: ", chunk_state_c_box.shape)

    # the wenetspeech model has a 2x larger rnn hidden size, so tile the demo states
    if FLAGS.model_type == 'wenetspeech':
        chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
        chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
    print("state shape: ", chunk_state_c_box.shape)

    # paddle
    model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
    res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
        paddle.to_tensor(audio_chunk),
        paddle.to_tensor(audio_chunk_lens),
        paddle.to_tensor(chunk_state_h_box),
        paddle.to_tensor(chunk_state_c_box), )

    # onnxruntime
    options = onnxruntime.SessionOptions()
    options.enable_profiling = True
    sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
    ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
        ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
            "audio_chunk": audio_chunk,
            "audio_chunk_lens": audio_chunk_lens,
            "chunk_state_h_box": chunk_state_h_box,
            "chunk_state_c_box": chunk_state_c_box
        })

    print(sess.end_profiling())

    # check that paddle and ort outputs are close
    print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
    print(np.allclose(ort_res_lens, res_lens, atol=1e-6))

    if FLAGS.model_type == 'aishell':
        print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
        print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
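The script above feeds the same pickled inputs through both the jitted Paddle model and the ONNX model, then compares the outputs with `np.allclose`. A typical invocation, following how `run.sh` calls it with the default aishell checkpoint layout:

```
python3 local/infer_check.py \
    --input_file exp/static_ds2online_inputs.pickle \
    --model_type aishell \
    --model_dir data/exp/deepspeech2_online/checkpoints \
    --model_prefix avg_1.jit \
    --onnx_model exp/model.onnx
```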
@ -0,0 +1,14 @@
#!/bin/bash

# show model

if [ $# != 1 ];then
    echo "usage: $0 model_path"
    exit 1
fi

file=$1

pip install netron
netron -p 8082 --host $(hostname -i) $file
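The script installs netron and serves the model graph over HTTP on port 8082 of the host; point a browser at that port to inspect the graph. For example (the model path is just an illustration):

```
./local/netron.sh exp/model.opt.onnx
```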
@ -0,0 +1,7 @@
#!/bin/bash

# clone onnx repos
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
(File diff suppressed because it is too large.)
@ -0,0 +1,20 @@
#!/bin/bash

set -e

if [ $# != 3 ];then
    # ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
    echo "usage: $0 onnx.model.in onnx.model.out input_shape"
    exit 1
fi

# onnx optimizer
pip install onnx-simplifier

in=$1
out=$2
input_shape=$3

check_n=3

onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape
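`onnxsim` re-runs shape inference and folds redundant subgraphs into constants; `check_n` is the number of random-input checks it runs against the original model to verify the simplified one. A concrete call matching the usage comment above, with the aishell shapes:

```
./local/onnx_opt.sh exp/model.onnx exp/model.opt.onnx \
    "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
```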
@ -0,0 +1,128 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names
import argparse
import copy
import sys

import onnx


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--output_names',
        required=True,
        nargs='+',
        help='The outputs of pruned model.')
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    if len(set(args.output_names)) < len(args.output_names):
        print(
            "[ERROR] There's duplicate name in --output_names, which is not allowed."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect all node outputs and graph output
    output_tensor_names = set()
    for node in model.graph.node:
        for out in node.output:
            # may contain model output
            output_tensor_names.add(out)

    # for out in model.graph.output:
    #     output_tensor_names.add(out.name)

    for output_name in args.output_names:
        if output_name not in output_tensor_names:
            print(
                "[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
                format(output_name))
            sys.exit(-1)

    output_node_indices = set()  # indices of nodes that produce the requested outputs
    output_to_node = dict()  # tensor name -> index of the node producing it
    for i, node in enumerate(model.graph.node):
        for out in node.output:
            output_to_node[out] = i
            if out in args.output_names:
                output_node_indices.add(i)

    # from the outputs, find all ancestor nodes (level-by-level backward walk)
    reserved_node_indices = copy.deepcopy(
        output_node_indices)  # nodes we need to keep
    reserved_inputs = set()  # model inputs to keep
    new_output_node_indices = copy.deepcopy(output_node_indices)

    while len(new_output_node_indices) > 0:
        output_node_indices = copy.deepcopy(new_output_node_indices)

        new_output_node_indices = set()

        for out_node_idx in output_node_indices:
            # backtrace to parent nodes
            for ipt in model.graph.node[out_node_idx].input:
                if ipt in output_to_node:
                    reserved_node_indices.add(output_to_node[ipt])
                    new_output_node_indices.add(output_to_node[ipt])
                else:
                    reserved_inputs.add(ipt)

    num_inputs = len(model.graph.input)
    num_outputs = len(model.graph.output)
    num_nodes = len(model.graph.node)
    print(
        f"old graph has {num_inputs} inputs, {num_outputs} outputs, {num_nodes} nodes"
    )
    print(f"{len(reserved_node_indices)} nodes to keep.")

    # delete nodes not to keep
    for idx in range(num_nodes - 1, -1, -1):
        if idx not in reserved_node_indices:
            del model.graph.node[idx]

    # delete graph inputs not to keep
    for idx in range(num_inputs - 1, -1, -1):
        if model.graph.input[idx].name not in reserved_inputs:
            del model.graph.input[idx]

    # delete old graph outputs
    for i in range(num_outputs):
        del model.graph.output[0]

    # the user-requested outputs become the new graph outputs
    for out in args.output_names:
        model.graph.output.extend([onnx.ValueInfoProto(name=out)])

    # infer shape
    try:
        from onnx_infer_shape import SymbolicShapeInference
        model = SymbolicShapeInference.infer_shapes(
            model,
            int_max=2**31 - 1,
            auto_merge=True,
            guess_output_rank=False,
            verbose=1)
    except Exception as e:
        print(f"skip infer shape step: {e}")

    # check onnx model
    onnx.checker.check_model(model)
    # save onnx model
    onnx.save(model, args.save_file)
    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
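The pruner walks backward from the requested outputs, keeps every ancestor node plus any graph inputs they touch, and rewrites the graph outputs. For this example's DeepSpeech2 model, a call might look like this (output names taken from `run.sh`; the save path is illustrative):

```
python3 local/onnx_prune_model.py \
    --model exp/model.onnx \
    --output_names softmax_0.tmp_0 tmp_5 concat_0.tmp_0 concat_1.tmp_0 \
    --save_file exp/model.pruned.onnx
```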
@ -0,0 +1,111 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename nodes to new names
import argparse
import sys

import onnx


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--origin_names',
        required=True,
        nargs='+',
        help='The original names you want to modify.')
    parser.add_argument(
        '--new_names',
        required=True,
        nargs='+',
        help='The new names you want to change to. The number of new_names must be the same as the number of origin_names.'
    )
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    if len(set(args.origin_names)) < len(args.origin_names):
        print(
            "[ERROR] There's duplicate name in --origin_names, which is not allowed."
        )
        sys.exit(-1)

    if len(set(args.new_names)) < len(args.new_names):
        print(
            "[ERROR] There's duplicate name in --new_names, which is not allowed."
        )
        sys.exit(-1)

    if len(args.new_names) != len(args.origin_names):
        print(
            "[ERROR] Number of --new_names must be the same as the number of --origin_names."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect graph inputs and all node outputs
    output_tensor_names = set()
    for ipt in model.graph.input:
        output_tensor_names.add(ipt.name)

    for node in model.graph.node:
        for out in node.output:
            output_tensor_names.add(out)

    for origin_name in args.origin_names:
        if origin_name not in output_tensor_names:
            print(
                f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
            )
            sys.exit(-1)

    for new_name in args.new_names:
        if new_name in output_tensor_names:
            print(
                f"[ERROR] The defined new_name '{new_name}' already exists in the onnx model, which is not allowed."
            )
            sys.exit(-1)

    # rename graph inputs
    for i, ipt in enumerate(model.graph.input):
        if ipt.name in args.origin_names:
            idx = args.origin_names.index(ipt.name)
            model.graph.input[i].name = args.new_names[idx]

    # rename node inputs and outputs
    for i, node in enumerate(model.graph.node):
        for j, ipt in enumerate(node.input):
            if ipt in args.origin_names:
                idx = args.origin_names.index(ipt)
                model.graph.node[i].input[j] = args.new_names[idx]

        for j, out in enumerate(node.output):
            if out in args.origin_names:
                idx = args.origin_names.index(out)
                model.graph.node[i].output[j] = args.new_names[idx]

    # rename graph outputs
    for i, out in enumerate(model.graph.output):
        if out.name in args.origin_names:
            idx = args.origin_names.index(out.name)
            model.graph.output[i].name = args.new_names[idx]

    # check onnx model
    onnx.checker.check_model(model)

    # save model
    onnx.save(model, args.save_file)

    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
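Renaming is purely textual: every occurrence of each origin name in the graph inputs, node inputs/outputs, and graph outputs is replaced by the corresponding new name. A hypothetical example (the origin names are real model inputs from this example; the new names and save path are illustrative only):

```
python3 local/onnx_rename_model.py \
    --model exp/model.onnx \
    --origin_names chunk_state_h_box chunk_state_c_box \
    --new_names state_h state_c \
    --save_file exp/model.renamed.onnx
```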
@ -0,0 +1,111 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape
import argparse

# infer shapes for a paddle inference model


def process_old_ops_desc(program):
    """Set the matmul op attr head_number to 1 if it does not exist.

    Args:
        program (Program): paddle inference program, patched in place.
    """
    for i in range(len(program.blocks[0].ops)):
        if program.blocks[0].ops[i].type == "matmul":
            if not program.blocks[0].ops[i].has_attr("head_number"):
                program.blocks[0].ops[i]._set_attr("head_number", 1)


def infer_shape(program, input_shape_dict):
    # model version, e.g. 2002002
    model_version = program.desc._version()
    # paddle version, e.g. 2.2.2
    paddle_version = paddle.__version__
    major_ver = model_version // 1000000
    minor_ver = (model_version - major_ver * 1000000) // 1000
    patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000
    model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
    if model_version != paddle_version:
        print(
            f"[WARNING] The model was saved by paddlepaddle v{model_version}, but your paddlepaddle version is {paddle_version}. This difference may cause errors; it is recommended to reinstall the same paddlepaddle version as the model."
        )

    OP_WITHOUT_KERNEL_SET = {
        'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
        'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
        'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
        'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
        'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
        'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
        'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
        'copy_cross_scope'
    }

    for k, v in input_shape_dict.items():
        program.blocks[0].var(k).desc.set_shape(v)

    for i in range(len(program.blocks)):
        for j in range(len(program.blocks[i].ops)):
            # for ops
            if program.blocks[i].ops[j].type in OP_WITHOUT_KERNEL_SET:
                print(f"not infer: {program.blocks[i].ops[j].type} op")
                continue
            print(f"infer: {program.blocks[i].ops[j].type} op")
            program.blocks[i].ops[j].desc.infer_shape(program.blocks[i].desc)


def parse_arguments():
    # python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \
    #     --model_filename avg_1.jit.pdmodel \
    #     --params_filename avg_1.jit.pdiparams \
    #     --save_dir . \
    #     --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--model_filename', required=True, help='model.pdmodel.')
    parser.add_argument(
        '--params_filename', required=True, help='model.pdiparams.')
    parser.add_argument(
        '--save_dir',
        required=True,
        help='directory to save the exported model.')
    parser.add_argument(
        '--input_shape_dict', required=True, help="The new shape information.")
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    import paddle
    paddle.enable_static()
    import paddle.fluid as fluid

    input_shape_dict_str = args.input_shape_dict
    input_shape_dict = eval(input_shape_dict_str)

    print("Start to load paddle model...")
    exe = fluid.Executor(fluid.CPUPlace())

    prog, ipts, outs = fluid.io.load_inference_model(
        args.model_dir,
        exe,
        model_filename=args.model_filename,
        params_filename=args.params_filename)

    process_old_ops_desc(prog)
    infer_shape(prog, input_shape_dict)

    fluid.io.save_inference_model(
        args.save_dir,
        ipts,
        outs,
        exe,
        prog,
        model_filename=args.model_filename,
        params_filename=args.params_filename)
@ -0,0 +1,158 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#1-%E8%A3%81%E5%89%AApaddle%E6%A8%A1%E5%9E%8B
import argparse
import sys
from typing import List

# prune a paddle inference model by output names


def prepend_feed_ops(program,
                     feed_target_names: List[str],
                     feed_holder_name='feed'):
    import paddle.fluid.core as core
    if len(feed_target_names) == 0:
        return

    global_block = program.global_block()
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True, )

    for i, name in enumerate(feed_target_names, 0):
        if not global_block.has_var(name):
            print(
                f"The input[{i}]: '{name}' doesn't exist in the pruned inference program, so it will be ignored in the new saved model."
            )
            continue

        out = global_block.var(name)
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i}, )


def append_fetch_ops(program,
                     fetch_target_names: List[str],
                     fetch_holder_name='fetch'):
    """Add fetch ops to the program in place.

    Args:
        program (Program): inference program
        fetch_target_names (List[str]): target names
        fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
    """
    import paddle.fluid.core as core
    global_block = program.global_block()
    fetch_var = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True, )

    print(f"the len of fetch_target_names: {len(fetch_target_names)}")

    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i}, )


def insert_fetch(program,
                 fetch_target_names: List[str],
                 fetch_holder_name='fetch'):
    """Replace the program's fetch ops in place.

    Args:
        program (Program): inference program
        fetch_target_names (List[str]): target names
        fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
    """
    global_block = program.global_block()

    # remove old fetch ops
    need_to_remove_op_index = []
    for i, op in enumerate(global_block.ops):
        if op.type == 'fetch':
            need_to_remove_op_index.append(i)

    for index in reversed(need_to_remove_op_index):
        global_block._remove_op(index)

    program.desc.flush()

    # append new fetch ops
    append_fetch_ops(program, fetch_target_names, fetch_holder_name)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--model_filename', required=True, help='model.pdmodel.')
    parser.add_argument(
        '--params_filename', required=True, help='model.pdiparams.')
    parser.add_argument(
        '--output_names',
        required=True,
        help='The outputs of model, separated by comma.')
    parser.add_argument(
        '--save_dir',
        required=True,
        help='directory to save the exported model.')
    parser.add_argument('--debug', default=False, help='output debug info.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    args.output_names = args.output_names.split(",")

    if len(set(args.output_names)) < len(args.output_names):
        print(
            f"[ERROR] There's duplicate name in --output_names {args.output_names}, which is not allowed."
        )
        sys.exit(-1)

    import paddle
    paddle.enable_static()
    # hack prepend_feed_ops
    paddle.fluid.io.prepend_feed_ops = prepend_feed_ops

    import paddle.fluid as fluid

    print("start to load paddle model")
    exe = fluid.Executor(fluid.CPUPlace())
    prog, ipts, outs = fluid.io.load_inference_model(
        args.model_dir,
        exe,
        model_filename=args.model_filename,
        params_filename=args.params_filename)

    print("start to insert fetch ops")
    new_outputs = []
    insert_fetch(prog, args.output_names)
    for out_name in args.output_names:
        new_outputs.append(prog.global_block().var(out_name))

    # not equal to paddle.static.save_inference_model
    fluid.io.save_inference_model(
        args.save_dir,
        ipts,
        new_outputs,
        exe,
        prog,
        model_filename=args.model_filename,
        params_filename=args.params_filename)

    if args.debug:
        for op in prog.global_block().ops:
            print(op)
@ -0,0 +1,23 @@
#!/bin/bash

set -e

if [ $# != 5 ]; then
    # local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD
    echo "usage: $0 model_dir model_filename param_filename output_names save_dir"
    exit 1
fi

dir=$1
model=$2
param=$3
outputs=$4
save_dir=$5

python local/pd_prune_model.py \
    --model_dir $dir \
    --model_filename $model \
    --params_filename $param \
    --output_names $outputs \
    --save_dir $save_dir
@ -0,0 +1,25 @@
#!/bin/bash

if [ $# != 4 ];then
    # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
    echo "usage: $0 model_dir model_name param_name onnx_output_name"
    exit 1
fi

dir=$1
model=$2
param=$3
output=$4

pip install paddle2onnx
pip install onnx

# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
paddle2onnx --model_dir $dir \
    --model_filename $model \
    --params_filename $param \
    --save_file $output \
    --enable_dev_version True \
    --opset_version 9 \
    --enable_onnx_checker True
@ -0,0 +1,14 @@
# This contains the locations of the binaries built and required for running the examples.

MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project built successfully"; }

export LC_ALL=C

export PATH=$PATH:$TOOLS_BIN
@ -0,0 +1,101 @@
#!/bin/bash

set -e

. path.sh

stage=0
stop_stage=100
#tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams

. utils/parse_options.sh

data=data
exp=exp

mkdir -p $data $exp

dir=$data/exp/deepspeech2_online/checkpoints

# wenetspeech or aishell
model_type=$(echo $tarfile | cut -d '_' -f 4)

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
    test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile

    # extract the ds2 model
    pushd $data
    tar zxvf $tarfile
    popd

    # ds2 model demo inputs
    pushd $exp
    wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
    popd
fi

output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
    # prune model by outputs
    mkdir -p $exp/prune

    # pruning depends on output_names.
    ./local/prune.sh $dir $model $param $output_names $exp/prune
fi

# aishell rnn hidden is 1024
# wenetspeech rnn hidden is 2048
if [ $model_type == 'aishell' ];then
    input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
elif [ $model_type == 'wenetspeech' ];then
    input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 2048], 'chunk_state_h_box':[5,1,2048]}"
else
    echo "not supported: $model_type"
    exit 1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then
    # infer shapes with the new input shapes
    mkdir -p $exp/shape
    echo $input_shape_dict
    python3 local/pd_infer_shape.py \
        --model_dir $dir \
        --model_filename $model \
        --params_filename $param \
        --save_dir $exp/shape \
        --input_shape_dict="${input_shape_dict}"
fi

input_file=$exp/static_ds2online_inputs.pickle
test -e $input_file

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
    # to onnx
    ./local/tonnx.sh $dir $model $param $exp/model.onnx

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
fi


# aishell rnn hidden is 1024
# wenetspeech rnn hidden is 2048
if [ $model_type == 'aishell' ];then
    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
elif [ $model_type == 'wenetspeech' ];then
    input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
else
    echo "not supported: $model_type"
    exit 1
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
    # the wenetspeech ds2 model exceeds the 2 GB limit and will error here.
    # simplify the onnx model
    ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"

    ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
fi
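Since `run.sh` sources `utils/parse_options.sh`, the variables defined before it (`stage`, `stop_stage`, `tarfile`, `model_prefix`) can be overridden from the command line, assuming `utils/parse_options.sh` is the usual Kaldi-style option parser. For example, to re-run only the export-and-check stage:

```
bash run.sh --stage 3 --stop_stage 3
```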
@ -0,0 +1 @@
../../../../utils/