sequence_mask makes all inputs zero, which causes the grad to be zero; this is a bug in LessThanOp

add grad clip by global norm
add model train test notebook
pull/522/head
Hui Zhang 5 years ago
parent 54b13722f5
commit f121f851d9
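
The commit message above describes the failure mode: when sequence_mask returns an all-zero mask, every masked activation becomes zero and no gradient flows back to the parameters. A minimal sketch of that effect, assuming a Paddle 2.x dygraph environment (illustrative only, not part of this patch):

import paddle

# Illustrative only: an all-zero mask zeroes the output, so the
# gradient w.r.t. the masked input is zero everywhere.
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
mask = paddle.zeros([3])
loss = (x * mask).sum()
loss.backward()
print(x.grad)  # all zeros -> no learning signal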

@@ -470,3 +470,108 @@ class SpeechCollator():
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
def create_dataloader(manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
random_seed=0,
keep_transcription_text=False,
is_training=False,
batch_size=1,
num_workers=0,
sortagrad=False,
shuffle_method=None,
dist=False):
dataset = DeepSpeech2Dataset(
manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config=augmentation_config,
max_duration=max_duration,
min_duration=min_duration,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
specgram_type=specgram_type,
use_dB_normalization=use_dB_normalization,
random_seed=random_seed,
keep_transcription_text=keep_transcription_text)
if dist:
batch_sampler = DeepSpeech2DistributedBatchSampler(
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=is_training,
drop_last=is_training,
sortagrad=is_training,
shuffle_method=shuffle_method)
else:
batch_sampler = DeepSpeech2BatchSampler(
dataset,
shuffle=is_training,
batch_size=batch_size,
drop_last=is_training,
sortagrad=is_training,
shuffle_method=shuffle_method)
def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
If ``padding_to`` is -1, the maximum shape in the batch will be used
as the target shape for padding. Otherwise, `padding_to` will be the
target shape (only refers to the second axis).
If `flatten` is True, features will be flattened to a 1-D array.
"""
new_batch = []
# get target shape
max_length = max([audio.shape[1] for audio, text in batch])
if padding_to != -1:
if padding_to < max_length:
raise ValueError("If padding_to is not -1, it should be larger "
"than any instance's shape in the batch")
max_length = padding_to
max_text_length = max([len(text) for audio, text in batch])
# padding
padded_audios = []
audio_lens = []
texts, text_lens = [], []
for audio, text in batch:
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio)
audio_lens.append(audio.shape[1])
padded_text = np.zeros([max_text_length])
padded_text[:len(text)] = text
texts.append(padded_text)
text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64')
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=partial(padding_batch, is_training=is_training),
num_workers=num_workers)
return loader
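
The padding_batch collate function above zero-pads every audio feature to the longest utterance in the batch (or to padding_to) and every transcript to the longest text. A standalone numpy sketch of the audio-padding step with two made-up utterances (illustrative only, not part of the patch):

import numpy as np

# Illustrative only: feature dim 2, audio lengths 3 and 5.
batch = [np.ones([2, 3]), np.ones([2, 5])]
max_length = max(audio.shape[1] for audio in batch)      # 5
padded = []
for audio in batch:
    padded_audio = np.zeros([audio.shape[0], max_length])
    padded_audio[:, :audio.shape[1]] = audio             # copy real frames, leave zeros as padding
    padded.append(padded_audio)
padded_audios = np.array(padded).astype('float32')
print(padded_audios.shape)  # (2, 2, 5) -> padded to the longest audio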

@@ -32,7 +32,6 @@ export FLAGS_sync_nccl_allreduce=0
#--specgram_type="linear" \
#--shuffle_method="batch_shuffle_clipped" \
CUDA_VISIBLE_DEVICES=2,3,5,7 \
python3 -u ${MAIN_ROOT}/train.py \
--device 'gpu' \
--nproc 4 \

@@ -17,6 +17,7 @@ import io
import sys
import os
import time
import logging
import numpy as np
from collections import defaultdict
@@ -24,6 +25,13 @@ import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import layers
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid import name_scope
from utils import mp_tools
from training import Trainer
@@ -41,6 +49,68 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
from utils.error_rate import char_errors, word_errors, cer, wer
logger = logging.getLogger(__name__)
class MyClipGradByGlobalNorm(paddle.nn.ClipGradByGlobalNorm):
def __init__(self, clip_norm):
super().__init__(clip_norm)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)
logger.info(f"Grad Before Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(merge_grad))))}")
sum_square_list.append(sum_square)
# all parameters have been filtered out
if len(sum_square_list) == 0:
return params_grads
global_norm_var = layers.concat(sum_square_list)
global_norm_var = layers.reduce_sum(global_norm_var)
global_norm_var = layers.sqrt(global_norm_var)
logger.info(f"Grad Global Norm: {float(global_norm_var)}!!!!")
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var, y=max_global_norm))
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = layers.elementwise_mul(x=g, y=clip_var)
logger.info(f"Grad After Clip: {p.name}: {float(layers.sqrt(layers.reduce_sum(layers.square(new_grad))))}")
params_and_grads.append((p, new_grad))
return params_and_grads
def print_grads(model, logger=None):
for n, p in model.named_parameters():
msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
if logger:
logger.info(msg)
def print_params(model, logger=None):
for n, p in model.named_parameters():
msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}"
if logger:
logger.info(msg)
class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
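
MyClipGradByGlobalNorm above subclasses paddle.nn.ClipGradByGlobalNorm and reimplements _dygraph_clip mainly to log gradient norms before and after clipping; the clipping rule itself is the standard one: scale every gradient by clip_norm / max(global_norm, clip_norm), where global_norm is the L2 norm over all gradients. A standalone numpy sketch of that rule (illustrative only, not part of the patch):

import numpy as np

# Illustrative only: the global-norm clipping rule used above, in numpy.
def clip_by_global_norm(grads, clip_norm):
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)   # <= 1.0, identity if already within clip_norm
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([12.0])]      # global norm = sqrt(9 + 16 + 144) = 13
clipped = clip_by_global_norm(grads, clip_norm=5.0)
print(np.sqrt(sum(np.sum(np.square(g)) for g in clipped)))  # 5.0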
@@ -61,6 +131,7 @@ class DeepSpeech2Trainer(Trainer):
loss = self.compute_losses(batch_data, outputs)
loss.backward()
print_grads(self.model, logger=None)
self.optimizer.step()
self.optimizer.clear_grad()
@@ -102,15 +173,19 @@ class DeepSpeech2Trainer(Trainer):
f"Train Total Examples: {len(self.train_loader.dataset)}")
self.new_epoch()
while self.epoch <= self.config.training.n_epoch:
for batch in self.train_loader:
self.iteration += 1
self.train_batch(batch)
try:
for batch in self.train_loader:
self.iteration += 1
self.train_batch(batch)
# if self.iteration % self.config.training.valid_interval == 0:
# self.valid()
# if self.iteration % self.config.training.valid_interval == 0:
# self.valid()
# if self.iteration % self.config.training.save_interval == 0:
# self.save()
# if self.iteration % self.config.training.save_interval == 0:
# self.save()
except Exception as e:
self.logger.error(e)
pass
self.valid()
self.save()
@@ -166,11 +241,9 @@ class DeepSpeech2Trainer(Trainer):
if self.parallel:
model = paddle.DataParallel(model)
for n, p in model.named_parameters():
self.logger.info(
f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}")
print_params(model, self.logger)
grad_clip = paddle.nn.ClipGradByGlobalNorm(
grad_clip = MyClipGradByGlobalNorm(
config.training.global_grad_clip)
# optimizer = paddle.optimizer.Adam(
@@ -299,6 +372,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate_func = cer if cfg.error_rate_type == 'cer' else wer
vocab_list = self.test_loader.dataset.vocab_list
for t in vocab_list:
self.logger.info(f"vocab: {t}")
target_transcripts = self.id2token(texts, texts_len, vocab_list)
result_transcripts = self.model.decode_probs(
probs.numpy(),

@@ -31,27 +31,30 @@ logger = logging.getLogger(__name__)
__all__ = ['DeepSpeech2', 'DeepSpeech2Loss']
def ctc_loss(log_probs,
def ctc_loss(logits,
labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean',
norm_by_times=True):
norm_by_times=False):
#logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out = paddle.fluid.layers.warpctc(
log_probs, labels, blank, norm_by_times, input_lengths, label_lengths)
logits, labels, blank, norm_by_times, input_lengths, label_lengths)
loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ")
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
logger.info(f"ctc loss: {loss_out}")
return loss_out
F.ctc_loss = ctc_loss
#F.ctc_loss = ctc_loss
def brelu(x, t_min=0.0, t_max=24.0, name=None):
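
In the patched ctc_loss above, warpctc returns one loss value per utterance (squeezed at the end, with norm_by_times now False), and the reduction branch either averages the per-utterance losses after dividing each by its label length ('mean') or simply adds them up ('sum'). A tiny numpy sketch of the two reductions with made-up numbers (illustrative only, not part of the patch):

import numpy as np

# Illustrative only: hypothetical per-utterance warpctc losses and label lengths.
per_utt_loss = np.array([10.0, 20.0])
label_lengths = np.array([5.0, 10.0])
print(np.mean(per_utt_loss / label_lengths))  # reduction='mean' -> 2.0
print(np.sum(per_utt_loss))                   # reduction='sum'  -> 30.0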
@@ -64,7 +67,8 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
max_len = max_len or x_len.max()
x_len = paddle.unsqueeze(x_len, -1)
row_vector = paddle.arange(max_len)
mask = row_vector < x_len
#mask = row_vector < x_len
mask = row_vector > x_len # workaround: broadcasting goes wrong here (LessThanOp bug)
mask = paddle.cast(mask, dtype)
return mask
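
For reference, this is the mask sequence_mask is meant to return: 1 for frames inside each utterance's length, 0 for padding. The numpy sketch below uses the original `<` comparison, which broadcasts correctly in numpy; the commit message reports that Paddle's LessThanOp did not at the time, hence the `>` workaround above (illustrative only, not part of the patch):

import numpy as np

# Illustrative only: lengths [2, 4], max_len 5.
x_len = np.array([[2], [4]])            # [B, 1]
row_vector = np.arange(5)               # [T]
mask = (row_vector < x_len).astype('float32')
print(mask)
# [[1. 1. 0. 0. 0.]
#  [1. 1. 1. 1. 0.]]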
@@ -119,7 +123,7 @@ class ConvBn(nn.Layer):
weight_attr=None,
bias_attr=None,
data_format='NCHW')
self.act = paddle.relu if act == 'relu' else brelu
self.act = F.relu if act == 'relu' else brelu
def forward(self, x, x_len):
"""
@@ -154,16 +158,13 @@ class ConvStack(nn.Layer):
self.feat_size = feat_size # D
self.num_stacks = num_stacks
self.filter_size = (41, 11) # [D, T]
self.stride = (2, 3)
self.padding = (20, 5)
self.conv_in = ConvBn(
num_channels_in=1,
num_channels_out=32,
kernel_size=self.filter_size,
stride=self.stride,
padding=self.padding,
act='brelu', )
kernel_size=(41, 11), #[D, T]
stride=(2, 3),
padding=(20, 5),
act='brelu')
out_channel = 32
self.conv_stack = nn.LayerList([
@@ -307,7 +308,7 @@ class GRUCellShare(nn.RNNCellBase):
self.input_size = input_size
self._gate_activation = F.sigmoid
#self._activation = paddle.tanh
self._activation = paddle.relu
self._activation = F.relu
def forward(self, inputs, states=None):
if states is None:
@@ -479,6 +480,9 @@ class RNNStack(nn.Layer):
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = sequence_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
x = x.multiply(masks)
return x, x_len
@@ -544,14 +548,17 @@ class DeepSpeech2(nn.Layer):
# convolution group
x, audio_len = self.conv(audio, audio_len)
#print('conv out', x.shape)
# convert data from convolution feature map to sequence of vectors
B, C, D, T = paddle.shape(x)
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
x = x.reshape([B, T, C * D]) #[B, T, C*D]
#print('rnn input', x.shape)
# remove padding part
x, audio_len = self.rnn(x, audio_len) #[B, T, D]
#print('rnn output', x.shape)
logits = self.fc(x) #[B, T, V + 1]
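
The forward pass above flattens the convolution output from [B, C, D, T] into the [B, T, C*D] layout the RNN stack expects. A shape-only numpy sketch with hypothetical sizes (illustrative only, not part of the patch):

import numpy as np

# Illustrative only: B=2, C=32, D=20, T=50.
x = np.zeros([2, 32, 20, 50])       # conv output [B, C, D, T]
x = x.transpose([0, 3, 1, 2])       # [B, T, C, D]
x = x.reshape([2, 50, 32 * 20])     # [B, T, C*D]
print(x.shape)                      # (2, 50, 640)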
@@ -713,7 +720,7 @@ class DeepSpeech2Loss(nn.Layer):
def __init__(self, vocab_size):
super().__init__()
# last token id as blank id
self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
self.loss = nn.CTCLoss(blank=vocab_size, reduction='sum')
def forward(self, logits, text, logits_len, text_len):
# warp-ctc do softmax on activations
@@ -721,7 +728,4 @@ class DeepSpeech2Loss(nn.Layer):
logits = logits.transpose([1, 0, 2])
ctc_loss = self.loss(logits, text, logits_len, text_len)
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
#ctc_loss /= logits_len # norm_by_times
ctc_loss = ctc_loss.sum()
return ctc_loss

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from network2 import DeepSpeech2
from model_utils.network import DeepSpeech2
import paddle
import numpy as np
@@ -51,7 +51,7 @@ if __name__ == '__main__':
rnn_size=1024,
use_gru=False,
share_rnn_weights=False, )
probs = model(audio, text, audio_len, text_len)
logits, probs, logits_len = model(audio, text, audio_len, text_len)
print('probs.shape', probs.shape)
print("-----------------")

File diff suppressed because it is too large