diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 992be5cd..85bb877b 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -524,10 +524,10 @@ class U2Tester(U2Trainer): List[paddle.static.InputSpec]: input spec. """ from paddlespeech.s2t.models.u2 import U2InferModel - infer_model = U2InferModel.from_pretrained(self.train_loader, + infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.clone(), self.args.checkpoint_path) - feat_dim = self.train_loader.feat_dim + feat_dim = self.test_loader.feat_dim input_spec = [ paddle.static.InputSpec(shape=[1, None, feat_dim], dtype='float32'), # audio, [B,T,D]