diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 26d83e738..11dd0b065 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -82,7 +82,6 @@ class U2Trainer(Trainer): with context(): if scaler: scaler.scale(loss).backward() - scaler.unscale_(self.optimizer) else: loss.backward() layer_tools.print_grads(self.model, print_func=None) @@ -91,6 +90,8 @@ class U2Trainer(Trainer): if (batch_index + 1) % train_conf.accum_grad == 0: # do global grad clip if train_conf.global_grad_clip != 0: + if scaler: + scaler.unscale_(self.optimizer) # need paddlepaddle==develop or paddlepaddle>=2.5 clip_grad_norm_(self.model.parameters(), train_conf.global_grad_clip)