rm to_static decarator; configure jit save for ctc_activation

pull/2212/head
Hui Zhang 3 years ago
parent 1c9f238ba0
commit 3a8869fba4

@ -513,9 +513,9 @@ class U2Tester(U2Trainer):
infer_model.forward_attention_decoder, input_spec=input_spec) infer_model.forward_attention_decoder, input_spec=input_spec)
######################### infer_model.ctc_activation ######################## ######################### infer_model.ctc_activation ########################
# TODO: 512(encoder_output) be configable
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[1, None, 512], dtype='float32') # encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
] ]
infer_model.ctc_activation = paddle.jit.to_static( infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec) infer_model.ctc_activation, input_spec=input_spec)

@ -599,12 +599,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
""" """
return self.eos return self.eos
# @jit.to_static(input_spec=[ # @jit.to_static
# paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'),
# paddle.static.InputSpec(shape=[1], dtype='int32'),
# -1,
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'),
# paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')])
def forward_encoder_chunk( def forward_encoder_chunk(
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
@ -658,10 +653,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
""" """
return self.ctc.log_softmax(xs) return self.ctc.log_softmax(xs)
# @jit.to_static(input_spec=[ # @jit.to_static
# paddle.static.InputSpec(shape=[None, None], dtype='int64'),
# paddle.static.InputSpec(shape=[None], dtype='int64'),
# paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')])
def forward_attention_decoder( def forward_attention_decoder(
self, self,
hyps: paddle.Tensor, hyps: paddle.Tensor,

Loading…
Cancel
Save