Add AMP (automatic mixed precision) training for the U2 conformer.

pull/3167/head
author zxcd
parent d3d86f59aa
commit fbd27aab41
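
This commit wires PaddlePaddle's automatic mixed precision (AMP) into U2 conformer training: the forward pass runs under paddle.amp.auto_cast, the loss is scaled by a paddle.amp.GradScaler before backward, and the scaler state is carried through checkpoints. Below is a minimal, self-contained sketch of that AMP pattern; the toy model, data, and hyperparameters are illustrative stand-ins, not the PaddleSpeech trainer itself.

    # Minimal sketch of the Paddle AMP loop this commit adopts (toy model/data).
    import paddle

    model = paddle.nn.Linear(16, 4)
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

    for step in range(10):
        x, y = paddle.randn([8, 16]), paddle.randn([8, 4])
        # Run the forward pass in mixed precision; 'O1' autocasts only
        # whitelisted ops, 'O2' additionally casts model parameters
        # (which requires paddle.amp.decorate, see the setup hunk below).
        with paddle.amp.auto_cast(level='O1'):
            loss = paddle.nn.functional.mse_loss(model(x), y)
        # Scale the loss so small float16 gradients do not underflow.
        scaler.scale(loss).backward()
        # step() unscales the gradients and skips the update if they are
        # not finite; update() adapts the loss scale for the next step.
        scaler.step(optimizer)
        scaler.update()
        optimizer.clear_grad()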

@@ -23,6 +23,7 @@ import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_
 
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                    text_len)
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)
 
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
@@ -77,12 +80,24 @@ class U2Trainer(Trainer):
             # processes.
             context = nullcontext
         with context():
-            loss.backward()
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
 
         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
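
The hunk above keeps the existing gradient-accumulation scheme: the loss is divided by accum_grad in the forward pass, and the update (scaler.step or optimizer.step) fires only every accum_grad batches, so gradients from several micro-batches are summed before one parameter update. The following standalone sketch shows that flow under AMP; the model, data, and accum_grad value are illustrative stand-ins for what the trainer reads from train_conf.

    # Gradient accumulation with an optional GradScaler (illustrative values).
    import paddle

    model = paddle.nn.Linear(16, 4)
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)  # or None to disable AMP
    accum_grad = 4  # plays the role of train_conf.accum_grad

    batches = [(paddle.randn([8, 16]), paddle.randn([8, 4])) for _ in range(8)]
    for batch_index, (x, y) in enumerate(batches):
        with paddle.amp.auto_cast(level='O1', enable=scaler is not None):
            loss = paddle.nn.functional.mse_loss(model(x), y)
        # Divide so the accumulated gradient matches one large batch.
        loss = loss / accum_grad
        if scaler:
            scaler.scale(loss).backward()  # gradients keep accumulating
        else:
            loss.backward()
        # Update (and clear) only once every accum_grad micro-batches.
        if (batch_index + 1) % accum_grad == 0:
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.clear_grad()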
@@ -173,7 +188,8 @@ class U2Trainer(Trainer):
                             report("epoch", self.epoch)
                             report('step', self.iteration)
                             report("lr", self.lr_scheduler())
-                            self.train_batch(batch_index, batch, msg)
+                            self.train_batch(batch_index, batch, self.scaler,
+                                             msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
                             if not self.use_streamdata:
@@ -253,6 +269,19 @@ class U2Trainer(Trainer):
                 model_conf.output_dim = self.test_loader.vocab_size
 
         model = U2Model.from_config(model_conf)
 
+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  # amp default: 32768.0
+            # set amp_level
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None
+
         if self.parallel:
             model = paddle.DataParallel(model)
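
This hunk introduces three config keys read via self.config.get: use_amp, amp_level, and scale_loss (the GradScaler's init_loss_scaling). For level 'O2' the model itself has to be decorated so its parameters are cast to float16. A hedged sketch of the 'O2' path follows; the model, optimizer, and values are illustrative, not the trainer's own.

    # Sketch of the 'O2' AMP level selected via amp_level (toy model/values).
    import paddle

    model = paddle.nn.Linear(16, 4)
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3, parameters=model.parameters())

    amp_level = "O2"  # stands in for config.get("amp_level", "O1")
    if amp_level == "O2":
        # decorate() casts the model's parameters to float16; master weights
        # for the update are typically kept in float32 by the optimizer.
        model, optimizer = paddle.amp.decorate(
            models=model, optimizers=optimizer, level="O2")

    scaler = paddle.amp.GradScaler(
        init_loss_scaling=32768.0)  # stands in for config.get("scale_loss", 32768.0)
    x, y = paddle.randn([8, 16]), paddle.randn([8, 4])
    with paddle.amp.auto_cast(level=amp_level):
        loss = paddle.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()

The hunk above decorates only the model; the sketch also passes the optimizer, which decorate accepts as an optional argument.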
@@ -290,7 +319,6 @@ class U2Trainer(Trainer):
             scheduler_type = train_config.scheduler
             scheduler_conf = train_config.scheduler_conf
             return {
-                "grad_clip": train_config.global_grad_clip,
                 "weight_decay": optim_conf.weight_decay,
                 "learning_rate": lr_scheduler
                 if lr_scheduler else optim_conf.lr,
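
The grad_clip entry is dropped from the optimizer arguments because clipping is now applied explicitly in train_batch via paddle.nn.utils.clip_grad_norm_ (see the earlier hunk), rather than being passed through the optimizer arguments (in Paddle, such a value typically ends up as a ClipGradByGlobalNorm attached to the optimizer). A hedged before/after sketch, with an illustrative clip value and optimizer:

    # Two ways to apply global-norm gradient clipping in Paddle (toy setup).
    import paddle

    model = paddle.nn.Linear(16, 4)

    # Before: the clip is attached to the optimizer and applied inside step().
    clipped_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        parameters=model.parameters(),
        grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0))

    # After: a plain optimizer; the trainer clips explicitly right before the
    # (scaler-aware) update. clip_grad_norm_ needs paddlepaddle >= 2.5.
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3, parameters=model.parameters())
    loss = model(paddle.randn([8, 16])).mean()
    loss.backward()
    paddle.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    optimizer.step()
    optimizer.clear_grad()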

@@ -110,6 +110,7 @@ class Trainer():
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self._train = True
+        self.scaler = None
 
         # print deps version
         all_version()
@@ -187,7 +188,8 @@ class Trainer():
         infos.update({
             "step": self.iteration,
             "epoch": self.epoch,
-            "lr": self.optimizer.get_lr()
+            "lr": self.optimizer.get_lr(),
+            "scaler": self.scaler
         })
         self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                         if tag is None else tag, self.model,
@@ -211,6 +213,8 @@ class Trainer():
             # lr will restore from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
+            self.scaler = paddle.amp.GradScaler()
+            self.scaler.load_state_dict(infos["scaler"])
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
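
With the two hunks above, the Trainer keeps the scaler next to step, epoch, and lr in the checkpoint infos and restores it through GradScaler.load_state_dict on resume, so the dynamic loss scale survives a restart. A minimal hedged sketch of round-tripping scaler state follows; the explicit state_dict() call and the dummy values are illustrative, and the trainer's Checkpoint helper handles the actual serialization of infos.

    # Round-trip GradScaler state as it would be carried in the checkpoint infos.
    import paddle

    scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)
    infos = {"step": 1000, "epoch": 3, "scaler": scaler.state_dict()}
    # ... infos is what Trainer.save() hands to the checkpoint helper ...

    # On resume: rebuild a scaler, then load the saved loss-scaling state.
    resumed_scaler = paddle.amp.GradScaler()
    resumed_scaler.load_state_dict(infos["scaler"])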
