pull/578/head
Hui Zhang 4 years ago
parent 156ccb947b
commit 281d46dad2

@@ -410,13 +410,11 @@ def ctc_loss(logits,
                               input_lengths, label_lengths)
     loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
-    logger.debug(f"warpctc loss: {loss_out}/{loss_out.shape} ")
     assert reduction in ['mean', 'sum', 'none']
     if reduction == 'mean':
         loss_out = paddle.mean(loss_out / label_lengths)
     elif reduction == 'sum':
         loss_out = paddle.sum(loss_out)
-    logger.debug(f"ctc loss: {loss_out}")
     return loss_out

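Note: the reduction branch kept above can be read on its own. A minimal sketch, assuming per-utterance CTC losses in a NumPy array (the helper name and values are illustrative, not part of the repo):

import numpy as np

def reduce_ctc_loss(loss_out, label_lengths, reduction='mean'):
    # loss_out: per-utterance CTC losses, shape [B]; label_lengths: shape [B]
    assert reduction in ['mean', 'sum', 'none']
    if reduction == 'mean':
        # normalize each utterance loss by its label length, then average
        return np.mean(loss_out / label_lengths)
    elif reduction == 'sum':
        return np.sum(loss_out)
    return loss_out  # 'none': keep the per-utterance losses

# reduce_ctc_loss(np.array([12.0, 30.0]), np.array([4.0, 10.0])) -> 3.0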
@@ -89,8 +89,9 @@ class U2Trainer(Trainer):
         if (batch_index + 1) % train_conf.accum_grad == 0:
             if dist.get_rank() == 0 and self.visualizer:
-                losses_np.update({"lr": self.lr_scheduler()})
-                self.visualizer.add_scalars("step", losses_np, self.iteration)
+                losses_np_v = losses_np.copy()
+                losses_np_v.update({"lr": self.lr_scheduler()})
+                self.visualizer.add_scalars("step", losses_np_v, self.iteration)
             self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
@@ -171,8 +172,9 @@ class U2Trainer(Trainer):
             logger.info(msg)
         if self.visualizer:
-            valid_losses.update({"lr": self.lr_scheduler()})
-            self.visualizer.add_scalars('epoch', valid_losses, self.epoch)
+            valid_losses_v = valid_losses.copy()
+            valid_losses_v.update({"lr": self.lr_scheduler()})
+            self.visualizer.add_scalars('epoch', valid_losses_v, self.epoch)
         return valid_losses

     def setup_dataloader(self):

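Note: both hunks above take a copy of the loss dict before adding the "lr" scalar, so the dict that callers keep using (e.g. the valid_losses returned above) does not pick up an extra key from the visualizer call. A minimal sketch of the underlying dict semantics (values are illustrative):

losses = {"train_loss": 0.42}

logged = losses.copy()            # shallow copy used only for add_scalars
logged.update({"lr": 1e-3})

assert "lr" not in losses         # the original dict stays clean
assert logged == {"train_loss": 0.42, "lr": 1e-3}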
@@ -297,13 +297,13 @@ class ManifestDataset(Dataset):
         else:
             speech_segment = SpeechSegment.from_file(audio_file, transcript)
         load_wav_time = time.time() - start_time
-        logger.debug(f"load wav time: {load_wav_time}")
+        #logger.debug(f"load wav time: {load_wav_time}")

         # audio augment
         start_time = time.time()
         self._augmentation_pipeline.transform_audio(speech_segment)
         audio_aug_time = time.time() - start_time
-        logger.debug(f"audio augmentation time: {audio_aug_time}")
+        #logger.debug(f"audio augmentation time: {audio_aug_time}")

         start_time = time.time()
         specgram, transcript_part = self._speech_featurizer.featurize(
@@ -311,13 +311,13 @@ class ManifestDataset(Dataset):
         if self._normalizer:
             specgram = self._normalizer.apply(specgram)
         feature_time = time.time() - start_time
-        logger.debug(f"audio & test feature time: {feature_time}")
+        #logger.debug(f"audio & test feature time: {feature_time}")

         # specgram augment
         start_time = time.time()
         specgram = self._augmentation_pipeline.transform_feature(specgram)
         feature_aug_time = time.time() - start_time
-        logger.debug(f"audio feature augmentation time: {feature_aug_time}")
+        #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
         return specgram, transcript_part

     def _instance_reader_creator(self, manifest):

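Note: the lines commented out in the two hunks above (and in the U2BaseModel hunks below) all follow the same start_time / elapsed pattern around a single pipeline stage. A hypothetical helper, not part of this commit, showing the same measurement written once as a context manager:

import time
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)

@contextmanager
def stage_timer(name):
    # logs elapsed wall-clock time for the wrapped stage at debug level
    start = time.time()
    yield
    logger.debug(f"{name} time: {time.time() - start}")

# usage, mirroring one stage of the dataset code above:
# with stage_timer("audio augmentation"):
#     self._augmentation_pipeline.transform_audio(speech_segment)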
@@ -159,7 +159,7 @@ class U2BaseModel(nn.Module):
         start = time.time()
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
         encoder_time = time.time() - start
-        logger.debug(f"encoder time: {encoder_time}")
+        #logger.debug(f"encoder time: {encoder_time}")
         #TODO(Hui Zhang): sum not support bool type
         #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
         encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
@@ -172,7 +172,7 @@ class U2BaseModel(nn.Module):
         loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                 text, text_lengths)
         decoder_time = time.time() - start
-        logger.debug(f"decoder time: {decoder_time}")
+        #logger.debug(f"decoder time: {decoder_time}")

         # 2b. CTC branch
         loss_ctc = None
@@ -181,7 +181,7 @@ class U2BaseModel(nn.Module):
             loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                 text_lengths)
             ctc_time = time.time() - start
-            logger.debug(f"ctc time: {ctc_time}")
+            #logger.debug(f"ctc time: {ctc_time}")

         if loss_ctc is None:
             loss = loss_att
