pull/3950/head
megemini 9 months ago
parent 73beb187da
commit f3eb9508de

@ -404,6 +404,12 @@ class DataLoaderFactory():
config['subsampling_factor'] = 1 config['subsampling_factor'] = 1
config['num_encs'] = 1 config['num_encs'] = 1
config['shortest_first'] = False config['shortest_first'] = False
config['minibatches'] = 0
config['batch_count'] = 'auto'
config['batch_bins'] = 0
config['batch_frames_in'] = 0
config['batch_frames_out'] = 0
config['batch_frames_inout'] = 0
elif mode == 'valid': elif mode == 'valid':
config['manifest'] = config.dev_manifest config['manifest'] = config.dev_manifest
config['train_mode'] = False config['train_mode'] = False

@ -170,7 +170,7 @@ class U2STBaseModel(nn.Layer):
ys_in_lens = ys_pad_lens + 1 ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder # 1. Forward decoder
decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad, decoder_out, *_ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens) ys_in_lens)
# 2. Compute attention loss # 2. Compute attention loss
@ -203,7 +203,7 @@ class U2STBaseModel(nn.Layer):
ys_in_lens = ys_pad_lens + 1 ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder # 1. Forward decoder
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, decoder_out, *_ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens) ys_in_lens)
# 2. Compute attention loss # 2. Compute attention loss

@ -110,14 +110,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
concat_after=concat_after, ) for _ in range(num_blocks) concat_after=concat_after, ) for _ in range(num_blocks)
]) ])
def forward( def forward(self,
self,
memory: paddle.Tensor, memory: paddle.Tensor,
memory_mask: paddle.Tensor, memory_mask: paddle.Tensor,
ys_in_pad: paddle.Tensor, ys_in_pad: paddle.Tensor,
ys_in_lens: paddle.Tensor, ys_in_lens: paddle.Tensor,
r_ys_in_pad: paddle.Tensor=paddle.empty([0]), r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]: reverse_weight: float=0.0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Forward decoder. """Forward decoder.
Args: Args:
memory: encoded memory, float32 (batch, maxlen_in, feat) memory: encoded memory, float32 (batch, maxlen_in, feat)

Loading…
Cancel
Save