rm to_static decarator; configure jit save for ctc_activation

pull/2212/head
Hui Zhang 3 years ago
parent 1c9f238ba0
commit 3a8869fba4

@ -513,9 +513,9 @@ class U2Tester(U2Trainer):
infer_model.forward_attention_decoder, input_spec=input_spec)
######################### infer_model.ctc_activation ########################
# TODO: 512(encoder_output) be configable
input_spec = [
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')
# encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec)

@ -599,12 +599,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.eos
# @jit.to_static(input_spec=[
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -1,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')])
# @jit.to_static
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
@ -658,10 +653,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
"""
return self.ctc.log_softmax(xs)
# @jit.to_static(input_spec=[
# paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# paddle.static.InputSpec(shape=[None], dtype='int64'),
# paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
# @jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,

Loading…
Cancel
Save