From 1f74af110b54127bd7b2b76a3c0664c909dbea98 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Thu, 3 Mar 2022 22:17:14 +0800
Subject: [PATCH] add training log info and comment, test=doc

---
 examples/voxceleb/sv0/local/train.py | 52 +++++++++++++++------
 paddlespeech/vector/training/time.py | 67 ++++++++++++++++++++++++++++
 2 files changed, 106 insertions(+), 13 deletions(-)
 create mode 100644 paddlespeech/vector/training/time.py

diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py
index f68f7373..f86b0a86 100644
--- a/examples/voxceleb/sv0/local/train.py
+++ b/examples/voxceleb/sv0/local/train.py
@@ -16,12 +16,13 @@ import os
 
 import numpy as np
 import paddle
+from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 
 from paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddleaudio.features.core import melspectrogram
-from paddleaudio.utils.time import Timer
+from paddlespeech.vector.training.time import Timer
 from paddlespeech.vector.datasets.batch import feature_normalize
 from paddlespeech.vector.datasets.batch import waveform_collate_fn
 from paddlespeech.vector.layers.loss import AdditiveAngularMargin
@@ -37,7 +38,6 @@ cpu_feat_conf = {
     'hop_length': 160,
 }
-
 
 def main(args):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
@@ -82,6 +82,7 @@ def main(args):
     # if pre-trained model exists, start epoch confirmed by the pre-trained model
     start_epoch = 0
     if args.load_checkpoint:
+        print("load the checkpoint")
         args.load_checkpoint = os.path.abspath(
             os.path.expanduser(args.load_checkpoint))
         try:
@@ -131,18 +132,30 @@ def main(args):
         num_corrects = 0
         num_samples = 0
         for batch_idx, batch in enumerate(train_loader):
+            # stage 9-1: the batch data contains audio sample points and speaker id labels
             waveforms, labels = batch['waveforms'], batch['labels']
 
+            # stage 9-2: audio sample augmentation, which is applied to the raw audio sample points
+            # todo
+
+            # stage 9-3: extract the audio features, such as fbank, mfcc, spectrogram
             feats = []
             for waveform in waveforms.numpy():
                 feat = melspectrogram(x=waveform, **cpu_feat_conf)
                 feats.append(feat)
             feats = paddle.to_tensor(np.asarray(feats))
+
+            # stage 9-4: feature normalization, which helps convergence and improves the performance
             feats = feature_normalize(
                 feats, mean_norm=True, std_norm=False)  # Features normalization
+
+            # stage 9-5: model forward, such as ecapa-tdnn, x-vector
             logits = model(feats)
 
+            # stage 9-6: loss function criterion, such as AngularMargin, AdditiveAngularMargin
             loss = criterion(logits, labels)
+
+            # stage 9-7: update the gradients and clear the gradient cache
             loss.backward()
             optimizer.step()
             if isinstance(optimizer._learning_rate,
@@ -150,22 +163,22 @@ def main(args):
                 optimizer._learning_rate.step()
             optimizer.clear_grad()
 
-            # Calculate loss
+            # stage 9-8: calculate the average loss per batch
             avg_loss += loss.numpy()[0]
 
-            # Calculate metrics
+            # stage 9-9: calculate metrics; the metric here is one-best accuracy
             preds = paddle.argmax(logits, axis=1)
             num_corrects += (preds == labels).numpy().sum()
             num_samples += feats.shape[0]
+            timer.count()  # advance the timer step by one
 
-            timer.count()
-
+            # stage 9-10: print the log information only on rank 0, every log-freq batches
             if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
                 lr = optimizer.get_lr()
                 avg_loss /= args.log_freq
                 avg_acc = num_corrects / num_samples
 
-                print_msg = 'Epoch={}/{}, Step={}/{}'.format(
+                print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
                     epoch, args.epochs, batch_idx + 1, steps_per_epoch)
                 print_msg += ' loss={:.4f}'.format(avg_loss)
                 print_msg += ' acc={:.4f}'.format(avg_acc)
@@ -177,36 +190,42 @@ def main(args):
                 num_corrects = 0
                 num_samples = 0
 
+            # stage 9-11: save the model parameters only on rank 0, every save-freq epochs
             if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
                 if local_rank != 0:
                     paddle.distributed.barrier(
                     )  # Wait for valid step in main process
                     continue  # Resume trainning on other process
-                dev_sampler = paddle.io.BatchSampler(
+                # stage 9-12: construct the validation dataset dataloader
+                dev_sampler = BatchSampler(
                     dev_ds,
                     batch_size=args.batch_size // 4,
                     shuffle=False,
                     drop_last=False)
 
-                dev_loader = paddle.io.DataLoader(
+                dev_loader = DataLoader(
                     dev_ds,
                     batch_sampler=dev_sampler,
                     collate_fn=waveform_collate_fn,
                     num_workers=args.num_workers,
                     return_list=True, )
 
+                # set the model to eval mode
                 model.eval()
                 num_corrects = 0
                 num_samples = 0
+
+                # stage 9-13: run evaluation on the validation dataset batches
                 print('Evaluate on validation dataset')
                 with paddle.no_grad():
                     for batch_idx, batch in enumerate(dev_loader):
                         waveforms, labels = batch['waveforms'], batch['labels']
-                        # feats = feature_extractor(waveforms)
+
                         feats = []
                         for waveform in waveforms.numpy():
                             feat = melspectrogram(x=waveform, **cpu_feat_conf)
                             feats.append(feat)
+
                         feats = paddle.to_tensor(np.asarray(feats))
                         feats = feature_normalize(
                             feats, mean_norm=True, std_norm=False)
@@ -218,10 +237,9 @@ def main(args):
                 print_msg = '[Evaluation result]'
                 print_msg += ' dev_acc={:.4f}'.format(num_corrects /
                                                       num_samples)
-
                 print(print_msg)
 
-                # Save model
+                # stage 9-14: save the model parameters
                 save_dir = os.path.join(args.checkpoint_dir,
                                         'epoch_{}'.format(epoch))
                 print('Saving model checkpoint to {}'.format(save_dir))
@@ -264,10 +282,18 @@ if __name__ == "__main__":
                         type=int,
                         default=50,
                         help="Number of epoches for fine-tuning.")
-    parser.add_argument("--log_freq",
+    parser.add_argument("--log-freq",
                         type=int,
                         default=10,
                         help="Log the training infomation every n steps.")
+    parser.add_argument("--save-freq",
+                        type=int,
+                        default=1,
+                        help="Save checkpoint every n epochs.")
+    parser.add_argument("--checkpoint-dir",
+                        type=str,
+                        default='./checkpoint',
+                        help="Directory to save model checkpoints.")
 
     args = parser.parse_args()
     # yapf: enable

diff --git a/paddlespeech/vector/training/time.py b/paddlespeech/vector/training/time.py
new file mode 100644
index 00000000..3a4e183d
--- /dev/null
+++ b/paddlespeech/vector/training/time.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import time
+
+
+class Timer(object):
+    '''Calculate running speed and estimated time of arrival (ETA)'''
+
+    def __init__(self, total_step: int):
+        self.total_step = total_step
+        self.last_start_step = 0
+        self.current_step = 0
+        self._is_running = True
+
+    def start(self):
+        self.last_time = time.time()
+        self.start_time = time.time()
+
+    def stop(self):
+        self._is_running = False
+        self.end_time = time.time()
+
+    def count(self) -> int:
+        if self.current_step < self.total_step:
+            self.current_step += 1
+        return self.current_step
+
+    @property
+    def timing(self) -> float:
+        run_steps = self.current_step - self.last_start_step
+        self.last_start_step = self.current_step
+        time_used = time.time() - self.last_time
+        self.last_time = time.time()
+        return time_used / run_steps
+
+    @property
+    def is_running(self) -> bool:
+        return self._is_running
+
+    @property
+    def eta(self) -> str:
+        if not self.is_running:
+            return '00:00:00'
+        scale = self.total_step / self.current_step
+        remaining_time = (time.time() - self.start_time) * scale
+        return seconds_to_hms(remaining_time)
+
+
+def seconds_to_hms(seconds: int) -> str:
+    '''Convert the number of seconds to hh:mm:ss'''
+    h = math.floor(seconds / 3600)
+    m = math.floor((seconds - h * 3600) / 60)
+    s = int(seconds - h * 3600 - m * 60)
+    hms_str = '{:0>2}:{:0>2}:{:0>2}'.format(h, m, s)
+    return hms_str
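
For context, the Timer added above is driven from the training loop in train.py through timer.count() and the log message. A minimal usage sketch, assuming the patch is applied; the epoch and step counts below are made up for illustration (train.py derives them from args.epochs and the train_loader instead):

    from paddlespeech.vector.training.time import Timer

    epochs = 10            # illustrative values only
    steps_per_epoch = 500

    timer = Timer(epochs * steps_per_epoch)  # total number of training steps
    timer.start()

    for epoch in range(1, epochs + 1):
        for batch_idx in range(steps_per_epoch):
            # ... forward, loss, backward and optimizer.step() run here ...
            timer.count()  # advance the step counter by one

            if (batch_idx + 1) % 100 == 0:
                # timing: average seconds per step since the last query;
                # eta: projected running time formatted as hh:mm:ss
                print('avg={:.4f} sec/step, eta={}'.format(timer.timing,
                                                           timer.eta))

    timer.stop()

Note that eta scales the elapsed time by total_step / current_step, so it reports the projected total running time of the job rather than only the time still remaining.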