|
|
@ -159,7 +159,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
start = time.time()
|
|
|
|
start = time.time()
|
|
|
|
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
|
|
|
|
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
|
|
|
|
encoder_time = time.time() - start
|
|
|
|
encoder_time = time.time() - start
|
|
|
|
logger.debug(f"encoder time: {encoder_time}")
|
|
|
|
#logger.debug(f"encoder time: {encoder_time}")
|
|
|
|
#TODO(Hui Zhang): sum not support bool type
|
|
|
|
#TODO(Hui Zhang): sum not support bool type
|
|
|
|
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
|
|
|
|
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
|
|
|
|
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
|
|
|
|
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
|
|
|
@ -172,7 +172,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
|
|
|
|
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
|
|
|
|
text, text_lengths)
|
|
|
|
text, text_lengths)
|
|
|
|
decoder_time = time.time() - start
|
|
|
|
decoder_time = time.time() - start
|
|
|
|
logger.debug(f"decoder time: {decoder_time}")
|
|
|
|
#logger.debug(f"decoder time: {decoder_time}")
|
|
|
|
|
|
|
|
|
|
|
|
# 2b. CTC branch
|
|
|
|
# 2b. CTC branch
|
|
|
|
loss_ctc = None
|
|
|
|
loss_ctc = None
|
|
|
@ -181,7 +181,7 @@ class U2BaseModel(nn.Module):
|
|
|
|
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
|
|
|
|
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
|
|
|
|
text_lengths)
|
|
|
|
text_lengths)
|
|
|
|
ctc_time = time.time() - start
|
|
|
|
ctc_time = time.time() - start
|
|
|
|
logger.debug(f"ctc time: {ctc_time}")
|
|
|
|
#logger.debug(f"ctc time: {ctc_time}")
|
|
|
|
|
|
|
|
|
|
|
|
if loss_ctc is None:
|
|
|
|
if loss_ctc is None:
|
|
|
|
loss = loss_att
|
|
|
|
loss = loss_att
|
|
|
|