pull/2034/head
Hui Zhang 3 years ago
parent 5a4e35b543
commit b472a148dc

@ -34,4 +34,4 @@ For more details please see `run.sh`.
## Outputs
The optimized onnx model is `exp/model.opt.onnx`.
To show the graph, please using `local/netron.sh`.
To show the graph, please using `local/netron.sh`.

@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle
import numpy as np
import onnxruntime
import paddle
import os
import pickle
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
@ -26,26 +27,19 @@ def parse_args():
'--input_file',
type=str,
default="static_ds2online_inputs.pickle",
help="ds2 input pickle file.",
)
help="ds2 input pickle file.", )
parser.add_argument(
'--model_dir',
type=str,
default=".",
help="paddle model dir."
)
'--model_dir', type=str, default=".", help="paddle model dir.")
parser.add_argument(
'--model_prefix',
type=str,
default="avg_1.jit",
help="paddle model prefix."
)
help="paddle model prefix.")
parser.add_argument(
'--onnx_model',
type=str,
default='./model.old.onnx',
help="onnx model."
)
help="onnx model.")
return parser.parse_args()
@ -69,19 +63,19 @@ if __name__ == '__main__':
paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box),
)
paddle.to_tensor(chunk_state_c_box), )
# onnxruntime
options = onnxruntime.SessionOptions()
options.enable_profiling=True
options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'],
{"audio_chunk": audio_chunk,
"audio_chunk_lens":audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box":chunk_state_c_box})
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
"audio_chunk": audio_chunk,
"audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling())
@ -89,4 +83,4 @@ if __name__ == '__main__':
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# flake8: noqa
import argparse
import logging
@ -491,9 +492,6 @@ class SymbolicShapeInference:
skip_infer = node.op_type in [
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
# contrib ops
'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \
@ -1605,8 +1603,8 @@ class SymbolicShapeInference:
def _infer_Scan(self, node):
subgraph = get_attribute(node, 'body')
num_scan_inputs = get_attribute(node, 'num_scan_inputs')
scan_input_axes = get_attribute(node, 'scan_input_axes', [0] *
num_scan_inputs)
scan_input_axes = get_attribute(node, 'scan_input_axes',
[0] * num_scan_inputs)
num_scan_states = len(node.input) - num_scan_inputs
scan_input_axes = [
handle_negative_axis(
@ -1627,8 +1625,8 @@ class SymbolicShapeInference:
si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph)
num_scan_outputs = len(node.output) - num_scan_states
scan_output_axes = get_attribute(node, 'scan_output_axes', [0] *
num_scan_outputs)
scan_output_axes = get_attribute(node, 'scan_output_axes',
[0] * num_scan_outputs)
scan_input_dim = get_shape_from_type_proto(
self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
for i, o in enumerate(node.output):
@ -1821,8 +1819,8 @@ class SymbolicShapeInference:
split = get_attribute(node, 'split')
if not split:
num_outputs = len(node.output)
split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)
] * num_outputs
split = [input_sympy_shape[axis] /
sympy.Integer(num_outputs)] * num_outputs
self._update_computed_dims(split)
else:
split = [sympy.Integer(s) for s in split]
@ -2174,8 +2172,8 @@ class SymbolicShapeInference:
subgraphs = []
if 'If' == node.op_type:
subgraphs = [
get_attribute(node, 'then_branch'), get_attribute(
node, 'else_branch')
get_attribute(node, 'then_branch'),
get_attribute(node, 'else_branch')
]
elif node.op_type in ['Loop', 'Scan']:
subgraphs = [get_attribute(node, 'body')]
@ -2330,8 +2328,8 @@ class SymbolicShapeInference:
'LessOrEqual', 'GreaterOrEqual'
]:
shapes = [
self._get_shape(node, i) for i in range(
len(node.input))
self._get_shape(node, i)
for i in range(len(node.input))
]
if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'

@ -1,13 +1,12 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names
import argparse
import copy
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(

@ -1,5 +1,4 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename node to new names
import argparse
import sys

@ -4,6 +4,7 @@ import argparse
# paddle inference shape
def process_old_ops_desc(program):
"""set matmul op head_number attr to 1 is not exist.

@ -6,6 +6,7 @@ from typing import List
# paddle prune model.
def prepend_feed_ops(program,
feed_target_names: List[str],
feed_holder_name='feed'):

@ -747,7 +747,7 @@ def num2chn(number_string,
previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and (
(previous_symbol is None) or
(previous_symbol.power != 1)):
(previous_symbol.power != 1)): # noqa: E129
result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output

Loading…
Cancel
Save