more comment

pull/2212/head
Hui Zhang 2 years ago
parent a7c6c54e75
commit 63aeb747b0

@ -492,10 +492,9 @@ class U2Tester(U2Trainer):
] ]
infer_model.forward_encoder_chunk = paddle.jit.to_static( infer_model.forward_encoder_chunk = paddle.jit.to_static(
infer_model.forward_encoder_chunk, input_spec=input_spec) infer_model.forward_encoder_chunk, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
######################### infer_model.forward_attention_decoder ######################## ######################### infer_model.forward_attention_decoder ########################
# TODO: 512(encoder_output) be configable. 1 for B # TODO: 512(encoder_output) be configable. 1 for BatchSize
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype='int64'), paddle.static.InputSpec(shape=[None, None], dtype='int64'),
paddle.static.InputSpec(shape=[None], dtype='int64'), paddle.static.InputSpec(shape=[None], dtype='int64'),
@ -503,7 +502,6 @@ class U2Tester(U2Trainer):
] ]
infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec) infer_model.forward_attention_decoder, input_spec=input_spec)
# paddle.jit.save(static_model, self.args.export_path, combine_params=True)
######################### infer_model.ctc_activation ######################## ######################### infer_model.ctc_activation ########################
# TODO: 512(encoder_output) be configable # TODO: 512(encoder_output) be configable
@ -513,8 +511,10 @@ class U2Tester(U2Trainer):
infer_model.ctc_activation = paddle.jit.to_static( infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec) infer_model.ctc_activation, input_spec=input_spec)
paddle.jit.save(infer_model, './export.jit', combine_params=True, skip_forward=True) # jit save
paddle.jit.save(infer_model, self.args.export_path, combine_params=True, skip_forward=True)
# test dy2static
def flatten(out): def flatten(out):
if isinstance(out, paddle.Tensor): if isinstance(out, paddle.Tensor):
return [out] return [out]
@ -541,7 +541,7 @@ class U2Tester(U2Trainer):
from paddle.jit.layer import Layer from paddle.jit.layer import Layer
layer = Layer() layer = Layer()
layer.load('./export.jit', paddle.CPUPlace()) layer.load(self.args.export_path, paddle.CPUPlace())
xs1 = paddle.full([1, 7, 80], 0.1, dtype='float32') xs1 = paddle.full([1, 7, 80], 0.1, dtype='float32')
offset = paddle.to_tensor([0], dtype='int32') offset = paddle.to_tensor([0], dtype='int32')

@ -251,7 +251,12 @@ class BaseEncoder(nn.Layer):
for i, layer in enumerate(self.encoders): for i, layer in enumerate(self.encoders):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2) # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
# zeros([0,0,0,0]) support [i:i+1] slice
# WARNING: eliminate if-else cond op in graph
# tensor zeros([0,0,0,0]) support [i:i+1] slice, will return zeros([0,0,0,0]) tensor
# raw code as below:
# att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
# cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
xs, _, new_att_cache, new_cnn_cache = layer( xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb, xs, att_mask, pos_emb,
att_cache=att_cache[i:i+1], att_cache=att_cache[i:i+1],

Loading…
Cancel
Save