diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index 95cb47f5a..26d83e738 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -82,6 +82,7 @@ class U2Trainer(Trainer):
         with context():
             if scaler:
                 scaler.scale(loss).backward()
+                scaler.unscale_(self.optimizer)
             else:
                 loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py
index d9ac6f8b0..a8f36f91b 100644
--- a/paddlespeech/s2t/training/trainer.py
+++ b/paddlespeech/s2t/training/trainer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
 import time
 from collections import OrderedDict
@@ -189,8 +190,12 @@ class Trainer():
             "step": self.iteration,
             "epoch": self.epoch,
             "lr": self.optimizer.get_lr(),
-            "scaler": self.scaler.state_dict()
         })
+        if self.scaler:
+            scaler_path = os.path.join(self.checkpoint_dir,
+                                       "{}".format(self.epoch)) + '.scaler'
+            paddle.save(self.scaler.state_dict(), scaler_path)
+
         self.checkpoint.save_parameters(self.checkpoint_dir,
                                         self.iteration if tag is None else tag,
                                         self.model, self.optimizer, infos)
@@ -213,8 +218,13 @@ class Trainer():
             # lr will resotre from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
-            self.scaler = paddle.amp.GradScaler()
-            self.scaler.load_state_dict(infos["scaler"])
+
+            scaler_path = os.path.join(self.checkpoint_dir,
+                                       "{}".format(self.epoch)) + '.scaler'
+            if os.path.exists(scaler_path):
+                scaler_state_dict = paddle.load(scaler_path)
+                self.scaler.load_state_dict(scaler_state_dict)
+
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
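
For context, here is a minimal, self-contained sketch of the two ideas the patch relies on. It is not the PaddleSpeech trainer itself: train_step, save_scaler and load_scaler are hypothetical helpers, and the model and paths are made up for illustration. It shows GradScaler.unscale_ being called right after the scaled backward pass (so that gradient printing or clipping sees gradients at their true magnitude), and the paddle.amp.GradScaler state being persisted as a separate "<epoch>.scaler" file next to the checkpoint instead of inside the checkpoint infos dict:

import os

import paddle
from paddle import nn


def train_step(model, optimizer, scaler, feats, labels):
    # One AMP training step: scale the loss, backprop, then unscale the
    # gradients so they can be inspected/clipped at their real magnitude
    # before the optimizer update.
    with paddle.amp.auto_cast():
        loss = nn.functional.mse_loss(model(feats), labels)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients are now in real (unscaled) units
    scaler.step(optimizer)      # skips the update if gradients overflowed
    scaler.update()
    optimizer.clear_grad()
    return float(loss)


def save_scaler(scaler, checkpoint_dir, epoch):
    # Mirror the "<epoch>.scaler" naming used in the diff above.
    os.makedirs(checkpoint_dir, exist_ok=True)
    scaler_path = os.path.join(checkpoint_dir, "{}".format(epoch)) + ".scaler"
    paddle.save(scaler.state_dict(), scaler_path)


def load_scaler(scaler, checkpoint_dir, epoch):
    # Restore only if a matching scaler file exists and is non-empty;
    # otherwise keep the freshly constructed scaler.
    scaler_path = os.path.join(checkpoint_dir, "{}".format(epoch)) + ".scaler"
    if os.path.exists(scaler_path):
        state = paddle.load(scaler_path)
        if state:  # empty when AMP scaling was disabled (e.g. CPU-only runs)
            scaler.load_state_dict(state)
    return scaler


if __name__ == "__main__":
    model = nn.Linear(8, 1)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=2.**16)

    feats = paddle.randn([4, 8])
    labels = paddle.randn([4, 1])
    train_step(model, optimizer, scaler, feats, labels)

    save_scaler(scaler, "exp/checkpoints", 0)
    scaler = load_scaler(paddle.amp.GradScaler(), "exp/checkpoints", 0)

Keeping the scaler state in its own file also lets checkpoints written before this change resume cleanly: the restore path simply skips load_state_dict when no matching .scaler file is found.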