From 38433729584469c71382459c72dbba610badf7ec Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 13 Sep 2021 09:44:29 +0000
Subject: [PATCH] u2 with chainer updater

---
 deepspeech/__init__.py                        |  58 ++---
 deepspeech/exps/u2/bin/train.py               |   4 +-
 deepspeech/exps/u2/trainer.py                 | 219 ++++++++++++++++++
 deepspeech/models/u2/__init__.py              |  19 ++
 deepspeech/models/{ => u2}/u2.py              |   0
 deepspeech/models/u2/updater.py               | 149 ++++++++++++
 deepspeech/modules/loss.py                    |   2 +-
 deepspeech/training/extensions/evaluator.py   |  36 ++-
 deepspeech/training/extensions/snapshot.py    |  35 ++-
 deepspeech/training/extensions/visualizer.py  |  10 +-
 deepspeech/training/trainer.py                |   2 +-
 .../training/updaters/standard_updater.py     |  53 +++--
 deepspeech/training/updaters/trainer.py       |   4 +-
 deepspeech/training/updaters/updater.py       |   1 +
 deepspeech/utils/checkpoint.py                |   4 +-
 deepspeech/utils/log.py                       |  73 ++----
 examples/tiny/s1/conf/transformer.yaml        |   2 +-
 requirements.txt                              |   4 +
 18 files changed, 544 insertions(+), 131 deletions(-)
 create mode 100644 deepspeech/exps/u2/trainer.py
 create mode 100644 deepspeech/models/u2/__init__.py
 rename deepspeech/models/{ => u2}/u2.py (100%)
 create mode 100644 deepspeech/models/u2/updater.py

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 3ac68a3b..5505ecbf 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype):


 if not hasattr(paddle, 'softmax'):
-    logger.warn("register user softmax to paddle, remove this when fixed!")
+    logger.debug("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)

 if not hasattr(paddle, 'log_softmax'):
-    logger.warn("register user log_softmax to paddle, remove this when fixed!")
+    logger.debug("register user log_softmax to paddle, remove this when fixed!")
     setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)

 if not hasattr(paddle, 'sigmoid'):
-    logger.warn("register user sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)

 if not hasattr(paddle, 'log_sigmoid'):
-    logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user log_sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)

 if not hasattr(paddle, 'relu'):
-    logger.warn("register user relu to paddle, remove this when fixed!")
+    logger.debug("register user relu to paddle, remove this when fixed!")
     setattr(paddle, 'relu', paddle.nn.functional.relu)
@@ -105,7 +105,7 @@ def cat(xs, dim=0):


 if not hasattr(paddle, 'cat'):
-    logger.warn(
+    logger.debug(
         "override cat of paddle if exists or register, remove this when fixed!")
     paddle.cat = cat
@@ -116,7 +116,7 @@ def item(x: paddle.Tensor):


 if not hasattr(paddle.Tensor, 'item'):
-    logger.warn(
+    logger.debug(
         "override item of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.item = item
@@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor):


 if not hasattr(paddle.Tensor, 'long'):
-    logger.warn(
+    logger.debug(
         "override long of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.long = func_long

 if not hasattr(paddle.Tensor, 'numel'):
-    logger.warn(
+    logger.debug(
         "override numel of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.numel = paddle.numel
@@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor,


 if not hasattr(paddle.Tensor, 'new_full'):
-    logger.warn(
+    logger.debug(
         "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.new_full = new_full
@@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.eq = eq

 if not hasattr(paddle, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle if exists or register, remove this when fixed!")
     paddle.eq = eq
@@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'contiguous'):
-    logger.warn(
+    logger.debug(
         "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.contiguous = contiguous
@@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:


 #`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
-logger.warn(
+logger.debug(
     "override size of paddle.Tensor "
     "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
 )
@@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'view'):
-    logger.warn("register user view to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user view to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view = view
@@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'view_as'):
-    logger.warn(
+    logger.debug(
         "register user view_as to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view_as = view_as
@@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor,


 if not hasattr(paddle.Tensor, 'masked_fill'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill = masked_fill
@@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor,


 if not hasattr(paddle.Tensor, 'masked_fill_'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill_ = masked_fill_
@@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'fill_'):
-    logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.fill_ = fill_
@@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'repeat'):
-    logger.warn(
+    logger.debug(
         "register user repeat to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.repeat = repeat

 if not hasattr(paddle.Tensor, 'softmax'):
-    logger.warn(
+    logger.debug(
         "register user softmax to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)

 if not hasattr(paddle.Tensor, 'sigmoid'):
-    logger.warn(
+    logger.debug(
         "register user sigmoid to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)

 if not hasattr(paddle.Tensor, 'relu'):
-    logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user relu to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
@@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'type_as'):
-    logger.warn(
+    logger.debug(
         "register user type_as to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'type_as', type_as)
@@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'to'):
-    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'to', to)
@@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'float'):
-    logger.warn("register user float to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user float to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'float', func_float)
@@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:


 if not hasattr(paddle.Tensor, 'int'):
-    logger.warn("register user int to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user int to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'int', func_int)
@@ -348,6 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]:


 if not hasattr(paddle.Tensor, 'tolist'):
-    logger.warn(
+    logger.debug(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py
index 9dd0041d..bd1aeb9e 100644
--- a/deepspeech/exps/u2/bin/train.py
+++ b/deepspeech/exps/u2/bin/train.py
@@ -18,10 +18,12 @@ import os
 from paddle import distributed as dist

 from deepspeech.exps.u2.config import get_cfg_defaults
-from deepspeech.exps.u2.model import U2Trainer as Trainer
+# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments

+from deepspeech.exps.u2.model import U2Trainer as Trainer
+

 def main_sp(config, args):
     exp = Trainer(config, args)
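Editorial note, not part of the patch: every hunk in `deepspeech/__init__.py` above follows the same guard-and-register shim, downgraded here from `warn` to `debug` because it fires on every import. A minimal sketch of that pattern, using `tolist` (one of the helpers the file actually registers) as the example:

```python
import paddle

def tolist(x: paddle.Tensor):
    """Stand-in for one of the helpers registered in deepspeech/__init__.py."""
    return x.numpy().tolist()

# Only patch when paddle itself does not provide the attribute, so a future
# paddle release that ships `Tensor.tolist` natively wins automatically.
if not hasattr(paddle.Tensor, 'tolist'):
    setattr(paddle.Tensor, 'tolist', tolist)
```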
+"""Contains U2 model.""" +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.u2 import U2Evaluator +from deepspeech.models.u2 import U2Model +from deepspeech.models.u2 import U2Updater +from deepspeech.training.extensions.snapshot import Snapshot +from deepspeech.training.extensions.visualizer import VisualDL +from deepspeech.training.optimizer import OptimizerFactory +from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer +from deepspeech.training.trainer import Trainer +from deepspeech.training.updaters.trainer import Trainer as NewTrainer +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +class U2Trainer(Trainer): + def __init__(self, config, args): + super().__init__(config, args) + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + # train/valid dataset, return token ids + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers, ) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + + # test dataset, return raw text + config.data.manifest = config.data.test_manifest + # filter test examples, will cause less examples, but no mismatch with training + # and can use large batch size , save training time, so filter test egs now. 
+        config.data.min_input_len = 0.0  # second
+        config.data.max_input_len = float('inf')  # second
+        config.data.min_output_len = 0.0  # tokens
+        config.data.max_output_len = float('inf')  # tokens
+        config.data.min_output_input_ratio = 0.00
+        config.data.max_output_input_ratio = float('inf')
+
+        test_dataset = ManifestDataset.from_config(config)
+        # return text ord id
+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator.from_config(config))
+        # return text token id
+        config.collator.keep_transcription_text = False
+        self.align_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator.from_config(config))
+        logger.info("Setup train/valid/test/align Dataloader!")
+
+    def setup_model(self):
+        config = self.config
+        model_conf = config.model
+        model_conf.defrost()
+        model_conf.input_dim = self.train_loader.collate_fn.feature_size
+        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
+        model_conf.freeze()
+        model = U2Model.from_config(model_conf)
+
+        if self.parallel:
+            model = paddle.DataParallel(model)
+
+        model.train()
+        logger.info(f"{model}")
+        layer_tools.print_params(model, logger.info)
+
+        train_config = config.training
+        optim_type = train_config.optim
+        optim_conf = train_config.optim_conf
+        scheduler_type = train_config.scheduler
+        scheduler_conf = train_config.scheduler_conf
+
+        scheduler_args = {
+            "learning_rate": optim_conf.lr,
+            "verbose": False,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
+        }
+        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
+                                                    scheduler_args)
+
+        def optimizer_args(
+                config,
+                parameters,
+                lr_scheduler=None, ):
+            train_config = config.training
+            optim_type = train_config.optim
+            optim_conf = train_config.optim_conf
+            scheduler_type = train_config.scheduler
+            scheduler_conf = train_config.scheduler_conf
+            return {
+                "grad_clip": train_config.global_grad_clip,
+                "weight_decay": optim_conf.weight_decay,
+                "learning_rate": lr_scheduler
+                if lr_scheduler else optim_conf.lr,
+                "parameters": parameters,
+                "epsilon": 1e-9 if optim_type == 'noam' else None,
+                "beta1": 0.9 if optim_type == 'noam' else None,
+                "beta2": 0.98 if optim_type == 'noam' else None,
+            }
+
+        optim_args = optimizer_args(config, model.parameters(), lr_scheduler)
+        optimizer = OptimizerFactory.from_args(optim_type, optim_args)
+
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        logger.info("Setup model/optimizer/lr_scheduler!")
+
+    def setup_updater(self):
+        output_dir = self.output_dir
+        config = self.config.training
+
+        updater = U2Updater(
+            model=self.model,
+            optimizer=self.optimizer,
+            scheduler=self.lr_scheduler,
+            dataloader=self.train_loader,
+            output_dir=output_dir,
+            accum_grad=config.accum_grad)
+
+        trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir)
+
+        evaluator = U2Evaluator(self.model, self.valid_loader)
+
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+
+        if dist.get_rank() == 0:
+            trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+            num_snapshots = config.checkpoint.kbest_n
+            trainer.extend(
+                Snapshot(
+                    mode='kbest',
+                    max_size=num_snapshots,
+                    indicator='VALID/LOSS',
+                    less_better=True),
+                trigger=(1, 'epoch'))
+        # print(trainer.extensions)
+        # trainer.run()
+        self.trainer = trainer
+
+    def run(self):
+        """The routine of the experiment after setup. This method is intended
+        to be used by the user.
+        """
+        self.setup_updater()
+        with Timer("Training Done: {}"):
+            self.trainer.run()
diff --git a/deepspeech/models/u2/__init__.py b/deepspeech/models/u2/__init__.py
new file mode 100644
index 00000000..a9010f1d
--- /dev/null
+++ b/deepspeech/models/u2/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .u2 import U2InferModel
+from .u2 import U2Model
+from .updater import U2Evaluator
+from .updater import U2Updater
+
+__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"]
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2/u2.py
similarity index 100%
rename from deepspeech/models/u2.py
rename to deepspeech/models/u2/u2.py
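Editorial note, not part of the patch: in `setup_updater` above, the Chainer-style `NewTrainer` keeps calling `updater.update()` until the stop trigger `(config.n_epoch, 'epoch')` fires, and runs each extension whenever its interval trigger matches. A rough, hypothetical sketch of the interval-trigger semantics those `trigger=(1, "epoch")` tuples assume (names are illustrative, not the framework's API):

```python
def make_interval_trigger(period: int, unit: str):
    """Fire every `period` iterations or epochs, mirroring (period, unit) tuples."""
    assert unit in ("iteration", "epoch")

    def trigger(state) -> bool:
        # state.iteration / state.epoch mirror the UpdaterState fields
        count = state.epoch if unit == "epoch" else state.iteration
        return count > 0 and count % period == 0

    return trigger
```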
diff --git a/deepspeech/models/u2/updater.py b/deepspeech/models/u2/updater.py
new file mode 100644
index 00000000..7b70ca04
--- /dev/null
+++ b/deepspeech/models/u2/updater.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+
+import paddle
+from paddle import distributed as dist
+
+from deepspeech.training.extensions.evaluator import StandardEvaluator
+from deepspeech.training.reporter import report
+from deepspeech.training.timer import Timer
+from deepspeech.training.updaters.standard_updater import StandardUpdater
+from deepspeech.utils import layer_tools
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+class U2Evaluator(StandardEvaluator):
+    def __init__(self, model, dataloader):
+        super().__init__(model, dataloader)
+        self.msg = ""
+        self.num_seen_utts = 0
+        self.total_loss = 0.0
+
+    def evaluate_core(self, batch):
+        self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
+        loss, attention_loss, ctc_loss = self.model(*batch[1:])
+        if paddle.isfinite(loss):
+            num_utts = batch[1].shape[0]
+            self.num_seen_utts += num_utts
+            self.total_loss += float(loss) * num_utts
+
+            losses_dict['loss'] = float(loss)
+            if attention_loss:
+                losses_dict['att_loss'] = float(attention_loss)
+            if ctc_loss:
+                losses_dict['ctc_loss'] = float(ctc_loss)
+
+            for k, v in losses_dict.items():
+                report("eval/" + k, v)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        logger.info(self.msg)
+        return self.total_loss, self.num_seen_utts
+
+
+class U2Updater(StandardUpdater):
+    def __init__(self,
+                 model,
+                 optimizer,
+                 scheduler,
+                 dataloader,
+                 init_state=None,
+                 accum_grad=1,
+                 **kwargs):
+        super().__init__(
+            model, optimizer, scheduler, dataloader, init_state=init_state)
+        self.accum_grad = accum_grad
+        self.forward_count = 0
+        self.msg = ""
+
+    def update_core(self, batch):
+        """One step.
+
+        Args:
+            batch (List[Object]): utts, xs, xlens, ys, ylens
+        """
+        losses_dict = {}
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+
+        # forward
+        batch_size = batch[1].shape[0]
+        loss, attention_loss, ctc_loss = self.model(*batch[1:])
+        # loss div by `batch_size * accum_grad`
+        loss /= self.accum_grad
+
+        # loss backward
+        if (self.forward_count + 1) != self.accum_grad:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # loss info
+        losses_dict['loss'] = float(loss) * self.accum_grad
+        if attention_loss:
+            losses_dict['att_loss'] = float(attention_loss)
+        if ctc_loss:
+            losses_dict['ctc_loss'] = float(ctc_loss)
+        # report loss
+        for k, v in losses_dict.items():
+            report("train/" + k, v)
+        # loss msg
+        self.msg += "batch size: {}, ".format(batch_size)
+        self.msg += "accum: {}, ".format(self.accum_grad)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+        # Truncate the graph
+        loss.detach()
+
+        # update parameters
+        self.forward_count += 1
+        if self.forward_count != self.accum_grad:
+            return
+        self.forward_count = 0
+
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+        self.scheduler.step()
+
+    def update(self):
+        # model is default in train mode
+
+        # training for a step is implemented here
+        with Timer("data time cost:{}"):
+            batch = self.read_batch()
+        with Timer("step time cost:{}"):
+            self.update_core(batch)
+
+        # #iterations with accum_grad > 1
+        # Ref.: https://github.com/espnet/espnet/issues/777
+        if self.forward_count == 0:
+            self.state.iteration += 1
+            if self.updates_per_epoch is not None:
+                if self.state.iteration % self.updates_per_epoch == 0:
+                    self.state.epoch += 1
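Editorial note, not part of the patch: `U2Updater.update_core` above only lets DDP synchronize gradients on the last micro-batch of each accumulation window; every earlier micro-batch runs `backward()` under `model.no_sync`. A condensed, standalone sketch of that pattern, where `model` (a `paddle.DataParallel` layer), `loader`, and `opt` are assumed to already exist:

```python
from contextlib import nullcontext

accum_grad = 4  # micro-batches per optimizer step

for step, batch in enumerate(loader):
    loss = model(*batch) / accum_grad  # scale so accumulated grads average out
    last_micro_batch = (step + 1) % accum_grad == 0
    # Skip the all-reduce on intermediate micro-batches; sync on the last one.
    ctx = nullcontext if last_micro_batch else model.no_sync
    with ctx():
        loss.backward()
    if last_micro_batch:
        opt.step()
        opt.clear_grad()
```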
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index 023a1923..2c58be7e 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -46,7 +46,7 @@ class CTCLoss(nn.Layer):
         if grad_norm_type == 'instance':
             self.norm_by_times = True
         if grad_norm_type == 'batch':
-            self.norm_by_times = True
+            self.norm_by_batchsize = True
         if grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True
diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py
index 96ff967f..d5b35982 100644
--- a/deepspeech/training/extensions/evaluator.py
+++ b/deepspeech/training/extensions/evaluator.py
@@ -13,14 +13,18 @@
 # limitations under the License.
 from typing import Dict

-import extension
 import paddle
+from paddle import distributed as dist
 from paddle.io import DataLoader
 from paddle.nn import Layer

+from . import extension
 from ..reporter import DictSummary
 from ..reporter import report
 from ..reporter import scope
+from ..timer import Timer
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()


 class StandardEvaluator(extension.Extension):
@@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension):
     def evaluate_core(self, batch):
         # compute
         self.model(batch)  # you may report here
+        return
+
+    def evaluate_sync(self, data):
+        # dist sync `evaluate_core` outputs
+        if data is None:
+            return
+
+        numerator, denominator = data
+        if dist.get_world_size() > 1:
+            numerator = paddle.to_tensor(numerator)
+            denominator = paddle.to_tensor(denominator)
+            # the default reduce operator of all_reduce is sum.
+            dist.all_reduce(numerator)
+            dist.all_reduce(denominator)
+            value = numerator / denominator
+            value = float(value)
+        else:
+            value = numerator / denominator
+        # used for `snapshot` to do kbest save.
+        report("VALID/LOSS", value)
+        logger.info(f"Valid: all-reduce loss {value}")

     def evaluate(self):
         # switch to eval mode
@@ -56,9 +81,13 @@ class StandardEvaluator(extension.Extension):
             with scope(observation):
                 # main evaluation computation here.
                 with paddle.no_grad():
-                    self.evaluate_core(batch)
+                    self.evaluate_sync(self.evaluate_core(batch))
             summary.add(observation)
         summary = summary.compute_mean()
+
+        # switch to train mode
+        for model in self.models.values():
+            model.train()
         return summary

     def __call__(self, trainer=None):
@@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension):
         # if it is used to extend a trainer, the metrics is reported to
         # to observation of the trainer
         # or otherwise, you can use your own observation
-        summary = self.evaluate()
+        with Timer("Eval Time Cost: {}"):
+            summary = self.evaluate()
         for k, v in summary.items():
             report(k, v)
diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py
index cb4e6dfb..1d3fe70c 100644
--- a/deepspeech/training/extensions/snapshot.py
+++ b/deepspeech/training/extensions/snapshot.py
@@ -20,8 +20,9 @@ from typing import List

 import jsonlines

-from deepspeech.training.extensions import extension
-from deepspeech.training.updaters.trainer import Trainer
+from . import extension
+from ..reporter import get_observations
+from ..updaters.trainer import Trainer
 from deepspeech.utils.log import Log
 from deepspeech.utils.mp_tools import rank_zero_only
@@ -52,8 +53,19 @@ class Snapshot(extension.Extension):
     priority = -100
     default_name = "snapshot"

-    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
+    def __init__(self,
+                 mode='latest',
+                 max_size: int=5,
+                 indicator=None,
+                 less_better=True,
+                 snapshot_on_error: bool=False):
         self.records: List[Dict[str, Any]] = []
+        assert mode in ('latest', 'kbest'), mode
+        if mode == 'kbest':
+            assert indicator is not None
+        self.mode = mode
+        self.indicator = indicator
+        self.less_is_better = less_better
         self.max_size = max_size
         self._snapshot_on_error = snapshot_on_error
         self._save_all = (max_size == -1)
@@ -66,16 +78,17 @@ class Snapshot(extension.Extension):
         # load existing records
         record_path: Path = self.checkpoint_dir / "records.jsonl"
         if record_path.exists():
-            logger.debug("Loading from an existing checkpoint dir")
             self.records = load_records(record_path)
-            trainer.updater.load(self.records[-1]['path'])
+            ckpt_path = self.records[-1]['path']
+            logger.info(f"Loading from an existing checkpoint {ckpt_path}")
+            trainer.updater.load(ckpt_path)

     def on_error(self, trainer, exc, tb):
         if self._snapshot_on_error:
-            self.save_checkpoint_and_update(trainer)
+            self.save_checkpoint_and_update(trainer, 'latest')

     def __call__(self, trainer: Trainer):
-        self.save_checkpoint_and_update(trainer)
+        self.save_checkpoint_and_update(trainer, self.mode)

     def full(self):
         """Whether the number of snapshots it keeps track of is greater
@@ -83,7 +96,7 @@ class Snapshot(extension.Extension):
         return (not self._save_all) and len(self.records) > self.max_size

     @rank_zero_only
-    def save_checkpoint_and_update(self, trainer: Trainer):
+    def save_checkpoint_and_update(self, trainer: Trainer, mode: str):
         """Saving new snapshot and remove the oldest snapshot if needed."""
         iteration = trainer.updater.state.iteration
         epoch = trainer.updater.state.epoch
@@ -97,11 +110,17 @@ class Snapshot(extension.Extension):
             'path': str(path.resolve()),  # use absolute path
             'iteration': iteration,
             'epoch': epoch,
+            'indicator': get_observations()[self.indicator]
         }
         self.records.append(record)

         # remove the earist
         if self.full():
+            if mode == 'kbest':
+                self.records = sorted(
+                    self.records,
+                    key=lambda record: record['indicator'],
+                    reverse=not self.less_is_better)
             eariest_record = self.records[0]
             os.remove(eariest_record["path"])
             self.records.pop(0)
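Editorial note, not part of the patch: `evaluate_sync` above all-reduces a (numerator, denominator) pair and divides once, instead of averaging per-worker means. A pure-Python sketch of why that is the right reduction when workers see different utterance counts (the two worker tuples are hypothetical):

```python
# Two hypothetical workers: (total_loss, num_seen_utts) from evaluate_core.
worker_outputs = [(120.0, 10), (330.0, 30)]

# all_reduce(sum) over numerators and denominators, then one division:
numerator = sum(n for n, _ in worker_outputs)    # 450.0
denominator = sum(d for _, d in worker_outputs)  # 40
print(numerator / denominator)                   # 11.25: true utterance-level mean

# Averaging per-worker means would over-weight the smaller worker:
print(sum(n / d for n, d in worker_outputs) / 2)  # 11.5: biased
```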
diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py
index b69e94aa..e5f456ca 100644
--- a/deepspeech/training/extensions/visualizer.py
+++ b/deepspeech/training/extensions/visualizer.py
@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from deepspeech.training.extensions import extension
-from deepspeech.training.updaters.trainer import Trainer
+from visualdl import LogWriter
+
+from . import extension
+from ..updaters.trainer import Trainer


 class VisualDL(extension.Extension):
@@ -26,8 +28,8 @@ class VisualDL(extension.Extension):
     default_name = 'visualdl'
     priority = extension.PRIORITY_READER

-    def __init__(self, writer):
-        self.writer = writer
+    def __init__(self, output_dir):
+        self.writer = LogWriter(str(output_dir))

     def __call__(self, trainer: Trainer):
         for k, v in trainer.observation.items():
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 25c002df..7959b41b 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -171,7 +171,7 @@ class Trainer():
             self.iteration = 0
             self.epoch = 0
             scratch = True
-
+        logger.info("Restore/Init checkpoint!")
         return scratch

     def new_epoch(self):
diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py
index fc758e93..10c99e7f 100644
--- a/deepspeech/training/updaters/standard_updater.py
+++ b/deepspeech/training/updaters/standard_updater.py
@@ -14,12 +14,13 @@
 from typing import Dict
 from typing import Optional

-from paddle import Tensor
+import paddle
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
-from timer import timer
+from paddle.optimizer.lr import LRScheduler

+from deepspeech.training.timer import Timer
 from deepspeech.training.reporter import report
 from deepspeech.training.updaters.updater import UpdaterBase
@@ -39,8 +40,10 @@ class StandardUpdater(UpdaterBase):
     def __init__(self,
                  model: Layer,
                  optimizer: Optimizer,
+                 scheduler: LRScheduler,
                  dataloader: DataLoader,
                  init_state: Optional[UpdaterState]=None):
+        super().__init__(init_state)
         # it is designed to hold multiple models
         models = {"main": model}
         self.models: Dict[str, Layer] = models
@@ -51,15 +54,14 @@ class StandardUpdater(UpdaterBase):
         self.optimizer = optimizer
         self.optimizers: Dict[str, Optimizer] = optimizers

+        # it is designed to hold multiple schedulers
+        schedulers = {"main": scheduler}
+        self.scheduler = scheduler
+        self.schedulers: Dict[str, LRScheduler] = schedulers
+
         # dataloaders
         self.dataloader = dataloader

-        # init state
-        if init_state is None:
-            self.state = UpdaterState()
-        else:
-            self.state = init_state
-
         self.train_iterator = iter(dataloader)

     def update(self):
@@ -103,8 +105,10 @@ class StandardUpdater(UpdaterBase):
                 model.train()

         # training for a step is implemented here
-        batch = self.read_batch()
-        self.update_core(batch)
+        with Timer("data time cost:{}"):
+            batch = self.read_batch()
+        with Timer("step time cost:{}"):
+            self.update_core(batch)

         self.state.iteration += 1
         if self.updates_per_epoch is not None:
@@ -115,13 +119,14 @@ class StandardUpdater(UpdaterBase):
         """A simple case for a training step.
         Basic assumptions are:
         Single model;
         Single optimizer;
+        Single scheduler, and update learning rate each step;
         A batch from the dataloader is just the input of the model;
         The model return a single loss, or a dict containing serval losses.
         Parameters updates at every batch, no gradient accumulation.
         """
         loss = self.model(*batch)
-        if isinstance(loss, Tensor):
+        if isinstance(loss, paddle.Tensor):
             loss_dict = {"main": loss}
         else:
             # Dict[str, Tensor]
@@ -135,14 +140,15 @@ class StandardUpdater(UpdaterBase):
         for name, loss_item in loss_dict.items():
             report(name, float(loss_item))

-        self.optimizer.clear_gradient()
+        self.optimizer.clear_grad()
         loss_dict["main"].backward()
-        self.optimizer.update()
+        self.optimizer.step()
+        self.scheduler.step()

     @property
     def updates_per_epoch(self):
-        """Number of updater per epoch, determined by the length of the
-        dataloader."""
+        """Number of steps per epoch,
+        determined by the length of the dataloader."""
         length_of_dataloader = None
         try:
             length_of_dataloader = len(self.dataloader)
@@ -163,18 +169,16 @@ class StandardUpdater(UpdaterBase):

     def read_batch(self):
         """Read a batch from the data loader, auto renew when data is exhausted."""
-        with timer() as t:
-            try:
-                batch = next(self.train_iterator)
-            except StopIteration:
-                self.new_epoch()
-                batch = next(self.train_iterator)
-            logger.debug(
-                f"Read a batch takes {t.elapse}s.")  # replace it with logger
+        try:
+            batch = next(self.train_iterator)
+        except StopIteration:
+            self.new_epoch()
+            batch = next(self.train_iterator)
         return batch

     def state_dict(self):
-        """State dict of a Updater, model, optimizer and updater state are included."""
+        """State dict of an Updater; model, optimizers/schedulers
+        and updater state are included."""
         state_dict = super().state_dict()
         for name, model in self.models.items():
             state_dict[f"{name}_params"] = model.state_dict()
@@ -184,7 +188,7 @@ class StandardUpdater(UpdaterBase):

     def set_state_dict(self, state_dict):
         """Set state dict for a Updater. Parameters of models, states for
-        optimizers and UpdaterState are restored."""
+        optimizers/schedulers and UpdaterState are restored."""
         for name, model in self.models.items():
             model.set_state_dict(state_dict[f"{name}_params"])
         for name, optim in self.optimizers.items():
diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py
index 954ce260..a52fb9eb 100644
--- a/deepspeech/training/updaters/trainer.py
+++ b/deepspeech/training/updaters/trainer.py
@@ -140,8 +140,8 @@ class Trainer():
             try:
                 while not stop_trigger(self):
                     self.observation = {}
-                    # set observation as the report target
-                    # you can use report freely in Updater.update()
+                    # set observation as the `report` target
+                    # you can use `report` freely in Updater.update()

                     # updating parameters and state
                     with scope(self.observation):
diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py
index 66fdc2bb..e5dd6556 100644
--- a/deepspeech/training/updaters/updater.py
+++ b/deepspeech/training/updaters/updater.py
@@ -52,6 +52,7 @@ class UpdaterBase():
     """

     def __init__(self, init_state=None):
+        # init state
         if init_state is None:
             self.state = UpdaterState()
         else:
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index a59f8be7..8e31edfa 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -114,13 +114,13 @@ class Checkpoint():
         params_path = checkpoint_path + ".pdparams"
         model_dict = paddle.load(params_path)
         model.set_state_dict(model_dict)
-        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
+        logger.info("Rank {}: Restore model from {}".format(rank, params_path))

         optimizer_path = checkpoint_path + ".pdopt"
         if optimizer and os.path.isfile(optimizer_path):
             optimizer_dict = paddle.load(optimizer_path)
             optimizer.set_state_dict(optimizer_dict)
-            logger.info("Rank {}: loaded optimizer state from {}".format(
+            logger.info("Rank {}: Restore optimizer state from {}".format(
                 rank, optimizer_path))

         info_path = re.sub('.pdparams$', '.json', params_path)
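Editorial note, not part of the patch: the reworked `StandardUpdater.update_core` above migrates from the removed `clear_gradient()`/`update()` calls to paddle's current optimizer API and ticks the scheduler once per optimizer step. A compact, runnable sketch of that per-batch order with a toy model (all names here are illustrative, not the project's):

```python
import paddle

# toy model/optimizer/scheduler so the step order is runnable end to end
model = paddle.nn.Linear(4, 1)
sched = paddle.optimizer.lr.NoamDecay(d_model=4, warmup_steps=10)
opt = paddle.optimizer.Adam(learning_rate=sched, parameters=model.parameters())

x = paddle.randn([8, 4])
loss = model(x).mean()

opt.clear_grad()   # replaces the removed clear_gradient()
loss.backward()
opt.step()         # replaces the removed optimizer.update()
sched.step()       # one learning-rate tick per optimizer step
```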
diff --git a/deepspeech/utils/log.py b/deepspeech/utils/log.py
index 3fd7d248..7e8de600 100644
--- a/deepspeech/utils/log.py
+++ b/deepspeech/utils/log.py
@@ -12,19 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import getpass
-import logging
 import os
 import socket
 import sys

+from loguru import logger
 from paddle import inference

-FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
-DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
-
-logging.basicConfig(
-    level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)
-

 def find_log_dir(log_dir=None):
     """Returns the most suitable directory to put log files into.
@@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None):


 class Log():
-
-    log_name = None
-
-    def __init__(self, logger=None):
-        self.logger = logging.getLogger(logger)
-        self.logger.setLevel(logging.DEBUG)
-
-        file_dir = os.getcwd() + '/log'
-        if not os.path.exists(file_dir):
-            os.mkdir(file_dir)
-        self.log_dir = file_dir
-
-        actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
-            program_name=None, log_dir=self.log_dir)
-
-        basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
-        filename = os.path.join(actual_log_dir, basename)
-        if Log.log_name is None:
-            Log.log_name = filename
-
-        # Create a symlink to the log file with a canonical name.
-        symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
-        try:
-            if os.path.islink(symlink):
-                os.unlink(symlink)
-            os.symlink(os.path.basename(Log.log_name), symlink)
-        except EnvironmentError:
-            # If it fails, we're sad but it's no error. Commonly, this
-            # fails because the symlink was created by another user and so
-            # we can't modify it
-            pass
-
-        if not self.logger.hasHandlers():
-            formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
-            fh = logging.FileHandler(Log.log_name)
-            fh.setLevel(logging.DEBUG)
-            fh.setFormatter(formatter)
-            self.logger.addHandler(fh)
-
-            ch = logging.StreamHandler()
-            ch.setLevel(logging.INFO)
-            ch.setFormatter(formatter)
-            self.logger.addHandler(ch)
-
-        # stop propagate for propagating may print
-        # log multiple times
-        self.logger.propagate = False
+    """Default Logger for all."""
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        level='INFO',
+        enqueue=True,
+        filter=lambda record: record['level'].no >= 20)
+    _, file_prefix, _ = find_log_dir_and_names()
+    sink_prefix = os.path.join("exp/log", file_prefix)
+    sink_path = sink_prefix[:-3] + "{time}.log"
+    logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB")
+
+    def __init__(self, name=None):
+        pass

     def getlog(self):
-        return self.logger
+        return logger


 class Autolog:
+    """Just used by the fullchain project."""
+
     def __init__(self,
                  batch_size,
                  model_name="DeepSpeech",
diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml
index 5127e8c6..fcbe1da4 100644
--- a/examples/tiny/s1/conf/transformer.yaml
+++ b/examples/tiny/s1/conf/transformer.yaml
@@ -86,7 +86,7 @@ training:
   lr_decay: 1.0
   log_interval: 1
   checkpoint:
-    kbest_n: 10
+    kbest_n: 2
     latest_n: 1
diff --git a/requirements.txt b/requirements.txt
index 7c3da37e..ebf879b5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,9 @@
 coverage
 gpustat
 jsonlines
+jsonlines
 kaldiio
+loguru
 Pillow
 pre-commit
 pybind11
@@ -14,5 +16,7 @@ SoundFile==0.9.0.post1
 sox
 tensorboardX
 textgrid
+tqdm
 typeguard
+visualdl==2.2.0
 yacs
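Editorial note, not part of the patch: the rewritten `Log` above replaces the stdlib `logging` handler/symlink machinery with a single module-level loguru logger configured once at class-definition time, so every `Log(name).getlog()` returns the same object. A minimal sketch of the same idea, with a hypothetical sink path:

```python
import sys
from loguru import logger

# Configure once, at import time: console gets INFO+, the file gets DEBUG+.
logger.remove()  # drop loguru's default stderr sink
logger.add(sys.stdout, level="INFO", enqueue=True)
logger.add("exp/log/worker.{time}.log",  # hypothetical path; rotates at 500 MB
           level="DEBUG", enqueue=True, rotation="500 MB")


class Log:
    """Every Log(name).getlog() hands back the same shared loguru logger."""

    def __init__(self, name=None):
        pass

    def getlog(self):
        return logger


Log(__name__).getlog().info("hello")
```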