diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index 5b7654d4..95cb47f5 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -23,6 +23,7 @@ import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_
 
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                    text_len)
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)
 
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
@@ -77,12 +80,24 @@ class U2Trainer(Trainer):
         # processes.
         context = nullcontext
         with context():
-            loss.backward()
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
 
         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
@@ -173,7 +188,8 @@ class U2Trainer(Trainer):
                         report("epoch", self.epoch)
                         report('step', self.iteration)
                         report("lr", self.lr_scheduler())
-                        self.train_batch(batch_index, batch, msg)
+                        self.train_batch(batch_index, batch, self.scaler,
+                                         msg)
                         self.after_train_batch()
                         report('iter', batch_index + 1)
                         if not self.use_streamdata:
@@ -253,6 +269,19 @@ class U2Trainer(Trainer):
                 model_conf.output_dim = self.test_loader.vocab_size
 
         model = U2Model.from_config(model_conf)
+
+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  #amp default num 32768.0
+            #Set amp_level
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None
 
         if self.parallel:
             model = paddle.DataParallel(model)
@@ -290,7 +319,6 @@ class U2Trainer(Trainer):
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
         return {
-            "grad_clip": train_config.global_grad_clip,
             "weight_decay": optim_conf.weight_decay,
             "learning_rate": lr_scheduler
             if lr_scheduler else optim_conf.lr,
diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py
index 4a69d78a..53a5d03f 100644
--- a/paddlespeech/s2t/training/trainer.py
+++ b/paddlespeech/s2t/training/trainer.py
@@ -110,6 +110,7 @@ class Trainer():
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self._train = True
+        self.scaler = None
 
         # print deps version
         all_version()
@@ -187,7 +188,8 @@ class Trainer():
         infos.update({
             "step": self.iteration,
             "epoch": self.epoch,
-            "lr": self.optimizer.get_lr()
+            "lr": self.optimizer.get_lr(),
+            "scaler": self.scaler
         })
         self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                         if tag is None else tag, self.model,
@@ -211,6 +213,8 @@ class Trainer():
             # lr will resotre from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
+            self.scaler = paddle.amp.GradScaler()
+            self.scaler.load_state_dict(infos["scaler"])
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")