|
|
|
@ -154,11 +154,11 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
config = self.config.clone()
|
|
|
|
|
with UpdateConfig(config):
|
|
|
|
|
if self.train:
|
|
|
|
|
config.model.feat_size = self.train_loader.collate_fn.feature_size
|
|
|
|
|
config.model.dict_size = self.train_loader.collate_fn.vocab_size
|
|
|
|
|
config.model.input_dim = self.train_loader.collate_fn.feature_size
|
|
|
|
|
config.model.output_dim = self.train_loader.collate_fn.vocab_size
|
|
|
|
|
else:
|
|
|
|
|
config.model.feat_size = self.test_loader.collate_fn.feature_size
|
|
|
|
|
config.model.dict_size = self.test_loader.collate_fn.vocab_size
|
|
|
|
|
config.model.input_dim = self.test_loader.collate_fn.feature_size
|
|
|
|
|
config.model.output_dim = self.test_loader.collate_fn.vocab_size
|
|
|
|
|
|
|
|
|
|
if self.args.model_type == 'offline':
|
|
|
|
|
model = DeepSpeech2Model.from_config(config.model)
|
|
|
|
|