@@ -15,6 +15,7 @@
 import os
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
+        train_conf = self.config.training
         start = time.time()
+
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss = self.model(audio, audio_len, text, text_len)
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-        self.optimizer.step()
-        self.optimizer.clear_grad()
-        iteration_time = time.time() - start
 
         losses_np = {
             'train_loss': float(loss),
         }
+
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
+        if (batch_index + 1) % train_conf.accum_grad == 0:
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+            self.iteration += 1
+
+        iteration_time = time.time() - start
+
         msg += "train time: {:>.3f}s, ".format(iteration_time)
         msg += "batch size: {}, ".format(self.config.collator.batch_size)
+        msg += "accum: {}, ".format(train_conf.accum_grad)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         logger.info(msg)
 
         if dist.get_rank() == 0 and self.visualizer:
             for k, v in losses_np.items():
+                # `step -1` since we update `step` after optimizer.step().
                 self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
-        self.iteration += 1
+                                           self.iteration - 1)
 
     @paddle.no_grad()
     def valid(self):
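# Note: a minimal, self-contained sketch of the accumulation pattern introduced
# above, kept outside the diff. It assumes a DataParallel-style model wrapper
# exposing `no_sync()` (as `self.model` does in the trainer) and an optimizer
# with `clear_grad()`; the helper `run_accum_batches` and its arguments are
# hypothetical and only illustrate the control flow.
from contextlib import nullcontext


def run_accum_batches(model, optimizer, batches, accum_grad):
    for batch_index, batch in enumerate(batches):
        loss = model(*batch)
        # Skip the DDP all-reduce on non-boundary steps so gradients only
        # accumulate locally; synchronize on the boundary step.
        if (batch_index + 1) % accum_grad != 0:
            context = model.no_sync
        else:
            context = nullcontext
        with context():
            loss.backward()
        # Apply and clear the accumulated gradients once every accum_grad batches.
        if (batch_index + 1) % accum_grad == 0:
            optimizer.step()
            optimizer.clear_grad()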