fix bug for batch dataloader using

pull/1012/head
Hui Zhang 3 years ago
parent 69bccb4f02
commit 9cdd2643b1

@ -591,7 +591,7 @@ class U2Tester(U2Trainer):
infer_model = U2InferModel.from_pretrained(self.test_loader, infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(), self.config.model.clone(),
self.args.checkpoint_path) self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.feat_dim
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[1, None, feat_dim], paddle.static.InputSpec(shape=[1, None, feat_dim],
dtype='float32'), # audio, [B,T,D] dtype='float32'), # audio, [B,T,D]

@ -934,8 +934,8 @@ class U2Model(U2DecodeModel):
DeepSpeech2Model: The model built from pretrained result. DeepSpeech2Model: The model built from pretrained result.
""" """
with UpdateConfig(config): with UpdateConfig(config):
config.input_dim = dataloader.collate_fn.feature_size config.input_dim = dataloader.feat_dim
config.output_dim = dataloader.collate_fn.vocab_size config.output_dim = dataloader.vocab_size
model = cls.from_config(config) model = cls.from_config(config)

Loading…
Cancel
Save