@@ -23,6 +23,7 @@ import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_
 
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                     text_len)
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                         text_len)
 
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
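For readers unfamiliar with Paddle's AMP API: the hunk above only wraps the forward pass in paddle.amp.auto_cast, and autocasting is switched off whenever no scaler is passed in. Below is a minimal, self-contained sketch of that pattern with a hypothetical toy model and loss (none of these names come from PaddleSpeech); it is illustrative only, not the trainer's actual code.

import paddle

# Hypothetical stand-ins for the real model and batch.
model = paddle.nn.Linear(10, 10)
x = paddle.randn([4, 10])
label = paddle.randn([4, 10])

# When a GradScaler is present, enable autocast; otherwise run in plain fp32.
scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

with paddle.amp.auto_cast(level='O1', enable=True if scaler else False):
    # Under O1, ops on the fp16 white list run in half precision,
    # everything else stays in float32.
    out = model(x)
    loss = paddle.nn.functional.mse_loss(out, label)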
@@ -77,12 +80,24 @@ class U2Trainer(Trainer):
             # processes.
             context = nullcontext
         with context():
-            loss.backward()
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
 
         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
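The backward/step rewrite above follows the usual loss-scaling recipe: backpropagate through scaler.scale(loss) so small fp16 gradients do not underflow, then let scaler.step/scaler.update drive the optimizer so steps with inf/nan gradients are skipped and the scale is re-tuned. A self-contained sketch of that ordering, including the manual clip_grad_norm_ call the hunk introduces, is below; the model, optimizer, accum_grad, and clip threshold are all made-up placeholders.

import paddle
from paddle.nn.utils import clip_grad_norm_  # needs paddlepaddle >= 2.5

# Hypothetical toy setup, only to show the scale/clip/step ordering.
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)
accum_grad = 4          # made-up gradient-accumulation factor
global_grad_clip = 5.0  # made-up clipping threshold

for batch_index in range(8):
    x = paddle.randn([4, 10])
    with paddle.amp.auto_cast(level='O1', enable=True if scaler else False):
        loss = model(x).mean() / accum_grad

    if scaler:
        # Backward on the scaled loss so tiny fp16 gradients survive.
        scaler.scale(loss).backward()
    else:
        loss.backward()

    if (batch_index + 1) % accum_grad == 0:
        # Global-norm clipping of the accumulated gradients.
        clip_grad_norm_(model.parameters(), global_grad_clip)
        if scaler:
            scaler.step(optimizer)  # skips the update if gradients hit inf/nan
            scaler.update()         # re-tunes the loss scale
        else:
            optimizer.step()
        optimizer.clear_grad()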
@@ -173,7 +188,8 @@ class U2Trainer(Trainer):
                             report("epoch", self.epoch)
                             report('step', self.iteration)
                             report("lr", self.lr_scheduler())
-                            self.train_batch(batch_index, batch, msg)
+                            self.train_batch(batch_index, batch, self.scaler,
+                                             msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
                             if not self.use_streamdata:
@@ -253,6 +269,19 @@ class U2Trainer(Trainer):
         model_conf.output_dim = self.test_loader.vocab_size
 
         model = U2Model.from_config(model_conf)
+
+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  #amp default num 32768.0
+            #Set amp_level
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None
         if self.parallel:
             model = paddle.DataParallel(model)
 
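The setup hunk above creates the scaler once per run and, for O2, casts the model with paddle.amp.decorate. A minimal sketch of that setup logic with hypothetical config values (the real ones come from the experiment YAML, and O2/decorate is intended for GPU training):

import paddle

# Hypothetical config values standing in for self.config.get(...).
use_amp = True
amp_level = 'O2'
scale_loss = 32768.0

model = paddle.nn.Linear(10, 10)  # stand-in for U2Model.from_config(...)

if use_amp:
    scaler = paddle.amp.GradScaler(init_loss_scaling=scale_loss)
    if amp_level == 'O2':
        # O2 keeps most parameters/activations in float16; decorate() casts
        # the model (and optionally an optimizer) to match that level.
        model = paddle.amp.decorate(models=model, level=amp_level)
else:
    scaler = None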
@@ -290,7 +319,6 @@ class U2Trainer(Trainer):
             scheduler_type = train_config.scheduler
             scheduler_conf = train_config.scheduler_conf
             return {
-                "grad_clip": train_config.global_grad_clip,
                 "weight_decay": optim_conf.weight_decay,
                 "learning_rate": lr_scheduler
                 if lr_scheduler else optim_conf.lr,