|
|
|
@ -520,6 +520,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
infer_model.ctc_activation, input_spec=input_spec)
|
|
|
|
|
|
|
|
|
|
######################### infer_model.forward_attention_decoder ########################
|
|
|
|
|
reverse_weight = 0.3
|
|
|
|
|
input_spec = [
|
|
|
|
|
# hyps, (B, U)
|
|
|
|
|
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
|
|
|
|
@ -527,7 +528,8 @@ class U2Tester(U2Trainer):
|
|
|
|
|
paddle.static.InputSpec(shape=[None], dtype='int64'),
|
|
|
|
|
# encoder_out, (B,T,D)
|
|
|
|
|
paddle.static.InputSpec(
|
|
|
|
|
shape=[batch_size, None, model_size], dtype='float32')
|
|
|
|
|
shape=[batch_size, None, model_size], dtype='float32'),
|
|
|
|
|
reverse_weight
|
|
|
|
|
]
|
|
|
|
|
infer_model.forward_attention_decoder = paddle.jit.to_static(
|
|
|
|
|
infer_model.forward_attention_decoder, input_spec=input_spec)
|
|
|
|
|