Merge pull request #753 from Jackwaterveg/ds2_online

模型Resume的学习率问题
pull/754/head v2.1.1
Hui Zhang 3 years ago committed by GitHub
commit 0309c36a3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -123,10 +123,6 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self):
config = self.config.clone()
config.defrost()
assert (self.train_loader.collate_fn.feature_size ==
self.test_loader.collate_fn.feature_size)
assert (self.train_loader.collate_fn.vocab_size ==
self.test_loader.collate_fn.vocab_size)
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.freeze()

@ -181,8 +181,7 @@ class Trainer():
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.iteration)
self.lr_scheduler.step(self.epoch)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)

Loading…
Cancel
Save