add reverse weight

pull/2425/head
Hui Zhang 2 years ago
parent 9b66680ea4
commit 309c8d70d9

@ -520,6 +520,7 @@ class U2Tester(U2Trainer):
infer_model.ctc_activation, input_spec=input_spec) infer_model.ctc_activation, input_spec=input_spec)
######################### infer_model.forward_attention_decoder ######################## ######################### infer_model.forward_attention_decoder ########################
reverse_weight = 0.3
input_spec = [ input_spec = [
# hyps, (B, U) # hyps, (B, U)
paddle.static.InputSpec(shape=[None, None], dtype='int64'), paddle.static.InputSpec(shape=[None, None], dtype='int64'),
@ -527,7 +528,8 @@ class U2Tester(U2Trainer):
paddle.static.InputSpec(shape=[None], dtype='int64'), paddle.static.InputSpec(shape=[None], dtype='int64'),
# encoder_out, (B,T,D) # encoder_out, (B,T,D)
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[batch_size, None, model_size], dtype='float32') shape=[batch_size, None, model_size], dtype='float32'),
reverse_weight
] ]
infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec) infer_model.forward_attention_decoder, input_spec=input_spec)

@ -706,7 +706,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
hyps: paddle.Tensor, hyps: paddle.Tensor,
hyps_lens: paddle.Tensor, hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor, encoder_out: paddle.Tensor,
reverse_weight: float=0.0, ) -> paddle.Tensor: reverse_weight: float=0.0) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple """ Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output hypothesis from ctc prefix beam search and one encoder output
Args: Args:

Loading…
Cancel
Save