From 584a2c0e39ab73b4a5826077528eccb4edf7afbd Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Wed, 9 Mar 2022 20:46:57 +0800 Subject: [PATCH] add ecapa-tdnn config yaml file --- examples/voxceleb/sv0/conf/ecapa_tdnn.yaml | 35 ++ examples/voxceleb/sv0/run.sh | 6 +- .../ecapa-tdnn/extract_speaker_embedding.py | 112 +++++++ .../ecapa-tdnn/speaker_verification_cosine.py | 207 ++++++++++++ paddlespeech/vector/exps/ecapa-tdnn/train.py | 298 ++++++++++++++++++ 5 files changed, 656 insertions(+), 2 deletions(-) create mode 100644 examples/voxceleb/sv0/conf/ecapa_tdnn.yaml create mode 100644 paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py create mode 100644 paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py create mode 100644 paddlespeech/vector/exps/ecapa-tdnn/train.py diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml new file mode 100644 index 00000000..33304054 --- /dev/null +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml @@ -0,0 +1,35 @@ +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +# currently, we only support fbank +feature: + n_mels: 80 + window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 + hop_length: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 + + +########################################################### +# MODEL SETTING # +########################################################### +# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml +# if we want use another model, please choose another configuration yaml file +model: + input_size: 80 + ##"channels": [1024, 1024, 1024, 1024, 3072], + # "channels": [512, 512, 512, 512, 1536], + channels: [512, 512, 512, 512, 1536] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + attention_channels: 128 + lin_neurons: 192 + +########################################### +# Training # 
+########################################### +seed: 0 +epochs: 10 +batch_size: 32 +num_workers: 2 +save_freq: 10 +log_freq: 10 +learning_rate: 1e-8 diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh index a6346cd5..2c0e55a6 100755 --- a/examples/voxceleb/sv0/run.sh +++ b/examples/voxceleb/sv0/run.sh @@ -31,20 +31,22 @@ if [ $stage -le 1 ]; then python3 \ -m paddle.distributed.launch --gpus=0,1,2,3 \ ${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \ - --save-freq 10 --data-dir ${dir} --batch-size 64 --epochs 100 + --data-dir ${dir} --config conf/ecapa_tdnn.yaml fi if [ $stage -le 2 ]; then # stage 1: get the speaker verification scores with cosine function python3 \ ${BIN_DIR}/speaker_verification_cosine.py\ - --batch-size 4 --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/ + --config conf/ecapa_tdnn.yaml \ + --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/ fi if [ $stage -le 3 ]; then # stage 3: extract the audio embedding python3 \ ${BIN_DIR}/extract_speaker_embedding.py\ + --config conf/ecapa_tdnn.yaml \ --audio-path "demo/csv/00001.wav" --load-checkpoint ${exp_dir}/epoch_60/ fi diff --git a/paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py b/paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py new file mode 100644 index 00000000..78498c61 --- /dev/null +++ b/paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import numpy as np
import paddle
from yacs.config import CfgNode

from paddleaudio.paddleaudio.backends import load as load_audio
from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def extract_audio_embedding(args, config):
    """Extract the speaker embedding of a single audio file.

    Args:
        args: Parsed command-line namespace. Reads ``device``,
            ``load_checkpoint`` and ``audio_path``; ``load_checkpoint`` is
            rewritten in place to its absolute, user-expanded form.
        config: Frozen yacs ``CfgNode``. Reads ``seed``, ``model`` (keyword
            arguments for ``EcapaTdnn``) and ``feature`` (keyword arguments
            for ``melspectrogram``).

    Returns:
        The squeezed backbone output converted to a numpy array
        (presumably of shape ``(emb_size,)`` -- TODO confirm against the
        backbone's output shape).

    Raises:
        FileNotFoundError: If ``model.pdparams`` does not exist inside the
            checkpoint directory.
    """
    # stage 0: select the device (cpu or gpu) and fix the random seed
    paddle.set_device(args.device)
    seed_everything(config.seed)

    # stage 1: build the backbone network and wrap it in the speaker-id model
    ecapa_tdnn = EcapaTdnn(**config.model)
    # NOTE(review): 1211 presumably equals VoxCeleb1.num_speakers used by the
    # sibling scripts; it must match the checkpoint's classifier head.
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1211)

    # stage 2: load the pre-trained parameters
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    params_path = os.path.join(args.load_checkpoint, 'model.pdparams')
    if not os.path.isfile(params_path):
        # Fail early with an actionable message instead of a loader error.
        raise FileNotFoundError(
            f'Model parameters not found: {params_path}')
    state_dict = paddle.load(params_path)
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage 3: inference only, so switch the model to eval mode
    model.eval()

    # stage 4: read the audio and compute its features
    # the sample rate returned by the loader is not used here
    waveform, _ = load_audio(args.audio_path)

    # melspectrogram output is turned into a tensor and given a leading batch
    # axis of size one (original note: [dim, time] -> [1, dim, time])
    feat = melspectrogram(x=waveform, **config.feature)
    feat = paddle.to_tensor(feat).unsqueeze(0)

    # a single, unpadded utterance: relative length is one
    lengths = paddle.ones([1])
    feat = feature_normalize(
        feat, mean_norm=True, std_norm=False, convert_to_numpy=True)

    # stage 5: forward through the backbone without building an autograd graph
    with paddle.no_grad():
        embedding = model.backbone(
            feat, lengths).squeeze().numpy()  # (1, emb_size, 1) -> (emb_size)

    # TODO: apply global normalization with an external mean and std
    return embedding


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to train model, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load model checkpoint to contiune trainning.")
    parser.add_argument("--global-embedding-norm",
                        type=str,
                        default=None,
                        help="Apply global normalization on speaker embeddings.")
    parser.add_argument("--audio-path",
                        default="./data/demo.wav",
                        type=str,
                        help="Single audio file path")
    args = parser.parse_args()
    # yapf: enable

    # The function reads config.seed / config.model / config.feature, so an
    # empty configuration can never work: report it as a usage error.
    if not args.config:
        parser.error("--config is required")

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    config.merge_from_file(args.config)

    config.freeze()
    print(config)

    extract_audio_embedding(args, config)

# ---------------------------------------------------------------------------
# Patch boundary preserved from the original (line-collapsed) patch text:
# diff --git a/paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py b/paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py
# new file mode 100644
# index 00000000..4d85bd62
# --- /dev/null
# +++ b/paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py
# @@ -0,0 +1,207 @@
# ---------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os

import numpy as np
import paddle
from yacs.config import CfgNode
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm

from paddleaudio.paddleaudio.datasets import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddleaudio.paddleaudio.metric import compute_eer
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def _build_eval_loader(subset, args, config):
    """Build the shuffled, feature-normalizing DataLoader for one subset."""
    dataset = VoxCeleb1(
        subset=subset,
        target_dir=args.data_dir,
        feat_type='melspectrogram',
        random_chunk=False,
        **config.feature)
    # Shuffle to make embedding normalization more robust.
    sampler = BatchSampler(
        dataset, batch_size=config.batch_size, shuffle=True)
    return DataLoader(dataset,
                      batch_sampler=sampler,
                      collate_fn=lambda x: batch_feature_normalize(
                          x, mean_norm=True, std_norm=False),
                      num_workers=config.num_workers,
                      return_list=True,)


def _read_trials(trial_file):
    """Read ``label enrol_path test_path`` trials, skipping blank lines.

    Returns:
        A ``(labels, enrol_ids, test_ids)`` tuple of parallel lists. Ids are
        the paths with the extension removed and ``/`` replaced by ``-``.
    """
    labels, enrol_ids, test_ids = [], [], []
    with open(trial_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            label, enrol_id, test_id = line.split(' ')
            labels.append(int(label))
            enrol_ids.append(enrol_id.split('.')[0].replace('/', '-'))
            test_ids.append(test_id.split('.')[0].replace('/', '-'))
    return labels, enrol_ids, test_ids


def main(args, config):
    """Score the VoxCeleb1 verification trials with cosine similarity.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``data_dir``,
            ``load_checkpoint`` (rewritten in place to an absolute path),
            ``global_embedding_norm``, ``embedding_mean_norm`` and
            ``embedding_std_norm``.
        config: Frozen yacs ``CfgNode``. Reads ``seed``, ``model``,
            ``feature``, ``batch_size`` and ``num_workers``.

    The equal error rate and its score threshold are written to the logger;
    nothing is returned.
    """
    # stage0: select the device (cpu or gpu) and fix the random seed
    paddle.set_device(args.device)
    seed_everything(config.seed)

    # stage1: build the backbone network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker-id model around the backbone
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)

    # stage3: load the pre-trained parameters
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage4: construct the enroll and test dataloaders
    enrol_loader = _build_eval_loader('enroll', args, config)
    test_loader = _build_eval_loader('test', args, config)

    # stage5: inference only, so switch the model to eval mode
    model.eval()

    # stage6: running statistics for global embedding normalization
    global_embedding_mean = None
    global_embedding_std = None
    mean_norm_flag = args.embedding_mean_norm
    std_norm_flag = args.embedding_std_norm
    batch_count = 0

    # stage7: compute embeddings for every enroll and test utterance
    id2embedding = {}
    # Run multiple times to make embedding normalization more stable: the
    # second pass overwrites entries using the more settled running stats.
    for i in range(2):
        for dl in [enrol_loader, test_loader]:
            logger.info(
                f'Loop {i + 1}: Computing embeddings on {dl.dataset.subset} dataset'
            )
            with paddle.no_grad():
                for batch in tqdm(dl):
                    # stage 7-1: extract the audio embeddings
                    ids, feats, lengths = batch['ids'], batch['feats'], batch[
                        'lengths']
                    embeddings = model.backbone(feats, lengths).squeeze(
                        -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)

                    # stage 7-2: global embedding normalization
                    if args.global_embedding_norm:
                        batch_count += 1
                        current_mean = embeddings.mean(
                            axis=0) if mean_norm_flag else 0
                        current_std = embeddings.std(
                            axis=0) if std_norm_flag else 1
                        # Update the running mean and std.
                        if global_embedding_mean is None and global_embedding_std is None:
                            global_embedding_mean, global_embedding_std = current_mean, current_std
                        else:
                            weight = 1 / batch_count  # Weight decays with batches.
                            global_embedding_mean = (
                                1 - weight
                            ) * global_embedding_mean + weight * current_mean
                            global_embedding_std = (
                                1 - weight
                            ) * global_embedding_std + weight * current_std
                        embeddings = (embeddings - global_embedding_mean
                                      ) / global_embedding_std

                    # stage 7-3: remember the embedding of every utterance id
                    id2embedding.update(dict(zip(ids, embeddings)))

    # stage8: compute cosine scores for the trial pairs
    labels, enrol_ids, test_ids = _read_trials(VoxCeleb1.veri_test_file)

    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
    enrol_embeddings, test_embeddings = (
        paddle.to_tensor(
            np.asarray([id2embedding[utt] for utt in utt_ids],
                       dtype='float32'))
        for utt_ids in (enrol_ids, test_ids))  # (N, emb_size)
    scores = cos_sim_func(enrol_embeddings, test_embeddings)
    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
    logger.info(
        f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
    )


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to train model, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load model checkpoint to contiune trainning.")
    parser.add_argument("--global-embedding-norm",
                        type=_str2bool,
                        default=True,
                        help="Apply global normalization on speaker embeddings.")
    parser.add_argument("--embedding-mean-norm",
                        type=_str2bool,
                        default=True,
                        help="Apply mean normalization on speaker embeddings.")
    parser.add_argument("--embedding-std-norm",
                        type=_str2bool,
                        default=False,
                        help="Apply std normalization on speaker embeddings.")
    args = parser.parse_args()
    # yapf: enable

    # main() reads config.seed / config.model / config.feature, so an empty
    # configuration can never work: report it as a usage error.
    if not args.config:
        parser.error("--config is required")

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    config.merge_from_file(args.config)

    config.freeze()
    print(config)
    main(args, config)

# ---------------------------------------------------------------------------
# Patch boundary preserved from the original (line-collapsed) patch text:
# diff --git a/paddlespeech/vector/exps/ecapa-tdnn/train.py b/paddlespeech/vector/exps/ecapa-tdnn/train.py
# new file mode 100644
# index 00000000..08a4ac1c
# --- /dev/null
# +++ b/paddlespeech/vector/exps/ecapa-tdnn/train.py
# @@ -0,0 +1,298 @@
# ---------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer

logger = Log(__name__).getlog()


def _epoch_from_checkpoint_dir(checkpoint_dir):
    """Return the trailing integer of the directory name, or 0 if absent.

    Checkpoints are saved by this script as ``epoch_<N>``; reading only the
    last character would turn ``epoch_10`` into epoch 0.
    """
    name = os.path.basename(os.path.normpath(checkpoint_dir))
    match = re.search(r'(\d+)$', name)
    return int(match.group(1)) if match else 0


def _extract_features(waveforms, feature_conf):
    """Compute mean-normalized mel-spectrogram features for a waveform batch."""
    feats = [
        melspectrogram(x=waveform, **feature_conf)
        for waveform in waveforms.numpy()
    ]
    feats = paddle.to_tensor(np.asarray(feats))
    # Feature normalization helps convergence and improves performance.
    return feature_normalize(feats, mean_norm=True, std_norm=False)


def _evaluate(model, dev_dataset, config):
    """Return the one-best accuracy of ``model`` on ``dev_dataset``.

    The model is left in eval mode; the caller switches back to train mode
    at the start of the next epoch.
    """
    dev_sampler = BatchSampler(
        dev_dataset,
        # Never let the integer division produce a batch size of zero.
        batch_size=max(1, config.batch_size // 4),
        shuffle=False,
        drop_last=False)
    dev_loader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=waveform_collate_fn,
        num_workers=config.num_workers,
        return_list=True, )

    model.eval()
    num_corrects = 0
    num_samples = 0
    with paddle.no_grad():
        for dev_batch in dev_loader:
            waveforms, labels = dev_batch['waveforms'], dev_batch['labels']
            feats = _extract_features(waveforms, config.feature)
            logits = model(feats)

            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]
    return num_corrects / num_samples if num_samples else 0.0


def main(args, config):
    """Train the ECAPA-TDNN speaker identification model on VoxCeleb1.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``data_dir``,
            ``augment``, ``load_checkpoint`` (rewritten in place to an
            absolute path when given) and ``checkpoint_dir``.
        config: Frozen yacs ``CfgNode``. Reads ``seed``, ``model``,
            ``feature``, ``learning_rate``, ``batch_size``, ``num_workers``,
            ``epochs``, ``log_freq`` and ``save_freq``.

    Every ``save_freq`` epochs rank 0 evaluates on the dev set and saves
    ``model.pdparams`` / ``model.pdopt`` under
    ``<checkpoint_dir>/epoch_<N>``.
    """
    # stage0: select the training device, cpu or gpu
    paddle.set_device(args.device)

    # stage1: init_parallel_env() must be called before any distributed query
    paddle.distributed.init_parallel_env()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    # stage2: prepare the datasets and the optional augmentation pipeline
    # note: some commands must only run on rank 0, so the data preparation
    #       code is expected to be refactored
    train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
    dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)

    if args.augment:
        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
    else:
        augment_pipeline = []

    # stage3: build the backbone network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage4: build the speaker-id training model around the backbone
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)

    # stage5: build the optimizer, currently only AdamW
    # YAML 1.1 loaders read an exponent without a decimal point (e.g. "1e-8")
    # as a string, so coerce the value explicitly.
    lr_schedule = CyclicLRScheduler(
        base_lr=float(config.learning_rate),
        max_lr=1e-3,
        step_size=140000 // nranks)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_schedule, parameters=model.parameters())

    # stage6: build the loss function, currently only LogSoftmaxWrapper
    criterion = LogSoftmaxWrapper(
        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))

    # stage7: determine the start epoch; a usable checkpoint resumes training
    start_epoch = 0
    if args.load_checkpoint:
        logger.info("load the check point")
        args.load_checkpoint = os.path.abspath(
            os.path.expanduser(args.load_checkpoint))
        params_path = os.path.join(args.load_checkpoint, 'model.pdparams')
        opt_path = os.path.join(args.load_checkpoint, 'model.pdopt')

        # Check for the files explicitly instead of relying on the exception
        # type raised by the loader for a missing path.
        if os.path.isfile(params_path) and os.path.isfile(opt_path):
            model.set_state_dict(paddle.load(params_path))
            optimizer.set_state_dict(paddle.load(opt_path))
            if local_rank == 0:
                logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

            start_epoch = _epoch_from_checkpoint_dir(args.load_checkpoint)
            if start_epoch:
                logger.info(f'Restore training from epoch {start_epoch}.')
        elif local_rank == 0:
            logger.info('Train from scratch.')

    # stage8: build the distributed batch sampler and the training loader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=False)
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.num_workers,
        collate_fn=waveform_collate_fn,
        return_list=True,
        use_buffer_reader=True, )

    # stage9: training loop
    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * config.epochs)
    timer.start()

    for epoch in range(start_epoch + 1, config.epochs + 1):
        # the model must be in train mode at the beginning of every epoch
        model.train()

        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        for batch_idx, batch in enumerate(train_loader):
            # stage 9-1: a batch holds audio samples and speaker id labels
            waveforms, labels = batch['waveforms'], batch['labels']

            # stage 9-2: waveform-level augmentation; labels are repeated once
            #            for the clean batch plus once per augmenter
            if len(augment_pipeline) != 0:
                waveforms = waveform_augment(waveforms, augment_pipeline)
                labels = paddle.concat(
                    [labels] * (len(augment_pipeline) + 1))

            # stage 9-3/9-4: extract and normalize the features
            feats = _extract_features(waveforms, config.feature)

            # stage 9-5: model forward
            logits = model(feats)

            # stage 9-6: loss
            loss = criterion(logits, labels)

            # stage 9-7: update the parameters, then clear the gradients
            loss.backward()
            optimizer.step()
            # lr_schedule is the object handed to the optimizer above, so step
            # it directly rather than through a private optimizer attribute.
            if isinstance(lr_schedule, paddle.optimizer.lr.LRScheduler):
                lr_schedule.step()
            optimizer.clear_grad()

            # stage 9-8: accumulate the loss; float() works for both shape-[1]
            #            and 0-D loss tensors
            avg_loss += float(loss)

            # stage 9-9: accumulate the one-best accuracy statistics
            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]
            timer.count()  # step plus one in timer

            # stage 9-10: log on rank 0 every log_freq batches
            if (batch_idx + 1) % config.log_freq == 0 and local_rank == 0:
                lr = optimizer.get_lr()
                avg_loss /= config.log_freq
                avg_acc = num_corrects / num_samples

                print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
                    epoch, config.epochs, batch_idx + 1, steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.info(print_msg)

                avg_loss = 0
                num_corrects = 0
                num_samples = 0

            # stage 9-11: after the last batch of every save_freq-th epoch,
            #             rank 0 validates and saves while other ranks wait
            if epoch % config.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
                if local_rank != 0:
                    paddle.distributed.barrier(
                    )  # Wait for the validation step in the main process
                    continue  # Resume training on the other processes

                # stage 9-12/9-13: evaluate on the validation dataset
                logger.info('Evaluate on validation dataset')
                dev_acc = _evaluate(model, dev_dataset, config)

                print_msg = '[Evaluation result]'
                print_msg += ' dev_acc={:.4f}'.format(dev_acc)
                logger.info(print_msg)

                # stage 9-14: save the model and optimizer state
                save_dir = os.path.join(args.checkpoint_dir,
                                        'epoch_{}'.format(epoch))
                logger.info('Saving model checkpoint to {}'.format(save_dir))
                paddle.save(model.state_dict(),
                            os.path.join(save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(save_dir, 'model.pdopt'))

                if nranks > 1:
                    paddle.distributed.barrier()  # Main process


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="cpu",
                        help="Select which device to train model, defaults to cpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default=None,
                        help="Directory to load model checkpoint to contiune trainning.")
    parser.add_argument("--checkpoint-dir",
                        type=str,
                        default='./checkpoint',
                        help="Directory to save model checkpoints.")
    parser.add_argument("--augment",
                        action="store_true",
                        default=False,
                        help="Apply audio augments.")

    args = parser.parse_args()
    # yapf: enable

    # main() reads config.seed / config.model / config.feature, so an empty
    # configuration can never work: report it as a usage error.
    if not args.config:
        parser.error("--config is required")

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    config.merge_from_file(args.config)

    config.freeze()
    print(config)

    main(args, config)