|
|
|
@ -11,182 +11,299 @@
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
"""Prepare VoxCeleb1 dataset
|
|
|
|
|
|
|
|
|
|
create manifest files.
|
|
|
|
|
Manifest file is a json-format file with each line containing the
|
|
|
|
|
meta data (i.e. audio filepath, transcript and audio duration)
|
|
|
|
|
of each audio file in the data set.
|
|
|
|
|
|
|
|
|
|
researchers should download the voxceleb1 dataset yourselves
|
|
|
|
|
through google form to get the username & password and unpack the data
|
|
|
|
|
"""
|
|
|
|
|
import argparse
|
|
|
|
|
import codecs
|
|
|
|
|
|
|
|
|
|
import collections
|
|
|
|
|
import csv
|
|
|
|
|
import glob
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import subprocess
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
import random
|
|
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
|
|
import soundfile
|
|
|
|
|
from paddle.io import Dataset
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
from pathos.multiprocessing import Pool
|
|
|
|
|
|
|
|
|
|
from utils.utility import check_md5sum
|
|
|
|
|
from paddleaudio.backends import load as load_audio
|
|
|
|
|
from paddleaudio.utils import DATA_HOME, decompress, download_and_decompress
|
|
|
|
|
from paddleaudio.datasets.dataset import feat_funcs
|
|
|
|
|
from utils.utility import unpack
|
|
|
|
|
from utils.utility import download
|
|
|
|
|
from utils.utility import unzip
|
|
|
|
|
|
|
|
|
|
# all the data will be download in the current data/voxceleb directory default
|
|
|
|
|
DATA_HOME = os.path.expanduser('.')
|
|
|
|
|
|
|
|
|
|
# if you use the http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/ as the download base url
|
|
|
|
|
# you need to get the username & password via the google form
|
|
|
|
|
|
|
|
|
|
# if you use the https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a as the download base url,
|
|
|
|
|
# you need use --no-check-certificate to connect the target download url
|
|
|
|
|
|
|
|
|
|
BASE_URL = "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a"
|
|
|
|
|
|
|
|
|
|
# dev data
|
|
|
|
|
DEV_LIST = {
|
|
|
|
|
"vox1_dev_wav_partaa": "e395d020928bc15670b570a21695ed96",
|
|
|
|
|
"vox1_dev_wav_partab": "bbfaaccefab65d82b21903e81a8a8020",
|
|
|
|
|
"vox1_dev_wav_partac": "017d579a2a96a077f40042ec33e51512",
|
|
|
|
|
"vox1_dev_wav_partad": "7bb1e9f70fddc7a678fa998ea8b3ba19",
|
|
|
|
|
}
|
|
|
|
|
DEV_TARGET_DATA = "vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f532ba230b"
|
|
|
|
|
|
|
|
|
|
# test data
|
|
|
|
|
TEST_LIST = {"vox1_test_wav.zip": "185fdc63c3c739954633d50379a3d102"}
|
|
|
|
|
TEST_TARGET_DATA = "vox1_test_wav.zip vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102"
|
|
|
|
|
|
|
|
|
|
# kaldi trial
|
|
|
|
|
# this trial file is organized by kaldi according the official file,
|
|
|
|
|
# which is a little different with the official trial veri_test2.txt
|
|
|
|
|
KALDI_BASE_URL = "http://www.openslr.org/resources/49/"
|
|
|
|
|
TRIAL_LIST = {"voxceleb1_test_v2.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7"}
|
|
|
|
|
TRIAL_TARGET_DATA = "voxceleb1_test_v2.txt voxceleb1_test_v2.txt 29fc7cc1c5d59f0816dc15d6e8be60f7"
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--target_dir",
|
|
|
|
|
default=DATA_HOME + "/voxceleb1/",
|
|
|
|
|
type=str,
|
|
|
|
|
help="Directory to save the voxceleb1 dataset. (default: %(default)s)")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--manifest_prefix",
|
|
|
|
|
default="manifest",
|
|
|
|
|
type=str,
|
|
|
|
|
help="Filepath prefix for output manifests. (default: %(default)s)")
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_manifest(data_dir, manifest_path_prefix):
|
|
|
|
|
print("Creating manifest %s ..." % manifest_path_prefix)
|
|
|
|
|
json_lines = []
|
|
|
|
|
data_path = os.path.join(data_dir, "wav", "**", "*.wav")
|
|
|
|
|
total_sec = 0.0
|
|
|
|
|
total_text = 0.0
|
|
|
|
|
total_num = 0
|
|
|
|
|
speakers = set()
|
|
|
|
|
for audio_path in glob.glob(data_path, recursive=True):
|
|
|
|
|
audio_id = "-".join(audio_path.split("/")[-3:])
|
|
|
|
|
utt2spk = audio_path.split("/")[-3]
|
|
|
|
|
duration = soundfile.info(audio_path).duration
|
|
|
|
|
text = ""
|
|
|
|
|
json_lines.append(
|
|
|
|
|
json.dumps(
|
|
|
|
|
{
|
|
|
|
|
"utt": audio_id,
|
|
|
|
|
"utt2spk": str(utt2spk),
|
|
|
|
|
"feat": audio_path,
|
|
|
|
|
"feat_shape": (duration, ),
|
|
|
|
|
"text": text # compatible with asr data format
|
|
|
|
|
},
|
|
|
|
|
ensure_ascii=False))
|
|
|
|
|
|
|
|
|
|
total_sec += duration
|
|
|
|
|
total_text += len(text)
|
|
|
|
|
total_num += 1
|
|
|
|
|
speakers.add(utt2spk)
|
|
|
|
|
|
|
|
|
|
# data_dir_name refer to dev or test
|
|
|
|
|
# voxceleb1 is given explicit in the path
|
|
|
|
|
data_dir_name = Path(data_dir).name
|
|
|
|
|
manifest_path_prefix = manifest_path_prefix + "." + data_dir_name
|
|
|
|
|
with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f:
|
|
|
|
|
for line in json_lines:
|
|
|
|
|
f.write(line + "\n")
|
|
|
|
|
|
|
|
|
|
manifest_dir = os.path.dirname(manifest_path_prefix)
|
|
|
|
|
meta_path = os.path.join(manifest_dir, "voxceleb1." +
|
|
|
|
|
data_dir_name) + ".meta"
|
|
|
|
|
with codecs.open(meta_path, 'w', encoding='utf-8') as f:
|
|
|
|
|
print(f"{total_num} utts", file=f)
|
|
|
|
|
print(f"{len(speakers)} speakers", file=f)
|
|
|
|
|
print(f"{total_sec / (60 * 60)} h", file=f)
|
|
|
|
|
print(f"{total_text} text", file=f)
|
|
|
|
|
print(f"{total_text / total_sec} text/sec", file=f)
|
|
|
|
|
print(f"{total_sec / total_num} sec/utt", file=f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_dataset(base_url, data_list, target_dir, manifest_path,
|
|
|
|
|
target_data):
|
|
|
|
|
if not os.path.exists(target_dir):
|
|
|
|
|
os.mkdir(target_dir)
|
|
|
|
|
|
|
|
|
|
# wav directory already exists, it need do nothing
|
|
|
|
|
if not os.path.exists(os.path.join(target_dir, "wav")):
|
|
|
|
|
# download all dataset part
|
|
|
|
|
for zip_part in data_list.keys():
|
|
|
|
|
download_url = " --no-check-certificate " + base_url + "/" + zip_part
|
|
|
|
|
download(
|
|
|
|
|
url=download_url,
|
|
|
|
|
md5sum=data_list[zip_part],
|
|
|
|
|
target_dir=target_dir)
|
|
|
|
|
|
|
|
|
|
# pack the all part to target zip file
|
|
|
|
|
all_target_part, target_name, target_md5sum = target_data.split()
|
|
|
|
|
target_name = os.path.join(target_dir, target_name)
|
|
|
|
|
if not os.path.exists(target_name):
|
|
|
|
|
pack_part_cmd = "cat {}/{} > {}".format(target_dir, all_target_part,
|
|
|
|
|
target_name)
|
|
|
|
|
subprocess.call(pack_part_cmd, shell=True)
|
|
|
|
|
|
|
|
|
|
# check the target zip file md5sum
|
|
|
|
|
if not check_md5sum(target_name, target_md5sum):
|
|
|
|
|
raise RuntimeError("{} MD5 checkssum failed".format(target_name))
|
|
|
|
|
else:
|
|
|
|
|
print("Check {} md5sum successfully".format(target_name))
|
|
|
|
|
|
|
|
|
|
# unzip the all zip file
|
|
|
|
|
if target_name.endswith(".zip"):
|
|
|
|
|
unzip(target_name, target_dir)
|
|
|
|
|
|
|
|
|
|
# create the manifest file
|
|
|
|
|
create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
if args.target_dir.startswith('~'):
|
|
|
|
|
args.target_dir = os.path.expanduser(args.target_dir)
|
|
|
|
|
|
|
|
|
|
prepare_dataset(
|
|
|
|
|
base_url=BASE_URL,
|
|
|
|
|
data_list=DEV_LIST,
|
|
|
|
|
target_dir=os.path.join(args.target_dir, "dev"),
|
|
|
|
|
manifest_path=args.manifest_prefix,
|
|
|
|
|
target_data=DEV_TARGET_DATA)
|
|
|
|
|
|
|
|
|
|
prepare_dataset(
|
|
|
|
|
base_url=BASE_URL,
|
|
|
|
|
data_list=TEST_LIST,
|
|
|
|
|
target_dir=os.path.join(args.target_dir, "test"),
|
|
|
|
|
manifest_path=args.manifest_prefix,
|
|
|
|
|
target_data=TEST_TARGET_DATA)
|
|
|
|
|
|
|
|
|
|
print("Manifest prepare done!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
main()
|
|
|
|
|
__all__ = ['VoxCeleb1']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VoxCeleb1(Dataset):
|
|
|
|
|
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
|
|
|
|
|
archieves_audio_dev = [
|
|
|
|
|
{
|
|
|
|
|
'url': source_url + 'vox1_dev_wav_partaa',
|
|
|
|
|
'md5': 'e395d020928bc15670b570a21695ed96',
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
'url': source_url + 'vox1_dev_wav_partab',
|
|
|
|
|
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
'url': source_url + 'vox1_dev_wav_partac',
|
|
|
|
|
'md5': '017d579a2a96a077f40042ec33e51512',
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
'url': source_url + 'vox1_dev_wav_partad',
|
|
|
|
|
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
archieves_audio_test = [
|
|
|
|
|
{
|
|
|
|
|
'url': source_url + 'vox1_test_wav.zip',
|
|
|
|
|
'md5': '185fdc63c3c739954633d50379a3d102',
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
archieves_meta = [
|
|
|
|
|
{
|
|
|
|
|
'url': 'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
|
|
|
|
|
'md5': 'b73110731c9223c1461fe49cb48dddfc',
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
|
|
|
|
|
sample_rate = 16000
|
|
|
|
|
meta_info = collections.namedtuple(
|
|
|
|
|
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
|
|
|
|
|
base_path = os.path.join(DATA_HOME, 'vox1')
|
|
|
|
|
wav_path = os.path.join(base_path, 'wav')
|
|
|
|
|
subsets = ['train', 'dev', 'enrol', 'test']
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
subset: str = 'train',
|
|
|
|
|
feat_type: str = 'raw',
|
|
|
|
|
random_chunk: bool = True,
|
|
|
|
|
chunk_duration: float = 3.0, # seconds
|
|
|
|
|
split_ratio: float = 0.9, # train split ratio
|
|
|
|
|
seed: int = 0,
|
|
|
|
|
target_dir: str = None,
|
|
|
|
|
**kwargs):
|
|
|
|
|
|
|
|
|
|
assert subset in self.subsets, \
|
|
|
|
|
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
|
|
|
|
|
|
|
|
|
|
self.subset = subset
|
|
|
|
|
self.spk_id2label = {}
|
|
|
|
|
self.feat_type = feat_type
|
|
|
|
|
self.feat_config = kwargs
|
|
|
|
|
self.random_chunk = random_chunk
|
|
|
|
|
self.chunk_duration = chunk_duration
|
|
|
|
|
self.split_ratio = split_ratio
|
|
|
|
|
self.target_dir = target_dir if target_dir else self.base_path
|
|
|
|
|
self.csv_path = os.path.join(target_dir, 'csv') if target_dir else os.path.join(self.base_path, 'csv')
|
|
|
|
|
self.meta_path = os.path.join(target_dir, 'meta') if target_dir else os.path.join(base_path, 'meta')
|
|
|
|
|
self.veri_test_file = os.path.join(self.meta_path, 'veri_test2.txt')
|
|
|
|
|
# self._data = self._get_data()[:1000] # KP: Small dataset test.
|
|
|
|
|
self._data = self._get_data()
|
|
|
|
|
super(VoxCeleb1, self).__init__()
|
|
|
|
|
|
|
|
|
|
# Set up a seed to reproduce training or predicting result.
|
|
|
|
|
# random.seed(seed)
|
|
|
|
|
|
|
|
|
|
def _get_data(self):
|
|
|
|
|
# Download audio files.
|
|
|
|
|
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
|
|
|
|
|
# so, we check the vox1/wav dir status
|
|
|
|
|
print("wav base path: {}".format(self.wav_path))
|
|
|
|
|
if not os.path.isdir(self.wav_path):
|
|
|
|
|
print("start to download the voxceleb1 dataset")
|
|
|
|
|
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
|
|
|
|
|
self.archieves_audio_dev, self.base_path, decompress=False)
|
|
|
|
|
download_and_decompress( # download the vox1_test_wav.zip and unzip
|
|
|
|
|
self.archieves_audio_test, self.base_path, decompress=True)
|
|
|
|
|
|
|
|
|
|
# Download all parts and concatenate the files into one zip file.
|
|
|
|
|
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
|
|
|
|
|
print(f'Concatenating all parts to: {dev_zipfile}')
|
|
|
|
|
os.system(
|
|
|
|
|
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Extract all audio files of dev and test set.
|
|
|
|
|
decompress(dev_zipfile, self.base_path)
|
|
|
|
|
|
|
|
|
|
# Download meta files.
|
|
|
|
|
if not os.path.isdir(self.meta_path):
|
|
|
|
|
download_and_decompress(
|
|
|
|
|
self.archieves_meta, self.meta_path, decompress=False)
|
|
|
|
|
|
|
|
|
|
# Data preparation.
|
|
|
|
|
if not os.path.isdir(self.csv_path):
|
|
|
|
|
os.makedirs(self.csv_path)
|
|
|
|
|
self.prepare_data()
|
|
|
|
|
|
|
|
|
|
data = []
|
|
|
|
|
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
|
|
|
|
|
for line in rf.readlines()[1:]:
|
|
|
|
|
audio_id, duration, wav, start, stop, spk_id = line.strip(
|
|
|
|
|
).split(',')
|
|
|
|
|
data.append(
|
|
|
|
|
self.meta_info(audio_id, float(duration), wav, int(start),
|
|
|
|
|
int(stop), spk_id))
|
|
|
|
|
|
|
|
|
|
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
|
|
|
|
|
for line in f.readlines():
|
|
|
|
|
spk_id, label = line.strip().split(' ')
|
|
|
|
|
self.spk_id2label[spk_id] = int(label)
|
|
|
|
|
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
def _convert_to_record(self, idx: int):
|
|
|
|
|
sample = self._data[idx]
|
|
|
|
|
|
|
|
|
|
record = {}
|
|
|
|
|
# To show all fields in a namedtuple: `type(sample)._fields`
|
|
|
|
|
for field in type(sample)._fields:
|
|
|
|
|
record[field] = getattr(sample, field)
|
|
|
|
|
|
|
|
|
|
waveform, sr = load_audio(record['wav'])
|
|
|
|
|
|
|
|
|
|
# random select a chunk audio samples from the audio
|
|
|
|
|
if self.random_chunk:
|
|
|
|
|
num_wav_samples = waveform.shape[0]
|
|
|
|
|
num_chunk_samples = int(self.chunk_duration * sr)
|
|
|
|
|
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
|
|
|
|
|
stop = start + num_chunk_samples
|
|
|
|
|
else:
|
|
|
|
|
start = record['start']
|
|
|
|
|
stop = record['stop']
|
|
|
|
|
|
|
|
|
|
waveform = waveform[start:stop]
|
|
|
|
|
|
|
|
|
|
assert self.feat_type in feat_funcs.keys(), \
|
|
|
|
|
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
|
|
|
|
|
feat_func = feat_funcs[self.feat_type]
|
|
|
|
|
feat = feat_func(
|
|
|
|
|
waveform, sr=sr, **self.feat_config) if feat_func else waveform
|
|
|
|
|
|
|
|
|
|
record.update({'feat': feat})
|
|
|
|
|
if self.subset in ['train',
|
|
|
|
|
'dev']: # Labels are available in train and dev.
|
|
|
|
|
record.update({'label': self.spk_id2label[record['spk_id']]})
|
|
|
|
|
|
|
|
|
|
return record
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _get_chunks(seg_dur, audio_id, audio_duration):
|
|
|
|
|
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
|
|
|
|
|
|
|
|
|
|
chunk_lst = [
|
|
|
|
|
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
|
|
|
|
|
for i in range(num_chunks)
|
|
|
|
|
]
|
|
|
|
|
return chunk_lst
|
|
|
|
|
|
|
|
|
|
def _get_audio_info(self, wav_file: str,
|
|
|
|
|
split_chunks: bool) -> List[List[str]]:
|
|
|
|
|
waveform, sr = load_audio(wav_file)
|
|
|
|
|
spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
|
|
|
|
|
audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
|
|
|
|
|
audio_duration = waveform.shape[0] / sr
|
|
|
|
|
|
|
|
|
|
ret = []
|
|
|
|
|
if split_chunks: # Split into pieces of self.chunk_duration seconds.
|
|
|
|
|
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
|
|
|
|
|
audio_duration)
|
|
|
|
|
|
|
|
|
|
for chunk in uniq_chunks_list:
|
|
|
|
|
s, e = chunk.split("_")[-2:] # Timestamps of start and end
|
|
|
|
|
start_sample = int(float(s) * sr)
|
|
|
|
|
end_sample = int(float(e) * sr)
|
|
|
|
|
# id, duration, wav, start, stop, spk_id
|
|
|
|
|
ret.append([
|
|
|
|
|
chunk, audio_duration, wav_file, start_sample, end_sample,
|
|
|
|
|
spk_id
|
|
|
|
|
])
|
|
|
|
|
else: # Keep whole audio.
|
|
|
|
|
ret.append([
|
|
|
|
|
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
|
|
|
|
|
])
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
def generate_csv(self,
|
|
|
|
|
wav_files: List[str],
|
|
|
|
|
output_file: str,
|
|
|
|
|
split_chunks: bool = True):
|
|
|
|
|
print(f'Generating csv: {output_file}')
|
|
|
|
|
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
|
|
|
|
|
|
|
|
|
|
with Pool(64) as p:
|
|
|
|
|
infos = list(
|
|
|
|
|
tqdm(
|
|
|
|
|
p.imap(lambda x: self._get_audio_info(x, split_chunks), wav_files), total=len(wav_files)))
|
|
|
|
|
|
|
|
|
|
csv_lines = []
|
|
|
|
|
for info in infos:
|
|
|
|
|
csv_lines.extend(info)
|
|
|
|
|
|
|
|
|
|
with open(output_file, mode="w") as csv_f:
|
|
|
|
|
csv_writer = csv.writer(
|
|
|
|
|
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
|
|
|
|
csv_writer.writerow(header)
|
|
|
|
|
for line in csv_lines:
|
|
|
|
|
csv_writer.writerow(line)
|
|
|
|
|
|
|
|
|
|
def prepare_data(self):
|
|
|
|
|
# Audio of speakers in veri_test_file should not be included in training set.
|
|
|
|
|
print("start to prepare the data csv file")
|
|
|
|
|
enrol_files = set()
|
|
|
|
|
test_files = set()
|
|
|
|
|
# get the enroll and test audio file path
|
|
|
|
|
with open(self.veri_test_file, 'r') as f:
|
|
|
|
|
for line in f.readlines():
|
|
|
|
|
_, enrol_file, test_file = line.strip().split(' ')
|
|
|
|
|
enrol_files.add(os.path.join(self.wav_path, enrol_file))
|
|
|
|
|
test_files.add(os.path.join(self.wav_path, test_file))
|
|
|
|
|
enrol_files = sorted(enrol_files)
|
|
|
|
|
test_files = sorted(test_files)
|
|
|
|
|
|
|
|
|
|
# get the enroll and test speakers
|
|
|
|
|
test_spks = set()
|
|
|
|
|
for file in (enrol_files + test_files):
|
|
|
|
|
spk = file.split('/wav/')[1].split('/')[0]
|
|
|
|
|
test_spks.add(spk)
|
|
|
|
|
|
|
|
|
|
# get all the train and dev audios file path
|
|
|
|
|
audio_files = []
|
|
|
|
|
speakers = set()
|
|
|
|
|
for path in [self.wav_path]:
|
|
|
|
|
for file in glob.glob(os.path.join(path, "**", "*.wav"), recursive=True):
|
|
|
|
|
spk = file.split('/wav/')[1].split('/')[0]
|
|
|
|
|
if spk in test_spks:
|
|
|
|
|
continue
|
|
|
|
|
speakers.add(spk)
|
|
|
|
|
audio_files.append(file)
|
|
|
|
|
|
|
|
|
|
print("start to generate the {}".format(os.path.join(self.meta_path, 'spk_id2label.txt')))
|
|
|
|
|
# encode the train and dev speakers label to spk_id2label.txt
|
|
|
|
|
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
|
|
|
|
|
for label, spk_id in enumerate(sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
|
|
|
|
|
f.write(f'{spk_id} {label}\n')
|
|
|
|
|
|
|
|
|
|
audio_files = sorted(audio_files)
|
|
|
|
|
random.shuffle(audio_files)
|
|
|
|
|
split_idx = int(self.split_ratio * len(audio_files))
|
|
|
|
|
# split_ratio to train
|
|
|
|
|
train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:]
|
|
|
|
|
|
|
|
|
|
self.generate_csv(train_files,
|
|
|
|
|
os.path.join(self.csv_path, 'train.csv'))
|
|
|
|
|
self.generate_csv(dev_files,
|
|
|
|
|
os.path.join(self.csv_path, 'dev.csv'))
|
|
|
|
|
self.generate_csv(enrol_files,
|
|
|
|
|
os.path.join(self.csv_path, 'enrol.csv'),
|
|
|
|
|
split_chunks=False)
|
|
|
|
|
self.generate_csv(test_files,
|
|
|
|
|
os.path.join(self.csv_path, 'test.csv'),
|
|
|
|
|
split_chunks=False)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
|
return self._convert_to_record(idx)
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return len(self._data)
|
|
|
|
|