Merge pull request #753 from Jackwaterveg/ds2_online

模型Resume的学习率问题
pull/754/head v2.1.1
Hui Zhang 4 years ago committed by GitHub
commit 0309c36a3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -123,10 +123,6 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self): def setup_model(self):
config = self.config.clone() config = self.config.clone()
config.defrost() config.defrost()
assert (self.train_loader.collate_fn.feature_size ==
self.test_loader.collate_fn.feature_size)
assert (self.train_loader.collate_fn.vocab_size ==
self.test_loader.collate_fn.vocab_size)
config.model.feat_size = self.train_loader.collate_fn.feature_size config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.freeze() config.freeze()

@ -181,8 +181,7 @@ class Trainer():
if from_scratch: if from_scratch:
# save init model, i.e. 0 epoch # save init model, i.e. 0 epoch
self.save(tag='init', infos=None) self.save(tag='init', infos=None)
self.lr_scheduler.step(self.epoch)
self.lr_scheduler.step(self.iteration)
if self.parallel: if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch) self.train_loader.batch_sampler.set_epoch(self.epoch)

Loading…
Cancel
Save