diff --git a/examples/voxceleb/README.md b/examples/voxceleb/README.md index 2c8ad138..59fb491c 100644 --- a/examples/voxceleb/README.md +++ b/examples/voxceleb/README.md @@ -6,3 +6,56 @@ sv0 - speaker verfication with softmax backend etc, all python code sv1 - dependence on kaldi, speaker verfication with plda/sc backend, more info refer to the sv1/readme.txt + + +## VoxCeleb2 preparation + +VoxCeleb2 audio files are released in m4a format. All the VoxCeleb2 m4a audio files must be converted to wav files before feeding them to PaddleSpeech. +Please follow these steps to prepare the dataset correctly: + +1. Download VoxCeleb2. +You can find download instructions here: http://www.robots.ox.ac.uk/~vgg/data/voxceleb/ + +2. Convert .m4a to wav +VoxCeleb2 stores files in the m4a audio format. To use them in PaddleSpeech, you have to convert all the m4a audio files into wav files. + +``` shell +ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s +``` + +``` shell +# copy this to root directory of data and +# chmod a+x convert.sh +# ./convert.sh +# https://unix.stackexchange.com/questions/103920/parallelize-a-bash-for-loop + +open_sem(){ + mkfifo pipe-$$ + exec 3<>pipe-$$ + rm pipe-$$ + local i=$1 + for((;i>0;i--)); do + printf %s 000 >&3 + done +} +run_with_lock(){ + local x + read -u 3 -n 3 x && ((0==x)) || exit $x + ( + ( "$@"; ) + printf '%.3d' $? >&3 + )& +} + +N=32 # number of vCPU +open_sem $N +for f in $(find . -name "*.m4a"); do + run_with_lock ffmpeg -loglevel panic -i "$f" -ar 16000 "${f%.*}.wav" +done +``` + +You can do the conversion using ffmpeg (see https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and only needs to be done once. + +3. Put all the wav files in a folder called `wav`. You should have something like `voxceleb2/wav/id*/*.wav` (e.g., `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`) + +4. 
\ No newline at end of file diff --git a/examples/voxceleb/sv0/local/data_prepare.py b/examples/voxceleb/sv0/local/data_prepare.py new file mode 100644 index 00000000..ca707fc2 --- /dev/null +++ b/examples/voxceleb/sv0/local/data_prepare.py @@ -0,0 +1,60 @@ +import argparse +import os + +import numpy as np +import paddle +from paddle.io import BatchSampler +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +from paddleaudio.datasets.voxceleb import VoxCeleb1 +from paddleaudio.features.core import melspectrogram +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.io.augment import build_augment_pipeline +from paddlespeech.vector.io.augment import waveform_augment +from paddlespeech.vector.io.batch import feature_normalize +from paddlespeech.vector.io.batch import waveform_collate_fn +from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn +from paddlespeech.vector.modules.loss import AdditiveAngularMargin +from paddlespeech.vector.modules.loss import LogSoftmaxWrapper +from paddlespeech.vector.modules.lr import CyclicLRScheduler +from paddlespeech.vector.modules.sid_model import SpeakerIdetification +from paddlespeech.vector.training.seeding import seed_everything +from paddlespeech.vector.utils.time import Timer + +logger = Log(__name__).getlog() + +def main(args): + # stage 0: use the cpu device; all data preparation is done in cpu mode + paddle.set_device("cpu") + # set the random seed; this is required for multiprocess training + seed_everything(args.seed) + + # stage 1: generate the voxceleb csv file + # Note: this may raise a C++ exception, but the program will still run fine, + # so the exception can be ignored + train_dataset = VoxCeleb1('train', target_dir=args.data_dir) + dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir) + + # stage 2: generate the augment noise csv file + if args.augment: + augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) + +if __name__ == "__main__": + # yapf: 
disable + parser = argparse.ArgumentParser(__doc__) + parser.add_argument("--seed", + default=0, + type=int, + help="random seed for paddle, numpy and python random package") + parser.add_argument("--data-dir", + default="./data/", + type=str, + help="data directory") + parser.add_argument("--augment", + action="store_true", + default=False, + help="Apply audio augments.") + args = parser.parse_args() + # yapf: enable + main(args) \ No newline at end of file diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh index 34a1cbd4..7ad3a36f 100755 --- a/examples/voxceleb/sv0/run.sh +++ b/examples/voxceleb/sv0/run.sh @@ -20,10 +20,10 @@ exp_dir=exp/ecapa-tdnn/ # experiment directory mkdir -p ${dir} mkdir -p ${exp_dir} -# if [ $stage -le 0 ]; then -# # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav -# # todo -# fi +if [ $stage -le 0 ]; then + # stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav + python3 local/data_prepare.py --data-dir ${dir} --augment +fi if [ $stage -le 1 ]; then # stage 1: train the speaker identification model diff --git a/paddleaudio/datasets/rirs_noises.py b/paddleaudio/datasets/rirs_noises.py index fa9e7f09..6af9fd9d 100644 --- a/paddleaudio/datasets/rirs_noises.py +++ b/paddleaudio/datasets/rirs_noises.py @@ -69,8 +69,9 @@ class OpenRIRNoise(Dataset): self.random_chunk = random_chunk self.chunk_duration = chunk_duration - self.csv_path = os.path.join(target_dir, "open_rir_noise", - "csv") if target_dir else self.csv_path + OpenRIRNoise.csv_path = os.path.join( + target_dir, "open_rir_noise", + "csv") if target_dir else self.csv_path self._data = self._get_data() super(OpenRIRNoise, self).__init__() diff --git a/paddleaudio/datasets/voxceleb.py b/paddleaudio/datasets/voxceleb.py index c97e825e..0011340e 100644 --- a/paddleaudio/datasets/voxceleb.py +++ b/paddleaudio/datasets/voxceleb.py @@ -16,6 +16,7 @@ import csv import glob import os import random +from 
multiprocessing import cpu_count from typing import Dict from typing import List from typing import Tuple @@ -28,8 +29,8 @@ from paddleaudio.backends import load as load_audio from paddleaudio.datasets.dataset import feat_funcs from paddleaudio.utils import DATA_HOME from paddleaudio.utils import decompress -from paddlespeech.vector.utils.download import download_and_decompress from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.utils.download import download_and_decompress from utils.utility import download from utils.utility import unpack @@ -105,14 +106,15 @@ class VoxCeleb1(Dataset): self.random_chunk = random_chunk self.chunk_duration = chunk_duration self.split_ratio = split_ratio - self.target_dir = target_dir if target_dir else self.base_path + self.target_dir = target_dir if target_dir else VoxCeleb1.base_path + + # if we set the target dir, we will change the vox data info data from base path to target dir VoxCeleb1.csv_path = os.path.join( - target_dir, 'csv') if target_dir else os.path.join(self.base_path, - 'csv') + target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb1.csv_path VoxCeleb1.meta_path = os.path.join( - target_dir, 'meta') if target_dir else os.path.join(self.base_path, - 'meta') - VoxCeleb1.veri_test_file = os.path.join(self.meta_path, + target_dir, "voxceleb", + 'meta') if target_dir else VoxCeleb1.meta_path + VoxCeleb1.veri_test_file = os.path.join(VoxCeleb1.meta_path, 'veri_test2.txt') # self._data = self._get_data()[:1000] # KP: Small dataset test. 
self._data = self._get_data() @@ -255,8 +257,9 @@ class VoxCeleb1(Dataset): split_chunks: bool=True): logger.info(f'Generating csv: {output_file}') header = ["id", "duration", "wav", "start", "stop", "spk_id"] - - with Pool(64) as p: + # Note: this may raise a C++ exception, but the program will still run fine, + # so the exception can be ignored + with Pool(cpu_count()) as p: infos = list( tqdm( p.imap(lambda x: self._get_audio_info(x, split_chunks), @@ -277,20 +280,20 @@ class VoxCeleb1(Dataset): def prepare_data(self): # Audio of speakers in veri_test_file should not be included in training set. logger.info("start to prepare the data csv file") - enrol_files = set() + enroll_files = set() test_files = set() # get the enroll and test audio file path with open(self.veri_test_file, 'r') as f: for line in f.readlines(): _, enrol_file, test_file = line.strip().split(' ') - enrol_files.add(os.path.join(self.wav_path, enrol_file)) + enroll_files.add(os.path.join(self.wav_path, enrol_file)) test_files.add(os.path.join(self.wav_path, test_file)) - enrol_files = sorted(enrol_files) + enroll_files = sorted(enroll_files) test_files = sorted(test_files) # get the enroll and test speakers test_spks = set() - for file in (enrol_files + test_files): + for file in (enroll_files + test_files): spk = file.split('/wav/')[1].split('/')[0] test_spks.add(spk) @@ -306,8 +309,9 @@ class VoxCeleb1(Dataset): speakers.add(spk) audio_files.append(file) - logger.info("start to generate the {}".format( - os.path.join(self.meta_path, 'spk_id2label.txt'))) + logger.info( + f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}" + ) # encode the train and dev speakers label to spk_id2label.txt with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f: for label, spk_id in enumerate( @@ -323,8 +327,9 @@ class VoxCeleb1(Dataset): self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv')) self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv')) + 
self.generate_csv( - enrol_files, + enroll_files, os.path.join(self.csv_path, 'enrol.csv'), split_chunks=False) self.generate_csv( diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py index af7aeb22..366c0cff 100644 --- a/paddlespeech/vector/io/augment.py +++ b/paddlespeech/vector/io/augment.py @@ -840,7 +840,7 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]: """ logger.info("start to build the augment pipeline") noise_dataset = OpenRIRNoise('noise', target_dir=target_dir) - rir_dataset = OpenRIRNoise('rir') + rir_dataset = OpenRIRNoise('rir', target_dir=target_dir) wavedrop = TimeDomainSpecAugment( sample_rate=16000,