fix ctc loss

lr schedule
sortagrad
logger
pull/522/head
Hui Zhang 5 years ago
parent 45ae60b198
commit aeb09d812b

.gitignore vendored

@@ -2,3 +2,4 @@
 *.pyc
 tools/venv
 dataset
+models/*

@ -15,6 +15,7 @@
import math import math
import random import random
import tarfile import tarfile
import logging
import numpy as np import numpy as np
import paddle import paddle
from paddle.io import Dataset from paddle.io import Dataset
@@ -30,6 +31,8 @@ from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
 from data_utils.speech import SpeechSegment
 from data_utils.normalizer import FeatureNormalizer

+logger = logging.getLogger(__name__)
+
 __all__ = [
     "DeepSpeech2Dataset",
     "DeepSpeech2DistributedBatchSampler",
@@ -234,9 +237,13 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         assert (clipped == False)
         if not clipped:
             res_len = len(indices) - shift_len - len(batch_indices)
+            assert res_len != 0, f"_batch_shuffle clipped {len(indices)} , {shift_len}, {len(batch_indices)}"
             # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
             batch_indices.extend(indices[-res_len:])
             batch_indices.extend(indices[0:shift_len])
+        assert len(indices) == len(
+            batch_indices
+        ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}"
         return batch_indices

     def __iter__(self):
@@ -247,8 +254,8 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
-            if self.epoch == 0 and self.sortagrad:
-                pass
+            if self.epoch == 0 and self._sortagrad:
+                logger.info(f'dataset sortagrad! epoch {self.epoch}')
             else:
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
@@ -374,9 +381,13 @@ class DeepSpeech2BatchSampler(BatchSampler):
         assert (clipped == False)
         if not clipped:
             res_len = len(indices) - shift_len - len(batch_indices)
+            assert res_len != 0, f"_batch_shuffle clipped {len(indices)} , {shift_len}, {len(batch_indices)}"
             # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
             batch_indices.extend(indices[-res_len:])
             batch_indices.extend(indices[0:shift_len])
+        assert len(indices) == len(
+            batch_indices
+        ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}"
         return batch_indices

     def __iter__(self):
@@ -388,7 +399,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
             if self.epoch == 0 and self._sortagrad:
-                pass
+                logger.info(f'dataset sortagrad! epoch {self.epoch}')
             else:
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
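For context on the two new asserts: _batch_shuffle takes the duration-sorted indices, applies a random shift of at most one batch, shuffles whole batches, and re-attaches the clipped head and tail, so the asserts guard against dropping or duplicating samples (extending with indices[-0:] would append the whole list again). Below is a minimal standalone sketch of that scheme; the names and the seed handling are illustrative, not the project's exact code.

import random

def batch_shuffle(indices, batch_size, clipped=False, seed=0):
    """Shuffle duration-sorted indices batch by batch with a random head shift."""
    rng = random.Random(seed)
    shift_len = rng.randint(0, batch_size - 1)
    # group the shifted indices into full batches and shuffle whole batches
    starts = range(shift_len, len(indices) - batch_size + 1, batch_size)
    batches = [indices[s:s + batch_size] for s in starts]
    rng.shuffle(batches)
    batch_indices = [idx for batch in batches for idx in batch]
    if not clipped:
        # re-attach whatever did not fill a full batch
        res_len = len(indices) - shift_len - len(batch_indices)
        if res_len != 0:                      # indices[-0:] would re-add everything
            batch_indices.extend(indices[-res_len:])
        batch_indices.extend(indices[0:shift_len])
        assert len(batch_indices) == len(indices)
    return batch_indices

print(batch_shuffle(list(range(10)), batch_size=4))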

@@ -3,33 +3,42 @@
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
 export FLAGS_sync_nccl_allreduce=0

-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
-python3 -u ${MAIN_ROOT}/train.py \
---batch_size=64 \
---num_epoch=50 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=1024 \
---num_iter_print=100 \
---save_epoch=1 \
---num_samples=120000 \
---learning_rate=5e-4 \
---max_duration=27.0 \
---min_duration=0.0 \
---test_off=False \
---use_sortagrad=True \
---use_gru=True \
---use_gpu=True \
---is_local=True \
---share_rnn_weights=False \
---train_manifest="data/manifest.train" \
---dev_manifest="data/manifest.dev" \
---mean_std_path="data/mean_std.npz" \
---vocab_path="data/vocab.txt" \
---output_model_dir="./checkpoints" \
---augment_conf_path="${MAIN_ROOT}/conf/augmentation.config" \
---specgram_type="linear" \
---shuffle_method="batch_shuffle_clipped" \
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+#python3 -u ${MAIN_ROOT}/train.py \
+#--batch_size=64 \
+#--num_epoch=50 \
+#--num_conv_layers=2 \
+#--num_rnn_layers=3 \
+#--rnn_layer_size=1024 \
+#--num_iter_print=100 \
+#--save_epoch=1 \
+#--num_samples=120000 \
+#--learning_rate=5e-4 \
+#--max_duration=27.0 \
+#--min_duration=0.0 \
+#--test_off=False \
+#--use_sortagrad=True \
+#--use_gru=True \
+#--use_gpu=True \
+#--is_local=True \
+#--share_rnn_weights=False \
+#--train_manifest="data/manifest.train" \
+#--dev_manifest="data/manifest.dev" \
+#--mean_std_path="data/mean_std.npz" \
+#--vocab_path="data/vocab.txt" \
+#--output_model_dir="./checkpoints" \
+#--augment_conf_path="${MAIN_ROOT}/conf/augmentation.config" \
+#--specgram_type="linear" \
+#--shuffle_method="batch_shuffle_clipped" \
+
+CUDA_VISIBLE_DEVICES=1,2,6,7 \
+python3 -u ${MAIN_ROOT}/train.py \
+--device 'gpu' \
+--nproc 4 \
+--config conf/deepspeech2.yaml \
+--output ckpt

 if [ $? -ne 0 ]; then
     echo "Failed in training!"

@@ -2,10 +2,12 @@
 export FLAGS_sync_nccl_allreduce=0

-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+#CUDA_VISIBLE_DEVICES=0,1,2,3 \
+#CUDA_VISIBLE_DEVICES=0,4,5,6 \
+CUDA_VISIBLE_DEVICES=0 \
 python3 -u ${MAIN_ROOT}/train.py \
 --device 'gpu' \
---nproc 4 \
+--nproc 1 \
 --config conf/deepspeech2.yaml \
 --output ckpt

@@ -41,6 +41,22 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from utils.error_rate import char_errors, word_errors, cer, wer

+# class ExponentialDecayOld(LRScheduler):
+#     def __init__(self, learning_rate, gamma, last_epoch=-1,
+#                  decay_steps, decay_rate, staircase=False, verbose=False):
+#         self.learning_rate = learning_rate
+#         self.decay_steps = decay_steps
+#         self.decay_rate = decay_rate
+#         self.staircase = staircase
+#         super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
+#                                                verbose)
+
+#     def get_lr(self):
+#         div_res = self.step_num / self.decay_steps
+#         if self.staircase:
+#             div_res = paddle.floor(div_res)
+
+#         return self.base_lr * (self.decay_rate**div_res)

 class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
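The commented-out class above is a dead-code sketch of porting the old fluid exponential_decay(staircase=...) schedule onto the 2.0-style LRScheduler API; the commit ends up using the built-in paddle.optimizer.lr.ExponentialDecay instead. For reference, a working version would look roughly like the following hedged sketch, which uses last_epoch as the step counter since LRScheduler does not appear to expose a step_num attribute; the class name is illustrative.

import math
from paddle.optimizer.lr import LRScheduler

class StaircaseExponentialDecay(LRScheduler):
    """Illustrative port of fluid exponential_decay with optional staircase."""

    def __init__(self, learning_rate, decay_steps, decay_rate,
                 staircase=False, last_epoch=-1, verbose=False):
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # last_epoch counts how many times step() has been called
        div_res = self.last_epoch / self.decay_steps
        if self.staircase:
            div_res = math.floor(div_res)
        return self.base_lr * (self.decay_rate ** div_res)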
@@ -73,21 +89,28 @@ class DeepSpeech2Trainer(Trainer):
         self.optimizer.clear_grad()
         self.model.train()
         audio, text, audio_len, text_len = batch
+        batch_size = audio.shape[0]
         outputs = self.model(audio, text, audio_len, text_len)
         loss = self.compute_losses(batch, outputs)
         loss.backward()
         self.optimizer.step()
         iteration_time = time.time() - start

-        losses_np = {'train_loss': float(loss)}
+        losses_np = {
+            'train_loss': float(loss),
+            'train_loss_div_batchsize':
+            float(loss) / self.config.data.batch_size
+        }
         msg = "Train: Rank: {}, ".format(dist.get_rank())
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
         msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
                                                   iteration_time)
+        msg += f"batch size: {batch_size}, "
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
+        #if self.iteration % 100 == 0:
         self.logger.info(msg)

         if dist.get_rank() == 0 and self.visualizer:
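One nuance in the logging above: DeepSpeech2Loss (further down in this commit) returns the loss summed over the batch, and train_loss_div_batchsize divides by the configured batch size, whereas the valid loop divides by the actual audio.shape[0], which is robust to a short final batch. A tiny hedged sketch of that per-utterance normalization, with made-up numbers:

def per_utterance_loss(summed_loss, actual_batch_size):
    # summed_loss is assumed to be float(loss) where the criterion sums over the batch
    return summed_loss / actual_batch_size

print(per_utterance_loss(128.0, 4))   # 32.0, i.e. float(loss) / audio.shape[0]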
@@ -107,15 +130,16 @@ class DeepSpeech2Trainer(Trainer):
                 self.iteration += 1
                 self.train_batch()

-                if self.iteration % self.config.training.valid_interval == 0:
-                    self.valid()
+                # if self.iteration % self.config.training.valid_interval == 0:
+                #     self.valid()

-                if self.iteration % self.config.training.save_interval == 0:
-                    self.save()
+                # if self.iteration % self.config.training.save_interval == 0:
+                #     self.save()
             except StopIteration:
                 self.iteration -= 1  #epoch end, iteration ahead 1
                 self.valid()
                 self.save()
+                self.lr_scheduler.step()
                 self.new_epoch()

     def compute_metrics(self, inputs, outputs):
@@ -128,11 +152,14 @@ class DeepSpeech2Trainer(Trainer):
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
             audio, text, audio_len, text_len = batch
+            batch_size = audio.shape[0]
             outputs = self.model(audio, text, audio_len, text_len)
             loss = self.compute_losses(batch, outputs)
             metrics = self.compute_metrics(batch, outputs)

             valid_losses['val_loss'].append(float(loss))
+            valid_losses['val_loss_div_batchsize'].append(
+                float(loss) / batch_size)

         # write visual log
         valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
@@ -166,19 +193,33 @@ class DeepSpeech2Trainer(Trainer):
         grad_clip = paddle.nn.ClipGradByGlobalNorm(
             config.training.global_grad_clip)

-        optimizer = paddle.optimizer.Adam(
-            learning_rate=config.training.lr,
-            parameters=model.parameters(),
-            weight_decay=paddle.regularizer.L2Decay(
-                config.training.weight_decay),
-            grad_clip=grad_clip)
+        # optimizer = paddle.optimizer.Adam(
+        #     learning_rate=config.training.lr,
+        #     parameters=model.parameters(),
+        #     weight_decay=paddle.regularizer.L2Decay(
+        #         config.training.weight_decay),
+        #     grad_clip=grad_clip)
+
+        #learning_rate=fluid.layers.exponential_decay(
+        #    learning_rate=learning_rate,
+        #    decay_steps=num_samples / batch_size / dev_count,
+        #    decay_rate=0.83,
+        #    staircase=True),
+        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
+            learning_rate=config.training.lr, gamma=0.83, verbose=True)
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=lr_scheduler,
+            parameters=model.parameters(),
+            grad_clip=grad_clip)

         criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)

         self.model = model
         self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
         self.criterion = criterion
-        self.logger.info("Setup model/optimizer/criterion!")
+        self.logger.info("Setup model/optimizer/lr_scheduler/criterion!")

     def setup_dataloader(self):
         config = self.config
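To make the new wiring concrete, here is a minimal toy sketch of the scheduler/optimizer pattern used above. The values and the stand-in model are illustrative; the real values come from conf/deepspeech2.yaml, and the scheduler is stepped once per epoch in train() on StopIteration.

import paddle

model = paddle.nn.Linear(10, 10)                        # stand-in for DeepSpeech2
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
    learning_rate=5e-4, gamma=0.83, verbose=True)       # lr *= 0.83 on each step()
optimizer = paddle.optimizer.Adam(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(5.0))      # illustrative clip value

for epoch in range(3):
    # ... run the batches of one epoch, calling optimizer.step()/clear_grad() ...
    lr_scheduler.step()    # decay once per epoch, mirroring train() above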

@@ -29,6 +29,28 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch

 __all__ = ['DeepSpeech2', 'DeepSpeech2Loss']

+
+def ctc_loss(log_probs,
+             labels,
+             input_lengths,
+             label_lengths,
+             blank=0,
+             reduction='mean',
+             norm_by_times=True):
+    #print("my ctc loss with norm by times")
+    loss_out = paddle.fluid.layers.warpctc(log_probs, labels, blank, norm_by_times,
+                                           input_lengths, label_lengths)
+
+    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
+    assert reduction in ['mean', 'sum', 'none']
+    if reduction == 'mean':
+        loss_out = paddle.mean(loss_out / label_lengths)
+    elif reduction == 'sum':
+        loss_out = paddle.sum(loss_out)
+    return loss_out
+
+
+F.ctc_loss = ctc_loss
+
+
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
     t_min = paddle.to_tensor(t_min)
     t_max = paddle.to_tensor(t_max)
@@ -683,12 +705,13 @@ class DeepSpeech2Loss(nn.Layer):
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')

-    def forward(self, logits, text, audio_len, text_len):
+    def forward(self, logits, text, logits_len, text_len):
         # warp-ctc do softmax on activations
         # warp-ctc need activation with shape [T, B, V + 1]
         logits = logits.transpose([1, 0, 2])
-        ctc_loss = self.loss(logits, text, audio_len, text_len)
-        ctc_loss /= text_len  # norm_by_times
+        ctc_loss = self.loss(logits, text, logits_len, text_len)
+        ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
+        #ctc_loss /= logits_len  # norm_by_times
         ctc_loss = ctc_loss.sum()
         return ctc_loss
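A self-contained sketch of how this loss module is driven, using only APIs visible in this diff: the model's logits arrive as [batch, time, vocab_size + 1], warp-ctc wants time-major input, and with reduction='none' the per-utterance losses are summed without the old division by text_len. The shapes, lengths, and values below are made up for illustration.

import paddle
import paddle.nn as nn

vocab_size = 28
B, T, U = 4, 100, 12                                    # batch, time steps, label length
logits = paddle.randn([B, T, vocab_size + 1])           # raw activations, pre-softmax
text = paddle.randint(0, vocab_size, shape=[B, U], dtype='int32')   # label ids
logits_len = paddle.full([B], T, dtype='int64')         # valid time steps per utterance
text_len = paddle.full([B], U, dtype='int64')           # valid labels per utterance

loss_fn = nn.CTCLoss(blank=vocab_size, reduction='none')
per_utt = loss_fn(logits.transpose([1, 0, 2]), text, logits_len, text_len)
loss = per_utt.sum()                                    # no division by text_len here
print(float(loss))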

@@ -161,9 +161,10 @@ class Trainer():
     def new_epoch(self):
         """Reset the train loader and increment ``epoch``.
         """
-        self.epoch += 1
         if self.parallel:
+            # batch sampler epoch start from 0
             self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.epoch += 1
         self.iterator = iter(self.train_loader)

     def train(self):
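The reordering matters for sortagrad: the samplers earlier in this commit only keep the duration-sorted order while their epoch counter is 0, so set_epoch must see the un-incremented value and the increment moves after the call. A hedged toy sketch of set_epoch usage with the stock paddle.io.DistributedBatchSampler (the dataset is an illustrative stand-in for the speech dataset):

from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 16
    def __getitem__(self, idx):
        return idx

sampler = DistributedBatchSampler(ToyDataset(), batch_size=4, shuffle=True)
for epoch in range(3):
    sampler.set_epoch(epoch)        # epoch numbering starts at 0, matching sortagrad
    batches = list(iter(sampler))   # per-epoch deterministic shuffle, same on all ranks
    print(epoch, batches)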
@@ -246,22 +247,30 @@ class Trainer():
         the standard output and a text file named ``worker_n.log`` in the
         output directory, where ``n`` means the rank of the process.
         """
+        format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
         logger = logging.getLogger(__name__)
         logger.setLevel("INFO")
-        formatter = logging.Formatter(
-            fmt='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
-            datefmt='%Y/%m/%d %H:%M:%S')
+        formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')

-        stream_handler = logging.StreamHandler()
-        stream_handler.setFormatter(formatter)
-        logger.addHandler(stream_handler)
+        #stream_handler = logging.StreamHandler()
+        #stream_handler.setFormatter(formatter)
+        #logger.addHandler(stream_handler)

         log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
         file_handler = logging.FileHandler(str(log_file))
         file_handler.setFormatter(formatter)
         logger.addHandler(file_handler)

+        # global logger
+        stdout = True
+        save_path = '/dev/null'
+        logging.basicConfig(
+            level=logging.DEBUG if stdout else logging.INFO,
+            format=format,
+            datefmt='%Y/%m/%d %H:%M:%S',
+            filename=save_path if not stdout else None)
+
         self.logger = logger

     @mp_tools.rank_zero_only
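The net effect of this change: console output is handled once by the root logger via basicConfig, so module-level loggers (such as the logger = logging.getLogger(__name__) added to the dataset above) also show up, while the trainer's own logger keeps a per-worker file handler. A hedged standalone sketch of that layout; the file path and logger names are illustrative:

import logging

fmt = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
logging.basicConfig(level=logging.DEBUG, format=fmt, datefmt='%Y/%m/%d %H:%M:%S')

trainer_logger = logging.getLogger('trainer')
file_handler = logging.FileHandler('worker_0.log')      # per-rank file, illustrative path
file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y/%m/%d %H:%M:%S'))
trainer_logger.addHandler(file_handler)

logging.getLogger('dataset_module').info('reaches the console via the root logger')
trainer_logger.info('reaches the console and worker_0.log')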
