sequence_mask made all inputs zero, which caused the gradients to be zero; this is a bug in LessThanOp

add grad clip by global norm
add model train test notebook
pull/522/head
Hui Zhang 5 years ago
parent 54b13722f5
commit f121f851d9

@@ -470,3 +470,108 @@ class SpeechCollator():
        texts = np.array(texts).astype('int32')
        text_lens = np.array(text_lens).astype('int64')
        return padded_audios, texts, audio_lens, text_lens

def create_dataloader(manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
random_seed=0,
keep_transcription_text=False,
is_training=False,
batch_size=1,
num_workers=0,
sortagrad=False,
shuffle_method=None,
dist=False):
dataset = DeepSpeech2Dataset(
manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config=augmentation_config,
max_duration=max_duration,
min_duration=min_duration,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
specgram_type=specgram_type,
use_dB_normalization=use_dB_normalization,
random_seed=random_seed,
keep_transcription_text=keep_transcription_text)
if dist:
batch_sampler = DeepSpeech2DistributedBatchSampler(
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=is_training,
drop_last=is_training,
sortagrad=is_training,
shuffle_method=shuffle_method)
else:
batch_sampler = DeepSpeech2BatchSampler(
dataset,
shuffle=is_training,
batch_size=batch_size,
drop_last=is_training,
sortagrad=is_training,
shuffle_method=shuffle_method)
def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
If ``padding_to`` is -1, the maximun shape in the batch will be used
as the target shape for padding. Otherwise, `padding_to` will be the
target shape (only refers to the second axis).
If `flatten` is True, features will be flatten to 1darray.
"""
new_batch = []
# get target shape
max_length = max([audio.shape[1] for audio, text in batch])
if padding_to != -1:
if padding_to < max_length:
raise ValueError("If padding_to is not -1, it should be larger "
"than any instance's shape in the batch")
max_length = padding_to
max_text_length = max([len(text) for audio, text in batch])
# padding
padded_audios = []
audio_lens = []
texts, text_lens = [], []
for audio, text in batch:
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio)
audio_lens.append(audio.shape[1])
padded_text = np.zeros([max_text_length])
padded_text[:len(text)] = text
texts.append(padded_text)
text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64')
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=partial(padding_batch, is_training=is_training),
num_workers=num_workers)
return loader
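
For context, a minimal usage sketch of the new helper; the manifest, vocab, and mean-std paths and the batch settings below are hypothetical placeholders, not files shipped with this change:

# Hypothetical invocation of create_dataloader; all paths are placeholders.
train_loader = create_dataloader(
    manifest_path='data/manifest.train',
    vocab_filepath='data/vocab.txt',
    mean_std_filepath='data/mean_std.npz',
    batch_size=32,
    is_training=True,
    dist=False)

for padded_audios, texts, audio_lens, text_lens in train_loader:
    # padded_audios: [B, D, Tmax] float32; texts: [B, Lmax] int32
    break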

@@ -32,7 +32,6 @@ export FLAGS_sync_nccl_allreduce=0
#--specgram_type="linear" \
#--shuffle_method="batch_shuffle_clipped" \
-CUDA_VISIBLE_DEVICES=2,3,5,7 \
python3 -u ${MAIN_ROOT}/train.py \
--device 'gpu' \
--nproc 4 \

@@ -17,6 +17,7 @@ import io
import sys
import os
import time
import logging

import numpy as np
from collections import defaultdict

@@ -24,6 +25,13 @@ import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import layers
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid import name_scope

from utils import mp_tools
from training import Trainer
@@ -41,6 +49,68 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from utils.error_rate import char_errors, word_errors, cer, wer

logger = logging.getLogger(__name__)

class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
def __init__(self, clip_norm):
super().__init__(clip_norm)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)
logger.info(f"Grad Before Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(merge_grad))) ) }")
sum_square_list.append(sum_square)
        # all parameters have been filtered out
if len(sum_square_list) == 0:
return params_grads
global_norm_var = layers.concat(sum_square_list)
global_norm_var = layers.reduce_sum(global_norm_var)
global_norm_var = layers.sqrt(global_norm_var)
logger.info(f"Grad Global Norm: {float(global_norm_var)}!!!!")
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var, y=max_global_norm))
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = layers.elementwise_mul(x=g, y=clip_var)
logger.info(f"Grad After Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(merge_grad))) ) }")
params_and_grads.append((p, new_grad))
return params_and_grads
def print_grads(model, logger=None):
for n, p in model.named_parameters():
msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
if logger:
logger.info(msg)
def print_params(model, logger=None):
for n, p in model.named_parameters():
msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}"
if logger:
logger.info(msg)
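
The clipping rule above reduces to scaling every gradient by clip_norm / max(global_norm, clip_norm); a small numpy sketch of that math, independent of Paddle:

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # global_norm = sqrt of the sum of squared entries across all grads
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]  # global norm = 13.0
clipped, gnorm = clip_by_global_norm(grads, clip_norm=5.0)
# every gradient is scaled by 5/13; a global norm already below 5 is left as-is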
class DeepSpeech2Trainer(Trainer):
    def __init__(self, config, args):

@@ -61,6 +131,7 @@ class DeepSpeech2Trainer(Trainer):
        loss = self.compute_losses(batch_data, outputs)
        loss.backward()
        print_grads(self.model, logger=None)
        self.optimizer.step()
        self.optimizer.clear_grad()
@@ -102,15 +173,19 @@ class DeepSpeech2Trainer(Trainer):
            f"Train Total Examples: {len(self.train_loader.dataset)}")
        self.new_epoch()
        while self.epoch <= self.config.training.n_epoch:
-            for batch in self.train_loader:
-                self.iteration += 1
-                self.train_batch(batch)
+            try:
+                for batch in self.train_loader:
+                    self.iteration += 1
+                    self.train_batch(batch)

                    # if self.iteration % self.config.training.valid_interval == 0:
                    #     self.valid()

                    # if self.iteration % self.config.training.save_interval == 0:
                    #     self.save()
+            except Exception as e:
+                self.logger.error(e)
+                pass

            self.valid()
            self.save()
@@ -166,11 +241,9 @@ class DeepSpeech2Trainer(Trainer):
        if self.parallel:
            model = paddle.DataParallel(model)

-        for n, p in model.named_parameters():
-            self.logger.info(
-                f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}")
+        print_params(model, self.logger)

-        grad_clip = paddle.nn.ClipGradByGlobalNorm(
+        grad_clip = MyClipGradByGlobalNorm(
            config.training.global_grad_clip)

        # optimizer = paddle.optimizer.Adam(
@@ -299,6 +372,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
        error_rate_func = cer if cfg.error_rate_type == 'cer' else wer
        vocab_list = self.test_loader.dataset.vocab_list
        for t in vocab_list:
            self.logger.info(f"vocab: {t}")
        target_transcripts = self.id2token(texts, texts_len, vocab_list)
        result_transcripts = self.model.decode_probs(
            probs.numpy(),

@@ -31,27 +31,30 @@ logger = logging.getLogger(__name__)
__all__ = ['DeepSpeech2', 'DeepSpeech2Loss']

-def ctc_loss(log_probs,
+def ctc_loss(logits,
             labels,
             input_lengths,
             label_lengths,
             blank=0,
             reduction='mean',
-             norm_by_times=True):
+             norm_by_times=False):
    #logger.info("my ctc loss with norm by times")
    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
    loss_out = paddle.fluid.layers.warpctc(
-        log_probs, labels, blank, norm_by_times, input_lengths, label_lengths)
+        logits, labels, blank, norm_by_times, input_lengths, label_lengths)
    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
    logger.info(f"warpctc loss: {loss_out}/{loss_out.shape}")
    assert reduction in ['mean', 'sum', 'none']
    if reduction == 'mean':
        loss_out = paddle.mean(loss_out / label_lengths)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out)
    logger.info(f"ctc loss: {loss_out}")
    return loss_out

-F.ctc_loss = ctc_loss
+#F.ctc_loss = ctc_loss

def brelu(x, t_min=0.0, t_max=24.0, name=None):
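
The two reductions above differ only in normalization: 'mean' divides each utterance's loss by its label length before averaging, while 'sum' totals the raw losses. A small numpy illustration with made-up numbers:

import numpy as np

loss_none = np.array([12.0, 30.0, 8.0])   # per-utterance CTC losses (made up)
label_lengths = np.array([4, 10, 2])

loss_mean = np.mean(loss_none / label_lengths)  # (3 + 3 + 4) / 3 = 3.33...
loss_sum = np.sum(loss_none)                    # 50.0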
@@ -64,7 +67,8 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
    max_len = max_len or x_len.max()
    x_len = paddle.unsqueeze(x_len, -1)
    row_vector = paddle.arange(max_len)
-    mask = row_vector < x_len
+    #mask = row_vector < x_len
+    mask = row_vector > x_len  # a bug: the broadcast goes wrong here (LessThanOp)
    mask = paddle.cast(mask, dtype)
    return mask
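
For reference, a numpy sketch of the mask this function is meant to produce (mask[b, t] = 1 for t < x_len[b]); the `>` line above is the author's workaround for the broken comparison broadcast, not the intended semantics:

import numpy as np

def sequence_mask_ref(x_len, max_len=None, dtype='float32'):
    # Intended semantics: mask[b, t] = 1 if t < x_len[b] else 0.
    max_len = max_len or int(x_len.max())
    row_vector = np.arange(max_len)         # [T]
    mask = row_vector < x_len[:, None]      # [T] vs [B, 1] -> [B, T]
    return mask.astype(dtype)

print(sequence_mask_ref(np.array([2, 3])))
# [[1. 1. 0.]
#  [1. 1. 1.]]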
@@ -119,7 +123,7 @@ class ConvBn(nn.Layer):
            weight_attr=None,
            bias_attr=None,
            data_format='NCHW')
-        self.act = paddle.relu if act == 'relu' else brelu
+        self.act = F.relu if act == 'relu' else brelu

    def forward(self, x, x_len):
        """
@@ -154,16 +158,13 @@ class ConvStack(nn.Layer):
        self.feat_size = feat_size  # D
        self.num_stacks = num_stacks

-        self.filter_size = (41, 11)  # [D, T]
-        self.stride = (2, 3)
-        self.padding = (20, 5)

        self.conv_in = ConvBn(
            num_channels_in=1,
            num_channels_out=32,
-            kernel_size=self.filter_size,
-            stride=self.stride,
-            padding=self.padding,
-            act='brelu', )
+            kernel_size=(41, 11),  #[D, T]
+            stride=(2, 3),
+            padding=(20, 5),
+            act='brelu')

        out_channel = 32
        self.conv_stack = nn.LayerList([
@@ -307,7 +308,7 @@ class GRUCellShare(nn.RNNCellBase):
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        #self._activation = paddle.tanh
-        self._activation = paddle.relu
+        self._activation = F.relu

    def forward(self, inputs, states=None):
        if states is None:
@@ -479,6 +480,9 @@ class RNNStack(nn.Layer):
        """
        for i, rnn in enumerate(self.rnn_stacks):
            x, x_len = rnn(x, x_len)
            masks = sequence_mask(x_len)  #[B, T]
            masks = masks.unsqueeze(-1)  # [B, T, 1]
            x = x.multiply(masks)
        return x, x_len
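
Zeroing the padded frames relies on broadcasting the [B, T, 1] mask across the feature axis; the same operation in plain numpy:

import numpy as np

B, T, D = 2, 3, 4
x = np.ones((B, T, D), dtype='float32')                     # RNN outputs
masks = np.array([[1, 1, 0], [1, 1, 1]], dtype='float32')   # [B, T]

x = x * masks[:, :, None]   # [B, T, 1] broadcasts over D
# x[0, 2] is now all zeros, so padding cannot leak into the next layer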
@@ -544,14 +548,17 @@ class DeepSpeech2(nn.Layer):
        # convolution group
        x, audio_len = self.conv(audio, audio_len)
        #print('conv out', x.shape)

        # convert data from convolution feature map to sequence of vectors
        B, C, D, T = paddle.shape(x)
        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
        x = x.reshape([B, T, C * D])  #[B, T, C*D]
        #print('rnn input', x.shape)

        # remove padding part
        x, audio_len = self.rnn(x, audio_len)  #[B, T, D]
        #print('rnn output', x.shape)

        logits = self.fc(x)  #[B, T, V + 1]
@@ -713,7 +720,7 @@ class DeepSpeech2Loss(nn.Layer):
    def __init__(self, vocab_size):
        super().__init__()
        # last token id as blank id
-        self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
+        self.loss = nn.CTCLoss(blank=vocab_size, reduction='sum')

    def forward(self, logits, text, logits_len, text_len):
        # warp-ctc do softmax on activations
@@ -721,7 +728,4 @@
        logits = logits.transpose([1, 0, 2])
        ctc_loss = self.loss(logits, text, logits_len, text_len)
-        ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-        #ctc_loss /= logits_len  # norm_by_times
-        ctc_loss = ctc_loss.sum()
        return ctc_loss
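
A minimal shape sketch of how the loss module is driven after this change; the batch sizes and vocab below are made up for illustration:

import paddle

B, T, V, L = 4, 50, 28, 10
loss_fn = DeepSpeech2Loss(vocab_size=V)   # blank id = V, reduction='sum'

logits = paddle.randn([B, T, V + 1])      # pre-softmax activations
text = paddle.randint(0, V, [B, L]).astype('int32')
logits_len = paddle.full([B], T, dtype='int64')
text_len = paddle.full([B], L, dtype='int64')

loss = loss_fn(logits, text, logits_len, text_len)  # scalar, summed over batch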

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from network2 import DeepSpeech2
+from model_utils.network import DeepSpeech2

import paddle
import numpy as np

@@ -51,7 +51,7 @@ if __name__ == '__main__':
        rnn_size=1024,
        use_gru=False,
        share_rnn_weights=False, )

-    probs = model(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model(audio, text, audio_len, text_len)
    print('probs.shape', probs.shape)
    print("-----------------")

File diff suppressed because it is too large