Merge pull request #814 from PaddlePaddle/no_sync

support no_sync for backward; ds support accum grad
pull/815/head
Jackwaterveg 4 years ago committed by GitHub
commit 75cd366ddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -351,20 +351,3 @@ if not hasattr(paddle.Tensor, 'tolist'):
logger.warn(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn #############
class GLU(nn.Layer):
    """Gated Linear Units (GLU) Layer.

    Thin layer wrapper around ``F.glu`` so the activation can be used
    wherever an ``nn.Layer`` instance is expected.

    Args:
        dim (int): axis forwarded to ``F.glu`` as ``axis``. Defaults to -1.
    """

    def __init__(self, dim: int=-1):
        super().__init__()
        # Stored so that forward() can pass it on to F.glu.
        self.dim = dim

    def forward(self, xs):
        """Return ``F.glu`` applied to ``xs`` along ``self.dim``."""
        gated = F.glu(xs, axis=self.dim)
        return gated
# Expose the local GLU class as ``paddle.nn.GLU`` only when paddle.nn does
# not already provide an attribute of that name.
if not hasattr(paddle.nn, 'GLU'):
    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
    paddle.nn.GLU = GLU

@ -15,6 +15,7 @@
import os
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len)
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
self.optimizer.step()
self.optimizer.clear_grad()
iteration_time = time.time() - start
losses_np = {
'train_loss': float(loss),
}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.iteration += 1
iteration_time = time.time() - start
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
logger.info(msg)
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
# `step -1` since we update `step` after optimizer.step().
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
self.iteration += 1
self.iteration - 1)
@paddle.no_grad()
def valid(self):

@ -17,6 +17,7 @@ import os
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
@ -79,21 +80,35 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
# forward
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad}
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss)
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()

@ -17,6 +17,7 @@ import os
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
@ -83,20 +84,34 @@ class U2Trainer(Trainer):
train_conf = self.config.training
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad}
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss)
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()

@ -17,6 +17,7 @@ import os
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
@ -83,6 +84,7 @@ class U2STTrainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
if isinstance(text, list) and isinstance(text_len, list):
# joint training with ASR. Two decoding texts [translation, transcription]
@ -94,18 +96,30 @@ class U2STTrainer(Trainer):
else:
loss, st_loss, attention_loss, ctc_loss = self.model(
audio, audio_len, text, text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad}
losses_np['st_loss'] = float(st_loss)
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss)
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()

@ -15,12 +15,13 @@ from collections import OrderedDict
import paddle
from paddle import nn
from paddle.nn import functional as F
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"]
def brelu(x, t_min=0.0, t_max=24.0, name=None):
@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return x.maximum(t_min).minimum(t_max)
class GLU(nn.Layer):
    """Gated Linear Units (GLU) Layer.

    Wraps the functional ``F.glu`` call in an ``nn.Layer`` so it can be
    constructed with no arguments, like the other layer classes in this
    module.

    Args:
        dim (int): axis handed to ``F.glu`` as ``axis``. Defaults to -1.
    """

    def __init__(self, dim: int=-1):
        super().__init__()
        # Remember the axis for use in forward().
        self.dim = dim

    def forward(self, xs):
        """Apply ``F.glu`` to ``xs`` along the configured axis."""
        out = F.glu(xs, axis=self.dim)
        return out
class LinearGLUBlock(nn.Layer):
"""A linear Gated Linear Units (GLU) block."""
@ -133,13 +145,18 @@ def get_activation(act):
"""Return activation function."""
# Lazy load to avoid unused import
activation_funcs = {
"hardshrink": paddle.nn.Hardshrink,
"hardswish": paddle.nn.Hardswish,
"hardtanh": paddle.nn.Hardtanh,
"tanh": paddle.nn.Tanh,
"relu": paddle.nn.ReLU,
"relu6": paddle.nn.ReLU6,
"leakyrelu": paddle.nn.LeakyReLU,
"selu": paddle.nn.SELU,
"swish": paddle.nn.Swish,
"gelu": paddle.nn.GELU,
"brelu": brelu,
"glu": GLU,
"elu": paddle.nn.ELU,
}
return activation_funcs[act]()

@ -44,6 +44,7 @@ model:
training:
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06

@ -46,6 +46,7 @@ model:
training:
n_epoch: 50
accum_grad: 1
lr: 2e-3
lr_decay: 0.9 # 0.83
weight_decay: 1e-06

@ -11,7 +11,7 @@ data:
max_output_input_ratio: .inf
collator:
batch_size: 20
batch_size: 15
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
@ -44,6 +44,7 @@ model:
training:
n_epoch: 50
accum_grad: 4
lr: 1e-3
lr_decay: 0.83
weight_decay: 1e-06

@ -11,7 +11,7 @@ data:
max_output_input_ratio: .inf
collator:
batch_size: 20
batch_size: 15
mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt
@ -46,6 +46,7 @@ model:
training:
n_epoch: 50
accum_grad: 4
lr: 1e-3
lr_decay: 0.83
weight_decay: 1e-06

@ -45,6 +45,7 @@ model:
training:
n_epoch: 10
accum_grad: 1
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06

@ -4,7 +4,7 @@ data:
dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny
min_input_len: 0.0
max_input_len: 27.0
max_input_len: 30.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
@ -47,6 +47,7 @@ model:
training:
n_epoch: 10
accum_grad: 1
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06

Loading…
Cancel
Save