diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index a2aa31f7..63327a8c 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -219,10 +219,10 @@ class DeepSpeech2Model(nn.Layer): The model built from pretrained result. """ model = cls( - #feat_size=dataloader.collate_fn.feature_size, - feat_size=dataloader.dataset.feature_size, - #dict_size=dataloader.collate_fn.vocab_size, - dict_size=dataloader.dataset.vocab_size, + feat_size=dataloader.collate_fn.feature_size, + #feat_size=dataloader.dataset.feature_size, + dict_size=dataloader.collate_fn.vocab_size, + #dict_size=dataloader.dataset.vocab_size, num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size,