diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index 66b95f63..1d813761 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -492,10 +492,9 @@ class U2Tester(U2Trainer):
         ]
         infer_model.forward_encoder_chunk = paddle.jit.to_static(
             infer_model.forward_encoder_chunk, input_spec=input_spec)
-        # paddle.jit.save(static_model, self.args.export_path, combine_params=True)
 
         ######################### infer_model.forward_attention_decoder ########################
-        # TODO: 512(encoder_output) be configable. 1 for B
+        # TODO: make 512 (the encoder output dim) configurable. 1 is the batch size.
         input_spec = [
             paddle.static.InputSpec(shape=[None, None], dtype='int64'),
             paddle.static.InputSpec(shape=[None], dtype='int64'),
@@ -503,7 +502,6 @@ class U2Tester(U2Trainer):
         ]
         infer_model.forward_attention_decoder = paddle.jit.to_static(
             infer_model.forward_attention_decoder, input_spec=input_spec)
-        # paddle.jit.save(static_model, self.args.export_path, combine_params=True)
 
         ######################### infer_model.ctc_activation ########################
         # TODO: 512(encoder_output) be configable
@@ -513,8 +511,10 @@ class U2Tester(U2Trainer):
         infer_model.ctc_activation = paddle.jit.to_static(
             infer_model.ctc_activation, input_spec=input_spec)
 
-        paddle.jit.save(infer_model, './export.jit', combine_params=True, skip_forward=True)
+        # jit save; the un-converted default forward() is skipped (skip_forward=True)
+        paddle.jit.save(infer_model, self.args.export_path, combine_params=True, skip_forward=True)
 
+        # test that dygraph and exported static-graph outputs match
         def flatten(out):
             if isinstance(out, paddle.Tensor):
                 return [out]
@@ -541,7 +541,7 @@ class U2Tester(U2Trainer):
 
         from paddle.jit.layer import Layer
         layer = Layer()
-        layer.load('./export.jit', paddle.CPUPlace())
+        layer.load(self.args.export_path, paddle.CPUPlace())
 
         xs1 = paddle.full([1, 7, 80], 0.1, dtype='float32')
         offset = paddle.to_tensor([0], dtype='int32')
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index a7919bca..230894d5 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -251,7 +251,12 @@ class BaseEncoder(nn.Layer):
         for i, layer in enumerate(self.encoders):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
             # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
-            # zeros([0,0,0,0]) support [i:i+1] slice
+
+            # WARNING: the unconditional slice below eliminates the if-else cond op in the graph.
+            # A zeros([0,0,0,0]) tensor supports the [i:i+1] slice and returns another zeros([0,0,0,0]) tensor.
+            # The original code was:
+            #     att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
+            #     cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
             xs, _, new_att_cache, new_cnn_cache = layer(
                 xs, att_mask, pos_emb,
                 att_cache=att_cache[i:i+1],
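For context, the model.py half of the diff converts three methods with `paddle.jit.to_static`, exports them with `paddle.jit.save`, and then round-trips the result through `paddle.jit.layer.Layer` to check dygraph/static parity. Below is a minimal sketch of that flow using a hypothetical one-method `TinyInferModel`; the `combine_params`/`skip_forward` kwargs and the `Layer.load` API are taken from the diff itself and assume a Paddle build that supports them.

```python
import paddle
import paddle.nn as nn
from paddle.jit.layer import Layer


class TinyInferModel(nn.Layer):
    """Hypothetical stand-in for the U2 infer model."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 32)

    def ctc_activation(self, xs):
        # xs: (B, T, 512) encoder output -> (B, T, vocab) log-probs
        return paddle.nn.functional.log_softmax(self.proj(xs), axis=-1)


model = TinyInferModel()
model.eval()

# Pin the input signature so dy2static traces one fixed graph for the method.
input_spec = [paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')]
model.ctc_activation = paddle.jit.to_static(
    model.ctc_activation, input_spec=input_spec)

# Export only the converted method; skip_forward=True because the
# default forward() was never converted (mirrors the diff).
paddle.jit.save(model, './export.jit', combine_params=True, skip_forward=True)

# Reload through the multi-method jit Layer and compare outputs.
layer = Layer()
layer.load('./export.jit', paddle.CPUPlace())

xs = paddle.full([1, 7, 512], 0.1, dtype='float32')
dy_out = model.ctc_activation(xs)
st_out = layer.ctc_activation(xs)
if isinstance(st_out, (list, tuple)):  # jit.layer methods may return a list
    st_out = st_out[0]
print(float((dy_out - st_out).abs().max()))  # expect ~0.0
```

Passing `skip_forward=True` mirrors the diff: only the explicitly converted methods are saved, and the model's default `forward` is skipped.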
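The encoder.py half removes a data-dependent if-else by exploiting zero-size tensor slicing, as its new comment explains. Here is a standalone sketch of that behavior; shapes such as `(3, 4, 16, 128)` are illustrative stand-ins for `(elayers, head, cache_t1, d_k*2)`.

```python
import paddle

# "No cache yet": a zero-size placeholder instead of None plus a branch.
att_cache = paddle.zeros([0, 0, 0, 0])
for i in range(3):
    sliced = att_cache[i:i + 1]  # no error: slicing an empty tensor
    print(sliced.shape)          # [0, 0, 0, 0]

# With a warm cache of shape (elayers, head, cache_t1, d_k*2), the
# same expression selects layer i's cache slice.
att_cache = paddle.zeros([3, 4, 16, 128])
print(att_cache[1:2].shape)      # [1, 4, 16, 128]
```

Because the same slice expression is valid for both the empty warm-up cache and a populated one, the exported graph needs no cond op and stays a straight-line program.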