diff --git a/.gitignore b/.gitignore
index 0bd3d362c..2ec11b5ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@
 *.pyc
 tools/venv
 dataset
+models/*
diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 2de8f87a2..658c5ba54 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -15,6 +15,7 @@
 import math
 import random
 import tarfile
+import logging
 import numpy as np
 import paddle
 from paddle.io import Dataset
@@ -30,6 +31,8 @@ from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
 from data_utils.speech import SpeechSegment
 from data_utils.normalizer import FeatureNormalizer
 
+logger = logging.getLogger(__name__)
+
 __all__ = [
     "DeepSpeech2Dataset",
     "DeepSpeech2DistributedBatchSampler",
@@ -234,9 +237,13 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         assert (clipped == False)
         if not clipped:
             res_len = len(indices) - shift_len - len(batch_indices)
+            assert res_len != 0, f"_batch_shuffle clipped {len(indices)} , {shift_len}, {len(batch_indices)}"
             # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
             batch_indices.extend(indices[-res_len:])
             batch_indices.extend(indices[0:shift_len])
+        assert len(indices) == len(
+            batch_indices
+        ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}"
         return batch_indices
 
     def __iter__(self):
@@ -247,8 +254,8 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
 
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
-            if self.epoch == 0 and self.sortagrad:
-                pass
+            if self.epoch == 0 and self._sortagrad:
+                logger.info(f'dataset sortagrad! epoch {self.epoch}')
             else:
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
@@ -374,9 +381,13 @@ class DeepSpeech2BatchSampler(BatchSampler):
         assert (clipped == False)
         if not clipped:
             res_len = len(indices) - shift_len - len(batch_indices)
+            assert res_len != 0, f"_batch_shuffle clipped {len(indices)} , {shift_len}, {len(batch_indices)}"
             # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
             batch_indices.extend(indices[-res_len:])
             batch_indices.extend(indices[0:shift_len])
+        assert len(indices) == len(
+            batch_indices
+        ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}"
         return batch_indices
 
     def __iter__(self):
@@ -388,7 +399,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
             if self.epoch == 0 and self._sortagrad:
-                pass
+                logger.info(f'dataset sortagrad! epoch {self.epoch}')
             else:
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
diff --git a/examples/aishell/local/run_train.sh b/examples/aishell/local/run_train.sh
index 5bde13721..e3e8c745e 100644
--- a/examples/aishell/local/run_train.sh
+++ b/examples/aishell/local/run_train.sh
@@ -3,33 +3,42 @@
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
 export FLAGS_sync_nccl_allreduce=0
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+
+#CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+#python3 -u ${MAIN_ROOT}/train.py \
+#--batch_size=64 \
+#--num_epoch=50 \
+#--num_conv_layers=2 \
+#--num_rnn_layers=3 \
+#--rnn_layer_size=1024 \
+#--num_iter_print=100 \
+#--save_epoch=1 \
+#--num_samples=120000 \
+#--learning_rate=5e-4 \
+#--max_duration=27.0 \
+#--min_duration=0.0 \
+#--test_off=False \
+#--use_sortagrad=True \
+#--use_gru=True \
+#--use_gpu=True \
+#--is_local=True \
+#--share_rnn_weights=False \
+#--train_manifest="data/manifest.train" \
+#--dev_manifest="data/manifest.dev" \
+#--mean_std_path="data/mean_std.npz" \
+#--vocab_path="data/vocab.txt" \
+#--output_model_dir="./checkpoints" \
+#--augment_conf_path="${MAIN_ROOT}/conf/augmentation.config" \
+#--specgram_type="linear" \
+#--shuffle_method="batch_shuffle_clipped" \
+
+CUDA_VISIBLE_DEVICES=1,2,6,7 \
 python3 -u ${MAIN_ROOT}/train.py \
---batch_size=64 \
---num_epoch=50 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=1024 \
---num_iter_print=100 \
---save_epoch=1 \
---num_samples=120000 \
---learning_rate=5e-4 \
---max_duration=27.0 \
---min_duration=0.0 \
---test_off=False \
---use_sortagrad=True \
---use_gru=True \
---use_gpu=True \
---is_local=True \
---share_rnn_weights=False \
---train_manifest="data/manifest.train" \
---dev_manifest="data/manifest.dev" \
---mean_std_path="data/mean_std.npz" \
---vocab_path="data/vocab.txt" \
---output_model_dir="./checkpoints" \
---augment_conf_path="${MAIN_ROOT}/conf/augmentation.config" \
---specgram_type="linear" \
---shuffle_method="batch_shuffle_clipped" \
+--device 'gpu' \
+--nproc 4 \
+--config conf/deepspeech2.yaml \
+--output ckpt
+
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"
diff --git a/examples/tiny/local/run_train.sh b/examples/tiny/local/run_train.sh
index 56abf7a34..7037c07e3 100644
--- a/examples/tiny/local/run_train.sh
+++ b/examples/tiny/local/run_train.sh
@@ -2,10 +2,12 @@
 
 export FLAGS_sync_nccl_allreduce=0
 
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+#CUDA_VISIBLE_DEVICES=0,1,2,3 \
+#CUDA_VISIBLE_DEVICES=0,4,5,6 \
+CUDA_VISIBLE_DEVICES=0 \
 python3 -u ${MAIN_ROOT}/train.py \
 --device 'gpu' \
---nproc 4 \
+--nproc 1 \
 --config conf/deepspeech2.yaml \
 --output ckpt
 
diff --git a/model_utils/model.py b/model_utils/model.py
index 36c19d7ca..a48307863 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -41,6 +41,22 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 
 from utils.error_rate import char_errors, word_errors, cer, wer
 
+# class ExponentialDecayOld(LRScheduler):
+#     def __init__(self, learning_rate, gamma, last_epoch=-1,
+#                  decay_steps, decay_rate, staircase=False, verbose=False):
+#         self.learning_rate = learning_rate
+#         self.decay_steps = decay_steps
+#         self.decay_rate = decay_rate
+#         self.staircase = staircase
+#         super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
+#                                                verbose)
+
+#     def get_lr(self):
+#         div_res = self.step_num / self.decay_steps
+#         if self.staircase:
+#             div_res = paddle.floor(div_res)
+#         return self.base_lr * (self.decay_rate**div_res)
+
 
 class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
@@ -73,21 +89,28 @@ class DeepSpeech2Trainer(Trainer):
         self.optimizer.clear_grad()
         self.model.train()
         audio, text, audio_len, text_len = batch
+        batch_size = audio.shape[0]
         outputs = self.model(audio, text, audio_len, text_len)
         loss = self.compute_losses(batch, outputs)
         loss.backward()
         self.optimizer.step()
         iteration_time = time.time() - start
 
-        losses_np = {'train_loss': float(loss)}
+        losses_np = {
+            'train_loss': float(loss),
+            'train_loss_div_batchsize':
+            float(loss) / self.config.data.batch_size
+        }
         msg = "Train: Rank: {}, ".format(dist.get_rank())
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
-        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
+        msg += f"batch size: {batch_size}, "
         msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items())
+
+        #if self.iteration % 100 == 0:
         self.logger.info(msg)
 
         if dist.get_rank() == 0 and self.visualizer:
@@ -107,15 +130,16 @@ class DeepSpeech2Trainer(Trainer):
             self.iteration += 1
             self.train_batch()
 
-            if self.iteration % self.config.training.valid_interval == 0:
-                self.valid()
+            # if self.iteration % self.config.training.valid_interval == 0:
+            #     self.valid()
 
-            if self.iteration % self.config.training.save_interval == 0:
-                self.save()
+            # if self.iteration % self.config.training.save_interval == 0:
+            #     self.save()
         except StopIteration:
             self.iteration -= 1  #epoch end, iteration ahead 1
             self.valid()
             self.save()
+            self.lr_scheduler.step()
             self.new_epoch()
 
     def compute_metrics(self, inputs, outputs):
@@ -128,11 +152,14 @@ class DeepSpeech2Trainer(Trainer):
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
             audio, text, audio_len, text_len = batch
+            batch_size = audio.shape[0]
             outputs = self.model(audio, text, audio_len, text_len)
             loss = self.compute_losses(batch, outputs)
             metrics = self.compute_metrics(batch, outputs)
 
             valid_losses['val_loss'].append(float(loss))
+            valid_losses['val_loss_div_batchsize'].append(
+                float(loss) / batch_size)
 
         # write visual log
         valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
@@ -166,19 +193,33 @@ class DeepSpeech2Trainer(Trainer):
         grad_clip = paddle.nn.ClipGradByGlobalNorm(
             config.training.global_grad_clip)
 
+        # optimizer = paddle.optimizer.Adam(
+        #     learning_rate=config.training.lr,
+        #     parameters=model.parameters(),
+        #     weight_decay=paddle.regularizer.L2Decay(
+        #         config.training.weight_decay),
+        #     grad_clip=grad_clip)
+
+        #learning_rate=fluid.layers.exponential_decay(
+        #    learning_rate=learning_rate,
+        #    decay_steps=num_samples / batch_size / dev_count,
+        #    decay_rate=0.83,
+        #    staircase=True),
+
+        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
+            learning_rate=config.training.lr, gamma=0.83, verbose=True)
         optimizer = paddle.optimizer.Adam(
-            learning_rate=config.training.lr,
+            learning_rate=lr_scheduler,
             parameters=model.parameters(),
-            weight_decay=paddle.regularizer.L2Decay(
-                config.training.weight_decay),
             grad_clip=grad_clip)
 
         criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
 
         self.model = model
         self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
         self.criterion = criterion
-        self.logger.info("Setup model/optimizer/criterion!")
+        self.logger.info("Setup model/optimizer/lr_scheduler/criterion!")
 
     def setup_dataloader(self):
         config = self.config
diff --git a/model_utils/network.py b/model_utils/network.py
index 2c310b855..2b0f6765b 100644
--- a/model_utils/network.py
+++ b/model_utils/network.py
@@ -29,6 +29,28 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 
 __all__ = ['DeepSpeech2', 'DeepSpeech2Loss']
 
+def ctc_loss(log_probs,
+             labels,
+             input_lengths,
+             label_lengths,
+             blank=0,
+             reduction='mean',
+             norm_by_times=True):
+    #print("my ctc loss with norm by times")
+    loss_out = paddle.fluid.layers.warpctc(log_probs, labels, blank, norm_by_times,
+                                           input_lengths, label_lengths)
+
+    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
+    assert reduction in ['mean', 'sum', 'none']
+    if reduction == 'mean':
+        loss_out = paddle.mean(loss_out / label_lengths)
+    elif reduction == 'sum':
+        loss_out = paddle.sum(loss_out)
+    return loss_out
+
+F.ctc_loss = ctc_loss
+
+
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
     t_min = paddle.to_tensor(t_min)
     t_max = paddle.to_tensor(t_max)
@@ -683,12 +705,13 @@ class DeepSpeech2Loss(nn.Layer):
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
 
-    def forward(self, logits, text, audio_len, text_len):
+    def forward(self, logits, text, logits_len, text_len):
         # warp-ctc do softmax on activations
         # warp-ctc need activation with shape [T, B, V + 1]
         logits = logits.transpose([1, 0, 2])
-        ctc_loss = self.loss(logits, text, audio_len, text_len)
-        ctc_loss /= text_len  # norm_by_times
+        ctc_loss = self.loss(logits, text, logits_len, text_len)
+        ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
+        #ctc_loss /= logits_len  # norm_by_times
         ctc_loss = ctc_loss.sum()
         return ctc_loss
 
diff --git a/training/trainer.py b/training/trainer.py
index f10505e3d..88ef847be 100644
--- a/training/trainer.py
+++ b/training/trainer.py
@@ -161,9 +161,10 @@ class Trainer():
     def new_epoch(self):
         """Reset the train loader and increment ``epoch``.
         """
-        self.epoch += 1
         if self.parallel:
+            # batch sampler epoch start from 0
             self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.epoch += 1
         self.iterator = iter(self.train_loader)
 
     def train(self):
@@ -246,22 +247,30 @@ class Trainer():
         the standard output and a text file named ``worker_n.log`` in the
         output directory, where ``n`` means the rank of the process.
""" + format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' + logger = logging.getLogger(__name__) logger.setLevel("INFO") - formatter = logging.Formatter( - fmt='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s', - datefmt='%Y/%m/%d %H:%M:%S') + formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S') - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(formatter) - logger.addHandler(stream_handler) + #stream_handler = logging.StreamHandler() + #stream_handler.setFormatter(formatter) + #logger.addHandler(stream_handler) log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank()) file_handler = logging.FileHandler(str(log_file)) file_handler.setFormatter(formatter) logger.addHandler(file_handler) + # global logger + stdout = True + save_path = '/dev/null' + logging.basicConfig( + level=logging.DEBUG if stdout else logging.INFO, + format=format, + datefmt='%Y/%m/%d %H:%M:%S', + filename=save_path if not stdout else None) self.logger = logger @mp_tools.rank_zero_only