diff --git a/examples/voxceleb/sv0/local/data_prepare.py b/examples/voxceleb/sv0/local/data_prepare.py
index ca707fc2..1a0a6392 100644
--- a/examples/voxceleb/sv0/local/data_prepare.py
+++ b/examples/voxceleb/sv0/local/data_prepare.py
@@ -3,24 +3,11 @@ import os
 
 import numpy as np
 import paddle
-from paddle.io import BatchSampler
-from paddle.io import DataLoader
-from paddle.io import DistributedBatchSampler
-from paddleaudio.datasets.voxceleb import VoxCeleb1
-from paddleaudio.features.core import melspectrogram
+from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.augment import build_augment_pipeline
-from paddlespeech.vector.io.augment import waveform_augment
-from paddlespeech.vector.io.batch import feature_normalize
-from paddlespeech.vector.io.batch import waveform_collate_fn
-from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.modules.loss import AdditiveAngularMargin
-from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
-from paddlespeech.vector.modules.lr import CyclicLRScheduler
-from paddlespeech.vector.modules.sid_model import SpeakerIdetification
 from paddlespeech.vector.training.seeding import seed_everything
-from paddlespeech.vector.utils.time import Timer
 
 logger = Log(__name__).getlog()
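The import list that survives in `data_prepare.py` is exactly what a preparation-only script needs: dataset download/conversion, the augmentation-noise pipeline, and seeding. A minimal sketch of the flow those imports support, with an illustrative device, seed, and path rather than the script's real argument parsing:

```python
import paddle
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything

paddle.set_device('cpu')  # 'gpu' works the same way
seed_everything(0)        # fixed seed so multi-process workers agree

# Instantiating the datasets triggers download and csv/meta generation
# under target_dir; build_augment_pipeline does the same for RIR/noise data.
train_dataset = VoxCeleb1('train', target_dir='./data/')
dev_dataset = VoxCeleb1('dev', target_dir='./data/')
augment_pipeline = build_augment_pipeline(target_dir='./data/')
```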
diff --git a/examples/voxceleb/sv0/local/extract_speaker_embedding.py b/examples/voxceleb/sv0/local/extract_speaker_embedding.py
deleted file mode 100644
index e7dad140..00000000
--- a/examples/voxceleb/sv0/local/extract_speaker_embedding.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import ast
-import os
-
-import numpy as np
-import paddle
-import paddle.nn.functional as F
-from paddle.io import BatchSampler
-from paddle.io import DataLoader
-from tqdm import tqdm
-
-from paddleaudio.backends import load as load_audio
-from paddleaudio.datasets.voxceleb import VoxCeleb1
-from paddleaudio.features.core import melspectrogram
-from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.io.batch import feature_normalize
-from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.modules.sid_model import SpeakerIdetification
-from paddlespeech.vector.training.metrics import compute_eer
-from paddlespeech.vector.training.seeding import seed_everything
-
-logger = Log(__name__).getlog()
-
-# feat configuration
-cpu_feat_conf = {
-    'n_mels': 80,
-    'window_size': 400,  #ms
-    'hop_length': 160,  #ms
-}
-
-
-def extract_audio_embedding(args):
-    # stage 0: set the training device, cpu or gpu
-    paddle.set_device(args.device)
-    # set the random seed, it is a must for multiprocess training
-    seed_everything(args.seed)
-
-    # stage 1: build the dnn backbone model network
-    ##"channels": [1024, 1024, 1024, 1024, 3072],
-    model_conf = {
-        "input_size": 80,
-        "channels": [512, 512, 512, 512, 1536],
-        "kernel_sizes": [5, 3, 3, 3, 1],
-        "dilations": [1, 2, 3, 4, 1],
-        "attention_channels": 128,
-        "lin_neurons": 192,
-    }
-    ecapa_tdnn = EcapaTdnn(**model_conf)
-
-    # stage4: build the speaker verification train instance with backbone model
-    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1211)
-    # stage 2: load the pre-trained model
-    args.load_checkpoint = os.path.abspath(
-        os.path.expanduser(args.load_checkpoint))
-
-    # load model checkpoint to sid model
-    state_dict = paddle.load(
-        os.path.join(args.load_checkpoint, 'model.pdparams'))
-    model.set_state_dict(state_dict)
-    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
-
-    # stage 3: we must set the model to eval mode
-    model.eval()
-
-    # stage 4: read the audio data and extract the embedding
-    # wavform is one dimension numpy array
-    waveform, sr = load_audio(args.audio_path)
-
-    # feat type is numpy array, whose shape is [dim, time]
-    # we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
-    # so the final shape is [1, dim, time]
-    feat = melspectrogram(x=waveform, **cpu_feat_conf)
-    feat = paddle.to_tensor(feat).unsqueeze(0)
-
-    # in inference period, the lengths is all one without padding
-    lengths = paddle.ones([1])
-    feat = feature_normalize(
-        feat, mean_norm=True, std_norm=False, convert_to_numpy=True)
-
-    # model backbone network forward the feats and get the embedding
-    embedding = model.backbone(
-        feat, lengths).squeeze().numpy()  # (1, emb_size, 1) -> (emb_size)
-
-    # stage 5: do global norm with external mean and std
-    # todo
-    # np.save("audio-embedding", embedding)
-    return embedding
-
-
-if __name__ == "__main__":
-    # yapf: disable
-    parser = argparse.ArgumentParser(__doc__)
-    parser.add_argument('--device',
-                        choices=['cpu', 'gpu'],
-                        default="gpu",
-                        help="Select which device to train model, defaults to gpu.")
-    parser.add_argument("--seed",
-                        default=0,
-                        type=int,
-                        help="random seed for paddle, numpy and python random package")
-    parser.add_argument("--load-checkpoint",
-                        type=str,
-                        default='',
-                        help="Directory to load model checkpoint to contiune trainning.")
-    parser.add_argument("--global-embedding-norm",
-                        type=str,
-                        default=None,
-                        help="Apply global normalization on speaker embeddings.")
-    parser.add_argument("--audio-path",
-                        default="./data/demo.wav",
-                        type=str,
-                        help="Single audio file path")
-    args = parser.parse_args()
-    # yapf: enable
-
-    extract_audio_embedding(args)
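The deleted extractor is relocated under `${BIN_DIR}` (see `path.sh`/`run.sh` below) rather than removed; it returns a single `(emb_size,)` numpy embedding. A hedged sketch of the usual next step, comparing two such embeddings; `cosine_score` is a helper named here for illustration only:

```python
import numpy as np

def cosine_score(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
    """Cosine similarity of two (emb_size,) speaker embeddings."""
    return float(np.dot(emb_a, emb_b) /
                 (np.linalg.norm(emb_a) * np.linalg.norm(emb_b)))

# Two utterances belong to the same speaker when the score clears a
# threshold tuned on a held-out trial list.
```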
diff --git a/examples/voxceleb/sv0/local/speaker_verification_cosine.py b/examples/voxceleb/sv0/local/speaker_verification_cosine.py
deleted file mode 100644
index 417e8aa3..00000000
--- a/examples/voxceleb/sv0/local/speaker_verification_cosine.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import ast
-import os
-
-import numpy as np
-import paddle
-import paddle.nn.functional as F
-from paddle.io import BatchSampler
-from paddle.io import DataLoader
-from tqdm import tqdm
-
-from paddleaudio.datasets.voxceleb import VoxCeleb1
-from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.modules.sid_model import SpeakerIdetification
-from paddlespeech.vector.training.metrics import compute_eer
-from paddlespeech.vector.training.seeding import seed_everything
-
-logger = Log(__name__).getlog()
-
-
-def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
-    x = np.asarray(x)
-    assert len(
-        x.shape) == 2, f'Only 2D arrays supported, but got shape: {x.shape}'
-
-    w = target_length - x.shape[axis]
-    assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[axis]}'
-
-    if axis == 0:
-        pad_width = [[0, w], [0, 0]]
-    else:
-        pad_width = [[0, 0], [0, w]]
-
-    return np.pad(x, pad_width, mode=mode, **kwargs)
-
-
-def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
-    ids = [item['id'] for item in batch]
-    lengths = np.asarray([item['feat'].shape[1] for item in batch])
-    feats = list(
-        map(lambda x: pad_right_2d(x, lengths.max()),
-            [item['feat'] for item in batch]))
-    feats = np.stack(feats)
-
-    # Features normalization if needed
-    for i in range(len(feats)):
-        feat = feats[i][:, :lengths[i]]  # Excluding pad values.
-        mean = feat.mean(axis=-1, keepdims=True) if mean_norm else 0
-        std = feat.std(axis=-1, keepdims=True) if std_norm else 1
-        feats[i][:, :lengths[i]] = (feat - mean) / std
-        assert feats[i][:, lengths[
-            i]:].sum() == 0  # Padding valus should all be 0.
-
-    # Converts into ratios.
-    lengths = (lengths / lengths.max()).astype(np.float32)
-
-    return {'ids': ids, 'feats': feats, 'lengths': lengths}
-
-
-# feat configuration
-cpu_feat_conf = {
-    'n_mels': 80,
-    'window_size': 400,  #ms
-    'hop_length': 160,  #ms
-}
-
-
-def main(args):
-    # stage0: set the training device, cpu or gpu
-    paddle.set_device(args.device)
-    # set the random seed, it is a must for multiprocess training
-    seed_everything(args.seed)
-
-    # stage1: build the dnn backbone model network
-    ##"channels": [1024, 1024, 1024, 1024, 3072],
-    model_conf = {
-        "input_size": 80,
-        "channels": [512, 512, 512, 512, 1536],
-        "kernel_sizes": [5, 3, 3, 3, 1],
-        "dilations": [1, 2, 3, 4, 1],
-        "attention_channels": 128,
-        "lin_neurons": 192,
-    }
-    ecapa_tdnn = EcapaTdnn(**model_conf)
-
-    # stage2: build the speaker verification eval instance with backbone model
-    model = SpeakerIdetification(
-        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
-
-    # stage3: load the pre-trained model
-    args.load_checkpoint = os.path.abspath(
-        os.path.expanduser(args.load_checkpoint))
-
-    # load model checkpoint to sid model
-    state_dict = paddle.load(
-        os.path.join(args.load_checkpoint, 'model.pdparams'))
-    model.set_state_dict(state_dict)
-    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
-
-    # stage4: construct the enroll and test dataloader
-    enrol_ds = VoxCeleb1(
-        subset='enrol',
-        target_dir=args.data_dir,
-        feat_type='melspectrogram',
-        random_chunk=False,
-        **cpu_feat_conf)
-    enrol_sampler = BatchSampler(
-        enrol_ds, batch_size=args.batch_size,
-        shuffle=True)  # Shuffle to make embedding normalization more robust.
-    enrol_loader = DataLoader(enrol_ds,
-                              batch_sampler=enrol_sampler,
-                              collate_fn=lambda x: feature_normalize(
-                                  x, mean_norm=True, std_norm=False),
-                              num_workers=args.num_workers,
-                              return_list=True,)
-
-    test_ds = VoxCeleb1(
-        subset='test',
-        target_dir=args.data_dir,
-        feat_type='melspectrogram',
-        random_chunk=False,
-        **cpu_feat_conf)
-
-    test_sampler = BatchSampler(
-        test_ds, batch_size=args.batch_size, shuffle=True)
-    test_loader = DataLoader(test_ds,
-                             batch_sampler=test_sampler,
-                             collate_fn=lambda x: feature_normalize(
-                                 x, mean_norm=True, std_norm=False),
-                             num_workers=args.num_workers,
-                             return_list=True,)
-    # stage6: we must set the model to eval mode
-    model.eval()
-
-    # stage7: global embedding norm to imporve the performance
-    if args.global_embedding_norm:
-        global_embedding_mean = None
-        global_embedding_std = None
-        mean_norm_flag = args.embedding_mean_norm
-        std_norm_flag = args.embedding_std_norm
-        batch_count = 0
-
-    # stage8: Compute embeddings of audios in enrol and test dataset from model.
-    id2embedding = {}
-    # Run multi times to make embedding normalization more stable.
-    for i in range(2):
-        for dl in [enrol_loader, test_loader]:
-            logger.info(
-                f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
-            )
-            with paddle.no_grad():
-                for batch_idx, batch in enumerate(tqdm(dl)):
-
-                    # stage 8-1: extrac the audio embedding
-                    ids, feats, lengths = batch['ids'], batch['feats'], batch[
-                        'lengths']
-                    embeddings = model.backbone(feats, lengths).squeeze(
-                        -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
-
-                    # Global embedding normalization.
-                    if args.global_embedding_norm:
-                        batch_count += 1
-                        current_mean = embeddings.mean(
-                            axis=0) if mean_norm_flag else 0
-                        current_std = embeddings.std(
-                            axis=0) if std_norm_flag else 1
-                        # Update global mean and std.
-                        if global_embedding_mean is None and global_embedding_std is None:
-                            global_embedding_mean, global_embedding_std = current_mean, current_std
-                        else:
-                            weight = 1 / batch_count  # Weight decay by batches.
-                            global_embedding_mean = (
-                                1 - weight
-                            ) * global_embedding_mean + weight * current_mean
-                            global_embedding_std = (
-                                1 - weight
-                            ) * global_embedding_std + weight * current_std
-                        # Apply global embedding normalization.
-                        embeddings = (embeddings - global_embedding_mean
-                                      ) / global_embedding_std
-
-                    # Update embedding dict.
-                    id2embedding.update(dict(zip(ids, embeddings)))
-
-    # stage 9: Compute cosine scores.
-    labels = []
-    enrol_ids = []
-    test_ids = []
-    with open(VoxCeleb1.veri_test_file, 'r') as f:
-        for line in f.readlines():
-            label, enrol_id, test_id = line.strip().split(' ')
-            labels.append(int(label))
-            enrol_ids.append(enrol_id.split('.')[0].replace('/', '-'))
-            test_ids.append(test_id.split('.')[0].replace('/', '-'))
-
-    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
-    enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
-        np.asarray([id2embedding[id] for id in ids], dtype='float32')),
-                                            [enrol_ids, test_ids
-                                             ])  # (N, emb_size)
-    scores = cos_sim_func(enrol_embeddings, test_embeddings)
-    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
-    logger.info(
-        f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
-    )
-
-
-if __name__ == "__main__":
-    # yapf: disable
-    parser = argparse.ArgumentParser(__doc__)
-    parser.add_argument('--device',
-                        choices=['cpu', 'gpu'],
-                        default="gpu",
-                        help="Select which device to train model, defaults to gpu.")
-    parser.add_argument("--seed",
-                        default=0,
-                        type=int,
-                        help="random seed for paddle, numpy and python random package")
-    parser.add_argument("--data-dir",
-                        default="./data/",
-                        type=str,
-                        help="data directory")
-    parser.add_argument("--batch-size",
-                        type=int,
-                        default=16,
-                        help="Total examples' number in batch for extract the embedding.")
-    parser.add_argument("--num-workers",
-                        type=int,
-                        default=0,
-                        help="Number of workers in dataloader.")
-    parser.add_argument("--load-checkpoint",
-                        type=str,
-                        default='',
-                        help="Directory to load model checkpoint to contiune trainning.")
-    parser.add_argument("--global-embedding-norm",
-                        type=bool,
-                        default=True,
-                        help="Apply global normalization on speaker embeddings.")
-    parser.add_argument("--embedding-mean-norm",
-                        type=bool,
-                        default=True,
-                        help="Apply mean normalization on speaker embeddings.")
-    parser.add_argument("--embedding-std-norm",
-                        type=bool,
-                        default=False,
-                        help="Apply std normalization on speaker embeddings.")
-    args = parser.parse_args()
-    # yapf: enable
-
-    main(args)
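This verification script also moves under `${BIN_DIR}`. Its scoring core is row-wise cosine similarity over paired enroll/test embeddings; a condensed sketch of that stage, with random tensors standing in for the embeddings gathered by the loops above:

```python
import numpy as np
import paddle

rng = np.random.default_rng(0)
# One row per trial pair; 192 matches the model's lin_neurons above.
enroll = paddle.to_tensor(rng.standard_normal((4, 192)).astype('float32'))
test = paddle.to_tensor(rng.standard_normal((4, 192)).astype('float32'))

cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
scores = cos_sim_func(enroll, test)  # shape (4,): one score per trial
```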
diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py
deleted file mode 100644
index 3fe67c8e..00000000
--- a/examples/voxceleb/sv0/local/train.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import os
-
-import numpy as np
-import paddle
-from paddle.io import BatchSampler
-from paddle.io import DataLoader
-from paddle.io import DistributedBatchSampler
-
-from paddleaudio.datasets.voxceleb import VoxCeleb1
-from paddleaudio.features.core import melspectrogram
-from paddlespeech.s2t.utils.log import Log
-from paddlespeech.vector.io.augment import build_augment_pipeline
-from paddlespeech.vector.io.augment import waveform_augment
-from paddlespeech.vector.io.batch import feature_normalize
-from paddlespeech.vector.io.batch import waveform_collate_fn
-from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
-from paddlespeech.vector.modules.loss import AdditiveAngularMargin
-from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
-from paddlespeech.vector.modules.lr import CyclicLRScheduler
-from paddlespeech.vector.modules.sid_model import SpeakerIdetification
-from paddlespeech.vector.training.seeding import seed_everything
-from paddlespeech.vector.utils.time import Timer
-
-logger = Log(__name__).getlog()
-
-# feat configuration
-cpu_feat_conf = {
-    'n_mels': 80,
-    'window_size': 400,  #ms
-    'hop_length': 160,  #ms
-}
-
-
-def main(args):
-    # stage0: set the training device, cpu or gpu
-    paddle.set_device(args.device)
-
-    # stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
-    paddle.distributed.init_parallel_env()
-    nranks = paddle.distributed.get_world_size()
-    local_rank = paddle.distributed.get_rank()
-    # set the random seed, it is a must for multiprocess training
-    seed_everything(args.seed)
-
-    # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
-    # note: some cmd must do in rank==0, so wo will refactor the data prepare code
-    train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
-    dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)
-
-    if args.augment:
-        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
-    else:
-        augment_pipeline = []
-
-    # stage3: build the dnn backbone model network
-    #"channels": [1024, 1024, 1024, 1024, 3072],
-    model_conf = {
-        "input_size": 80,
-        "channels": [512, 512, 512, 512, 1536],
-        "kernel_sizes": [5, 3, 3, 3, 1],
-        "dilations": [1, 2, 3, 4, 1],
-        "attention_channels": 128,
-        "lin_neurons": 192,
-    }
-    ecapa_tdnn = EcapaTdnn(**model_conf)
-
-    # stage4: build the speaker verification train instance with backbone model
-    model = SpeakerIdetification(
-        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
-
-    # stage5: build the optimizer, we now only construct the AdamW optimizer
-    lr_schedule = CyclicLRScheduler(
-        base_lr=args.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
-    optimizer = paddle.optimizer.AdamW(
-        learning_rate=lr_schedule, parameters=model.parameters())
-
-    # stage6: build the loss function, we now only support LogSoftmaxWrapper
-    criterion = LogSoftmaxWrapper(
-        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
-
-    # stage7: confirm training start epoch
-    # if pre-trained model exists, start epoch confirmed by the pre-trained model
-    start_epoch = 0
-    if args.load_checkpoint:
-        logger.info("load the check point")
-        args.load_checkpoint = os.path.abspath(
-            os.path.expanduser(args.load_checkpoint))
-        try:
-            # load model checkpoint
-            state_dict = paddle.load(
-                os.path.join(args.load_checkpoint, 'model.pdparams'))
-            model.set_state_dict(state_dict)
-
-            # load optimizer checkpoint
-            state_dict = paddle.load(
-                os.path.join(args.load_checkpoint, 'model.pdopt'))
-            optimizer.set_state_dict(state_dict)
-            if local_rank == 0:
-                logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
-        except FileExistsError:
-            if local_rank == 0:
-                logger.info('Train from scratch.')
-
-        try:
-            start_epoch = int(args.load_checkpoint[-1])
-            logger.info(f'Restore training from epoch {start_epoch}.')
-        except ValueError:
-            pass
-
-    # stage8: we build the batch sampler for paddle.DataLoader
-    train_sampler = DistributedBatchSampler(
-        train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
-    train_loader = DataLoader(
-        train_dataset,
-        batch_sampler=train_sampler,
-        num_workers=args.num_workers,
-        collate_fn=waveform_collate_fn,
-        return_list=True,
-        use_buffer_reader=True, )
-
-    # stage9: start to train
-    # we will comment the training process
-    steps_per_epoch = len(train_sampler)
-    timer = Timer(steps_per_epoch * args.epochs)
-    timer.start()
-
-    for epoch in range(start_epoch + 1, args.epochs + 1):
-        # at the begining, model must set to train mode
-        model.train()
-
-        avg_loss = 0
-        num_corrects = 0
-        num_samples = 0
-        for batch_idx, batch in enumerate(train_loader):
-            # stage 9-1: batch data is audio sample points and speaker id label
-            waveforms, labels = batch['waveforms'], batch['labels']
-
-            # stage 9-2: audio sample augment method, which is done on the audio sample point
-            if len(augment_pipeline) != 0:
-                waveforms = waveform_augment(waveforms, augment_pipeline)
-                labels = paddle.concat(
-                    [labels for i in range(len(augment_pipeline) + 1)])
-
-            # stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
-            feats = []
-            for waveform in waveforms.numpy():
-                feat = melspectrogram(x=waveform, **cpu_feat_conf)
-                feats.append(feat)
-            feats = paddle.to_tensor(np.asarray(feats))
-
-            # stage 9-4: feature normalize, which help converge and imporve the performance
-            feats = feature_normalize(
-                feats, mean_norm=True, std_norm=False)  # Features normalization
-
-            # stage 9-5: model forward, such ecapa-tdnn, x-vector
-            logits = model(feats)
-
-            # stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
-            loss = criterion(logits, labels)
-
-            # stage 9-7: update the gradient and clear the gradient cache
-            loss.backward()
-            optimizer.step()
-            if isinstance(optimizer._learning_rate,
-                          paddle.optimizer.lr.LRScheduler):
-                optimizer._learning_rate.step()
-            optimizer.clear_grad()
-
-            # stage 9-8: Calculate average loss per batch
-            avg_loss += loss.numpy()[0]
-
-            # stage 9-9: Calculate metrics, which is one-best accuracy
-            preds = paddle.argmax(logits, axis=1)
-            num_corrects += (preds == labels).numpy().sum()
-            num_samples += feats.shape[0]
-            timer.count()  # step plus one in timer
-
-            # stage 9-10: print the log information only on 0-rank per log-freq batchs
-            if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
-                lr = optimizer.get_lr()
-                avg_loss /= args.log_freq
-                avg_acc = num_corrects / num_samples
-
-                print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
-                    epoch, args.epochs, batch_idx + 1, steps_per_epoch)
-                print_msg += ' loss={:.4f}'.format(avg_loss)
-                print_msg += ' acc={:.4f}'.format(avg_acc)
-                print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
-                    lr, timer.timing, timer.eta)
-                logger.info(print_msg)
-
-                avg_loss = 0
-                num_corrects = 0
-                num_samples = 0
-
-            # stage 9-11: save the model parameters only on 0-rank per save-freq batchs
-            if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
-                if local_rank != 0:
-                    paddle.distributed.barrier(
-                    )  # Wait for valid step in main process
-                    continue  # Resume trainning on other process
-
-                # stage 9-12: construct the valid dataset dataloader
-                dev_sampler = BatchSampler(
-                    dev_dataset,
-                    batch_size=args.batch_size // 4,
-                    shuffle=False,
-                    drop_last=False)
-                dev_loader = DataLoader(
-                    dev_dataset,
-                    batch_sampler=dev_sampler,
-                    collate_fn=waveform_collate_fn,
-                    num_workers=args.num_workers,
-                    return_list=True, )
-
-                # set the model to eval mode
-                model.eval()
-                num_corrects = 0
-                num_samples = 0
-
-                # stage 9-13: evaluation the valid dataset batch data
-                logger.info('Evaluate on validation dataset')
-                with paddle.no_grad():
-                    for batch_idx, batch in enumerate(dev_loader):
-                        waveforms, labels = batch['waveforms'], batch['labels']
-
-                        feats = []
-                        for waveform in waveforms.numpy():
-                            feat = melspectrogram(x=waveform, **cpu_feat_conf)
-                            feats.append(feat)
-
-                        feats = paddle.to_tensor(np.asarray(feats))
-                        feats = feature_normalize(
-                            feats, mean_norm=True, std_norm=False)
-                        logits = model(feats)
-
-                        preds = paddle.argmax(logits, axis=1)
-                        num_corrects += (preds == labels).numpy().sum()
-                        num_samples += feats.shape[0]
-
-                print_msg = '[Evaluation result]'
-                print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
-                logger.info(print_msg)
-
-                # stage 9-14: Save model parameters
-                save_dir = os.path.join(args.checkpoint_dir,
-                                        'epoch_{}'.format(epoch))
-                logger.info('Saving model checkpoint to {}'.format(save_dir))
-                paddle.save(model.state_dict(),
-                            os.path.join(save_dir, 'model.pdparams'))
-                paddle.save(optimizer.state_dict(),
-                            os.path.join(save_dir, 'model.pdopt'))
-
-                if nranks > 1:
-                    paddle.distributed.barrier()  # Main process
-
-
-if __name__ == "__main__":
-    # yapf: disable
-    parser = argparse.ArgumentParser(__doc__)
-    parser.add_argument('--device',
-                        choices=['cpu', 'gpu'],
-                        default="cpu",
-                        help="Select which device to train model, defaults to gpu.")
-    parser.add_argument("--seed",
-                        default=0,
-                        type=int,
-                        help="random seed for paddle, numpy and python random package")
-    parser.add_argument("--data-dir",
-                        default="./data/",
-                        type=str,
-                        help="data directory")
-    parser.add_argument("--learning-rate",
-                        type=float,
-                        default=1e-8,
-                        help="Learning rate used to train with warmup.")
-    parser.add_argument("--load-checkpoint",
-                        type=str,
-                        default=None,
-                        help="Directory to load model checkpoint to contiune trainning.")
-    parser.add_argument("--batch-size",
-                        type=int, default=64,
-                        help="Total examples' number in batch for training.")
-    parser.add_argument("--num-workers",
-                        type=int,
-                        default=0,
-                        help="Number of workers in dataloader.")
-    parser.add_argument("--epochs",
-                        type=int,
-                        default=50,
-                        help="Number of epoches for fine-tuning.")
-    parser.add_argument("--log-freq",
-                        type=int,
-                        default=10,
-                        help="Log the training infomation every n steps.")
-    parser.add_argument("--save-freq",
-                        type=int,
-                        default=1,
-                        help="Save checkpoint every n epoch.")
-    parser.add_argument("--checkpoint-dir",
-                        type=str,
-                        default='./checkpoint',
-                        help="Directory to save model checkpoints.")
-    parser.add_argument("--augment",
-                        action="store_true",
-                        default=False,
-                        help="Apply audio augments.")
-
-    args = parser.parse_args()
-    # yapf: enable
-
-    main(args)
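`train.py` moves under `${BIN_DIR}` as well. One dependency worth flagging: the `CyclicLRScheduler` it imports is renamed at the bottom of this diff from `paddlespeech.vector.modules.lr` to `paddlespeech.vector.training.scheduler`. A sketch of the optimizer wiring with the values the deleted file used, against the new module path; the plain `Linear` is a stand-in for the real `SpeakerIdetification` model:

```python
import paddle
from paddlespeech.vector.training.scheduler import CyclicLRScheduler

model = paddle.nn.Linear(192, 1211)  # stand-in for SpeakerIdetification
nranks = 4                           # world size under distributed launch

# base_lr/max_lr/step_size copied from the deleted train.py.
lr_schedule = CyclicLRScheduler(
    base_lr=1e-8, max_lr=1e-3, step_size=140000 // nranks)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_schedule, parameters=model.parameters())
# The deleted training loop advances the schedule manually after each
# optimizer.step() via optimizer._learning_rate.step().
```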
diff --git a/examples/voxceleb/sv0/path.sh b/examples/voxceleb/sv0/path.sh
index 38a242a4..6d19f994 100755
--- a/examples/voxceleb/sv0/path.sh
+++ b/examples/voxceleb/sv0/path.sh
@@ -9,3 +9,6 @@ export PYTHONIOENCODING=UTF-8
 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
 
 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+MODEL=ecapa-tdnn
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}
\ No newline at end of file
diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh
index a2336fb6..a6346cd5 100755
--- a/examples/voxceleb/sv0/run.sh
+++ b/examples/voxceleb/sv0/run.sh
@@ -30,23 +30,21 @@ if [ $stage -le 1 ]; then
     # stage 1: train the speaker identification model
     python3 \
     -m paddle.distributed.launch --gpus=0,1,2,3 \
-    local/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
+    ${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
     --save-freq 10 --data-dir ${dir} --batch-size 64 --epochs 100
 fi
 
 if [ $stage -le 2 ]; then
-    # stage 1: train the speaker identification model
-    # you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset
+    # stage 2: get the speaker verification scores with cosine function
     python3 \
-        local/speaker_verification_cosine.py\
+        ${BIN_DIR}/speaker_verification_cosine.py\
         --batch-size 4 --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
 fi
 
 if [ $stage -le 3 ]; then
-    # stage 1: train the speaker identification model
-    # you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset
+    # stage 3: extract the audio embedding
     python3 \
-        local/extract_speaker_embedding.py\
+        ${BIN_DIR}/extract_speaker_embedding.py\
         --audio-path "demo/csv/00001.wav" --load-checkpoint ${exp_dir}/epoch_60/
 fi
diff --git a/paddleaudio/paddleaudio/datasets/__init__.py b/paddleaudio/paddleaudio/datasets/__init__.py
index 5c5f0369..cbf9b3ae 100644
--- a/paddleaudio/paddleaudio/datasets/__init__.py
+++ b/paddleaudio/paddleaudio/datasets/__init__.py
@@ -15,3 +15,5 @@ from .esc50 import ESC50
 from .gtzan import GTZAN
 from .tess import TESS
 from .urban_sound import UrbanSound8K
+from .voxceleb import VoxCeleb1
+from .rirs_noises import OpenRIRNoise
diff --git a/paddleaudio/datasets/rirs_noises.py b/paddleaudio/paddleaudio/datasets/rirs_noises.py
similarity index 97%
rename from paddleaudio/datasets/rirs_noises.py
rename to paddleaudio/paddleaudio/datasets/rirs_noises.py
index 6af9fd9d..df5dec61 100644
--- a/paddleaudio/datasets/rirs_noises.py
+++ b/paddleaudio/paddleaudio/datasets/rirs_noises.py
@@ -23,11 +23,11 @@ from typing import Tuple
 from paddle.io import Dataset
 from tqdm import tqdm
 
-from paddleaudio.backends import load as load_audio
-from paddleaudio.backends import save_wav
-from paddleaudio.datasets.dataset import feat_funcs
-from paddleaudio.utils import DATA_HOME
-from paddleaudio.utils import decompress
+from ..backends import load as load_audio
+from ..backends import save as save_wav
+from .dataset import feat_funcs
+from ..utils import DATA_HOME
+from ..utils import decompress
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.utils.download import download_and_decompress
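With the sources nested under `paddleaudio/paddleaudio` and both classes re-exported from `datasets/__init__.py`, downstream code can import them from the datasets module directly. A sketch of the new import surface; the `'rir'` subset name for `OpenRIRNoise` is an assumption based on how the augmentation pipeline uses the dataset:

```python
from paddleaudio.paddleaudio.datasets import OpenRIRNoise
from paddleaudio.paddleaudio.datasets import VoxCeleb1

# Note the corrected subset spelling introduced in this diff: 'enroll'.
enroll_ds = VoxCeleb1('enroll', target_dir='./data/')
rir_ds = OpenRIRNoise('rir', target_dir='./data/')  # assumed subset name
```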
diff --git a/paddleaudio/datasets/voxceleb.py b/paddleaudio/paddleaudio/datasets/voxceleb.py
similarity index 97%
rename from paddleaudio/datasets/voxceleb.py
rename to paddleaudio/paddleaudio/datasets/voxceleb.py
index 0011340e..4989accb 100644
--- a/paddleaudio/datasets/voxceleb.py
+++ b/paddleaudio/paddleaudio/datasets/voxceleb.py
@@ -25,10 +25,10 @@ from paddle.io import Dataset
 from pathos.multiprocessing import Pool
 from tqdm import tqdm
 
-from paddleaudio.backends import load as load_audio
-from paddleaudio.datasets.dataset import feat_funcs
-from paddleaudio.utils import DATA_HOME
-from paddleaudio.utils import decompress
+from .dataset import feat_funcs
+from ..backends import load as load_audio
+from ..utils import DATA_HOME
+from ..utils import decompress
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.utils.download import download_and_decompress
 from utils.utility import download
@@ -83,7 +83,7 @@ class VoxCeleb1(Dataset):
     meta_path = os.path.join(base_path, 'meta')
     veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
     csv_path = os.path.join(base_path, 'csv')
-    subsets = ['train', 'dev', 'enrol', 'test']
+    subsets = ['train', 'dev', 'enroll', 'test']
 
     def __init__(
             self,
@@ -330,7 +330,7 @@ class VoxCeleb1(Dataset):
 
         self.generate_csv(
             enroll_files,
-            os.path.join(self.csv_path, 'enrol.csv'),
+            os.path.join(self.csv_path, 'enroll.csv'),
             split_chunks=False)
         self.generate_csv(
             test_files,
diff --git a/paddleaudio/paddleaudio/metric/__init__.py b/paddleaudio/paddleaudio/metric/__init__.py
index a96530ff..b435571d 100644
--- a/paddleaudio/paddleaudio/metric/__init__.py
+++ b/paddleaudio/paddleaudio/metric/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 from .dtw import dtw_distance
 from .mcd import mcd_distance
+from .eer import compute_eer
diff --git a/paddlespeech/vector/training/metrics.py b/paddleaudio/paddleaudio/metric/eer.py
similarity index 100%
rename from paddlespeech/vector/training/metrics.py
rename to paddleaudio/paddleaudio/metric/eer.py
diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py
index 366c0cff..76312978 100644
--- a/paddlespeech/vector/io/augment.py
+++ b/paddlespeech/vector/io/augment.py
@@ -20,8 +20,8 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-from paddleaudio.backends import load as load_audio
-from paddleaudio.datasets.rirs_noises import OpenRIRNoise
+from paddleaudio.paddleaudio import load as load_audio
+from paddleaudio.paddleaudio.datasets.rirs_noises import OpenRIRNoise
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.signal_processing import compute_amplitude
 from paddlespeech.vector.io.signal_processing import convolve1d
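`compute_eer` moves verbatim (similarity index 100%) into paddleaudio's metric package and is re-exported from its `__init__.py`. A minimal usage sketch with toy trials; per the deleted verification script it returns an `(eer, threshold)` pair:

```python
import numpy as np
from paddleaudio.paddleaudio.metric import compute_eer

labels = np.array([1, 1, 0, 0])              # 1 = same-speaker trial
scores = np.array([0.82, 0.64, 0.45, 0.31])  # e.g. cosine scores
eer, threshold = compute_eer(labels, scores)
print(f'EER: {eer * 100:.4f}%, threshold: {threshold:.5f}')
```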
diff --git a/paddlespeech/vector/io/batch.py b/paddlespeech/vector/io/batch.py
index 879cde3a..811775e2 100644
--- a/paddlespeech/vector/io/batch.py
+++ b/paddlespeech/vector/io/batch.py
@@ -40,3 +40,41 @@ def feature_normalize(feats: paddle.Tensor,
         feats = (feats - mean) / std
 
     return feats
+
+
+def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
+    x = np.asarray(x)
+    assert len(
+        x.shape) == 2, f'Only 2D arrays supported, but got shape: {x.shape}'
+
+    w = target_length - x.shape[axis]
+    assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[axis]}'
+
+    if axis == 0:
+        pad_width = [[0, w], [0, 0]]
+    else:
+        pad_width = [[0, 0], [0, w]]
+
+    return np.pad(x, pad_width, mode=mode, **kwargs)
+
+
+def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
+    ids = [item['id'] for item in batch]
+    lengths = np.asarray([item['feat'].shape[1] for item in batch])
+    feats = list(
+        map(lambda x: pad_right_2d(x, lengths.max()),
+            [item['feat'] for item in batch]))
+    feats = np.stack(feats)
+
+    # Features normalization if needed
+    for i in range(len(feats)):
+        feat = feats[i][:, :lengths[i]]  # Excluding pad values.
+        mean = feat.mean(axis=-1, keepdims=True) if mean_norm else 0
+        std = feat.std(axis=-1, keepdims=True) if std_norm else 1
+        feats[i][:, :lengths[i]] = (feat - mean) / std
+        assert feats[i][:, lengths[
+            i]:].sum() == 0  # Padding values should all be 0.
+
+    # Convert lengths into ratios.
+    lengths = (lengths / lengths.max()).astype(np.float32)
+
+    return {'ids': ids, 'feats': feats, 'lengths': lengths}
\ No newline at end of file
diff --git a/paddlespeech/vector/modules/lr.py b/paddlespeech/vector/training/scheduler.py
similarity index 100%
rename from paddlespeech/vector/modules/lr.py
rename to paddlespeech/vector/training/scheduler.py
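The new `batch_feature_normalize` generalizes the collate logic that the deleted scripts each re-implemented, so callers can hand it straight to a `DataLoader`. A hedged sketch of that wiring, assuming `enroll_ds` is a `VoxCeleb1` instance yielding `{'id': ..., 'feat': ...}` items:

```python
from paddle.io import BatchSampler, DataLoader
from paddlespeech.vector.io.batch import batch_feature_normalize

enroll_sampler = BatchSampler(enroll_ds, batch_size=16, shuffle=True)
enroll_loader = DataLoader(
    enroll_ds,
    batch_sampler=enroll_sampler,
    # Pad to the longest feat in the batch, then mean-normalize per utterance.
    collate_fn=lambda batch: batch_feature_normalize(
        batch, mean_norm=True, std_norm=False),
    num_workers=0,
    return_list=True)
```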