pull/2034/head
Hui Zhang 3 years ago
parent 5a4e35b543
commit b472a148dc

@ -34,4 +34,4 @@ For more details please see `run.sh`.
## Outputs ## Outputs
The optimized onnx model is `exp/model.opt.onnx`. The optimized onnx model is `exp/model.opt.onnx`.
To show the graph, please use `local/netron.sh`. To show the graph, please use `local/netron.sh`.

@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import pickle
import numpy as np import numpy as np
import onnxruntime import onnxruntime
import paddle import paddle
import os
import pickle
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
@ -26,26 +27,19 @@ def parse_args():
'--input_file', '--input_file',
type=str, type=str,
default="static_ds2online_inputs.pickle", default="static_ds2online_inputs.pickle",
help="ds2 input pickle file.", help="ds2 input pickle file.", )
)
parser.add_argument( parser.add_argument(
'--model_dir', '--model_dir', type=str, default=".", help="paddle model dir.")
type=str,
default=".",
help="paddle model dir."
)
parser.add_argument( parser.add_argument(
'--model_prefix', '--model_prefix',
type=str, type=str,
default="avg_1.jit", default="avg_1.jit",
help="paddle model prefix." help="paddle model prefix.")
)
parser.add_argument( parser.add_argument(
'--onnx_model', '--onnx_model',
type=str, type=str,
default='./model.old.onnx', default='./model.old.onnx',
help="onnx model." help="onnx model.")
)
return parser.parse_args() return parser.parse_args()
@ -69,19 +63,19 @@ if __name__ == '__main__':
paddle.to_tensor(audio_chunk), paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens), paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box), paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box), paddle.to_tensor(chunk_state_c_box), )
)
# onnxruntime # onnxruntime
options = onnxruntime.SessionOptions() options = onnxruntime.SessionOptions()
options.enable_profiling=True options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options) sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run( ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
{"audio_chunk": audio_chunk, "audio_chunk": audio_chunk,
"audio_chunk_lens":audio_chunk_lens, "audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box, "chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box":chunk_state_c_box}) "chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling()) print(sess.end_profiling())
@ -89,4 +83,4 @@ if __name__ == '__main__':
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6)) print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))

@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved. # Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. # Licensed under the MIT License.
# flake8: noqa
import argparse import argparse
import logging import logging
@ -491,9 +492,6 @@ class SymbolicShapeInference:
skip_infer = node.op_type in [ skip_infer = node.op_type in [
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \ 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
# contrib ops # contrib ops
'Attention', 'BiasGelu', \ 'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \ 'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \ 'FastGelu', 'Gelu', 'LayerNormalization', \
@ -1605,8 +1603,8 @@ class SymbolicShapeInference:
def _infer_Scan(self, node): def _infer_Scan(self, node):
subgraph = get_attribute(node, 'body') subgraph = get_attribute(node, 'body')
num_scan_inputs = get_attribute(node, 'num_scan_inputs') num_scan_inputs = get_attribute(node, 'num_scan_inputs')
scan_input_axes = get_attribute(node, 'scan_input_axes', [0] * scan_input_axes = get_attribute(node, 'scan_input_axes',
num_scan_inputs) [0] * num_scan_inputs)
num_scan_states = len(node.input) - num_scan_inputs num_scan_states = len(node.input) - num_scan_inputs
scan_input_axes = [ scan_input_axes = [
handle_negative_axis( handle_negative_axis(
@ -1627,8 +1625,8 @@ class SymbolicShapeInference:
si.name = subgraph_name si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph) self._onnx_infer_subgraph(node, subgraph)
num_scan_outputs = len(node.output) - num_scan_states num_scan_outputs = len(node.output) - num_scan_states
scan_output_axes = get_attribute(node, 'scan_output_axes', [0] * scan_output_axes = get_attribute(node, 'scan_output_axes',
num_scan_outputs) [0] * num_scan_outputs)
scan_input_dim = get_shape_from_type_proto( scan_input_dim = get_shape_from_type_proto(
self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
for i, o in enumerate(node.output): for i, o in enumerate(node.output):
@ -1821,8 +1819,8 @@ class SymbolicShapeInference:
split = get_attribute(node, 'split') split = get_attribute(node, 'split')
if not split: if not split:
num_outputs = len(node.output) num_outputs = len(node.output)
split = [input_sympy_shape[axis] / sympy.Integer(num_outputs) split = [input_sympy_shape[axis] /
] * num_outputs sympy.Integer(num_outputs)] * num_outputs
self._update_computed_dims(split) self._update_computed_dims(split)
else: else:
split = [sympy.Integer(s) for s in split] split = [sympy.Integer(s) for s in split]
@ -2174,8 +2172,8 @@ class SymbolicShapeInference:
subgraphs = [] subgraphs = []
if 'If' == node.op_type: if 'If' == node.op_type:
subgraphs = [ subgraphs = [
get_attribute(node, 'then_branch'), get_attribute( get_attribute(node, 'then_branch'),
node, 'else_branch') get_attribute(node, 'else_branch')
] ]
elif node.op_type in ['Loop', 'Scan']: elif node.op_type in ['Loop', 'Scan']:
subgraphs = [get_attribute(node, 'body')] subgraphs = [get_attribute(node, 'body')]
@ -2330,8 +2328,8 @@ class SymbolicShapeInference:
'LessOrEqual', 'GreaterOrEqual' 'LessOrEqual', 'GreaterOrEqual'
]: ]:
shapes = [ shapes = [
self._get_shape(node, i) for i in range( self._get_shape(node, i)
len(node.input)) for i in range(len(node.input))
] ]
if node.op_type in [ if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16' 'MatMul', 'MatMulInteger', 'MatMulInteger16'

@ -1,13 +1,12 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning #!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names # prune model by output names
import argparse import argparse
import copy import copy
import sys import sys
import onnx import onnx
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(

@ -1,5 +1,4 @@
#!/usr/bin/env python3 -W ignore::DeprecationWarning #!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename node to new names # rename node to new names
import argparse import argparse
import sys import sys

@ -4,6 +4,7 @@ import argparse
# paddle inference shape # paddle inference shape
def process_old_ops_desc(program): def process_old_ops_desc(program):
"""set matmul op head_number attr to 1 is not exist. """set matmul op head_number attr to 1 is not exist.

@ -6,6 +6,7 @@ from typing import List
# paddle prune model. # paddle prune model.
def prepend_feed_ops(program, def prepend_feed_ops(program,
feed_target_names: List[str], feed_target_names: List[str],
feed_holder_name='feed'): feed_holder_name='feed'):

@ -747,7 +747,7 @@ def num2chn(number_string,
previous_symbol, (CNU, type(None))): previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and ( if next_symbol.power != 1 and (
(previous_symbol is None) or (previous_symbol is None) or
(previous_symbol.power != 1)): (previous_symbol.power != 1)): # noqa: E129
result_symbols[i] = liang result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output # if big is True, '两' will not be used and `alt_two` has no impact on output

Loading…
Cancel
Save