From 7e8da9c04da6667b1ec97366360a23ca525c9960 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 5 Nov 2021 06:07:26 +0000 Subject: [PATCH] fix bug for batch dataloader using --- paddlespeech/s2t/exps/u2/model.py | 2 +- paddlespeech/s2t/models/u2/u2.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 2f0e752f8..8dad50748 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -591,7 +591,7 @@ class U2Tester(U2Trainer): infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.model.clone(), self.args.checkpoint_path) - feat_dim = self.test_loader.collate_fn.feature_size + feat_dim = self.test_loader.feat_dim input_spec = [ paddle.static.InputSpec(shape=[1, None, feat_dim], dtype='float32'), # audio, [B,T,D] diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index fd9982716..916a6a059 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -934,8 +934,8 @@ class U2Model(U2DecodeModel): DeepSpeech2Model: The model built from pretrained result. """ with UpdateConfig(config): - config.input_dim = dataloader.collate_fn.feature_size - config.output_dim = dataloader.collate_fn.vocab_size + config.input_dim = dataloader.feat_dim + config.output_dim = dataloader.vocab_size model = cls.from_config(config)