|
|
|
@ -48,6 +48,24 @@ class Wav2Vec2ASRTrainer(Trainer):
|
|
|
|
|
super().__init__(config, args)
|
|
|
|
|
self.avg_train_loss = 0
|
|
|
|
|
|
|
|
|
|
def update_average(self, batch_index, loss, avg_loss):
    """Update running average of the loss.

    Folds ``loss`` into ``avg_loss`` using a weight of
    ``1 / (batch_index + 1)``. A non-finite loss (NaN or +/-inf) is
    skipped and ``avg_loss`` is returned unchanged.

    Arguments
    ---------
    batch_index : int
        Index of the current batch; ``batch_index + 1`` is used as the
        number of samples in the average.
    loss : paddle.tensor or float
        detached loss, a single float value.
    avg_loss : float
        current running average.

    Returns
    -------
    avg_loss : float
        The average loss.
    """
    # Local import keeps this method self-contained; the stdlib check
    # works on the converted Python float, so both single-element
    # tensors and plain numbers are accepted.
    import math

    # Convert once and reuse for both the finiteness test and the update.
    loss_value = float(loss)
    if math.isfinite(loss_value):
        # Incremental mean: remove the old average's share, add the new
        # sample's share.
        avg_loss -= avg_loss / (batch_index + 1)
        avg_loss += loss_value / (batch_index + 1)
    return avg_loss
|
|
|
|
|
|
|
|
|
|
def train_batch(self, batch_index, batch, msg):
|
|
|
|
|
train_conf = self.config
|
|
|
|
|
start = time.time()
|
|
|
|
@ -59,11 +77,11 @@ class Wav2Vec2ASRTrainer(Trainer):
|
|
|
|
|
wav = wav[:, :, 0]
|
|
|
|
|
wav = self.speech_augmentation(wav, wavs_lens_rate)
|
|
|
|
|
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
|
|
|
|
|
# print(wav, wavs_lens_rate, target, target_lens_rate)
|
|
|
|
|
# loss div by `batch_size * accum_grad`
|
|
|
|
|
loss /= train_conf.accum_grad
|
|
|
|
|
|
|
|
|
|
losses_np = {'loss': float(loss) * train_conf.accum_grad}
|
|
|
|
|
self.avg_train_loss = self.update_average(batch_index, loss,
|
|
|
|
|
self.avg_train_loss)
|
|
|
|
|
|
|
|
|
|
# loss backward
|
|
|
|
|
if (batch_index + 1) % train_conf.accum_grad != 0:
|
|
|
|
@ -87,6 +105,8 @@ class Wav2Vec2ASRTrainer(Trainer):
|
|
|
|
|
self.optimizer.clear_grad()
|
|
|
|
|
self.lr_scheduler.step()
|
|
|
|
|
self.iteration += 1
|
|
|
|
|
|
|
|
|
|
losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
|
|
|
|
|
iteration_time = time.time() - start
|
|
|
|
|
for k, v in losses_np.items():
|
|
|
|
|
report(k, v)
|
|
|
|
|