diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 45fbcb404..dae618db6 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -513,9 +513,9 @@ class U2Tester(U2Trainer): infer_model.forward_attention_decoder, input_spec=input_spec) ######################### infer_model.ctc_activation ######################## - # TODO: 512(encoder_output) be configable input_spec = [ - paddle.static.InputSpec(shape=[1, None, 512], dtype='float32') + # encoder_out, (B,T,D) + paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32') ] infer_model.ctc_activation = paddle.jit.to_static( infer_model.ctc_activation, input_spec=input_spec) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index a1daccf18..149170ed6 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -599,12 +599,7 @@ class U2BaseModel(ASRInterface, nn.Layer): """ return self.eos - # @jit.to_static(input_spec=[ - # paddle.static.InputSpec(shape=[1, None, 80], dtype='float32'), - # paddle.static.InputSpec(shape=[1], dtype='int32'), - # -1, - # paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32'), - # paddle.static.InputSpec(shape=[None, None, None, None], dtype='float32')]) + # @jit.to_static def forward_encoder_chunk( self, xs: paddle.Tensor, @@ -658,10 +653,7 @@ class U2BaseModel(ASRInterface, nn.Layer): """ return self.ctc.log_softmax(xs) - # @jit.to_static(input_spec=[ - # paddle.static.InputSpec(shape=[None, None], dtype='int64'), - # paddle.static.InputSpec(shape=[None], dtype='int64'), - # paddle.static.InputSpec(shape=[1, None, 512], dtype='float32')]) + # @jit.to_static def forward_attention_decoder( self, hyps: paddle.Tensor,