From 5186319f48d0cd631a48f26ff9fc94f5fc4ff3f0 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 9 Mar 2023 15:04:29 +0800 Subject: [PATCH] fix load model schedule error, config optional. (#3008) --- paddlespeech/s2t/exps/wav2vec2/model.py | 2 +- paddlespeech/s2t/training/scheduler.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 86b56b876..6c90f99e1 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -591,7 +591,7 @@ class Wav2Vec2ASRTrainer(Trainer): def setup_dataloader(self): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) - self.use_sb = config.use_sb_pipeline + self.use_sb = config.get("use_sb_pipeline", False) if self.use_sb: hparams_file = config.sb_pipeline_conf with open(hparams_file, 'r', encoding='utf8') as fin: diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py index 53c756ce3..a5e7a08f1 100644 --- a/paddlespeech/s2t/training/scheduler.py +++ b/paddlespeech/s2t/training/scheduler.py @@ -220,7 +220,6 @@ class NewBobScheduler(LRScheduler): def load(self, data): """Loads the needed information.""" - data = paddle.load(data) self.last_epoch = data["current_epoch_index"] self.hyperparam_value = data["hyperparam_value"] self.metric_values = data["metric_values"]