@@ -18,12 +18,14 @@ import logging
 import numpy as np
 from collections import defaultdict
 from pathlib import Path
 from typing import Optional
 from yacs.config import CfgNode
 
 import paddle
 from paddle import distributed as dist
 from paddle.io import DataLoader
 
-from deepspeech.training import Trainer
+from deepspeech.training.trainer import Trainer
+from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
+from deepspeech.training.scheduler import WarmupLR
 
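Note: besides correcting the Trainer import path, this hunk pulls in a global-norm gradient clipper that also logs the norm, and a warmup learning-rate schedule. A minimal sketch of the Noam-style warmup that WarmupLR is commonly understood to implement (the class internals are an assumption here; only the import itself is from the patch):

    def warmup_lr(step: int, base_lr: float = 0.002, warmup_steps: int = 25000) -> float:
        # Noam schedule: linear ramp for warmup_steps, then inverse-sqrt decay;
        # the rate peaks at base_lr when step == warmup_steps.
        step = max(step, 1)  # guard against 0 ** -0.5 on the first call
        return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)
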
@@ -77,7 +79,7 @@ class U2Trainer(Trainer):
         self.model.train()
 
         start = time.time()
-        loss = self.model(*batch_data)
+        loss, attention_loss, ctc_loss = self.model(*batch_data)
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
         if self.iteration % train_conf.accum_grad == 0:
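Note: the forward pass now returns the combined loss together with its attention and CTC components, and the optimizer step guarded by the `accum_grad` check fires only every `accum_grad` batches. A sketch of that accumulation pattern (`loader`, `model`, `optimizer`, and `accum_grad` are assumed to be in scope; this is not the repository's exact loop):

    for batch_idx, batch in enumerate(loader, start=1):
        loss, attention_loss, ctc_loss = model(*batch)
        loss.backward()                    # gradients accumulate across batches
        if batch_idx % accum_grad == 0:    # effective batch = batch_size * accum_grad
            optimizer.step()
            optimizer.clear_grad()
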
@@ -88,13 +90,15 @@ class U2Trainer(Trainer):
 
         losses_np = {
             'train_loss': float(loss),
-            'train_loss_div_batchsize':
-            float(loss) / self.config.data.batch_size
+            'train_att_loss': float(attention_loss),
+            'train_ctc_loss': float(ctc_loss),
         }
         msg = "Train: Rank: {}, ".format(dist.get_rank())
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
         msg += "time: {:>.3f}s, ".format(iteration_time)
+        msg += f"batch size: {self.config.data.batch_size}, "
+        msg += f"accum: {train_conf.accum_grad}, "
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         if self.iteration % train_conf.log_interval == 0:
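Note: per-component losses replace the old `train_loss_div_batchsize` normalization, and the batch size and accumulation factor move into the message prefix. (The added accum line originally read `train_config.accum_grad`; it is corrected above to `train_conf`, the name the surrounding context lines use.) A runnable check of the formatting expression, with illustrative values:

    losses_np = {'train_loss': 120.3457, 'train_att_loss': 80.1235, 'train_ctc_loss': 40.2222}
    print(', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items()))
    # train_loss: 120.345700, train_att_loss: 80.123500, train_ctc_loss: 40.222200
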
@@ -113,11 +117,11 @@ class U2Trainer(Trainer):
             f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
-            loss = self.model(*batch)
+            total_loss, attention_loss, ctc_loss = self.model(*batch)
 
-            valid_losses['val_loss'].append(float(loss))
-            valid_losses['val_loss_div_batchsize'].append(
-                float(loss) / self.config.data.batch_size)
+            valid_losses['val_loss'].append(float(total_loss))
+            valid_losses['val_att_loss'].append(float(attention_loss))
+            valid_losses['val_ctc_loss'].append(float(ctc_loss))
 
         # write visual log
         valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
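Note: validation mirrors the training change, collecting each loss component per batch and averaging after the loop via the `defaultdict(list)` + `np.mean` pattern visible in the context lines. A self-contained illustration:

    from collections import defaultdict
    import numpy as np

    valid_losses = defaultdict(list)
    for batch_loss in [10.0, 12.0, 14.0]:   # stand-ins for per-batch floats
        valid_losses['val_loss'].append(batch_loss)
    valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
    # mean over the three batches: valid_losses['val_loss'] == 12.0
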
@@ -137,13 +141,14 @@ class U2Trainer(Trainer):
 
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
         config.data.keep_transcription_text = False
 
         # train/valid dataset, return token ids
-        config.data.manfiest = config.data.train_manifest
+        config.data.manifest = config.data.train_manifest
         train_dataset = ManifestDataset.from_config(config)
 
-        config.data.manfiest = config.data.dev_manifest
+        config.data.manifest = config.data.dev_manifest
+        config.data.augmentation_config = ""
         dev_dataset = ManifestDataset.from_config(config)
 
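Note: the `manfiest` -> `manifest` rename fixes a typo that wrote the path into a stray config key (yacs silently creates new keys on an unfrozen node), so the dataset never saw the intended manifest; augmentation is also disabled for the dev set. The surrounding `clone()`/`defrost()` idiom is standard yacs usage, i.e. mutate a private copy rather than the trainer's frozen config; a minimal sketch:

    from yacs.config import CfgNode

    cfg = CfgNode({'data': CfgNode({'manifest': '', 'train_manifest': 'train.json'})})
    cfg.freeze()                                     # trainer-level config is immutable
    local = cfg.clone()                              # work on a private copy
    local.defrost()
    local.data.manifest = local.data.train_manifest  # mutation touches the copy only
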
@@ -181,7 +186,7 @@ class U2Trainer(Trainer):
         # test dataset, return raw text
         config.data.keep_transcription_text = True
         config.data.augmentation_config = ""
-        config.data.manfiest = config.data.test_manifest
+        config.data.manifest = config.data.test_manifest
         test_dataset = ManifestDataset.from_config(config)
         # return text ord id
         self.test_loader = DataLoader(
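Note: the same typo fix, applied to the test manifest. Unlike train/dev, the test dataset keeps transcriptions as raw text so downstream scoring can compare strings rather than token ids; an illustrative contrast (example values, not the dataset's actual output):

    # keep_transcription_text = False  ->  targets arrive as token ids, e.g. [23, 7, 41]
    # keep_transcription_text = True   ->  targets arrive as raw text, e.g. "hello world"
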
@@ -193,10 +198,12 @@ class U2Trainer(Trainer):
         self.logger.info("Setup train/valid/test Dataloader!")
 
     def setup_model(self):
-        config = self.config.clone()
+        config = self.config
         model_conf = config.model
+        model_conf.defrost()
         model_conf.input_dim = self.train_loader.dataset.feature_size
         model_conf.output_dim = self.train_loader.dataset.vocab_size
+        model_conf.freeze()
         model = U2Model.from_config(model_conf)
 
         if self.parallel:
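Note: dropping `clone()` means the dataset-derived `input_dim`/`output_dim` are written back into `self.config` itself (so they appear in any later config dump), which is why the mutation must be bracketed by `defrost()`/`freeze()`. yacs enforces the lock at runtime; a minimal demonstration:

    from yacs.config import CfgNode

    node = CfgNode({'input_dim': 0})
    node.freeze()
    # node.input_dim = 80   # would raise AttributeError: node is immutable
    node.defrost()
    node.input_dim = 80     # allowed while defrosted
    node.freeze()           # re-lock before handing the config to the model
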
@@ -206,12 +213,12 @@ class U2Trainer(Trainer):
 
         train_config = config.training
         optim_type = train_config.optim
-        optim_conf = train_config.train_config
+        optim_conf = train_config.optim_conf
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
 
         grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
-        weight_decay = paddle.regularizer.L2Decay(train_config.weight_decay)
+        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
 
         if scheduler_type == 'expdecaylr':
             lr_scheduler = paddle.optimizer.lr.ExponentialDecay(