diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index dfd81241..65c905a1 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -123,10 +123,6 @@ class DeepSpeech2Trainer(Trainer):
     def setup_model(self):
         config = self.config.clone()
         config.defrost()
-        assert (self.train_loader.collate_fn.feature_size ==
-                self.test_loader.collate_fn.feature_size)
-        assert (self.train_loader.collate_fn.vocab_size ==
-                self.test_loader.collate_fn.vocab_size)
         config.model.feat_size = self.train_loader.collate_fn.feature_size
         config.model.dict_size = self.train_loader.collate_fn.vocab_size
         config.freeze()
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 5ebba1a9..209e2240 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -29,37 +29,37 @@ logger = Log(__name__).getlog()
 class Trainer():
     """
-    An experiment template in order to structure the training code and take 
-    care of saving, loading, logging, visualization stuffs. It's intended to 
-    be flexible and simple. 
-    
-    So it only handles output directory (create directory for the output, 
-    create a checkpoint directory, dump the config in use and create 
+    An experiment template that structures the training code and takes
+    care of saving, loading, logging and visualization stuff. It's intended
+    to be flexible and simple.
+
+    So it only handles the output directory (create a directory for the
+    output, a checkpoint directory, dump the config in use and create a
     visualizer and logger) in a standard way without enforcing any
-    input-output protocols to the model and dataloader. It leaves the main 
-    part for the user to implement their own (setup the model, criterion, 
-    optimizer, define a training step, define a validation function and 
+    input-output protocols on the model and dataloader. It leaves the main
+    part for the user to implement (set up the model, criterion and
+    optimizer, define a training step, define a validation function and
     customize all the text and visual logs).
-    It does not save too much boilerplate code. The users still have to write 
-    the forward/backward/update mannually, but they are free to add 
+    It does not save much boilerplate code, though. Users still have to write
+    the forward/backward/update steps manually, but they are free to add
     non-standard behaviors if needed.
     We have some conventions to follow.
-    1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and 
+    1. The experiment should have ``model``, ``optimizer``, ``train_loader`` and
     ``valid_loader``, ``config`` and ``args`` attributes.
-    2. The config should have a ``training`` field, which has 
-    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is 
-    used as the trigger to invoke validation, checkpointing and stop of the 
+    2. The config should have a ``training`` field, which has
+    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
+    used as the trigger to invoke validation, checkpointing and stopping of the
     experiment.
-    3. There are four methods, namely ``train_batch``, ``valid``, 
+    3. There are four methods, namely ``train_batch``, ``valid``,
     ``setup_model`` and ``setup_dataloader`` that should be implemented.
-    Feel free to add/overwrite other methods and standalone functions if you 
+    Feel free to add/overwrite other methods and standalone functions if you
     need.
-    
+
     Parameters
     ----------
     config: yacs.config.CfgNode
         The configuration used for the experiment.
-    
+
     args: argparse.Namespace
         The parsed command line arguments.
 
     Examples
@@ -68,16 +68,16 @@ class Trainer():
     >>> exp = Trainer(config, args)
     >>> exp.setup()
     >>> exp.run()
-    >>> 
+    >>>
     >>> config = get_cfg_defaults()
     >>> parser = default_argument_parser()
     >>> args = parser.parse_args()
-    >>> if args.config: 
+    >>> if args.config:
    >>>     config.merge_from_file(args.config)
     >>> if args.opts:
     >>>     config.merge_from_list(args.opts)
     >>> config.freeze()
-    >>> 
+    >>>
     >>> if args.nprocs > 1 and args.device == "gpu":
     >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     >>> else:
@@ -114,7 +114,7 @@ class Trainer():
 
     @property
     def parallel(self):
-        """A flag indicating whether the experiment should run with 
+        """A flag indicating whether the experiment should run with
         multiprocessing.
         """
         return self.args.device == "gpu" and self.args.nprocs > 1
@@ -144,9 +144,9 @@ class Trainer():
                             self.optimizer, infos)
 
     def resume_or_scratch(self):
-        """Resume from latest checkpoint at checkpoints in the output 
+        """Resume from the latest checkpoint at "checkpoints" in the output
         directory or load a specified checkpoint.
-        
+
         If ``args.checkpoint_path`` is not None, load the checkpoint, else
         resume training.
         """
@@ -181,8 +181,7 @@ class Trainer():
             if from_scratch:
                 # save init model, i.e. 0 epoch
                 self.save(tag='init', infos=None)
-
-        self.lr_scheduler.step(self.iteration)
+        self.lr_scheduler.step(self.epoch)
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
@@ -254,7 +253,7 @@ class Trainer():
 
     def setup_checkpointer(self):
         """Create a directory used to save checkpoints into.
-        
+
         It is "checkpoints" inside the output directory.
         """
         # checkpoint dir
@@ -277,13 +276,13 @@ class Trainer():
     @mp_tools.rank_zero_only
     def setup_visualizer(self):
         """Initialize a visualizer to log the experiment.
-        
+
         The visual log is saved in the output directory.
-        
+
         Notes
         ------
-        Only the main process has a visualizer with it. Use multiple 
-        visualizers in multiprocess to write to a same log file may cause 
+        Only the main process has a visualizer. Using multiple visualizers
+        from multiple processes to write to the same log file may cause
         unexpected behaviors.
         """
         # visualizer
@@ -292,9 +291,9 @@ class Trainer():
 
     @mp_tools.rank_zero_only
     def dump_config(self):
-        """Save the configuration used for this experiment. 
-        
-        It is saved in to ``config.yaml`` in the output directory at the 
+        """Save the configuration used for this experiment.
+
+        It is saved to ``config.yaml`` in the output directory at the
         beginning of the experiment.
         """
         with open(self.output_dir / "config.yaml", 'wt') as f:
@@ -312,13 +311,13 @@ class Trainer():
         raise NotImplementedError("valid should be implemented.")
 
     def setup_model(self):
-        """Setup model, criterion and optimizer, etc. A subclass should 
+        """Set up the model, criterion and optimizer, etc. A subclass should
         implement this method.
         """
         raise NotImplementedError("setup_model should be implemented.")
 
     def setup_dataloader(self):
-        """Setup training dataloader and validation dataloader. A subclass 
+        """Set up the training dataloader and validation dataloader. A subclass
         should implement this method.
         """
         raise NotImplementedError("setup_dataloader should be implemented.")
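
Note on the `self.lr_scheduler.step(self.iteration)` -> `self.lr_scheduler.step(self.epoch)` change above: this call sits in the per-epoch training loop, and paddle LR schedulers interpret the argument of `step()` as the epoch index to set `last_epoch` to, not as a number of steps to advance. Stepping with the global iteration count therefore drives an epoch-wise decay schedule to its floor almost immediately. A minimal sketch of the semantics, using `ExponentialDecay` purely for illustration (the scheduler actually used by the trainer is configured elsewhere in the repo):

import paddle

# Illustrative only; not part of the patch.
# ExponentialDecay computes lr = base_lr * gamma ** last_epoch.
scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=1e-3, gamma=0.9)

scheduler.step(2)            # interpreted as "we are now at epoch 2"
print(scheduler.get_lr())    # 1e-3 * 0.9**2 = 8.1e-04

scheduler.step(2000)         # passing a global iteration count by mistake
print(scheduler.get_lr())    # 1e-3 * 0.9**2000 -- effectively zero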
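For reference, the subclass contract that the `Trainer` docstring above describes, as a hypothetical sketch (the `MyExperiment` name and the elided bodies are placeholders, not code from the repo):

from deepspeech.training.trainer import Trainer

class MyExperiment(Trainer):
    def setup_model(self):
        # Attach model, criterion, optimizer (and lr_scheduler) as attributes.
        ...

    def setup_dataloader(self):
        # Attach the train_loader and valid_loader attributes.
        ...

    def train_batch(self):
        # One manual forward/backward/update step.
        ...

    def valid(self):
        # Full pass over valid_loader; invoked every config.training.valid_interval.
        ...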