do not jit-save forward; use [i:i+1] slicing on the zeros([0,0,0,0]) cache tensor

pull/2212/head
Hui Zhang 3 years ago
parent c1fbfe928e
commit d638325c46

@@ -482,10 +482,12 @@ class U2Tester(U2Trainer):
         # TODO: 80(feature dim) be configable
         input_spec = [
             paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
-            paddle.static.InputSpec(shape=[1], dtype='int32'), -1,
+            paddle.static.InputSpec(shape=[1], dtype='int32'),
+            -1,
             paddle.static.InputSpec(
                 shape=[None, None, None, None],
-                dtype='float32'), paddle.static.InputSpec(
+                dtype='float32'),
+            paddle.static.InputSpec(
                 shape=[None, None, None, None], dtype='float32')
         ]
         infer_model.forward_encoder_chunk = paddle.jit.to_static(
@@ -511,7 +513,7 @@ class U2Tester(U2Trainer):
         infer_model.ctc_activation = paddle.jit.to_static(
             infer_model.ctc_activation, input_spec=input_spec)
-        paddle.jit.save(infer_model, './export.jit', combine_params=True)
+        paddle.jit.save(infer_model, './export.jit', combine_params=True, skip_forward=True)

         def flatten(out):
             if isinstance(out, paddle.Tensor):
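
Note: with skip_forward=True, paddle.jit.save exports only the methods explicitly converted with paddle.jit.to_static (forward_encoder_chunk, ctc_activation, ...) and leaves forward as a plain dygraph method; this pairs with the removal of forward's @jit.to_static decorator in the hunk at line 924 below. A minimal sketch of the pattern, assuming the patched Paddle build this PR targets (the combine_params and skip_forward arguments are taken from this commit; ToyModel is illustrative):

import paddle

class ToyModel(paddle.nn.Layer):
    def forward(self, x):
        # dygraph-only entry point; intentionally not exported
        return self.ctc_activation(x)

    def ctc_activation(self, x):
        return paddle.nn.functional.log_softmax(x, axis=-1)

model = ToyModel()
# Convert only the method we want in the exported program.
model.ctc_activation = paddle.jit.to_static(
    model.ctc_activation,
    input_spec=[paddle.static.InputSpec(shape=[1, None, 80], dtype='float32')])
# skip_forward=True: do not trace/save forward(), which has no input_spec here.
paddle.jit.save(model, './export.jit', combine_params=True, skip_forward=True)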
@@ -531,33 +533,20 @@ class U2Tester(U2Trainer):
         att_cache = paddle.zeros([0, 0, 0, 0])
         cnn_cache = paddle.zeros([0, 0, 0, 0])
-        # xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
-        # xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
-        # offset = paddle.to_tensor([16], dtype='int32')
-        # out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache)
-        # print(out1)
-        xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(
-            xs1, offset, att_cache, cnn_cache)
+        xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
         xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32')
         offset = paddle.to_tensor([16], dtype='int32')
-        out1 = infer_model.forward_encoder_chunk(xs2, offset, att_cache,
-                                                 cnn_cache)
-        print(out1)
-        # from paddle.jit.layer import Layer
-        # layer = Layer()
-        # layer.load('./export.jit', paddle.CPUPlace())
-        # offset = paddle.to_tensor([0], dtype='int32')
-        # att_cache = paddle.zeros([0, 0, 0, 0])
-        # cnn_cache = paddle.zeros([0, 0, 0, 0])
-        # xs, att_cache, cnn_cache = layer.forward_encoder_chunk(xs1, offset, att_cache, cnn_cache)
-        # offset = paddle.to_tensor([16], dtype='int32')
-        # out2 = layer.forward_encoder_chunk(xs2, offset, att_cache, cnn_cache)
-        # # print(out2)
-        # out1 = flatten(out1)
-        # out2 = flatten(out2)
-        # for i in range(len(out1)):
-        #     print(np.equal(out1[i].numpy(), out2[i].numpy()).all())
+        out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache)
+        print('py encoder', out1)
+        from paddle.jit.layer import Layer
+        layer = Layer()
+        layer.load('./export.jit', paddle.CPUPlace())
+        xs1 = paddle.full([1, 7, 80], 0.1, dtype='float32')
+        offset = paddle.to_tensor([0], dtype='int32')
+        att_cache = paddle.zeros([0, 0, 0, 0])
+        cnn_cache = paddle.zeros([0, 0, 0, 0])
+        func = getattr(layer, 'forward_encoder_chunk')
+        xs, att_cache, cnn_cache = func(xs1, offset, att_cache, cnn_cache)
+        print('py static encoder', xs)
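
The commented-out consistency check deleted above compared the dygraph outputs with the ones from the loaded jit Layer. A sketch of how that check could be restored on top of the new code, reusing the flatten helper defined earlier; out1 is the dygraph result from above, and the second-chunk offset mirrors the deleted comments:

import numpy as np

# Second chunk through the loaded static layer (func), mirroring the
# dygraph call that produced out1.
offset = paddle.to_tensor([16], dtype='int32')
out2 = func(xs2, offset, att_cache, cnn_cache)
for a, b in zip(flatten(out1), flatten(out2)):
    # elementwise comparison of dygraph vs. exported static outputs
    print(np.equal(a.numpy(), b.numpy()).all())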

@@ -924,10 +924,6 @@ class U2InferModel(U2Model):
     def __init__(self, configs: dict):
         super().__init__(configs)

-    @jit.to_static(input_spec=[
-        paddle.static.InputSpec(shape=[1, 1, 1], dtype='int64'),
-        paddle.static.InputSpec(shape=[1], dtype='int64')
-    ])
     def forward(self,
                 feats,
                 feats_lengths,

@@ -251,10 +251,11 @@ class BaseEncoder(nn.Layer):
         for i, layer in enumerate(self.encoders):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
             # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
+            # zeros([0,0,0,0]) supports the [i:i+1] slice
             xs, _, new_att_cache, new_cnn_cache = layer(
                 xs, att_mask, pos_emb,
-                att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
-                cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
+                att_cache=att_cache[i:i+1],
+                cnn_cache=cnn_cache[i:i+1],
             )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
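
Why the conditionals can go: slicing an empty 4-D tensor with [i:i+1] just yields another empty tensor, so the zeros([0, 0, 0, 0]) "no cache yet" sentinel flows through the same code path as a real cache, and the shape-dependent Python branch (which is awkward under jit.to_static) disappears. A standalone sketch of the behavior this relies on:

import paddle

att_cache = paddle.zeros([0, 0, 0, 0])  # sentinel: no cache on the first chunk
for i in range(3):  # e.g. three encoder layers
    layer_cache = att_cache[i:i + 1]  # slicing an empty dim stays empty
    print(layer_cache.shape)  # -> [0, 0, 0, 0]

The per-layer attention and convolution modules can then detect the first chunk from the cache's size-0 dimensions, so no Python-level branch is needed in the exported static graph.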
