diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index d85a3dde7..3ac68a3b5 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -351,20 +351,3 @@ if not hasattr(paddle.Tensor, 'tolist'):
     logger.warn(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
-
-
-########### hcak paddle.nn #############
-class GLU(nn.Layer):
-    """Gated Linear Units (GLU) Layer"""
-
-    def __init__(self, dim: int=-1):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, xs):
-        return F.glu(xs, axis=self.dim)
-
-
-if not hasattr(paddle.nn, 'GLU'):
-    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'GLU', GLU)
diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index f3e3fcadf..fbc357ca0 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -15,6 +15,7 @@
 import os
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
+        train_conf = self.config.training
         start = time.time()
+
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss = self.model(audio, audio_len, text, text_len)
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-        self.optimizer.step()
-        self.optimizer.clear_grad()
-        iteration_time = time.time() - start
-
         losses_np = {
             'train_loss': float(loss),
         }
+
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
+        if (batch_index + 1) % train_conf.accum_grad == 0:
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+            self.iteration += 1
+
+        iteration_time = time.time() - start
+
         msg += "train time: {:>.3f}s, ".format(iteration_time)
         msg += "batch size: {}, ".format(self.config.collator.batch_size)
+        msg += "accum: {}, ".format(train_conf.accum_grad)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         logger.info(msg)
 
         if dist.get_rank() == 0 and self.visualizer:
             for k, v in losses_np.items():
+                # `step -1` since we update `step` after optimizer.step().
                 self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
-        self.iteration += 1
+                                           self.iteration - 1)
 
     @paddle.no_grad()
     def valid(self):
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 0662e38d9..8ab9a26e8 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -17,6 +17,7 @@
 import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
@@ -79,21 +80,35 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
-        utt, audio, audio_len, text, text_len = batch_data
+        # forward
+        utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 6a932d751..140ee947f 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -17,6 +17,7 @@
 import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
@@ -83,20 +84,34 @@ class U2Trainer(Trainer):
         train_conf = self.config.training
         start = time.time()
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index 5734e15f5..ef5938b77 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -17,6 +17,7 @@
 import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
@@ -83,6 +84,7 @@ class U2STTrainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         if isinstance(text, list) and isinstance(text_len, list):
             # joint training with ASR. Two decoding texts [translation, transcription]
@@ -94,18 +96,30 @@
         else:
             loss, st_loss, attention_loss, ctc_loss = self.model(
                 audio, audio_len, text, text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
-
         losses_np['st_loss'] = float(st_loss)
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index 30132775e..3cb8729e1 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -15,12 +15,13 @@ from collections import OrderedDict
 
 import paddle
 from paddle import nn
+from paddle.nn import functional as F
 
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
+__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)
 
 
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, dim: int=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, xs):
+        return F.glu(xs, axis=self.dim)
+
+
 class LinearGLUBlock(nn.Layer):
     """A linear Gated Linear Units (GLU) block."""
@@ -133,13 +145,18 @@ def get_activation(act):
     """Return activation function."""
     # Lazy load to avoid unused import
     activation_funcs = {
+        "hardshrink": paddle.nn.Hardshrink,
+        "hardswish": paddle.nn.Hardswish,
         "hardtanh": paddle.nn.Hardtanh,
         "tanh": paddle.nn.Tanh,
         "relu": paddle.nn.ReLU,
+        "relu6": paddle.nn.ReLU6,
+        "leakyrelu": paddle.nn.LeakyReLU,
         "selu": paddle.nn.SELU,
         "swish": paddle.nn.Swish,
         "gelu": paddle.nn.GELU,
-        "brelu": brelu,
+        "glu": GLU,
+        "elu": paddle.nn.ELU,
     }
 
     return activation_funcs[act]()
diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml
index 0f465a8f7..4bf03ec63 100644
--- a/examples/aishell/s0/conf/deepspeech2.yaml
+++ b/examples/aishell/s0/conf/deepspeech2.yaml
@@ -44,6 +44,7 @@ model:
 
 training:
   n_epoch: 80
+  accum_grad: 1
   lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml
index 9f05d8dd8..9946852d0 100644
--- a/examples/aishell/s0/conf/deepspeech2_online.yaml
+++ b/examples/aishell/s0/conf/deepspeech2_online.yaml
@@ -46,6 +46,7 @@ model:
 
 training:
   n_epoch: 50
+  accum_grad: 1
   lr: 2e-3
   lr_decay: 0.9 # 0.83
   weight_decay: 1e-06
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml
index 2c31e66e1..0e6ed5bab 100644
--- a/examples/librispeech/s0/conf/deepspeech2.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2.yaml
@@ -11,7 +11,7 @@ data:
   max_output_input_ratio: .inf
 
 collator:
-  batch_size: 20
+  batch_size: 15
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
@@ -44,6 +44,7 @@ model:
 
 training:
   n_epoch: 50
+  accum_grad: 4
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml
index 87445c0b4..6e74f7042 100644
--- a/examples/librispeech/s0/conf/deepspeech2_online.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml
@@ -11,7 +11,7 @@ data:
   max_output_input_ratio: .inf
 
 collator:
-  batch_size: 20
+  batch_size: 15
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
@@ -46,6 +46,7 @@ model:
 
 training:
   n_epoch: 50
+  accum_grad: 4
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index c93217d32..5c9436e39 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -45,6 +45,7 @@ model:
 
 training:
   n_epoch: 10
+  accum_grad: 1
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml
index 4205a04ac..e435ff969 100644
--- a/examples/tiny/s0/conf/deepspeech2_online.yaml
+++ b/examples/tiny/s0/conf/deepspeech2_online.yaml
@@ -4,7 +4,7 @@ data:
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
   min_input_len: 0.0
-  max_input_len: 27.0
+  max_input_len: 30.0
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
@@ -47,6 +47,7 @@ model:
 
 training:
   n_epoch: 10
+  accum_grad: 1
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
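
Note: all four trainers above apply the same gradient-accumulation recipe, so a condensed stand-alone sketch of the pattern follows. It is illustrative only: `train_one_epoch`, `model`, `optimizer`, `dataloader`, and `accum_grad` are hypothetical names, and `model` is assumed to be a `paddle.DataParallel` wrapper whose forward pass returns the training loss, as in the patched trainers.

    from contextlib import nullcontext

    def train_one_epoch(model, optimizer, dataloader, accum_grad):
        """Accumulate gradients over `accum_grad` batches per optimizer update."""
        for batch_index, batch in enumerate(dataloader):
            # forward; the model is assumed to return the training loss
            loss = model(*batch)
            # scale so the accumulated gradient matches one large batch
            loss = loss / accum_grad

            if (batch_index + 1) % accum_grad != 0:
                # non-update step: keep gradients local, skip the DDP all-reduce
                context = model.no_sync
            else:
                # update step (or single-GPU run): allow gradient synchronization
                context = nullcontext

            with context():
                loss.backward()

            if (batch_index + 1) % accum_grad == 0:
                optimizer.step()
                optimizer.clear_grad()

This is also why `self.iteration` in the DeepSpeech2 trainer now advances only inside the optimizer-step branch, and why the visualizer logs against `self.iteration - 1`.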