|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
# flake8: noqa
|
|
|
|
|
import argparse
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
@ -491,9 +492,6 @@ class SymbolicShapeInference:
|
|
|
|
|
skip_infer = node.op_type in [
|
|
|
|
|
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
|
|
|
|
|
# contrib ops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'Attention', 'BiasGelu', \
|
|
|
|
|
'EmbedLayerNormalization', \
|
|
|
|
|
'FastGelu', 'Gelu', 'LayerNormalization', \
|
|
|
|
@ -1605,8 +1603,8 @@ class SymbolicShapeInference:
|
|
|
|
|
def _infer_Scan(self, node):
|
|
|
|
|
subgraph = get_attribute(node, 'body')
|
|
|
|
|
num_scan_inputs = get_attribute(node, 'num_scan_inputs')
|
|
|
|
|
scan_input_axes = get_attribute(node, 'scan_input_axes', [0] *
|
|
|
|
|
num_scan_inputs)
|
|
|
|
|
scan_input_axes = get_attribute(node, 'scan_input_axes',
|
|
|
|
|
[0] * num_scan_inputs)
|
|
|
|
|
num_scan_states = len(node.input) - num_scan_inputs
|
|
|
|
|
scan_input_axes = [
|
|
|
|
|
handle_negative_axis(
|
|
|
|
@ -1627,8 +1625,8 @@ class SymbolicShapeInference:
|
|
|
|
|
si.name = subgraph_name
|
|
|
|
|
self._onnx_infer_subgraph(node, subgraph)
|
|
|
|
|
num_scan_outputs = len(node.output) - num_scan_states
|
|
|
|
|
scan_output_axes = get_attribute(node, 'scan_output_axes', [0] *
|
|
|
|
|
num_scan_outputs)
|
|
|
|
|
scan_output_axes = get_attribute(node, 'scan_output_axes',
|
|
|
|
|
[0] * num_scan_outputs)
|
|
|
|
|
scan_input_dim = get_shape_from_type_proto(
|
|
|
|
|
self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
|
|
|
|
|
for i, o in enumerate(node.output):
|
|
|
|
@ -1821,8 +1819,8 @@ class SymbolicShapeInference:
|
|
|
|
|
split = get_attribute(node, 'split')
|
|
|
|
|
if not split:
|
|
|
|
|
num_outputs = len(node.output)
|
|
|
|
|
split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)
|
|
|
|
|
] * num_outputs
|
|
|
|
|
split = [input_sympy_shape[axis] /
|
|
|
|
|
sympy.Integer(num_outputs)] * num_outputs
|
|
|
|
|
self._update_computed_dims(split)
|
|
|
|
|
else:
|
|
|
|
|
split = [sympy.Integer(s) for s in split]
|
|
|
|
@ -2174,8 +2172,8 @@ class SymbolicShapeInference:
|
|
|
|
|
subgraphs = []
|
|
|
|
|
if 'If' == node.op_type:
|
|
|
|
|
subgraphs = [
|
|
|
|
|
get_attribute(node, 'then_branch'), get_attribute(
|
|
|
|
|
node, 'else_branch')
|
|
|
|
|
get_attribute(node, 'then_branch'),
|
|
|
|
|
get_attribute(node, 'else_branch')
|
|
|
|
|
]
|
|
|
|
|
elif node.op_type in ['Loop', 'Scan']:
|
|
|
|
|
subgraphs = [get_attribute(node, 'body')]
|
|
|
|
@ -2330,8 +2328,8 @@ class SymbolicShapeInference:
|
|
|
|
|
'LessOrEqual', 'GreaterOrEqual'
|
|
|
|
|
]:
|
|
|
|
|
shapes = [
|
|
|
|
|
self._get_shape(node, i) for i in range(
|
|
|
|
|
len(node.input))
|
|
|
|
|
self._get_shape(node, i)
|
|
|
|
|
for i in range(len(node.input))
|
|
|
|
|
]
|
|
|
|
|
if node.op_type in [
|
|
|
|
|
'MatMul', 'MatMulInteger', 'MatMulInteger16'
|
|
|
|
|