|
|
|
@ -153,8 +153,12 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config.clone()
|
|
|
|
|
with UpdateConfig(config):
|
|
|
|
|
if self.train:
|
|
|
|
|
config.model.feat_size = self.train_loader.collate_fn.feature_size
|
|
|
|
|
config.model.dict_size = self.train_loader.collate_fn.vocab_size
|
|
|
|
|
else:
|
|
|
|
|
config.model.feat_size = self.test_loader.collate_fn.feature_size
|
|
|
|
|
config.model.dict_size = self.test_loader.collate_fn.vocab_size
|
|
|
|
|
|
|
|
|
|
if self.args.model_type == 'offline':
|
|
|
|
|
model = DeepSpeech2Model.from_config(config.model)
|
|
|
|
@ -189,7 +193,6 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
self.lr_scheduler = lr_scheduler
|
|
|
|
|
logger.info("Setup optimizer/lr_scheduler!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_dataloader(self):
|
|
|
|
|
config = self.config.clone()
|
|
|
|
|
config.defrost()
|
|
|
|
|