From 309c8d70d9e7168eac597a5ffb030fc6703d7e87 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 20 Sep 2022 12:56:07 +0000 Subject: [PATCH] add reverse weight --- paddlespeech/s2t/exps/u2/model.py | 4 +++- paddlespeech/s2t/models/u2/u2.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 54810f22f..64b6c8df6 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -520,6 +520,7 @@ class U2Tester(U2Trainer): infer_model.ctc_activation, input_spec=input_spec) ######################### infer_model.forward_attention_decoder ######################## + reverse_weight = 0.3 input_spec = [ # hyps, (B, U) paddle.static.InputSpec(shape=[None, None], dtype='int64'), @@ -527,7 +528,8 @@ class U2Tester(U2Trainer): paddle.static.InputSpec(shape=[None], dtype='int64'), # encoder_out, (B,T,D) paddle.static.InputSpec( - shape=[batch_size, None, model_size], dtype='float32') + shape=[batch_size, None, model_size], dtype='float32'), + reverse_weight ] infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder, input_spec=input_spec) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index d699b684b..1681bf1d9 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -706,7 +706,7 @@ class U2BaseModel(ASRInterface, nn.Layer): hyps: paddle.Tensor, hyps_lens: paddle.Tensor, encoder_out: paddle.Tensor, - reverse_weight: float=0.0, ) -> paddle.Tensor: + reverse_weight: float=0.0) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: