diff --git a/dataset/voxceleb/voxceleb1.py b/dataset/voxceleb/voxceleb1.py
index e50c91bc1..0c9c68dc9 100644
--- a/dataset/voxceleb/voxceleb1.py
+++ b/dataset/voxceleb/voxceleb1.py
@@ -11,182 +11,299 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Prepare VoxCeleb1 dataset
-
-create manifest files.
-Manifest file is a json-format file with each line containing the
-meta data (i.e. audio filepath, transcript and audio duration)
-of each audio file in the data set.
-
-researchers should download the voxceleb1 dataset yourselves
-through google form to get the username & password and unpack the data
-"""
-import argparse
-import codecs
+
+import collections
+import csv
 import glob
-import json
 import os
-import subprocess
-from pathlib import Path
+import random
+from typing import Dict, List, Tuple
 
-import soundfile
+from paddle.io import Dataset
+from tqdm import tqdm
+from pathos.multiprocessing import Pool
 
-from utils.utility import check_md5sum
+from paddleaudio.backends import load as load_audio
+from paddleaudio.utils import DATA_HOME, decompress, download_and_decompress
+from paddleaudio.datasets.dataset import feat_funcs
+from utils.utility import unpack
 from utils.utility import download
-from utils.utility import unzip
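+
+# On first use, the audio archives and metadata are downloaded into
+# ${DATA_HOME}/vox1 (DATA_HOME comes from paddleaudio.utils).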
-
-# all the data will be download in the current data/voxceleb directory default
-DATA_HOME = os.path.expanduser('.')
-
-# if you use the http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/ as the download base url
-# you need to get the username & password via the google form
-
-# if you use the https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a as the download base url,
-# you need use --no-check-certificate to connect the target download url
-
-BASE_URL = "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a"
-
-# dev data
-DEV_LIST = {
-    "vox1_dev_wav_partaa": "e395d020928bc15670b570a21695ed96",
-    "vox1_dev_wav_partab": "bbfaaccefab65d82b21903e81a8a8020",
-    "vox1_dev_wav_partac": "017d579a2a96a077f40042ec33e51512",
-    "vox1_dev_wav_partad": "7bb1e9f70fddc7a678fa998ea8b3ba19",
-}
-DEV_TARGET_DATA = "vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f532ba230b"
-
-# test data
-TEST_LIST = {"vox1_test_wav.zip": "185fdc63c3c739954633d50379a3d102"}
-TEST_TARGET_DATA = "vox1_test_wav.zip vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102"
-
-# kaldi trial
-# this trial file is organized by kaldi according the official file,
-# which is a little different with the official trial veri_test2.txt
-KALDI_BASE_URL = "http://www.openslr.org/resources/49/"
-TRIAL_LIST = {"voxceleb1_test_v2.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7"}
-TRIAL_TARGET_DATA = "voxceleb1_test_v2.txt voxceleb1_test_v2.txt 29fc7cc1c5d59f0816dc15d6e8be60f7"
-
-parser = argparse.ArgumentParser(description=__doc__)
-parser.add_argument(
-    "--target_dir",
-    default=DATA_HOME + "/voxceleb1/",
-    type=str,
-    help="Directory to save the voxceleb1 dataset. (default: %(default)s)")
-parser.add_argument(
-    "--manifest_prefix",
-    default="manifest",
-    type=str,
-    help="Filepath prefix for output manifests. (default: %(default)s)")
-
-args = parser.parse_args()
-
-
-def create_manifest(data_dir, manifest_path_prefix):
-    print("Creating manifest %s ..." % manifest_path_prefix)
-    json_lines = []
-    data_path = os.path.join(data_dir, "wav", "**", "*.wav")
-    total_sec = 0.0
-    total_text = 0.0
-    total_num = 0
-    speakers = set()
-    for audio_path in glob.glob(data_path, recursive=True):
-        audio_id = "-".join(audio_path.split("/")[-3:])
-        utt2spk = audio_path.split("/")[-3]
-        duration = soundfile.info(audio_path).duration
-        text = ""
-        json_lines.append(
-            json.dumps(
-                {
-                    "utt": audio_id,
-                    "utt2spk": str(utt2spk),
-                    "feat": audio_path,
-                    "feat_shape": (duration, ),
-                    "text": text  # compatible with asr data format
-                },
-                ensure_ascii=False))
-
-        total_sec += duration
-        total_text += len(text)
-        total_num += 1
-        speakers.add(utt2spk)
-
-    # data_dir_name refer to dev or test
-    # voxceleb1 is given explicit in the path
-    data_dir_name = Path(data_dir).name
-    manifest_path_prefix = manifest_path_prefix + "." + data_dir_name
-    with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f:
-        for line in json_lines:
-            f.write(line + "\n")
-
-    manifest_dir = os.path.dirname(manifest_path_prefix)
-    meta_path = os.path.join(manifest_dir, "voxceleb1." +
-                             data_dir_name) + ".meta"
-    with codecs.open(meta_path, 'w', encoding='utf-8') as f:
-        print(f"{total_num} utts", file=f)
-        print(f"{len(speakers)} speakers", file=f)
-        print(f"{total_sec / (60 * 60)} h", file=f)
-        print(f"{total_text} text", file=f)
-        print(f"{total_text / total_sec} text/sec", file=f)
-        print(f"{total_sec / total_num} sec/utt", file=f)
-
-
-def prepare_dataset(base_url, data_list, target_dir, manifest_path,
-                    target_data):
-    if not os.path.exists(target_dir):
-        os.mkdir(target_dir)
-
-    # wav directory already exists, it need do nothing
-    if not os.path.exists(os.path.join(target_dir, "wav")):
-        # download all dataset part
-        for zip_part in data_list.keys():
-            download_url = " --no-check-certificate " + base_url + "/" + zip_part
-            download(
-                url=download_url,
-                md5sum=data_list[zip_part],
-                target_dir=target_dir)
-
-        # pack the all part to target zip file
-        all_target_part, target_name, target_md5sum = target_data.split()
-        target_name = os.path.join(target_dir, target_name)
-        if not os.path.exists(target_name):
-            pack_part_cmd = "cat {}/{} > {}".format(target_dir, all_target_part,
-                                                    target_name)
-            subprocess.call(pack_part_cmd, shell=True)
-
-        # check the target zip file md5sum
-        if not check_md5sum(target_name, target_md5sum):
-            raise RuntimeError("{} MD5 checkssum failed".format(target_name))
-        else:
-            print("Check {} md5sum successfully".format(target_name))
-
-        # unzip the all zip file
-        if target_name.endswith(".zip"):
-            unzip(target_name, target_dir)
-
-    # create the manifest file
-    create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
-
-
-def main():
-    if args.target_dir.startswith('~'):
-        args.target_dir = os.path.expanduser(args.target_dir)
-    prepare_dataset(
-        base_url=BASE_URL,
-        data_list=DEV_LIST,
-        target_dir=os.path.join(args.target_dir, "dev"),
-        manifest_path=args.manifest_prefix,
-        target_data=DEV_TARGET_DATA)
-
-    prepare_dataset(
-        base_url=BASE_URL,
-        data_list=TEST_LIST,
-        target_dir=os.path.join(args.target_dir, "test"),
-        manifest_path=args.manifest_prefix,
-        target_data=TEST_TARGET_DATA)
-
-    print("Manifest prepare done!")
-
-
-if __name__ == '__main__':
-    main()
+__all__ = ['VoxCeleb1']
+
+
+class VoxCeleb1(Dataset):
+    source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
+    archieves_audio_dev = [
+        {
+            'url': source_url + 'vox1_dev_wav_partaa',
+            'md5': 'e395d020928bc15670b570a21695ed96',
+        },
+        {
+            'url': source_url + 'vox1_dev_wav_partab',
+            'md5': 'bbfaaccefab65d82b21903e81a8a8020',
+        },
+        {
+            'url': source_url + 'vox1_dev_wav_partac',
+            'md5': '017d579a2a96a077f40042ec33e51512',
+        },
+        {
+            'url': source_url + 'vox1_dev_wav_partad',
+            'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
+        },
+    ]
+    archieves_audio_test = [
+        {
+            'url': source_url + 'vox1_test_wav.zip',
+            'md5': '185fdc63c3c739954633d50379a3d102',
+        },
+    ]
+    archieves_meta = [
+        {
+            'url': 'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
+            'md5': 'b73110731c9223c1461fe49cb48dddfc',
+        },
+    ]
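+
+    # veri_test2.txt is the official verification trial list; every speaker
+    # appearing in it is held out of the train/dev split (see prepare_data).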
+
+    num_speakers = 1211  # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+    sample_rate = 16000
+    meta_info = collections.namedtuple(
+        'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
+    base_path = os.path.join(DATA_HOME, 'vox1')
+    wav_path = os.path.join(base_path, 'wav')
+    subsets = ['train', 'dev', 'enrol', 'test']
+
+    def __init__(self,
+                 subset: str = 'train',
+                 feat_type: str = 'raw',
+                 random_chunk: bool = True,
+                 chunk_duration: float = 3.0,  # seconds
+                 split_ratio: float = 0.9,  # train split ratio
+                 seed: int = 0,
+                 target_dir: str = None,
+                 **kwargs):
+
+        assert subset in self.subsets, \
+            'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
+
+        self.subset = subset
+        self.spk_id2label = {}
+        self.feat_type = feat_type
+        self.feat_config = kwargs
+        self.random_chunk = random_chunk
+        self.chunk_duration = chunk_duration
+        self.split_ratio = split_ratio
+        # csv and meta files live under target_dir if given, else under base_path.
+        self.target_dir = target_dir if target_dir else self.base_path
+        self.csv_path = os.path.join(self.target_dir, 'csv')
+        self.meta_path = os.path.join(self.target_dir, 'meta')
+        self.veri_test_file = os.path.join(self.meta_path, 'veri_test2.txt')
+        # Seed the RNG so the train/dev split and random chunking are reproducible.
+        random.seed(seed)
+        self._data = self._get_data()
+        super(VoxCeleb1, self).__init__()
+
+    def _get_data(self):
+        # Download audio files.
+        # All of vox1/dev/wav and vox1/test/wav must end up decompressed
+        # under vox1/wav/, so we check that directory first.
+        print("wav base path: {}".format(self.wav_path))
+        if not os.path.isdir(self.wav_path):
+            print("start to download the voxceleb1 dataset")
+            download_and_decompress(  # multi-part zips are downloaded but not decompressed
+                self.archieves_audio_dev, self.base_path, decompress=False)
+            download_and_decompress(  # download vox1_test_wav.zip and unzip it
+                self.archieves_audio_test, self.base_path, decompress=True)
+
+            # Concatenate the downloaded dev parts into one zip file.
+            dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
+            print(f'Concatenating all parts to: {dev_zipfile}')
+            os.system(
+                f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
+            )
+
+            # Extract all audio files of the dev set.
+            decompress(dev_zipfile, self.base_path)
+
+        # Download meta files.
+        if not os.path.isdir(self.meta_path):
+            download_and_decompress(
+                self.archieves_meta, self.meta_path, decompress=False)
+
+        # Data preparation.
+        if not os.path.isdir(self.csv_path):
+            os.makedirs(self.csv_path)
+            self.prepare_data()
+
+        data = []
+        with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
+            for line in rf.readlines()[1:]:
+                audio_id, duration, wav, start, stop, spk_id = line.strip(
+                ).split(',')
+                data.append(
+                    self.meta_info(audio_id, float(duration), wav, int(start),
+                                   int(stop), spk_id))
+
+        with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
+            for line in f.readlines():
+                spk_id, label = line.strip().split(' ')
+                self.spk_id2label[spk_id] = int(label)
+
+        return data
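+
+    # Records are assembled lazily in __getitem__ via _convert_to_record:
+    # load the waveform, take a (random) chunk, then apply the feature
+    # function selected by feat_type.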
+
+    def _convert_to_record(self, idx: int):
+        sample = self._data[idx]
+
+        record = {}
+        # To show all fields in a namedtuple: `type(sample)._fields`
+        for field in type(sample)._fields:
+            record[field] = getattr(sample, field)
+
+        waveform, sr = load_audio(record['wav'])
+
+        # Randomly select a chunk of samples from the audio.
+        if self.random_chunk:
+            num_wav_samples = waveform.shape[0]
+            num_chunk_samples = int(self.chunk_duration * sr)
+            # Guard against clips shorter than one chunk; the slice below
+            # simply truncates in that case.
+            start = random.randint(
+                0, max(0, num_wav_samples - num_chunk_samples - 1))
+            stop = start + num_chunk_samples
+        else:
+            start = record['start']
+            stop = record['stop']
+
+        waveform = waveform[start:stop]
+
+        assert self.feat_type in feat_funcs.keys(), \
+            f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
+        feat_func = feat_funcs[self.feat_type]
+        feat = feat_func(
+            waveform, sr=sr, **self.feat_config) if feat_func else waveform
+
+        record.update({'feat': feat})
+        if self.subset in ['train',
+                           'dev']:  # Labels are available in train and dev.
+            record.update({'label': self.spk_id2label[record['spk_id']]})
+
+        return record
+
+    @staticmethod
+    def _get_chunks(seg_dur, audio_id, audio_duration):
+        # seg_dur and audio_duration are both in seconds.
+        num_chunks = int(audio_duration / seg_dur)
+
+        chunk_lst = [
+            audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
+            for i in range(num_chunks)
+        ]
+        return chunk_lst
+
+    def _get_audio_info(self, wav_file: str,
+                        split_chunks: bool) -> List[List[str]]:
+        waveform, sr = load_audio(wav_file)
+        spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
+        audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
+        audio_duration = waveform.shape[0] / sr
+
+        ret = []
+        if split_chunks:  # Split into pieces of self.chunk_duration seconds.
+            uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
+                                                audio_duration)
+
+            for chunk in uniq_chunks_list:
+                s, e = chunk.split("_")[-2:]  # Timestamps of start and end
+                start_sample = int(float(s) * sr)
+                end_sample = int(float(e) * sr)
+                # id, duration, wav, start, stop, spk_id
+                ret.append([
+                    chunk, audio_duration, wav_file, start_sample, end_sample,
+                    spk_id
+                ])
+        else:  # Keep whole audio.
+            ret.append([
+                audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
+            ])
+        return ret
+
+    def generate_csv(self,
+                     wav_files: List[str],
+                     output_file: str,
+                     split_chunks: bool = True):
+        print(f'Generating csv: {output_file}')
+        header = ["id", "duration", "wav", "start", "stop", "spk_id"]
+
+        # pathos.multiprocessing is used here because the stdlib
+        # multiprocessing cannot pickle the lambda below.
+        with Pool(64) as p:
+            infos = list(
+                tqdm(
+                    p.imap(lambda x: self._get_audio_info(x, split_chunks),
+                           wav_files),
+                    total=len(wav_files)))
+
+        csv_lines = []
+        for info in infos:
+            csv_lines.extend(info)
+
+        with open(output_file, mode="w") as csv_f:
+            csv_writer = csv.writer(
+                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
+            csv_writer.writerow(header)
+            for line in csv_lines:
+                csv_writer.writerow(line)
+
+    def prepare_data(self):
+        # Audio of speakers in veri_test_file should not be included in the training set.
+        print("start to prepare the data csv file")
+        enrol_files = set()
+        test_files = set()
+        # Collect the enrol and test audio file paths.
+        with open(self.veri_test_file, 'r') as f:
+            for line in f.readlines():
+                _, enrol_file, test_file = line.strip().split(' ')
+                enrol_files.add(os.path.join(self.wav_path, enrol_file))
+                test_files.add(os.path.join(self.wav_path, test_file))
+        enrol_files = sorted(enrol_files)
+        test_files = sorted(test_files)
+
+        # Collect the enrol and test speakers.
+        test_spks = set()
+        for file in (enrol_files + test_files):
+            spk = file.split('/wav/')[1].split('/')[0]
+            test_spks.add(spk)
+
+        # Collect all train and dev audio file paths.
+        audio_files = []
+        speakers = set()
+        for path in [self.wav_path]:
+            for file in glob.glob(
+                    os.path.join(path, "**", "*.wav"), recursive=True):
+                spk = file.split('/wav/')[1].split('/')[0]
+                if spk in test_spks:
+                    continue
+                speakers.add(spk)
+                audio_files.append(file)
+
+        print("start to generate the {}".format(
+            os.path.join(self.meta_path, 'spk_id2label.txt')))
+        # Encode the train and dev speaker labels into spk_id2label.txt.
+        with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
+            for label, spk_id in enumerate(
+                    sorted(speakers)):  # 1211 vox1, 5994 vox2, 7205 vox1+2
+                f.write(f'{spk_id} {label}\n')
+
+        audio_files = sorted(audio_files)
+        random.shuffle(audio_files)
+        split_idx = int(self.split_ratio * len(audio_files))
+        # The first split_ratio of the shuffled files goes to train.
+        train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:]
+
+        self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
+        self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
+        self.generate_csv(
+            enrol_files,
+            os.path.join(self.csv_path, 'enrol.csv'),
+            split_chunks=False)
+        self.generate_csv(
+            test_files,
+            os.path.join(self.csv_path, 'test.csv'),
+            split_chunks=False)
+
+    def __getitem__(self, idx):
+        return self._convert_to_record(idx)
+
+    def __len__(self):
+        return len(self._data)
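
A minimal usage sketch of the class above, assuming paddleaudio and this repository are on PYTHONPATH (path.sh below sets this up); the first instantiation downloads and unpacks the archives and generates the csv/meta files:

    from dataset.voxceleb.voxceleb1 import VoxCeleb1

    # Triggers download, extraction and csv generation on first run.
    train_ds = VoxCeleb1(subset='train', feat_type='raw', chunk_duration=3.0)
    record = train_ds[0]  # dict with 'feat' (a raw 3 s chunk) and 'label'
    print(len(train_ds), record['feat'].shape, record['label'])
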
diff --git a/examples/voxceleb/sv0/local/train.py b/examples/voxceleb/sv0/local/train.py
new file mode 100644
index 000000000..e8619cca9
--- /dev/null
+++ b/examples/voxceleb/sv0/local/train.py
@@ -0,0 +1,31 @@
+import argparse
+import paddle
+from dataset.voxceleb.voxceleb1 import VoxCeleb1
+
+
+def main(args):
+    paddle.set_device(args.device)
+
+    # stage1: paddle.distributed.init_parallel_env() must be called at the beginning
+    paddle.distributed.init_parallel_env()
+    nranks = paddle.distributed.get_world_size()
+    local_rank = paddle.distributed.get_rank()
+
+    # stage2: data preparation
+    train_ds = VoxCeleb1('train', target_dir=args.data_dir)
+
+if __name__ == "__main__":
+    # yapf: disable
+    parser = argparse.ArgumentParser(__doc__)
+    parser.add_argument('--device',
+                        choices=['cpu', 'gpu'],
+                        default="cpu",
+                        help="Select which device to train model, defaults to cpu.")
+    parser.add_argument("--data-dir",
+                        default="./data/",
+                        type=str,
+                        help="data directory")
+    args = parser.parse_args()
+    # yapf: enable
+
+    main(args)
\ No newline at end of file
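
Since train.py calls paddle.distributed.init_parallel_env(), multi-GPU runs would go through Paddle's launcher rather than invoking the script directly; a sketch, with illustrative GPU ids:

    python -m paddle.distributed.launch --gpus "0,1" \
        local/train.py --device gpu --data-dir ./data/
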
diff --git a/examples/voxceleb/sv0/path.sh b/examples/voxceleb/sv0/path.sh
new file mode 100755
index 000000000..38a242a4a
--- /dev/null
+++ b/examples/voxceleb/sv0/path.sh
@@ -0,0 +1,11 @@
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh
new file mode 100755
index 000000000..c24cbff4f
--- /dev/null
+++ b/examples/voxceleb/sv0/run.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+. ./path.sh
+set -e
+
+# You can export PPAUDIO_HOME to specify where the vox1 dataset is
+# downloaded, e.g. export PPAUDIO_HOME=/path/to/your/ppaudio/home
+dir=./data/
+mkdir -p ${dir}
+python3 \
+    local/train.py \
+    --data-dir ${dir}
\ No newline at end of file
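
The training stub stops after constructing the dataset. A minimal sketch of how the records could feed a paddle DataLoader, assuming feat_type='raw' so every 'feat' is a fixed-length chunk; batch_size and num_workers are illustrative:

    import numpy as np
    from paddle.io import DataLoader, DistributedBatchSampler

    def collate(batch):
        # Stack fixed-length raw chunks and integer speaker labels; string
        # fields such as 'id' and 'wav' are dropped for training.
        feats = np.stack([item['feat'] for item in batch]).astype('float32')
        labels = np.array([item['label'] for item in batch], dtype='int64')
        return feats, labels

    sampler = DistributedBatchSampler(train_ds, batch_size=32, shuffle=True)
    train_loader = DataLoader(
        train_ds, batch_sampler=sampler, collate_fn=collate, num_workers=4)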