|
|
|
@ -170,7 +170,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
|
|
|
|
|
|
# 1. Forward decoder
|
|
|
|
|
decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
|
decoder_out, *_ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
|
ys_in_lens)
|
|
|
|
|
|
|
|
|
|
# 2. Compute attention loss
|
|
|
|
@ -203,7 +203,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
|
|
|
|
|
|
# 1. Forward decoder
|
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
|
decoder_out, *_ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
|
ys_in_lens)
|
|
|
|
|
|
|
|
|
|
# 2. Compute attention loss
|
|
|
|
|