diff --git a/dataset/voxceleb/voxceleb1.py b/dataset/voxceleb/voxceleb1.py index c6fc0695..d0978d9d 100644 --- a/dataset/voxceleb/voxceleb1.py +++ b/dataset/voxceleb/voxceleb1.py @@ -59,12 +59,17 @@ DEV_TARGET_DATA = "vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f5 TEST_LIST = {"vox1_test_wav.zip": "185fdc63c3c739954633d50379a3d102"} TEST_TARGET_DATA = "vox1_test_wav.zip vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102" -# kaldi trial -# this trial file is organized by kaldi according the official file, -# which is a little different with the official trial veri_test2.txt -KALDI_BASE_URL = "http://www.openslr.org/resources/49/" -TRIAL_LIST = {"voxceleb1_test_v2.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7"} -TRIAL_TARGET_DATA = "voxceleb1_test_v2.txt voxceleb1_test_v2.txt 29fc7cc1c5d59f0816dc15d6e8be60f7" +# voxceleb trial + +TRIAL_BASE_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/" +TRIAL_LIST = { + "veri_test.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7", # voxceleb1 + "veri_test2.txt": "b73110731c9223c1461fe49cb48dddfc", # voxceleb1(cleaned) + "list_test_hard.txt": "21c341b6b2168eea2634df0fb4b8fff1", # voxceleb1-H + "list_test_hard2.txt": "857790e09d579a68eb2e339a090343c8", # voxceleb1-H(cleaned) + "list_test_all.txt": "b9ecf7aa49d4b656aa927a8092844e4a", # voxceleb1-E + "list_test_all2.txt": "a53e059deb562ffcfc092bf5d90d9f3a" # voxceleb1-E(cleaned) + } parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -82,7 +87,7 @@ args = parser.parse_args() def create_manifest(data_dir, manifest_path_prefix): - print("Creating manifest %s ..." % manifest_path_prefix) + print(f"Creating manifest {manifest_path_prefix} from {data_dir}") json_lines = [] data_path = os.path.join(data_dir, "wav", "**", "*.wav") total_sec = 0.0 @@ -114,6 +119,9 @@ def create_manifest(data_dir, manifest_path_prefix): # voxceleb1 is given explicit in the path data_dir_name = Path(data_dir).name manifest_path_prefix = manifest_path_prefix + "." + data_dir_name + if not os.path.exists(os.path.dirname(manifest_path_prefix)): + os.makedirs(os.path.dirname(manifest_path_prefix)) + with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f: for line in json_lines: f.write(line + "\n") @@ -133,11 +141,13 @@ def create_manifest(data_dir, manifest_path_prefix): def prepare_dataset(base_url, data_list, target_dir, manifest_path, target_data): if not os.path.exists(target_dir): - os.mkdir(target_dir) + os.makedirs(target_dir) # wav directory already exists, it need do nothing + # we will download the voxceleb1 data to ${target_dir}/vox1/dev/ or ${target_dir}/vox1/test directory if not os.path.exists(os.path.join(target_dir, "wav")): # download all dataset part + print("start to download the vox1 dev zip package") for zip_part in data_list.keys(): download_url = " --no-check-certificate " + base_url + "/" + zip_part download( @@ -166,11 +176,20 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path, # create the manifest file create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path) +def prepare_trial(base_url, data_list, target_dir): + if not os.path.exists(target_dir): + os.makedirs(target_dir) + for trial, md5sum in data_list.items(): + target_trial = os.path.join(target_dir, trial) + if not os.path.exists(os.path.join(target_dir, trial)): + download_url = " --no-check-certificate " + base_url + "/" + trial + download(url=download_url, md5sum=md5sum, target_dir=target_dir) def main(): if args.target_dir.startswith('~'): args.target_dir = os.path.expanduser(args.target_dir) - + + # prepare the vox1 dev data prepare_dataset( base_url=BASE_URL, data_list=DEV_LIST, @@ -178,6 +197,7 @@ def main(): manifest_path=args.manifest_prefix, target_data=DEV_TARGET_DATA) + # prepare the vox1 test data prepare_dataset( base_url=BASE_URL, data_list=TEST_LIST, @@ -185,8 +205,15 @@ def main(): manifest_path=args.manifest_prefix, target_data=TEST_TARGET_DATA) + # prepare the vox1 trial + prepare_trial( + base_url=TRIAL_BASE_URL, + data_list=TRIAL_LIST, + target_dir=os.path.dirname(args.manifest_prefix) + ) + print("Manifest prepare done!") if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/dataset/voxceleb/voxceleb2.py b/dataset/voxceleb/voxceleb2.py new file mode 100644 index 00000000..ef7bb230 --- /dev/null +++ b/dataset/voxceleb/voxceleb2.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare VoxCeleb2 dataset + +Download and unpack the voxceleb2 data files. +Voxceleb2 data is stored as the m4a format, +so we need convert the m4a to wav with the convert.sh scripts +""" +import argparse +import codecs +import glob +import json +import os +import subprocess +from pathlib import Path + +import soundfile + +from utils.utility import check_md5sum +from utils.utility import download +from utils.utility import unzip + +# all the data will be download in the current data/voxceleb directory default +DATA_HOME = os.path.expanduser('.') + +BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/" + +# dev data +DEV_DATA_URL = BASE_URL + '/vox2_aac.zip' +DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402" + + +# test data +TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip' +TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312" + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/voxceleb2/", + type=str, + help="Directory to save the voxceleb1 dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +parser.add_argument("--download", + default=False, + action="store_true", + help="Download the voxceleb2 dataset. (default: %(default)s)") +parser.add_argument("--generate", + default=False, + action="store_true", + help="Generate the manifest files. (default: %(default)s)") + +args = parser.parse_args() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + data_path = os.path.join(data_dir, "**", "*.wav") + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + speakers = set() + for audio_path in glob.glob(data_path, recursive=True): + audio_id = "-".join(audio_path.split("/")[-3:]) + utt2spk = audio_path.split("/")[-3] + duration = soundfile.info(audio_path).duration + text = "" + json_lines.append( + json.dumps( + { + "utt": audio_id, + "utt2spk": str(utt2spk), + "feat": audio_path, + "feat_shape": (duration, ), + "text": text # compatible with asr data format + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(text) + total_num += 1 + speakers.add(utt2spk) + + # data_dir_name refer to dev or test + # voxceleb2 is given explicit in the path + data_dir_name = Path(data_dir).name + manifest_path_prefix = manifest_path_prefix + "." + data_dir_name + + if not os.path.exists(os.path.dirname(manifest_path_prefix)): + os.makedirs(os.path.dirname(manifest_path_prefix)) + with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f: + for line in json_lines: + f.write(line + "\n") + + manifest_dir = os.path.dirname(manifest_path_prefix) + meta_path = os.path.join(manifest_dir, "voxceleb2." + + data_dir_name) + ".meta" + with codecs.open(meta_path, 'w', encoding='utf-8') as f: + print(f"{total_num} utts", file=f) + print(f"{len(speakers)} speakers", file=f) + print(f"{total_sec / (60 * 60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + + +def download_dataset(url, md5sum, target_dir, dataset): + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + # wav directory already exists, it need do nothing + print("target dir {}".format(os.path.join(target_dir, dataset))) + # unzip the dev dataset will create the dev and unzip the m4a to dev dir + # but the test dataset will unzip to aac + # so, wo create the ${target_dir}/test and unzip the m4a to test dir + if not os.path.exists(os.path.join(target_dir, dataset)): + filepath = download(url, md5sum, target_dir) + if dataset == "test": + unzip(filepath, os.path.join(target_dir, "test")) + + +def main(): + if args.target_dir.startswith('~'): + args.target_dir = os.path.expanduser(args.target_dir) + + # download and unpack the vox2-dev data + print("download: {}".format(args.download)) + if args.download: + download_dataset( + url=DEV_DATA_URL, + md5sum=DEV_MD5SUM, + target_dir=args.target_dir, + dataset="dev") + + download_dataset( + url=TEST_DATA_URL, + md5sum=TEST_MD5SUM, + target_dir=args.target_dir, + dataset="test") + + print("VoxCeleb2 download is done!") + + if args.generate: + create_manifest(args.target_dir, manifest_path_prefix=args.manifest_prefix) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py index e30a50e4..ec24be51 100644 --- a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py +++ b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py @@ -28,6 +28,91 @@ from paddlespeech.vector.training.seeding import seed_everything logger = Log(__name__).getlog() +class VectorWrapper: + """ VectorWrapper extract the audio embedding, + and single audio will get only an embedding + """ + def __init__(self, + device, + config_path, + model_path,): + super(VectorWrapper, self).__init__() + # stage 0: config the + self.device = device + self.config_path = config_path + self.model_path = model_path + + # stage 1: set the run host device + paddle.device.set_device(device) + + # stage 2: read the yaml config and set the seed factor + self.read_yaml_config(self.config_path) + seed_everything(self.config.seed) + + # stage 3: init the speaker verification model + self.init_vector_model(self.config, self.model_path) + + def read_yaml_config(self, config_path): + """Read the yaml config from the config path + + Args: + config_path (str): yaml config path + """ + config = CfgNode(new_allowed=True) + + if config_path: + config.merge_from_file(config_path) + + config.freeze() + self.config = config + + def init_vector_model(self, config, model_path): + """Init the vector model from yaml config + + Args: + config (CfgNode): yaml config + model_path (str): pretrained model path and the stored model is named as model.pdparams + """ + # get the backbone network instance + ecapa_tdnn = EcapaTdnn(**config.model) + + # get the sid instance + model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=config.num_speakers) + + # read the model parameters to sid model + model_path = os.path.abspath(os.path.expanduser(model_path)) + state_dict = paddle.load(os.path.join(model_path, "model.pdparams")) + model.set_state_dict(state_dict) + + model.eval() + self.model = model + + def extract_audio_embedding(self, audio_path): + """Extract the audio embedding + + Args: + audio_path (str): audio path, which will be extracted the embedding + + Returns: + embedding (numpy.array) : audio embedding + """ + waveform, sr = load_audio(audio_path) + feat = melspectrogram(x=waveform, + sr=self.config.sr, + n_mels=self.config.n_mels, + window_size=self.config.window_size, + hop_length=self.config.hop_size) + # conver the audio feat to batch shape, which means batch_size is equal to one + feat = paddle.to_tensor(feat).unsqueeze(0) + + # in inference period, the lengths is all one without padding + lengths = paddle.ones([1]) + feat = feature_normalize(feat, mean_norm=True, std_norm=False) + + # model backbone network forward the feats and get the embedding + embedding = self.model.backbone(feat, lengths).squeeze().numpy() # (1, emb_size, 1) -> (emb_size) + + return embedding def extract_audio_embedding(args, config): # stage 0: set the training device, cpu or gpu @@ -83,6 +168,7 @@ def extract_audio_embedding(args, config): # stage 5: do global norm with external mean and std rtf = elapsed_time / audio_length logger.info(f"{args.device} rft={rtf}") + paddle.save(embedding, "emb1") return embedding @@ -116,3 +202,10 @@ if __name__ == "__main__": print(config) extract_audio_embedding(args, config) + + # use the VectorWrapper to extract the audio embedding + vector_inst = VectorWrapper(device="gpu", + config_path=args.config, + model_path=args.load_checkpoint) + + embedding = vector_inst.extract_audio_embedding(args.audio_path)