|
|
@ -170,8 +170,8 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Forward decoder
|
|
|
|
# 1. Forward decoder
|
|
|
|
decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
decoder_out, *_ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
ys_in_lens)
|
|
|
|
ys_in_lens)
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Compute attention loss
|
|
|
|
# 2. Compute attention loss
|
|
|
|
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
|
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
@ -203,8 +203,8 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
ys_in_lens = ys_pad_lens + 1
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Forward decoder
|
|
|
|
# 1. Forward decoder
|
|
|
|
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
decoder_out, *_ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
|
|
|
|
ys_in_lens)
|
|
|
|
ys_in_lens)
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Compute attention loss
|
|
|
|
# 2. Compute attention loss
|
|
|
|
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
|
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
|