diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index e3d6369bb..ebf479172 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -45,9 +45,10 @@ class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_data):
-        start = time.time()
+    def train_batch(self, batch_data, msg):
         self.model.train()
+        start = time.time()
+
         loss = self.model(*batch_data)
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
@@ -59,10 +60,8 @@ class DeepSpeech2Trainer(Trainer):
         losses_np = {
             'train_loss': float(loss),
         }
-        msg = "Train: Rank: {}, ".format(dist.get_rank())
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
         msg += "time: {:>.3f}s, ".format(iteration_time)
+        msg += "batch size: {}, ".format(self.config.data.batch_size)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         self.logger.info(msg)
@@ -71,6 +70,7 @@ class DeepSpeech2Trainer(Trainer):
             for k, v in losses_np.items():
                 self.visualizer.add_scalar("train/{}".format(k), v,
                                            self.iteration)
+        self.iteration += 1
 
     @mp_tools.rank_zero_only
     @paddle.no_grad()
diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py
index 5afabbb81..02845e9cc 100644
--- a/deepspeech/exps/u2/bin/train.py
+++ b/deepspeech/exps/u2/bin/train.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Trainer for U2 model."""
+import os
+import cProfile
 
 from paddle import distributed as dist
 
 from deepspeech.utils.utility import print_arguments
@@ -52,4 +54,7 @@ if __name__ == "__main__":
         with open(args.dump_config, 'w') as f:
             print(config, file=f)
 
-    main(config, args)
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join('.', 'train.profile'))
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 38e24cef5..87d0c94db 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -80,54 +80,60 @@ class U2Trainer(Trainer):
         self.model.train()
         start = time.time()
-
+        loss, attention_loss, ctc_loss = self.model(*batch_data)
+        # loss div by `batch_size * accum_grad`
+        loss /= train_conf.accum_grad
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
-
-        if self.iteration % train_conf.accum_grad == 0:
+
+        losses_np = {
+            'train_loss': float(loss) * train_conf.accum_grad,
+            'train_att_loss': float(attention_loss),
+            'train_ctc_loss': float(ctc_loss),
+        }
+
+        if (self.iteration + 1) % train_conf.accum_grad == 0:
+            if dist.get_rank() == 0 and self.visualizer:
+                for k, v in losses_np.items():
+                    self.visualizer.add_scalar("train/{}".format(k), v,
+                                               self.iteration)
             self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
+            self.iteration += 1
 
         iteration_time = time.time() - start
 
-        losses_np = {
-            'train_loss': float(loss),
-            'train_att_loss': float(attention_loss),
-            'train_ctc_loss': float(ctc_loss),
-        }
-        msg += "time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
-        msg += "accum: {}, ".format(train_conf.accum_grad)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in losses_np.items())
-        if self.iteration % train_conf.log_interval == 0:
+        if (self.iteration + 1) % train_conf.log_interval == 0:
+            msg += "time: {:>.3f}s, ".format(iteration_time)
+            msg += "batch size: {}, ".format(self.config.data.batch_size)
+            msg += "accum: {}, ".format(train_conf.accum_grad)
+            msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                             for k, v in losses_np.items())
             self.logger.info(msg)
 
-        # display
-        if dist.get_rank() == 0 and self.visualizer:
-            for k, v in losses_np.items():
-                self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
-
     def train(self):
-        """The training process.
-        It includes forward/backward/update and periodical validation and
-        saving.
-        """
+        """The training process control by step."""
        # !!!IMPORTANT!!!
        # Try to export the model by script, if fails, we should refine
        # the code to satisfy the script export requirements
        # script_model = paddle.jit.to_static(self.model)
        # script_model_path = str(self.checkpoint_dir / 'init')
        # paddle.jit.save(script_model, script_model_path)
-
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+
+        self.lr_scheduler.step(self.iteration)
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
-
-        while self.epoch <= self.config.training.n_epoch:
+        while self.epoch < self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
@@ -135,19 +141,18 @@ class U2Trainer(Trainer):
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
-                    msg += "lr: {}, ".foramt(self.lr_scheduler())
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
-                    self.iteration += 1
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
             except Exception as e:
                 self.logger.error(e)
                 raise e
-
+
             self.valid()
             self.save()
             self.new_epoch()
-
+
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
@@ -263,12 +268,12 @@ class U2Trainer(Trainer):
             lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
                 learning_rate=optim_conf.lr,
                 gamma=scheduler_conf.lr_decay,
-                verbose=True)
+                verbose=False)
         elif scheduler_type == 'warmuplr':
             lr_scheduler = WarmupLR(
                 learning_rate=optim_conf.lr,
                 warmup_steps=scheduler_conf.warmup_steps,
-                verbose=True)
+                verbose=False)
         else:
             raise ValueError(f"Not support scheduler: {scheduler_type}")
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 474f8d728..0f18e3ba0 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -127,7 +127,7 @@ class Trainer():
         dist.init_parallel_env()
 
     @mp_tools.rank_zero_only
-    def save(self, infos=None):
+    def save(self, tag=None, infos=None):
         """Save checkpoint (model parameters and optimizer states).
         """
         if infos is None:
@@ -136,8 +136,9 @@ class Trainer():
             "epoch": self.epoch,
             "lr": self.optimizer.get_lr(),
         }
-        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
-                                   self.model, self.optimizer, infos)
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration
+                                   if tag is None else tag, self.model,
+                                   self.optimizer, infos)
 
     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output
@@ -146,6 +147,7 @@ class Trainer():
         If ``args.checkpoint_path`` is not None, load the checkpoint, else
         resume training.
         """
+        scratch = None
         infos = checkpoint.load_parameters(
             self.model,
             self.optimizer,
@@ -155,44 +157,41 @@ class Trainer():
             # restore from ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
-            self.lr_scheduler.step(self.iteration)
-            if self.parallel:
-                self.train_loader.batch_sampler.set_epoch(self.epoch)
-            return False
+            scratch = False
         else:
-            # from scratch, epoch and iteration init with zero
-            # save init model, i.e. 0 epoch
-            self.save()
-            # self.epoch start from 1.
-            self.new_epoch()
-            return True
+            scratch = True
+
+        return scratch
 
     def new_epoch(self):
-        """Reset the train loader and increment ``epoch``.
+        """Reset the train loader seed and increment `epoch`.
         """
+        self.epoch += 1
         if self.parallel:
-            # batch sampler epoch start from 0
             self.train_loader.batch_sampler.set_epoch(self.epoch)
-        self.epoch += 1
 
     def train(self):
-        """The training process.
-
-        """
+        """The training process control by epoch."""
         from_scratch = self.resume_or_scratch()
-
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+
+        self.lr_scheduler.step(self.iteration)
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch <= self.config.training.n_epoch:
+        while self.epoch < self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
                     dataload_time = time.time() - data_start_time
-                    # iteration start from 1.
-                    self.iteration += 1
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
             except Exception as e:
@@ -202,7 +201,6 @@ class Trainer():
 
             self.valid()
             self.save()
-            # lr control by epoch
             self.lr_scheduler.step()
             self.new_epoch()
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index 622811d0d..c674b1117 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -16,6 +16,7 @@ import os
 import logging
 import re
 import json
+from typing import Union
 
 import paddle
 from paddle import distributed as dist
@@ -79,7 +80,7 @@ def load_parameters(model,
     configs = {}
 
     if checkpoint_path is not None:
-        iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
+        tag = os.path.basename(checkpoint_path).split(":")[-1]
     elif checkpoint_dir is not None:
         iteration = _load_latest_checkpoint(checkpoint_dir)
         if iteration == -1:
@@ -113,14 +114,14 @@ def load_parameters(model,
 
 @mp_tools.rank_zero_only
 def save_parameters(checkpoint_dir: str,
-                    iteration: int,
+                    tag_or_iteration: Union[int, str],
                     model: paddle.nn.Layer,
                     optimizer: Optimizer=None,
                     infos: dict=None):
     """Checkpoint the latest trained model parameters.
     Args:
         checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration(step or epoch) number.
+        tag_or_iteration (int or str): the latest iteration(step or epoch) number.
         model (Layer): model to be checkpointed.
         optimizer (Optimizer, optional): optimizer to be checkpointed.
             Defaults to None.
@@ -128,7 +129,8 @@ def save_parameters(checkpoint_dir: str,
     Returns:
         None
     """
-    checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
+    checkpoint_path = os.path.join(checkpoint_dir,
+                                   "{}".format(tag_or_iteration))
 
     model_dict = model.state_dict()
     params_path = checkpoint_path + ".pdparams"
@@ -142,10 +144,10 @@ def save_parameters(checkpoint_dir: str,
         logger.info("Saved optimzier state to {}".format(optimizer_path))
 
     info_path = re.sub('.pdparams$', '.json', params_path)
-    if infos is None:
-        infos = {}
+    infos = {} if infos is None else infos
     with open(info_path, 'w') as fout:
         data = json.dumps(infos)
         fout.write(data)
-    _save_checkpoint(checkpoint_dir, iteration)
+    if isinstance(tag_or_iteration, int):
+        _save_checkpoint(checkpoint_dir, tag_or_iteration)
diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml
index 1b374507a..40c40629f 100644
--- a/examples/aishell/s1/conf/conformer.yaml
+++ b/examples/aishell/s1/conf/conformer.yaml
@@ -6,7 +6,6 @@ data:
   vocab_filepath: data/vocab.txt
   unit_type: 'char'
   spm_model_prefix: ''
-  mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 64
   min_input_len: 0.5
diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml
index b1101736d..0e58002af 100644
--- a/examples/tiny/s1/conf/conformer.yaml
+++ b/examples/tiny/s1/conf/conformer.yaml
@@ -12,7 +12,7 @@ data:
   min_input_len: 0.5
   max_input_len: 20.0
   min_output_len: 0.0
-  max_output_len: 400
+  max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
   raw_wav: True  # use raw_wav or kaldi feature