diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 86b56b876..6c90f99e1 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -591,7 +591,7 @@ class Wav2Vec2ASRTrainer(Trainer): def setup_dataloader(self): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) - self.use_sb = config.use_sb_pipeline + self.use_sb = config.get("use_sb_pipeline", False) if self.use_sb: hparams_file = config.sb_pipeline_conf with open(hparams_file, 'r', encoding='utf8') as fin: diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py index 53c756ce3..a5e7a08f1 100644 --- a/paddlespeech/s2t/training/scheduler.py +++ b/paddlespeech/s2t/training/scheduler.py @@ -220,7 +220,6 @@ class NewBobScheduler(LRScheduler): def load(self, data): """Loads the needed information.""" - data = paddle.load(data) self.last_epoch = data["current_epoch_index"] self.hyperparam_value = data["hyperparam_value"] self.metric_values = data["metric_values"]