Merge pull request #2034 from zh794390558/onnx

[speechx] deepspeech2 to onnx
pull/2052/head
Hui Zhang
commit b4c6a52beb

.gitignore

@ -39,6 +39,9 @@ tools/env.sh
tools/openfst-1.8.1/
tools/libsndfile/
tools/python-soundfile/
tools/onnx
tools/onnxruntime
tools/Paddle2ONNX
speechx/fc_patch/

@ -795,6 +795,7 @@ class ASRServerExecutor(ASRExecutor):
if num_decoding_left_chunks:
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring decoding methods
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [

@ -0,0 +1,3 @@
data
log
exp

@ -0,0 +1,37 @@
# DeepSpeech2 ONNX model
1. Convert the deepspeech2 model to ONNX, using Paddle2ONNX.
2. Check that the Paddle Inference and onnxruntime outputs are equal.
3. Optimize the onnx model.
4. Check that the Paddle Inference and optimized onnxruntime outputs are equal.
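A command-level sketch of these steps, using the scripts under `local/` (paths follow `run.sh`; the prune and shape-inference preparation stages are omitted here, and `ckpt` is shorthand introduced only for this sketch):
```
ckpt=data/exp/deepspeech2_online/checkpoints
# 1. convert the paddle model to ONNX
./local/tonnx.sh $ckpt avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
# 2. check paddle inference vs onnxruntime outputs
./local/infer_check.py --input_file exp/static_ds2online_inputs.pickle --model_dir $ckpt --onnx_model exp/model.onnx
# 3. optimize (simplify) the onnx model
./local/onnx_opt.sh exp/model.onnx exp/model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
# 4. check again against the optimized model
./local/infer_check.py --input_file exp/static_ds2online_inputs.pickle --model_dir $ckpt --onnx_model exp/model.opt.onnx
```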
Please make sure the [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) versions are correct.
This example was tested with the following packages installed:
```
paddle2onnx 0.9.8rc0 # develop af4354b4e9a61a93be6490640059a02a4499bc7a
paddleaudio 0.2.1
paddlefsl 1.1.0
paddlenlp 2.2.6
paddlepaddle-gpu 2.2.2
paddlespeech 0.0.0 # develop
paddlespeech-ctcdecoders 0.2.0
paddlespeech-feat 0.1.0
onnx 1.11.0
onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape
onnxoptimizer 0.2.7
onnxruntime 1.11.0
```
## Usage
```
bash run.sh
```
For more details please see `run.sh`.
## Outputs
The optimized onnx model is `exp/model.opt.onnx`.
To show the graph, please use `local/netron.sh`.
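For example, to serve the optimized model's graph (the script installs netron and listens on port 8082):
```
bash local/netron.sh exp/model.opt.onnx
```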

@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle
import numpy as np
import onnxruntime
import paddle
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--input_file',
type=str,
default="static_ds2online_inputs.pickle",
help="ds2 input pickle file.", )
parser.add_argument(
'--model_dir', type=str, default=".", help="paddle model dir.")
parser.add_argument(
'--model_prefix',
type=str,
default="avg_1.jit",
help="paddle model prefix.")
parser.add_argument(
'--onnx_model',
type=str,
default='./model.old.onnx',
help="onnx model.")
return parser.parse_args()
if __name__ == '__main__':
FLAGS = parse_args()
# input and output
with open(FLAGS.input_file, 'rb') as f:
iodict = pickle.load(f)
print(iodict.keys())
audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box']
chunk_state_c_box = iodict['chunk_state_c_bos']  # note: the pickle key is spelled '_bos'
# paddle
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box), )
# onnxruntime
options = onnxruntime.SessionOptions()
options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
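# fetch the four outputs by tensor name; the order matches the paddle model's
# (softmax output, lens, lstm h state, lstm c state), per output_names in run.sh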
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
"audio_chunk": audio_chunk,
"audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling())
# check that the paddle and onnxruntime outputs are close
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
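# example usage (paths here are illustrative assumptions; run.sh invokes this script):
#   ./local/infer_check.py --input_file exp/static_ds2online_inputs.pickle \
#       --model_dir data/exp/deepspeech2_online/checkpoints --onnx_model exp/model.onnx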

@ -0,0 +1,14 @@
#!/bin/bash
# show model
if [ $# != 1 ];then
echo "usage: $0 model_path"
exit 1
fi
file=$1
pip install netron
netron -p 8082 --host "$(hostname -i)" "$file"

@ -0,0 +1,7 @@
#!/bin/bash
# clone onnx repos
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git

File diff suppressed because it is too large

@ -0,0 +1,20 @@
#!/bin/bash
set -e
if [ $# != 3 ];then
# ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
echo "usage: $0 onnx.model.in onnx.model.out input_shape "
exit 1
fi
# onnx optimizer
pip install onnx-simplifier
in=$1
out=$2
input_shape=$3
check_n=3
onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape

@ -0,0 +1,128 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names
import argparse
import copy
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model',
required=True,
help='Path of the input onnx model.')
parser.add_argument(
'--output_names',
required=True,
nargs='+',
help='The outputs of pruned model.')
parser.add_argument(
'--save_file', required=True, help='Path to save the new onnx model.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
if len(set(args.output_names)) < len(args.output_names):
print(
"[ERROR] There's dumplicate name in --output_names, which is not allowed."
)
sys.exit(-1)
model = onnx.load(args.model)
# collect all node outputs and graph output
output_tensor_names = set()
for node in model.graph.node:
for out in node.output:
# may contain model output
output_tensor_names.add(out)
# for out in model.graph.output:
# output_tensor_names.add(out.name)
for output_name in args.output_names:
if output_name not in output_tensor_names:
print(
"[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
format(output_name))
sys.exit(-1)
output_node_indices = set() # has output names
output_to_node = dict() # all node outputs
for i, node in enumerate(model.graph.node):
for out in node.output:
output_to_node[out] = i
if out in args.output_names:
output_node_indices.add(i)
# from outputs find all the ancestors
reserved_node_indices = copy.deepcopy(
output_node_indices) # nodes need to keep
reserved_inputs = set() # model input to keep
new_output_node_indices = copy.deepcopy(output_node_indices)
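# breadth-first backtrace: treat new_output_node_indices as the frontier and
# repeatedly add each frontier node's producers until only graph inputs remain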
while len(new_output_node_indices) > 0:
output_node_indices = copy.deepcopy(new_output_node_indices)
new_output_node_indices = set()
for out_node_idx in output_node_indices:
# backtrace to parent nodes
for ipt in model.graph.node[out_node_idx].input:
if ipt in output_to_node:
reserved_node_indices.add(output_to_node[ipt])
new_output_node_indices.add(output_to_node[ipt])
else:
reserved_inputs.add(ipt)
num_inputs = len(model.graph.input)
num_outputs = len(model.graph.output)
num_nodes = len(model.graph.node)
print(
f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes"
)
print(f"{len(reserved_node_indices)} node to keep.")
# del node not to keep
for idx in range(num_nodes - 1, -1, -1):
if idx not in reserved_node_indices:
del model.graph.node[idx]
# del graph input not to keep
for idx in range(num_inputs - 1, -1, -1):
if model.graph.input[idx].name not in reserved_inputs:
del model.graph.input[idx]
# del old graph outputs
for i in range(num_outputs):
del model.graph.output[0]
# new graph output as user input
for out in args.output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=out)])
# infer shape
try:
from onnx_infer_shape import SymbolicShapeInference
model = SymbolicShapeInference.infer_shapes(
model,
int_max=2**31 - 1,
auto_merge=True,
guess_output_rank=False,
verbose=1)
except Exception as e:
print(f"skip infer shape step: {e}")
# check onnx model
onnx.checker.check_model(model)
# save onnx model
onnx.save(model, args.save_file)
print("[Finished] The new model saved in {}.".format(args.save_file))
print("[DEBUG INFO] The inputs of new model: {}".format(
[x.name for x in model.graph.input]))
print("[DEBUG INFO] The outputs of new model: {}".format(
[x.name for x in model.graph.output]))
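# example usage (tensor names here are illustrative assumptions; see run.sh for the real ones):
#   python3 local/onnx_prune_model.py --model model.onnx \
#       --output_names softmax_0.tmp_0 tmp_5 --save_file model.pruned.onnx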

@ -0,0 +1,111 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename node to new names
import argparse
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model',
required=True,
help='Path of the input onnx model.')
parser.add_argument(
'--origin_names',
required=True,
nargs='+',
help='The original name you want to modify.')
parser.add_argument(
'--new_names',
required=True,
nargs='+',
help='The new names to change to; the number of new_names must match the number of origin_names.'
)
parser.add_argument(
'--save_file', required=True, help='Path to save the new onnx model.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
if len(set(args.origin_names)) < len(args.origin_names):
print(
"[ERROR] There's dumplicate name in --origin_names, which is not allowed."
)
sys.exit(-1)
if len(set(args.new_names)) < len(args.new_names):
print(
"[ERROR] There's dumplicate name in --new_names, which is not allowed."
)
sys.exit(-1)
if len(args.new_names) != len(args.origin_names):
print(
"[ERROR] Number of --new_names must be same with the number of --origin_names."
)
sys.exit(-1)
model = onnx.load(args.model)
# collect input and all node output
output_tensor_names = set()
for ipt in model.graph.input:
output_tensor_names.add(ipt.name)
for node in model.graph.node:
for out in node.output:
output_tensor_names.add(out)
for origin_name in args.origin_names:
if origin_name not in output_tensor_names:
print(
f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
)
sys.exit(-1)
for new_name in args.new_names:
if new_name in output_tensor_names:
print(
"[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed."
)
sys.exit(-1)
# rename graph input
for i, ipt in enumerate(model.graph.input):
if ipt.name in args.origin_names:
idx = args.origin_names.index(ipt.name)
model.graph.input[i].name = args.new_names[idx]
# rename node input and output
for i, node in enumerate(model.graph.node):
for j, ipt in enumerate(node.input):
if ipt in args.origin_names:
idx = args.origin_names.index(ipt)
model.graph.node[i].input[j] = args.new_names[idx]
for j, out in enumerate(node.output):
if out in args.origin_names:
idx = args.origin_names.index(out)
model.graph.node[i].output[j] = args.new_names[idx]
# rename graph output
for i, out in enumerate(model.graph.output):
if out.name in args.origin_names:
idx = args.origin_names.index(out.name)
model.graph.output[i].name = args.new_names[idx]
# check onnx model
onnx.checker.check_model(model)
# save model
onnx.save(model, args.save_file)
print("[Finished] The new model saved in {}.".format(args.save_file))
print("[DEBUG INFO] The inputs of new model: {}".format(
[x.name for x in model.graph.input]))
print("[DEBUG INFO] The outputs of new model: {}".format(
[x.name for x in model.graph.output]))
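# example usage (all tensor names here are illustrative assumptions):
#   python3 local/onnx_rename_model.py --model model.onnx \
#       --origin_names x_0 lens_0 --new_names audio_chunk audio_chunk_lens \
#       --save_file model.renamed.onnx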

@ -0,0 +1,111 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape
import argparse
import ast
# paddle inference shape
def process_old_ops_desc(program):
"""set matmul op head_number attr to 1 is not exist.
Args:
program (_type_): _description_
"""
for i in range(len(program.blocks[0].ops)):
if program.blocks[0].ops[i].type == "matmul":
if not program.blocks[0].ops[i].has_attr("head_number"):
program.blocks[0].ops[i]._set_attr("head_number", 1)
def infer_shape(program, input_shape_dict):
# 2002002
model_version = program.desc._version()
# 2.2.2
paddle_version = paddle.__version__
major_ver = model_version // 1000000
minor_ver = (model_version - major_ver * 1000000) // 1000
patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000
model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
if model_version != paddle_version:
print(
f"[WARNING] The model is saved by paddlepaddle v{model_version}, but now your paddlepaddle is version of {paddle_version}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model"
)
OP_WITHOUT_KERNEL_SET = {
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
'copy_cross_scope'
}
for k, v in input_shape_dict.items():
program.blocks[0].var(k).desc.set_shape(v)
for i in range(len(program.blocks)):
for j in range(len(program.blocks[i].ops)):
# for ops
if program.blocks[i].ops[j].type in OP_WITHOUT_KERNEL_SET:
print(f"not infer: {program.blocks[i].ops[j].type} op")
continue
print(f"infer: {program.blocks[i].ops[j].type} op")
program.blocks[i].ops[j].desc.infer_shape(program.blocks[i].desc)
def parse_arguments():
# python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \
# --model_filename avg_1.jit.pdmodel\
# --params_filename avg_1.jit.pdiparams \
# --save_dir . \
# --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_dir',
required=True,
help='Path of directory saved the input model.')
parser.add_argument(
'--model_filename', required=True, help='model.pdmodel.')
parser.add_argument(
'--params_filename', required=True, help='model.pdiparams.')
parser.add_argument(
'--save_dir',
required=True,
help='directory to save the exported model.')
parser.add_argument(
'--input_shape_dict', required=True, help="The new shape information.")
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
import paddle
paddle.enable_static()
import paddle.fluid as fluid
input_shape_dict_str = args.input_shape_dict
input_shape_dict = ast.literal_eval(input_shape_dict_str)  # safer than eval for a literal dict
print("Start to load paddle model...")
exe = fluid.Executor(fluid.CPUPlace())
prog, ipts, outs = fluid.io.load_inference_model(
args.model_dir,
exe,
model_filename=args.model_filename,
params_filename=args.params_filename)
process_old_ops_desc(prog)
infer_shape(prog, input_shape_dict)
fluid.io.save_inference_model(
args.save_dir,
ipts,
outs,
exe,
prog,
model_filename=args.model_filename,
params_filename=args.params_filename)

@ -0,0 +1,158 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#1-%E8%A3%81%E5%89%AApaddle%E6%A8%A1%E5%9E%8B
import argparse
import sys
from typing import List
# paddle prune model.
def prepend_feed_ops(program,
feed_target_names: List[str],
feed_holder_name='feed'):
import paddle.fluid.core as core
if len(feed_target_names) == 0:
return
global_block = program.global_block()
feed_var = global_block.create_var(
name=feed_holder_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True, )
for i, name in enumerate(feed_target_names, 0):
if not global_block.has_var(name):
print(
f"The input[{i}]: '{name}' doesn't exist in pruned inference program, which will be ignored in new saved model."
)
continue
out = global_block.var(name)
global_block._prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i}, )
def append_fetch_ops(program,
fetch_target_names: List[str],
fetch_holder_name='fetch'):
"""in the place, we will add the fetch op
Args:
program (_type_): inference program
fetch_target_names (List[str]): target names
fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
"""
import paddle.fluid.core as core
global_block = program.global_block()
fetch_var = global_block.create_var(
name=fetch_holder_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True, )
print(f"the len of fetch_target_names: {len(fetch_target_names)}")
for i, name in enumerate(fetch_target_names):
global_block.append_op(
type='fetch',
inputs={'X': [name]},
outputs={'Out': [fetch_var]},
attrs={'col': i}, )
def insert_fetch(program,
fetch_target_names: List[str],
fetch_holder_name='fetch'):
"""in the place, we will add the fetch op
Args:
program (_type_): inference program
fetch_target_names (List[str]): target names
fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
"""
global_block = program.global_block()
# remove fetch
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
if op.type == 'fetch':
need_to_remove_op_index.append(i)
for index in reversed(need_to_remove_op_index):
global_block._remove_op(index)
program.desc.flush()
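# flush() syncs these python-side block edits back into the underlying program desc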
# append new fetch
append_fetch_ops(program, fetch_target_names, fetch_holder_name)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_dir',
required=True,
help='Path of directory saved the input model.')
parser.add_argument(
'--model_filename', required=True, help='model.pdmodel.')
parser.add_argument(
'--params_filename', required=True, help='model.pdiparams.')
parser.add_argument(
'--output_names',
required=True,
help='The outputs of model. sep by comma')
parser.add_argument(
'--save_dir',
required=True,
help='directory to save the exported model.')
parser.add_argument('--debug', action='store_true', help='output debug info.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
args.output_names = args.output_names.split(",")
if len(set(args.output_names)) < len(args.output_names):
print(
f"[ERROR] There's dumplicate name in --output_names {args.output_names}, which is not allowed."
)
sys.exit(-1)
import paddle
paddle.enable_static()
# hack prepend_feed_ops
paddle.fluid.io.prepend_feed_ops = prepend_feed_ops
import paddle.fluid as fluid
print("start to load paddle model")
exe = fluid.Executor(fluid.CPUPlace())
prog, ipts, outs = fluid.io.load_inference_model(
args.model_dir,
exe,
model_filename=args.model_filename,
params_filename=args.params_filename)
print("start to load insert fetch op")
new_outputs = []
insert_fetch(prog, args.output_names)
for out_name in args.output_names:
new_outputs.append(prog.global_block().var(out_name))
# not equal to paddle.static.save_inference_model
fluid.io.save_inference_model(
args.save_dir,
ipts,
new_outputs,
exe,
prog,
model_filename=args.model_filename,
params_filename=args.params_filename)
if args.debug:
for op in prog.global_block().ops:
print(op)

@ -0,0 +1,23 @@
#!/bin/bash
set -e
if [ $# != 5 ]; then
# local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD
echo "usage: $0 model_dir model_filename param_filename outputs_names save_dir"
exit 1
fi
dir=$1
model=$2
param=$3
outputs=$4
save_dir=$5
python local/pd_prune_model.py \
--model_dir $dir \
--model_filename $model \
--params_filename $param \
--output_names $outputs \
--save_dir $save_dir

@ -0,0 +1,25 @@
#!/bin/bash
if [ $# != 4 ];then
# local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
echo "usage: $0 model_dir model_name param_name onnx_output_name"
exit 1
fi
dir=$1
model=$2
param=$3
output=$4
pip install paddle2onnx
pip install onnx
# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
paddle2onnx --model_dir $dir \
--model_filename $model \
--params_filename $param \
--save_file $output \
--enable_dev_version True \
--opset_version 9 \
--enable_onnx_checker True

@ -0,0 +1,14 @@
# This contains the locations of the binaries built for running the examples.
MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project builds successfully."; }
export LC_ALL=C
export PATH=$PATH:$TOOLS_BIN

@ -0,0 +1,76 @@
#!/bin/bash
set -e
. path.sh
stage=0
stop_stage=100
. utils/parse_options.sh
data=data
exp=exp
mkdir -p $data $exp
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
test -f $data/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz || wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz -P $data
# wenetspeech ds2 model
pushd $data
tar zxvf asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
popd
# ds2 model demo inputs
pushd $exp
wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
popd
fi
dir=$data/exp/deepspeech2_online/checkpoints
model=avg_1.jit.pdmodel
param=avg_1.jit.pdiparams
output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
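# i.e. softmax output, output lens, and the lstm h/c states (order matches local/infer_check.py)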
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
# prune model by outputs
mkdir -p $exp/prune
# pruning the model depends on output_names.
./local/prune.sh $dir $model $param $output_names $exp/prune
fi
input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then
# re-infer shapes with the new input shapes
mkdir -p $exp/shape
python3 local/pd_infer_shape.py \
--model_dir $dir \
--model_filename $model \
--params_filename $param \
--save_dir $exp/shape \
--input_shape_dict="${input_shape_dict}"
fi
input_file=$exp/static_ds2online_inputs.pickle
test -e $input_file
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
# to onnx
./local/tonnx.sh $dir $model $param $exp/model.onnx
./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.onnx
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
# simplify the onnx model
./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"
./local/infer_check.py --input_file $input_file --model_dir $dir --onnx_model $exp/model.opt.onnx
fi

@ -747,7 +747,7 @@ def num2chn(number_string,
previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and (
(previous_symbol is None) or
(previous_symbol.power != 1)):
(previous_symbol.power != 1)): # noqa: E129
result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output
