diff --git a/dataset/voxceleb/voxceleb1.py b/dataset/voxceleb/voxceleb1.py index e50c91bc..d0978d9d 100644 --- a/dataset/voxceleb/voxceleb1.py +++ b/dataset/voxceleb/voxceleb1.py @@ -59,12 +59,17 @@ DEV_TARGET_DATA = "vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f5 TEST_LIST = {"vox1_test_wav.zip": "185fdc63c3c739954633d50379a3d102"} TEST_TARGET_DATA = "vox1_test_wav.zip vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102" -# kaldi trial -# this trial file is organized by kaldi according the official file, -# which is a little different with the official trial veri_test2.txt -KALDI_BASE_URL = "http://www.openslr.org/resources/49/" -TRIAL_LIST = {"voxceleb1_test_v2.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7"} -TRIAL_TARGET_DATA = "voxceleb1_test_v2.txt voxceleb1_test_v2.txt 29fc7cc1c5d59f0816dc15d6e8be60f7" +# voxceleb trial + +TRIAL_BASE_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/" +TRIAL_LIST = { + "veri_test.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7", # voxceleb1 + "veri_test2.txt": "b73110731c9223c1461fe49cb48dddfc", # voxceleb1(cleaned) + "list_test_hard.txt": "21c341b6b2168eea2634df0fb4b8fff1", # voxceleb1-H + "list_test_hard2.txt": "857790e09d579a68eb2e339a090343c8", # voxceleb1-H(cleaned) + "list_test_all.txt": "b9ecf7aa49d4b656aa927a8092844e4a", # voxceleb1-E + "list_test_all2.txt": "a53e059deb562ffcfc092bf5d90d9f3a" # voxceleb1-E(cleaned) + } parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -82,7 +87,7 @@ args = parser.parse_args() def create_manifest(data_dir, manifest_path_prefix): - print("Creating manifest %s ..." % manifest_path_prefix) + print(f"Creating manifest {manifest_path_prefix} from {data_dir}") json_lines = [] data_path = os.path.join(data_dir, "wav", "**", "*.wav") total_sec = 0.0 @@ -114,6 +119,9 @@ def create_manifest(data_dir, manifest_path_prefix): # voxceleb1 is given explicit in the path data_dir_name = Path(data_dir).name manifest_path_prefix = manifest_path_prefix + "." 
+ data_dir_name + if not os.path.exists(os.path.dirname(manifest_path_prefix)): + os.makedirs(os.path.dirname(manifest_path_prefix)) + with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f: for line in json_lines: f.write(line + "\n") @@ -133,11 +141,13 @@ def create_manifest(data_dir, manifest_path_prefix): def prepare_dataset(base_url, data_list, target_dir, manifest_path, target_data): if not os.path.exists(target_dir): - os.mkdir(target_dir) + os.makedirs(target_dir) # wav directory already exists, it need do nothing + # we will download the voxceleb1 data to ${target_dir}/vox1/dev/ or ${target_dir}/vox1/test directory if not os.path.exists(os.path.join(target_dir, "wav")): # download all dataset part + print("start to download the vox1 dev zip package") for zip_part in data_list.keys(): download_url = " --no-check-certificate " + base_url + "/" + zip_part download( @@ -166,11 +176,20 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path, # create the manifest file create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path) +def prepare_trial(base_url, data_list, target_dir): + if not os.path.exists(target_dir): + os.makedirs(target_dir) + for trial, md5sum in data_list.items(): + target_trial = os.path.join(target_dir, trial) + if not os.path.exists(os.path.join(target_dir, trial)): + download_url = " --no-check-certificate " + base_url + "/" + trial + download(url=download_url, md5sum=md5sum, target_dir=target_dir) def main(): if args.target_dir.startswith('~'): args.target_dir = os.path.expanduser(args.target_dir) - + + # prepare the vox1 dev data prepare_dataset( base_url=BASE_URL, data_list=DEV_LIST, @@ -178,6 +197,7 @@ def main(): manifest_path=args.manifest_prefix, target_data=DEV_TARGET_DATA) + # prepare the vox1 test data prepare_dataset( base_url=BASE_URL, data_list=TEST_LIST, @@ -185,6 +205,13 @@ def main(): manifest_path=args.manifest_prefix, target_data=TEST_TARGET_DATA) + # prepare the vox1 trial + prepare_trial( + base_url=TRIAL_BASE_URL, + data_list=TRIAL_LIST, + target_dir=os.path.dirname(args.manifest_prefix) + ) + print("Manifest prepare done!") diff --git a/dataset/voxceleb/voxceleb2.py b/dataset/voxceleb/voxceleb2.py new file mode 100644 index 00000000..ef7bb230 --- /dev/null +++ b/dataset/voxceleb/voxceleb2.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare VoxCeleb2 dataset + +Download and unpack the voxceleb2 data files. 
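+
+A typical invocation (the paths here are only illustrative) uses the
+--download and --generate flags defined below:
+
+    python voxceleb2.py --download --generate \
+        --target_dir ./data/voxceleb2/ --manifest_prefix data/vox2/manifest
+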
+Voxceleb2 data is stored as the m4a format, +so we need convert the m4a to wav with the convert.sh scripts +""" +import argparse +import codecs +import glob +import json +import os +import subprocess +from pathlib import Path + +import soundfile + +from utils.utility import check_md5sum +from utils.utility import download +from utils.utility import unzip + +# all the data will be download in the current data/voxceleb directory default +DATA_HOME = os.path.expanduser('.') + +BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/" + +# dev data +DEV_DATA_URL = BASE_URL + '/vox2_aac.zip' +DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402" + + +# test data +TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip' +TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312" + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/voxceleb2/", + type=str, + help="Directory to save the voxceleb1 dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +parser.add_argument("--download", + default=False, + action="store_true", + help="Download the voxceleb2 dataset. (default: %(default)s)") +parser.add_argument("--generate", + default=False, + action="store_true", + help="Generate the manifest files. (default: %(default)s)") + +args = parser.parse_args() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + data_path = os.path.join(data_dir, "**", "*.wav") + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + speakers = set() + for audio_path in glob.glob(data_path, recursive=True): + audio_id = "-".join(audio_path.split("/")[-3:]) + utt2spk = audio_path.split("/")[-3] + duration = soundfile.info(audio_path).duration + text = "" + json_lines.append( + json.dumps( + { + "utt": audio_id, + "utt2spk": str(utt2spk), + "feat": audio_path, + "feat_shape": (duration, ), + "text": text # compatible with asr data format + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(text) + total_num += 1 + speakers.add(utt2spk) + + # data_dir_name refer to dev or test + # voxceleb2 is given explicit in the path + data_dir_name = Path(data_dir).name + manifest_path_prefix = manifest_path_prefix + "." + data_dir_name + + if not os.path.exists(os.path.dirname(manifest_path_prefix)): + os.makedirs(os.path.dirname(manifest_path_prefix)) + with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f: + for line in json_lines: + f.write(line + "\n") + + manifest_dir = os.path.dirname(manifest_path_prefix) + meta_path = os.path.join(manifest_dir, "voxceleb2." 
+ + data_dir_name) + ".meta" + with codecs.open(meta_path, 'w', encoding='utf-8') as f: + print(f"{total_num} utts", file=f) + print(f"{len(speakers)} speakers", file=f) + print(f"{total_sec / (60 * 60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + + +def download_dataset(url, md5sum, target_dir, dataset): + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + # wav directory already exists, it need do nothing + print("target dir {}".format(os.path.join(target_dir, dataset))) + # unzip the dev dataset will create the dev and unzip the m4a to dev dir + # but the test dataset will unzip to aac + # so, wo create the ${target_dir}/test and unzip the m4a to test dir + if not os.path.exists(os.path.join(target_dir, dataset)): + filepath = download(url, md5sum, target_dir) + if dataset == "test": + unzip(filepath, os.path.join(target_dir, "test")) + + +def main(): + if args.target_dir.startswith('~'): + args.target_dir = os.path.expanduser(args.target_dir) + + # download and unpack the vox2-dev data + print("download: {}".format(args.download)) + if args.download: + download_dataset( + url=DEV_DATA_URL, + md5sum=DEV_MD5SUM, + target_dir=args.target_dir, + dataset="dev") + + download_dataset( + url=TEST_DATA_URL, + md5sum=TEST_MD5SUM, + target_dir=args.target_dir, + dataset="test") + + print("VoxCeleb2 download is done!") + + if args.generate: + create_manifest(args.target_dir, manifest_path_prefix=args.manifest_prefix) + +if __name__ == '__main__': + main() diff --git a/examples/voxceleb/README.md b/examples/voxceleb/README.md index 2c8ad138..a2e58e00 100644 --- a/examples/voxceleb/README.md +++ b/examples/voxceleb/README.md @@ -6,3 +6,51 @@ sv0 - speaker verfication with softmax backend etc, all python code sv1 - dependence on kaldi, speaker verfication with plda/sc backend, more info refer to the sv1/readme.txt + + +## VoxCeleb2 preparation + +VoxCeleb2 audio files are released in m4a format. All the VoxCeleb2 m4a audio files must be converted in wav files before feeding them in PaddleSpeech. +Please, follow these steps to prepare the dataset correctly: + +1. Download Voxceleb2. +You can find download instructions here: http://www.robots.ox.ac.uk/~vgg/data/voxceleb/ + +2. Convert .m4a to wav +VoxCeleb2 stores files with the m4a audio format. To use them in PaddleSpeech, you have to convert all the m4a audio files into wav files. + +``` shell +ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s +``` + +You can do the conversion using ffmpeg https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and should be only once. + +3. Put all the wav files in a folder called `wav`. 
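+
+One possible batch-conversion sketch (assuming the unpacked m4a files sit under `voxceleb2/aac/`; adjust the paths to your layout) is:
+
+``` shell
+cd voxceleb2
+# convert every m4a to a 16 kHz mono wav, mirroring the directory layout under wav/
+find aac -name "*.m4a" | while read -r src; do
+    dst="wav/${src#aac/}"
+    dst="${dst%.m4a}.wav"
+    mkdir -p "$(dirname "$dst")"
+    ffmpeg -y -i "$src" -ac 1 -vn -acodec pcm_s16le -ar 16000 "$dst"
+done
+```
+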
You should have something like `voxceleb2/wav/id*/*.wav` (e.g, `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`) + + +## voxceleb dataset summary + + +|dataset | vox1 - dev | vox1 - test |vox2 - dev| vox2 - test| +|---------|-----------|------------|-----------|----------| +|spks | 1211 |40 | 5994 | 118| +|utts | 148642 | 4874 | 1092009 |36273| +| time(h) | 340.4 | 11.2 | 2360.2 |79.9 | + + +## trial summary + +| trial | filename | nums | positive | negative | +|--------|-----------|--------|-------|------| +| VoxCeleb1 | veri_test.txt | 37720 | 18860 | 18860 | +| VoxCeleb1(cleaned) | veri_test2.txt | 37611 | 18802 | 18809 | +| VoxCeleb1-H | list_test_hard.txt | 552536 | 276270 | 276266 | +|VoxCeleb1-H(cleaned) |list_test_hard2.txt | 550894 | 275488 | 275406 | +|VoxCeleb1-E | list_test_all.txt | 581480 | 290743 | 290737 | +|VoxCeleb1-E(cleaned) | list_test_all2.txt |579818 |289921 |289897 | + + + + + + diff --git a/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml new file mode 100644 index 00000000..e58dca82 --- /dev/null +++ b/examples/voxceleb/sv0/conf/ecapa_tdnn.yaml @@ -0,0 +1,52 @@ +########################################### +# Data # +########################################### +# we should explicitly specify the wav path of vox2 audio data converted from m4a +vox2_base_path: +augment: True +batch_size: 16 +num_workers: 2 +num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 +shuffle: True +random_chunk: True + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +# currently, we only support fbank +sr: 16000 # sample rate +n_mels: 80 +window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 +hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 + +########################################################### +# MODEL SETTING # +########################################################### +# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml +# if we want use another model, please choose another configuration yaml file +model: + input_size: 80 + # "channels": [512, 512, 512, 512, 1536], + channels: [1024, 1024, 1024, 1024, 3072] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + attention_channels: 128 + lin_neurons: 192 + +########################################### +# Training # +########################################### +seed: 1986 # according from speechbrain configuration +epochs: 10 +save_interval: 1 +log_interval: 1 +learning_rate: 1e-8 + + +########################################### +# Testing # +########################################### +global_embedding_norm: True +embedding_mean_norm: True +embedding_std_norm: False + diff --git a/examples/voxceleb/sv0/local/data.sh b/examples/voxceleb/sv0/local/data.sh new file mode 100755 index 00000000..a3ff1c48 --- /dev/null +++ b/examples/voxceleb/sv0/local/data.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +stage=1 +stop_stage=100 + +. ${MAIN_ROOT}/utils/parse_options.sh || exit -1; + +if [ $# -ne 2 ] ; then + echo "Usage: $0 [options] "; + echo "e.g.: $0 ./data/ conf/ecapa_tdnn.yaml" + echo "Options: " + echo " --stage # Used to run a partially-completed data process from somewhere in the middle." + echo " --stop-stage # Used to run a partially-completed data process stop stage in the middle" + exit 1; +fi + +dir=$1 +conf_path=$2 +mkdir -p ${dir} + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # data prepare for vox1 and vox2, vox2 must be converted from m4a to wav + # we should use the local/convert.sh convert m4a to wav + python3 local/data_prepare.py \ + --data-dir ${dir} \ + --config ${conf_path} +fi + +TARGET_DIR=${MAIN_ROOT}/dataset +mkdir -p ${TARGET_DIR} + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \ + --manifest_prefix="data/vox1/manifest" \ + --target_dir="${TARGET_DIR}/voxceleb/vox1/" + + if [ $? -ne 0 ]; then + echo "Prepare voxceleb failed. Terminated." + exit 1 + fi + + # for dataset in train dev test; do + # mv data/manifest.${dataset} data/manifest.${dataset}.raw + # done +fi \ No newline at end of file diff --git a/examples/voxceleb/sv0/local/data_prepare.py b/examples/voxceleb/sv0/local/data_prepare.py new file mode 100644 index 00000000..19ba41b8 --- /dev/null +++ b/examples/voxceleb/sv0/local/data_prepare.py @@ -0,0 +1,71 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
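+# This script only prepares metadata on CPU: it builds the VoxCeleb
+# train/dev/enroll/test csv files through paddleaudio's VoxCeleb dataset
+# class and, when `augment` is enabled in the yaml config, the noise/rir
+# csv files used by the augmentation pipeline.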
+import argparse
+import os
+
+import paddle
+from yacs.config import CfgNode
+
+from paddleaudio.datasets.voxceleb import VoxCeleb
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.io.augment import build_augment_pipeline
+from paddlespeech.vector.training.seeding import seed_everything
+
+logger = Log(__name__).getlog()
+
+
+def main(args, config):
+
+    # stage 0: set the cpu device; all data preparation will be done in cpu mode
+    paddle.set_device("cpu")
+    # set the random seed; this is a must for reproducible multiprocess training
+    seed_everything(config.seed)
+
+    # stage 1: generate the voxceleb csv file
+    # Note: this may raise a C++ exception, but the program will execute fine,
+    # so we ignore the exception
+    # we explicitly pass the vox2 base path to data prepare and generate the audio info
+    logger.info("start to generate the voxceleb dataset info")
+    train_dataset = VoxCeleb(
+        'train', target_dir=args.data_dir, vox2_base_path=config.vox2_base_path)
+
+    # stage 2: generate the augment noise csv file
+    if config.augment:
+        logger.info("start to generate the augment dataset info")
+        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
+
+
+if __name__ == "__main__":
+    # yapf: disable
+    parser = argparse.ArgumentParser(__doc__)
+    parser.add_argument("--data-dir",
+                        default="./data/",
+                        type=str,
+                        help="data directory")
+    parser.add_argument("--config",
+                        default=None,
+                        type=str,
+                        help="configuration file")
+    args = parser.parse_args()
+    # yapf: enable
+
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
+    config.freeze()
+    print(config)
+
+    main(args, config)
diff --git a/examples/voxceleb/sv0/local/emb.sh b/examples/voxceleb/sv0/local/emb.sh
new file mode 100755
index 00000000..31d79e52
--- /dev/null
+++ b/examples/voxceleb/sv0/local/emb.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+. ./path.sh
+
+stage=0
+stop_stage=100
+exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
+conf_path=conf/ecapa_tdnn.yaml
+audio_path="demo/voxceleb/00001.wav"
+use_gpu=true
+
+. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
+
+if [ $# -ne 0 ] ; then
+    echo "Usage: $0 [options]";
+    echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
+    echo "Options: "
+    echo "  --use-gpu     # specify whether gpu is to be used"
+    echo "  --stage       # Used to run a partially-completed data process from somewhere in the middle."
+    echo "  --stop-stage  # Used to stop a partially-completed data process at a given stage"
+    echo "  --exp-dir     # experiment directory, which contains model.pdparams"
+    echo "  --conf-path   # configuration file for extracting the embedding"
+    echo "  --audio-path  # path of the audio from which the embedding will be extracted"
+    exit 1;
+fi
+
+# set the test device
+device="cpu"
+if ${use_gpu}; then
+    device="gpu"
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # extract the audio embedding
+    python3 ${BIN_DIR}/extract_emb.py --device ${device} \
+            --config ${conf_path} \
+            --audio-path ${audio_path} --load-checkpoint ${exp_dir}
+fi
\ No newline at end of file
diff --git a/examples/voxceleb/sv0/local/test.sh b/examples/voxceleb/sv0/local/test.sh
new file mode 100644
index 00000000..4460a165
--- /dev/null
+++ b/examples/voxceleb/sv0/local/test.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+stage=1
+stop_stage=100
+use_gpu=true # if true, we run on GPU.
+
+. 
${MAIN_ROOT}/utils/parse_options.sh || exit -1; + +if [ $# -ne 3 ] ; then + echo "Usage: $0 [options] "; + echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml" + echo "Options: " + echo " --use-gpu # specify is gpu is to be used for training" + echo " --stage # Used to run a partially-completed data process from somewhere in the middle." + echo " --stop-stage # Used to run a partially-completed data process stop stage in the middle" + exit 1; +fi + +dir=$1 +exp_dir=$2 +conf_path=$3 + +# get the gpu nums for training +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +# setting training device +device="cpu" +if ${use_gpu}; then + device="gpu" +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train the speaker identification task with voxceleb data + # and we will create the trained model parameters in ${exp_dir}/model.pdparams as the soft link + # Note: we will store the log file in exp/log directory + python3 -m paddle.distributed.launch --gpus=$CUDA_VISIBLE_DEVICES \ + ${BIN_DIR}/train.py --device ${device} --checkpoint-dir ${exp_dir} \ + --data-dir ${dir} --config ${conf_path} + +fi + +if [ $? -ne 0 ]; then + echo "Failed in training!" + exit 1 +fi + +exit 0 \ No newline at end of file diff --git a/examples/voxceleb/sv0/path.sh b/examples/voxceleb/sv0/path.sh new file mode 100755 index 00000000..2be098e0 --- /dev/null +++ b/examples/voxceleb/sv0/path.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=ecapa_tdnn +export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL} \ No newline at end of file diff --git a/examples/voxceleb/sv0/run.sh b/examples/voxceleb/sv0/run.sh new file mode 100755 index 00000000..bbc9e3db --- /dev/null +++ b/examples/voxceleb/sv0/run.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +. 
./path.sh
+set -e
+
+#######################################################################
+# stage 0: data prepare, including voxceleb1 download and generating {train,dev,enroll,test}.csv
+#          voxceleb2 data is in m4a format, so users must convert it to wav themselves,
+#          as described in README.md, with the script local/convert.sh
+# stage 1: train the speaker identification model
+# stage 2: test speaker identification
+# stage 3: extract the training embedding to train the LDA and PLDA
+######################################################################
+
+# we can set the variable PPAUDIO_HOME to specify the root directory of the downloaded vox1 and vox2 dataset
+# by default the dataset will be stored in ~/.paddleaudio/
+# the vox2 dataset is stored in m4a format, so we need to convert the audio from m4a to wav ourselves
+# and put all of it in ${PPAUDIO_HOME}/datasets/vox2
+# we will find the wav files in ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
+# export PPAUDIO_HOME=
+stage=0
+stop_stage=50
+
+# data directory
+# if we set the variable ${dir}, we will store the wav info in this directory
+# otherwise, we will store the wav info in the vox1 and vox2 directories respectively
+# vox2 wav path: we must convert the m4a format to wav format
+dir=data/ # data info directory
+
+exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
+conf_path=conf/ecapa_tdnn.yaml
+gpus=0,1,2,3
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+mkdir -p ${exp_dir}
+
+if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # stage 0: data prepare for vox1 and vox2; vox2 must be converted from m4a to wav
+    bash ./local/data.sh ${dir} ${conf_path} || exit -1;
+fi
+
+if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # stage 1: train the speaker identification model
+    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
+fi
+
+if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # stage 2: get the speaker verification scores with the cosine function
+    # for now we only support cosine scoring
+    CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
+fi
+
+# if [ $stage -le 3 ]; then
+#     # stage 3: extract the training embedding to train the LDA and PLDA
+#     # todo: extract the training embedding
+# fi
diff --git a/examples/voxceleb/sv0/utils b/examples/voxceleb/sv0/utils
new file mode 120000
index 00000000..256f914a
--- /dev/null
+++ b/examples/voxceleb/sv0/utils
@@ -0,0 +1 @@
+../../../utils/
\ No newline at end of file
diff --git a/paddleaudio/paddleaudio/datasets/__init__.py b/paddleaudio/paddleaudio/datasets/__init__.py
index 5c5f0369..6f44e977 100644
--- a/paddleaudio/paddleaudio/datasets/__init__.py
+++ b/paddleaudio/paddleaudio/datasets/__init__.py
@@ -15,3 +15,5 @@ from .esc50 import ESC50
 from .gtzan import GTZAN
 from .tess import TESS
 from .urban_sound import UrbanSound8K
+from .voxceleb import VoxCeleb
+from .rirs_noises import OpenRIRNoise
diff --git a/paddleaudio/paddleaudio/datasets/rirs_noises.py b/paddleaudio/paddleaudio/datasets/rirs_noises.py
new file mode 100644
index 00000000..80bb2d74
--- /dev/null
+++ b/paddleaudio/paddleaudio/datasets/rirs_noises.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import csv +import glob +import os +import random +from typing import Dict +from typing import List +from typing import Tuple + +from paddle.io import Dataset +from tqdm import tqdm + +from ..backends import load as load_audio +from ..backends import save as save_wav +from ..utils import DATA_HOME +from ..utils import decompress +from ..utils.download import download_and_decompress +from .dataset import feat_funcs + +__all__ = ['OpenRIRNoise'] + + +class OpenRIRNoise(Dataset): + archieves = [ + { + 'url': 'http://www.openslr.org/resources/28/rirs_noises.zip', + 'md5': 'e6f48e257286e05de56413b4779d8ffb', + }, + ] + + sample_rate = 16000 + meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav')) + base_path = os.path.join(DATA_HOME, 'open_rir_noise') + wav_path = os.path.join(base_path, 'RIRS_NOISES') + csv_path = os.path.join(base_path, 'csv') + subsets = ['rir', 'noise'] + + def __init__(self, + subset: str='rir', + feat_type: str='raw', + target_dir=None, + random_chunk: bool=True, + chunk_duration: float=3.0, + seed: int=0, + **kwargs): + + assert subset in self.subsets, \ + 'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset) + + self.subset = subset + self.feat_type = feat_type + self.feat_config = kwargs + self.random_chunk = random_chunk + self.chunk_duration = chunk_duration + + OpenRIRNoise.csv_path = os.path.join( + target_dir, "open_rir_noise", + "csv") if target_dir else self.csv_path + self._data = self._get_data() + super(OpenRIRNoise, self).__init__() + + # Set up a seed to reproduce training or predicting result. + # random.seed(seed) + + def _get_data(self): + # Download audio files. + print(f"rirs noises base path: {self.base_path}") + if not os.path.isdir(self.base_path): + download_and_decompress( + self.archieves, self.base_path, decompress=True) + else: + print( + f"{self.base_path} already exists, we will not download and decompress again" + ) + + # Data preparation. 
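+        # The csv directory acts as a cache: rir.csv / noise.csv are written
+        # once by prepare_data() below and simply re-read on later runs.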
+ print(f"prepare the csv to {self.csv_path}") + if not os.path.isdir(self.csv_path): + os.makedirs(self.csv_path) + self.prepare_data() + + data = [] + with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf: + for line in rf.readlines()[1:]: + audio_id, duration, wav = line.strip().split(',') + data.append(self.meta_info(audio_id, float(duration), wav)) + + random.shuffle(data) + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple: `type(sample)._fields` + for field in type(sample)._fields: + record[field] = getattr(sample, field) + + waveform, sr = load_audio(record['wav']) + + assert self.feat_type in feat_funcs.keys(), \ + f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}" + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + return record + + @staticmethod + def _get_chunks(seg_dur, audio_id, audio_duration): + num_chunks = int(audio_duration / seg_dur) # all in milliseconds + + chunk_lst = [ + audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur) + for i in range(num_chunks) + ] + return chunk_lst + + def _get_audio_info(self, wav_file: str, + split_chunks: bool) -> List[List[str]]: + waveform, sr = load_audio(wav_file) + audio_id = wav_file.split("/open_rir_noise/")[-1].split(".")[0] + audio_duration = waveform.shape[0] / sr + + ret = [] + if split_chunks and audio_duration > self.chunk_duration: # Split into pieces of self.chunk_duration seconds. + uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id, + audio_duration) + + for idx, chunk in enumerate(uniq_chunks_list): + s, e = chunk.split("_")[-2:] # Timestamps of start and end + start_sample = int(float(s) * sr) + end_sample = int(float(e) * sr) + new_wav_file = os.path.join(self.base_path, + audio_id + f'_chunk_{idx+1:02}.wav') + save_wav(waveform[start_sample:end_sample], sr, new_wav_file) + # id, duration, new_wav + ret.append([chunk, self.chunk_duration, new_wav_file]) + else: # Keep whole audio. 
+ ret.append([audio_id, audio_duration, wav_file]) + return ret + + def generate_csv(self, + wav_files: List[str], + output_file: str, + split_chunks: bool=True): + print(f'Generating csv: {output_file}') + header = ["id", "duration", "wav"] + + infos = list( + tqdm( + map(self._get_audio_info, wav_files, [split_chunks] * len( + wav_files)), + total=len(wav_files))) + + csv_lines = [] + for info in infos: + csv_lines.extend(info) + + with open(output_file, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) + csv_writer.writerow(header) + for line in csv_lines: + csv_writer.writerow(line) + + def prepare_data(self): + rir_list = os.path.join(self.wav_path, "real_rirs_isotropic_noises", + "rir_list") + rir_files = [] + with open(rir_list, 'r') as f: + for line in f.readlines(): + rir_file = line.strip().split(' ')[-1] + rir_files.append(os.path.join(self.base_path, rir_file)) + + noise_list = os.path.join(self.wav_path, "pointsource_noises", + "noise_list") + noise_files = [] + with open(noise_list, 'r') as f: + for line in f.readlines(): + noise_file = line.strip().split(' ')[-1] + noise_files.append(os.path.join(self.base_path, noise_file)) + + self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv')) + self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv')) + + def __getitem__(self, idx): + return self._convert_to_record(idx) + + def __len__(self): + return len(self._data) diff --git a/paddleaudio/paddleaudio/datasets/voxceleb.py b/paddleaudio/paddleaudio/datasets/voxceleb.py new file mode 100644 index 00000000..b9b8c271 --- /dev/null +++ b/paddleaudio/paddleaudio/datasets/voxceleb.py @@ -0,0 +1,358 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
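+# VoxCeleb1 dataset wrapper: downloads the audio archives and the
+# veri_test2.txt trial file, excludes the trial speakers from training,
+# splits the remaining audio into train/dev by split_ratio, writes csv
+# manifests plus a spk_id2label map, and serves (optionally chunked)
+# waveform samples.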
+import collections +import csv +import glob +import os +import random +from multiprocessing import cpu_count +from typing import Dict +from typing import List +from typing import Tuple + +from paddle.io import Dataset +from pathos.multiprocessing import Pool +from tqdm import tqdm + +from ..backends import load as load_audio +from ..utils import DATA_HOME +from ..utils import decompress +from ..utils.download import download_and_decompress +from .dataset import feat_funcs + +__all__ = ['VoxCeleb'] + + +class VoxCeleb(Dataset): + source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/' + archieves_audio_dev = [ + { + 'url': source_url + 'vox1_dev_wav_partaa', + 'md5': 'e395d020928bc15670b570a21695ed96', + }, + { + 'url': source_url + 'vox1_dev_wav_partab', + 'md5': 'bbfaaccefab65d82b21903e81a8a8020', + }, + { + 'url': source_url + 'vox1_dev_wav_partac', + 'md5': '017d579a2a96a077f40042ec33e51512', + }, + { + 'url': source_url + 'vox1_dev_wav_partad', + 'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19', + }, + ] + archieves_audio_test = [ + { + 'url': source_url + 'vox1_test_wav.zip', + 'md5': '185fdc63c3c739954633d50379a3d102', + }, + ] + archieves_meta = [ + { + 'url': + 'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt', + 'md5': + 'b73110731c9223c1461fe49cb48dddfc', + }, + ] + + num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41 + sample_rate = 16000 + meta_info = collections.namedtuple( + 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id')) + base_path = os.path.join(DATA_HOME, 'vox1') + wav_path = os.path.join(base_path, 'wav') + meta_path = os.path.join(base_path, 'meta') + veri_test_file = os.path.join(meta_path, 'veri_test2.txt') + csv_path = os.path.join(base_path, 'csv') + subsets = ['train', 'dev', 'enroll', 'test'] + + def __init__( + self, + subset: str='train', + feat_type: str='raw', + random_chunk: bool=True, + chunk_duration: float=3.0, # seconds + split_ratio: float=0.9, # train split ratio + seed: int=0, + target_dir: str=None, + vox2_base_path=None, + **kwargs): + """VoxCeleb data prepare and get the specific dataset audio info + + Args: + subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'. + feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'. + random_chunk (bool, optional): random select a duration from audio. Defaults to True. + chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0. + target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None. + vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None. 
+ """ + assert subset in self.subsets, \ + 'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset) + + self.subset = subset + self.spk_id2label = {} + self.feat_type = feat_type + self.feat_config = kwargs + self.random_chunk = random_chunk + self.chunk_duration = chunk_duration + self.split_ratio = split_ratio + self.target_dir = target_dir if target_dir else VoxCeleb.base_path + self.vox2_base_path = vox2_base_path + + # if we set the target dir, we will change the vox data info data from base path to target dir + VoxCeleb.csv_path = os.path.join( + target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path + VoxCeleb.meta_path = os.path.join( + target_dir, "voxceleb", + 'meta') if target_dir else VoxCeleb.meta_path + VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path, + 'veri_test2.txt') + # self._data = self._get_data()[:1000] # KP: Small dataset test. + self._data = self._get_data() + super(VoxCeleb, self).__init__() + + # Set up a seed to reproduce training or predicting result. + # random.seed(seed) + + def _get_data(self): + # Download audio files. + # We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir + # so, we check the vox1/wav dir status + print(f"wav base path: {self.wav_path}") + if not os.path.isdir(self.wav_path): + print(f"start to download the voxceleb1 dataset") + download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip + self.archieves_audio_dev, + self.base_path, + decompress=False) + download_and_decompress( # download the vox1_test_wav.zip and unzip + self.archieves_audio_test, + self.base_path, + decompress=True) + + # Download all parts and concatenate the files into one zip file. + dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip') + print(f'Concatenating all parts to: {dev_zipfile}') + os.system( + f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}' + ) + + # Extract all audio files of dev and test set. + decompress(dev_zipfile, self.base_path) + + # Download meta files. + if not os.path.isdir(self.meta_path): + print("prepare the meta data") + download_and_decompress( + self.archieves_meta, self.meta_path, decompress=False) + + # Data preparation. 
+ if not os.path.isdir(self.csv_path): + os.makedirs(self.csv_path) + self.prepare_data() + + data = [] + print( + f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}" + ) + with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf: + for line in rf.readlines()[1:]: + audio_id, duration, wav, start, stop, spk_id = line.strip( + ).split(',') + data.append( + self.meta_info(audio_id, + float(duration), wav, + int(start), int(stop), spk_id)) + + with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f: + for line in f.readlines(): + spk_id, label = line.strip().split(' ') + self.spk_id2label[spk_id] = int(label) + + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple: `type(sample)._fields` + for field in type(sample)._fields: + record[field] = getattr(sample, field) + + waveform, sr = load_audio(record['wav']) + + # random select a chunk audio samples from the audio + if self.random_chunk: + num_wav_samples = waveform.shape[0] + num_chunk_samples = int(self.chunk_duration * sr) + start = random.randint(0, num_wav_samples - num_chunk_samples - 1) + stop = start + num_chunk_samples + else: + start = record['start'] + stop = record['stop'] + + waveform = waveform[start:stop] + + assert self.feat_type in feat_funcs.keys(), \ + f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}" + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + if self.subset in ['train', + 'dev']: # Labels are available in train and dev. + record.update({'label': self.spk_id2label[record['spk_id']]}) + + return record + + @staticmethod + def _get_chunks(seg_dur, audio_id, audio_duration): + num_chunks = int(audio_duration / seg_dur) # all in milliseconds + + chunk_lst = [ + audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur) + for i in range(num_chunks) + ] + return chunk_lst + + def _get_audio_info(self, wav_file: str, + split_chunks: bool) -> List[List[str]]: + waveform, sr = load_audio(wav_file) + spk_id, sess_id, utt_id = wav_file.split("/")[-3:] + audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]]) + audio_duration = waveform.shape[0] / sr + + ret = [] + if split_chunks: # Split into pieces of self.chunk_duration seconds. + uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id, + audio_duration) + + for chunk in uniq_chunks_list: + s, e = chunk.split("_")[-2:] # Timestamps of start and end + start_sample = int(float(s) * sr) + end_sample = int(float(e) * sr) + # id, duration, wav, start, stop, spk_id + ret.append([ + chunk, audio_duration, wav_file, start_sample, end_sample, + spk_id + ]) + else: # Keep whole audio. 
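+            # enroll/test csvs are generated with split_chunks=False, so the utterance is kept whole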
+ ret.append([ + audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id + ]) + return ret + + def generate_csv(self, + wav_files: List[str], + output_file: str, + split_chunks: bool=True): + print(f'Generating csv: {output_file}') + header = ["ID", "duration", "wav", "start", "stop", "spk_id"] + # Note: this may occurs c++ execption, but the program will execute fine + # so we can ignore the execption + with Pool(cpu_count()) as p: + infos = list( + tqdm( + p.imap(lambda x: self._get_audio_info(x, split_chunks), + wav_files), + total=len(wav_files))) + + csv_lines = [] + for info in infos: + csv_lines.extend(info) + + with open(output_file, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) + csv_writer.writerow(header) + for line in csv_lines: + csv_writer.writerow(line) + + def prepare_data(self): + # Audio of speakers in veri_test_file should not be included in training set. + print("start to prepare the data csv file") + enroll_files = set() + test_files = set() + # get the enroll and test audio file path + with open(self.veri_test_file, 'r') as f: + for line in f.readlines(): + _, enrol_file, test_file = line.strip().split(' ') + enroll_files.add(os.path.join(self.wav_path, enrol_file)) + test_files.add(os.path.join(self.wav_path, test_file)) + enroll_files = sorted(enroll_files) + test_files = sorted(test_files) + + # get the enroll and test speakers + test_spks = set() + for file in (enroll_files + test_files): + spk = file.split('/wav/')[1].split('/')[0] + test_spks.add(spk) + + # get all the train and dev audios file path + audio_files = [] + speakers = set() + print("Getting file list...") + for path in [self.wav_path, self.vox2_base_path]: + # if vox2 directory is not set and vox2 is not a directory + # we will not process this directory + if not path or not os.path.exists(path): + print(f"{path} is an invalid path, please check again, " + "and we will ignore the vox2 base path") + continue + for file in glob.glob( + os.path.join(path, "**", "*.wav"), recursive=True): + spk = file.split('/wav/')[1].split('/')[0] + if spk in test_spks: + continue + speakers.add(spk) + audio_files.append(file) + + print( + f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}" + ) + # encode the train and dev speakers label to spk_id2label.txt + with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f: + for label, spk_id in enumerate( + sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2 + f.write(f'{spk_id} {label}\n') + + audio_files = sorted(audio_files) + random.shuffle(audio_files) + split_idx = int(self.split_ratio * len(audio_files)) + # split_ratio to train + train_files, dev_files = audio_files[:split_idx], audio_files[ + split_idx:] + + self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv')) + self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv')) + + self.generate_csv( + enroll_files, + os.path.join(self.csv_path, 'enroll.csv'), + split_chunks=False) + self.generate_csv( + test_files, + os.path.join(self.csv_path, 'test.csv'), + split_chunks=False) + + def __getitem__(self, idx): + return self._convert_to_record(idx) + + def __len__(self): + return len(self._data) diff --git a/paddleaudio/paddleaudio/metric/__init__.py b/paddleaudio/paddleaudio/metric/__init__.py index a96530ff..8e5ca9f7 100644 --- a/paddleaudio/paddleaudio/metric/__init__.py +++ b/paddleaudio/paddleaudio/metric/__init__.py @@ -12,4 +12,6 @@ # See the License for the specific 
language governing permissions and
 # limitations under the License.
 from .dtw import dtw_distance
+from .eer import compute_eer
+from .eer import compute_minDCF
 from .mcd import mcd_distance
diff --git a/paddleaudio/paddleaudio/metric/eer.py b/paddleaudio/paddleaudio/metric/eer.py
new file mode 100644
index 00000000..a1166d3f
--- /dev/null
+++ b/paddleaudio/paddleaudio/metric/eer.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import numpy as np
+import paddle
+from sklearn.metrics import roc_curve
+
+
+def compute_eer(labels: np.ndarray, scores: np.ndarray) -> Tuple[float, float]:
+    """Compute EER and return the score threshold at which it occurs.
+
+    Args:
+        labels (np.ndarray): the trial labels, shape: [N], one-dimensional, where N is the number of trials
+        scores (np.ndarray): the trial scores, shape: [N], one-dimensional, where N is the number of trials
+
+    Returns:
+        Tuple[float, float]: the EER and the corresponding threshold
+    """
+    fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
+    fnr = 1 - tpr
+    eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
+    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
+    return eer, eer_threshold
+
+
+def compute_minDCF(positive_scores,
+                   negative_scores,
+                   c_miss=1.0,
+                   c_fa=1.0,
+                   p_target=0.01):
+    """
+    This is modified from SpeechBrain
+    https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509
+    Computes the minDCF metric normally used to evaluate speaker verification
+    systems. The min_DCF is the minimum of the following C_det function computed
+    within the defined threshold range:
+
+    C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
+
+    where p_miss is the miss probability and p_fa is the probability of a
+    false alarm.
+
+    Args:
+        positive_scores (Paddle.Tensor): The scores from entries of the same class.
+        negative_scores (Paddle.Tensor): The scores from entries of different classes.
+        c_miss (float, optional): Cost assigned to a missing error (default 1.0).
+        c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
+        p_target (float, optional): Prior probability of having a target (default 0.01).
+ + Returns: + List[float]: min dcf and the specific threshold + """ + # Computing candidate thresholds + if len(positive_scores.shape) > 1: + positive_scores = positive_scores.squeeze() + + if len(negative_scores.shape) > 1: + negative_scores = negative_scores.squeeze() + + thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores])) + thresholds = paddle.unique(thresholds) + + # Adding intermediate thresholds + interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2 + thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds])) + + # Computing False Rejection Rate (miss detection) + positive_scores = paddle.concat( + len(thresholds) * [positive_scores.unsqueeze(0)]) + pos_scores_threshold = positive_scores.transpose(perm=[1, 0]) <= thresholds + p_miss = (pos_scores_threshold.sum(0) + ).astype("float32") / positive_scores.shape[1] + del positive_scores + del pos_scores_threshold + + # Computing False Acceptance Rate (false alarm) + negative_scores = paddle.concat( + len(thresholds) * [negative_scores.unsqueeze(0)]) + neg_scores_threshold = negative_scores.transpose(perm=[1, 0]) > thresholds + p_fa = (neg_scores_threshold.sum(0) + ).astype("float32") / negative_scores.shape[1] + del negative_scores + del neg_scores_threshold + + c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target) + c_min = paddle.min(c_det, axis=0) + min_index = paddle.argmin(c_det, axis=0) + return float(c_min), float(thresholds[min_index]) diff --git a/paddleaudio/paddleaudio/utils/download.py b/paddleaudio/paddleaudio/utils/download.py index 4658352f..07d5eea8 100644 --- a/paddleaudio/paddleaudio/utils/download.py +++ b/paddleaudio/paddleaudio/utils/download.py @@ -37,7 +37,9 @@ def decompress(file: str): download._decompress(file) -def download_and_decompress(archives: List[Dict[str, str]], path: str): +def download_and_decompress(archives: List[Dict[str, str]], + path: str, + decompress: bool=True): """ Download archieves and decompress to specific path. """ @@ -47,8 +49,8 @@ def download_and_decompress(archives: List[Dict[str, str]], path: str): for archive in archives: assert 'url' in archive and 'md5' in archive, \ 'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}' - - download.get_path_from_url(archive['url'], path, archive['md5']) + download.get_path_from_url( + archive['url'], path, archive['md5'], decompress=decompress) def load_state_dict_from_url(url: str, path: str, md5: str=None): diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py index b526a384..ddf0359b 100644 --- a/paddlespeech/cli/__init__.py +++ b/paddlespeech/cli/__init__.py @@ -21,5 +21,6 @@ from .st import STExecutor from .stats import StatsExecutor from .text import TextExecutor from .tts import TTSExecutor +from .vector import VectorExecutor _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) diff --git a/paddlespeech/cli/vector/__init__.py b/paddlespeech/cli/vector/__init__.py new file mode 100644 index 00000000..038596af --- /dev/null +++ b/paddlespeech/cli/vector/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import VectorExecutor diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py new file mode 100644 index 00000000..91974761 --- /dev/null +++ b/paddlespeech/cli/vector/infer.py @@ -0,0 +1,354 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +import sys +from collections import OrderedDict +from typing import List +from typing import Optional +from typing import Union + +import paddle +import soundfile +from yacs.config import CfgNode + +from ..executor import BaseExecutor +from ..log import logger +from ..utils import cli_register +from ..utils import download_and_decompress +from ..utils import MODEL_HOME +from ..utils import stats_wrapper +from paddleaudio.backends import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.vector.io.batch import feature_normalize +from paddlespeech.vector.modules.sid_model import SpeakerIdetification + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". + # e.g. "ecapatdnn_voxceleb12-16k". + # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: + # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" + "ecapatdnn_voxceleb12-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_0.tar.gz', + 'md5': + '85ff08ce0ef406b8c6d7b5ffc5b2b48f', + 'cfg_path': + 'conf/model.yaml', + 'ckpt_path': + 'model/model', + }, +} + +model_alias = { + "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", +} + + +@cli_register( + name="paddlespeech.vector", + description="Speech to vector embedding infer command.") +class VectorExecutor(BaseExecutor): + def __init__(self): + super(VectorExecutor, self).__init__() + + self.parser = argparse.ArgumentParser( + prog="paddlespeech.vector", add_help=True) + self.parser.add_argument( + "--model", + type=str, + default="ecapatdnn_voxceleb12", + choices=["ecapatdnn_voxceleb12"], + help="Choose model type of asr task.") + self.parser.add_argument( + "--task", + type=str, + default="spk", + choices=["spk"], + help="task type in vector domain") + self.parser.add_argument( + "--input", type=str, default=None, help="Audio file to recognize.") + self.parser.add_argument( + "--sample_rate", + type=int, + default=16000, + choices=[16000], + help="Choose the audio sample rate of the model. 
8000 or 16000") + self.parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="Checkpoint file of model.") + self.parser.add_argument( + '--config', + type=str, + default=None, + help='Config of asr task. Use deault config when it is None.') + self.parser.add_argument( + "--device", + type=str, + default=paddle.get_device(), + help="Choose device to execute model inference.") + self.parser.add_argument( + '-d', + '--job_dump_result', + action='store_true', + help='Save job result into file.') + + self.parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Increase logger verbosity of current task.') + + def execute(self, argv: List[str]) -> bool: + """Command line entry for vector model + + Args: + argv (List[str]): command line args list + + Returns: + bool: + False: some audio occurs error + True: all audio process success + """ + # stage 0: parse the args and get the required args + parser_args = self.parser.parse_args(argv) + model = parser_args.model + sample_rate = parser_args.sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + device = parser_args.device + + # stage 1: configurate the verbose flag + if not parser_args.verbose: + self.disable_task_loggers() + + # stage 2: read the input data and store them as a list + task_source = self.get_task_source(parser_args.input) + logger.info(f"task source: {task_source}") + + # stage 3: process the audio one by one + task_result = OrderedDict() + has_exceptions = False + for id_, input_ in task_source.items(): + try: + res = self(input_, model, sample_rate, config, ckpt_path, + device) + task_result[id_] = res + except Exception as e: + has_exceptions = True + task_result[id_] = f'{e.__class__.__name__}: {e}' + + logger.info("task result as follows: ") + logger.info(f"{task_result}") + + # stage 4: process the all the task results + self.process_task_results(parser_args.input, task_result, + parser_args.job_dump_result) + + # stage 5: return the exception flag + # if return False, somen audio process occurs error + if has_exceptions: + return False + else: + return True + + @stats_wrapper + def __call__(self, + audio_file: os.PathLike, + model: str='ecapatdnn-voxceleb12', + sample_rate: int=16000, + config: os.PathLike=None, + ckpt_path: os.PathLike=None, + force_yes: bool=False, + device=paddle.get_device()): + audio_file = os.path.abspath(audio_file) + if not self._check(audio_file, sample_rate): + sys.exit(-1) + + logger.info(f"device type: {device}") + paddle.device.set_device(device) + self._init_from_path(model, sample_rate, config, ckpt_path) + self.preprocess(model, audio_file) + self.infer(model) + res = self.postprocess() + + return res + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, \ + 'The model "{}" you want to use has not been supported,'\ + 'please choose other models.\n' \ + 'The support models includes\n\t\t{}'.format(tag, "\n\t\t".join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + + def _init_from_path(self, + model_type: str='ecapatdnn_voxceleb12', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + ckpt_path: Optional[os.PathLike]=None): + if hasattr(self, "model"): + 
logger.info("Model has been initialized") + return + + # stage 1: get the model and config path + if cfg_path is None or ckpt_path is None: + sample_rate_str = "16k" if sample_rate == 16000 else "8k" + tag = model_type + "-" + sample_rate_str + logger.info(f"load the pretrained model: {tag}") + res_path = self._get_pretrained_path(tag) + self.res_path = res_path + + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) + self.ckpt_path = os.path.join( + res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams') + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + + logger.info(f"start to read the ckpt from {self.ckpt_path}") + logger.info(f"read the config from {self.cfg_path}") + logger.info(f"get the res path {self.res_path}") + + # stage 2: read and config and init the model body + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + + # stage 3: get the model name to instance the model network with dynamic_import + # Noet: we use the '-' to get the model name instead of '_' + logger.info("start to dynamic import the model class") + model_name = model_type[:model_type.rindex('_')] + logger.info(f"model name {model_name}") + model_class = dynamic_import(model_name, model_alias) + model_conf = self.config.model + backbone = model_class(**model_conf) + model = SpeakerIdetification( + backbone=backbone, num_class=self.config.num_speakers) + self.model = model + self.model.eval() + + # stage 4: load the model parameters + logger.info("start to set the model parameters to model") + model_dict = paddle.load(self.ckpt_path) + self.model.set_state_dict(model_dict) + + logger.info("create the model instance success") + + @paddle.no_grad() + def infer(self, model_type: str): + + feats = self._inputs["feats"] + lengths = self._inputs["lengths"] + logger.info("start to do backbone network model forward") + logger.info( + f"feats shape:{feats.shape}, lengths shape: {lengths.shape}") + # embedding from (1, emb_size, 1) -> (emb_size) + embedding = self.model.backbone(feats, lengths).squeeze().numpy() + logger.info(f"embedding size: {embedding.shape}") + + self._outputs["embedding"] = embedding + + def postprocess(self) -> Union[str, os.PathLike]: + return self._outputs["embedding"] + + def preprocess(self, model_type: str, input_file: Union[str, os.PathLike]): + audio_file = input_file + if isinstance(audio_file, (str, os.PathLike)): + logger.info(f"Preprocess audio file: {audio_file}") + + # stage 1: load the audio + waveform, sr = load_audio(audio_file) + logger.info(f"load the audio sample points, shape is: {waveform.shape}") + + # stage 2: get the audio feat + try: + feat = melspectrogram( + x=waveform, + sr=self.config.sr, + n_mels=self.config.n_mels, + window_size=self.config.window_size, + hop_length=self.config.hop_size) + logger.info(f"extract the audio feat, shape is: {feat.shape}") + except Exception as e: + logger.info(f"feat occurs exception {e}") + sys.exit(-1) + + feat = paddle.to_tensor(feat).unsqueeze(0) + # in inference period, the lengths is all one without padding + lengths = paddle.ones([1]) + feat = feature_normalize(feat, mean_norm=True, std_norm=False) + + logger.info(f"feats shape: {feat.shape}") + self._inputs["feats"] = feat + self._inputs["lengths"] = lengths + + logger.info("audio extract the feat success") + + def _check(self, audio_file: str, sample_rate: int): + 
self.sample_rate = sample_rate
+        if self.sample_rate != 16000 and self.sample_rate != 8000:
+            logger.error(
+                "invalid sample rate, please input --sample_rate 8000 or --sample_rate 16000")
+            return False
+
+        if isinstance(audio_file, (str, os.PathLike)):
+            if not os.path.isfile(audio_file):
+                logger.error("Please input the right audio file path")
+                return False
+
+        logger.info("checking the audio file format......")
+        try:
+            audio, audio_sample_rate = soundfile.read(
+                audio_file, dtype="float32", always_2d=True)
+        except Exception as e:
+            logger.exception(e)
+            logger.error(
+                "can not open the audio file, please check that the audio file format is 'wav'. \n \
+                you can try to use sox to change the file format.\n \
+                For example: \n \
+                sample rate: 16k \n \
+                sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+                sample rate: 8k \n \
+                sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+                ")
+            return False
+
+        logger.info(f"The sample rate is {audio_sample_rate}")
+
+        if audio_sample_rate != self.sample_rate:
+            logger.error("The sample rate of the input file is not {}.\n \
+                This executor does not resample the audio automatically,\n \
+                please input a {} Hz, 16 bit, 1 channel wav file. \
+                ".format(self.sample_rate, self.sample_rate))
+            sys.exit(-1)
+        else:
+            logger.info("The audio file format is right")
+
+        return True
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
new file mode 100644
index 00000000..686de936
--- /dev/null
+++ b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
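+# Example invocation of this script (the config and checkpoint paths below
+# are illustrative, not shipped with the repo):
+#
+#   python extract_emb.py \
+#       --device gpu \
+#       --config conf/ecapa_tdnn.yaml \
+#       --load-checkpoint checkpoints/epoch_10 \
+#       --audio-path ./data/demo.wav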
+import argparse
+import os
+import time
+
+import paddle
+from yacs.config import CfgNode
+
+from paddleaudio.backends import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.io.batch import feature_normalize
+from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.seeding import seed_everything
+
+logger = Log(__name__).getlog()
+
+
+def extract_audio_embedding(args, config):
+    # stage 0: set the device, cpu or gpu
+    paddle.set_device(args.device)
+    # set the random seed, it is a must for multiprocess training
+    seed_everything(config.seed)
+
+    # stage 1: build the dnn backbone model network
+    ecapa_tdnn = EcapaTdnn(**config.model)
+
+    # stage 2: build the speaker verification model with the backbone
+    model = SpeakerIdetification(
+        backbone=ecapa_tdnn, num_class=config.num_speakers)
+    # stage 3: load the pre-trained model
+    args.load_checkpoint = os.path.abspath(
+        os.path.expanduser(args.load_checkpoint))
+
+    # load model checkpoint to sid model
+    state_dict = paddle.load(
+        os.path.join(args.load_checkpoint, 'model.pdparams'))
+    model.set_state_dict(state_dict)
+    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
+
+    # stage 4: we must set the model to eval mode
+    model.eval()
+
+    # stage 5: read the audio data and extract the embedding
+    # waveform is a one-dimensional numpy array
+    waveform, sr = load_audio(args.audio_path)
+
+    # feat type is numpy array, whose shape is [dim, time]
+    # we need to convert the audio feat to one-batch shape [batch, dim, time],
+    # where the batch is one, so the final shape is [1, dim, time]
+    start_time = time.time()
+    feat = melspectrogram(
+        x=waveform,
+        sr=config.sr,
+        n_mels=config.n_mels,
+        window_size=config.window_size,
+        hop_length=config.hop_size)
+    feat = paddle.to_tensor(feat).unsqueeze(0)
+
+    # at inference time, the lengths are all ones since there is no padding
+    lengths = paddle.ones([1])
+    feat = feature_normalize(feat, mean_norm=True, std_norm=False)
+
+    # forward the feats through the backbone network to get the embedding
+    embedding = model.backbone(
+        feat, lengths).squeeze().numpy()  # (1, emb_size, 1) -> (emb_size)
+    elapsed_time = time.time() - start_time
+    audio_length = waveform.shape[0] / sr
+
+    # stage 6: compute the real time factor (RTF) of the extraction
+    rtf = elapsed_time / audio_length
+    logger.info(f"{args.device} rtf={rtf}")
+
+    return embedding
+
+
+if __name__ == "__main__":
+    # yapf: disable
+    parser = argparse.ArgumentParser(__doc__)
+    parser.add_argument('--device',
+                        choices=['cpu', 'gpu'],
+                        default="cpu",
+                        help="Select which device to run on, defaults to cpu.")
+    parser.add_argument("--config",
+                        default=None,
+                        type=str,
+                        help="configuration file")
+    parser.add_argument("--load-checkpoint",
+                        type=str,
+                        default='',
+                        help="Directory to load the model checkpoint from.")
+    parser.add_argument("--audio-path",
+                        default="./data/demo.wav",
+                        type=str,
+                        help="Single audio file path")
+    args = parser.parse_args()
+    # yapf: enable
+    # https://yaml.org/type/float.html
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
+    config.freeze()
+    print(config)
+
+    extract_audio_embedding(args, config)
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py
new file mode 100644
index 00000000..76832fd8
--- 
/dev/null +++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py @@ -0,0 +1,205 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import ast +import os + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle.io import BatchSampler +from paddle.io import DataLoader +from tqdm import tqdm +from yacs.config import CfgNode + +from paddleaudio.datasets import VoxCeleb +from paddleaudio.metric import compute_eer +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn +from paddlespeech.vector.modules.sid_model import SpeakerIdetification +from paddlespeech.vector.training.seeding import seed_everything + +logger = Log(__name__).getlog() + + +def main(args, config): + # stage0: set the training device, cpu or gpu + paddle.set_device(args.device) + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + # stage1: build the dnn backbone model network + ecapa_tdnn = EcapaTdnn(**config.model) + + # stage2: build the speaker verification eval instance with backbone model + model = SpeakerIdetification( + backbone=ecapa_tdnn, num_class=config.num_speakers) + + # stage3: load the pre-trained model + # we get the last model from the epoch and save_interval + args.load_checkpoint = os.path.abspath( + os.path.expanduser(args.load_checkpoint)) + + # load model checkpoint to sid model + state_dict = paddle.load( + os.path.join(args.load_checkpoint, 'model.pdparams')) + model.set_state_dict(state_dict) + logger.info(f'Checkpoint loaded from {args.load_checkpoint}') + + # stage4: construct the enroll and test dataloader + + enroll_dataset = VoxCeleb( + subset='enroll', + target_dir=args.data_dir, + feat_type='melspectrogram', + random_chunk=False, + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + enroll_sampler = BatchSampler( + enroll_dataset, batch_size=config.batch_size, + shuffle=True) # Shuffle to make embedding normalization more robust. 
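+    # batch_feature_normalize pads each batch to its longest utterance and
+    # applies per-utterance mean normalization over the valid frames only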
+    enrol_loader = DataLoader(enroll_dataset,
+                              batch_sampler=enroll_sampler,
+                              collate_fn=lambda x: batch_feature_normalize(
+                                  x, mean_norm=True, std_norm=False),
+                              num_workers=config.num_workers,
+                              return_list=True,)
+    test_dataset = VoxCeleb(
+        subset='test',
+        target_dir=args.data_dir,
+        feat_type='melspectrogram',
+        random_chunk=False,
+        n_mels=config.n_mels,
+        window_size=config.window_size,
+        hop_length=config.hop_size)
+
+    test_sampler = BatchSampler(
+        test_dataset, batch_size=config.batch_size, shuffle=True)
+    test_loader = DataLoader(test_dataset,
+                             batch_sampler=test_sampler,
+                             collate_fn=lambda x: batch_feature_normalize(
+                                 x, mean_norm=True, std_norm=False),
+                             num_workers=config.num_workers,
+                             return_list=True,)
+    # stage5: we must set the model to eval mode
+    model.eval()
+
+    # stage6: global embedding norm to improve the performance
+    logger.info(f"global embedding norm: {config.global_embedding_norm}")
+    if config.global_embedding_norm:
+        global_embedding_mean = None
+        global_embedding_std = None
+        mean_norm_flag = config.embedding_mean_norm
+        std_norm_flag = config.embedding_std_norm
+        batch_count = 0
+
+    # stage7: compute the embeddings of the audios in the enroll and test datasets
+    id2embedding = {}
+    # Run multiple times to make the embedding normalization more stable.
+    for i in range(2):
+        for dl in [enrol_loader, test_loader]:
+            logger.info(
+                f'Loop {i+1}: Computing embeddings on the {dl.dataset.subset} dataset'
+            )
+            with paddle.no_grad():
+                for batch_idx, batch in enumerate(tqdm(dl)):
+
+                    # stage 7-1: extract the audio embedding
+                    ids, feats, lengths = batch['ids'], batch['feats'], batch[
+                        'lengths']
+                    embeddings = model.backbone(feats, lengths).squeeze(
+                        -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
+
+                    # Global embedding normalization.
+                    # Using the global embedding norm typically reduces the EER
+                    # by about 10% relative.
+                    if config.global_embedding_norm:
+                        batch_count += 1
+                        current_mean = embeddings.mean(
+                            axis=0) if mean_norm_flag else 0
+                        current_std = embeddings.std(
+                            axis=0) if std_norm_flag else 1
+                        # Update global mean and std.
+                        if global_embedding_mean is None and global_embedding_std is None:
+                            global_embedding_mean, global_embedding_std = current_mean, current_std
+                        else:
+                            weight = 1 / batch_count  # Weight decay by batches.
+                            global_embedding_mean = (
+                                1 - weight
+                            ) * global_embedding_mean + weight * current_mean
+                            global_embedding_std = (
+                                1 - weight
+                            ) * global_embedding_std + weight * current_std
+                        # Apply global embedding normalization.
+                        embeddings = (embeddings - global_embedding_mean
+                                      ) / global_embedding_std
+
+                    # Update embedding dict.
+                    id2embedding.update(dict(zip(ids, embeddings)))
+
+    # stage 8: Compute cosine scores.
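+    # For an enroll embedding e and a test embedding t the score is
+    #     score = (e . t) / (||e|| * ||t||)
+    # i.e. cosine similarity in [-1, 1]; compute_eer then finds the threshold
+    # where the false-acceptance and false-rejection rates are equal.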
+ labels = [] + enroll_ids = [] + test_ids = [] + logger.info(f"read the trial from {VoxCeleb.veri_test_file}") + with open(VoxCeleb.veri_test_file, 'r') as f: + for line in f.readlines(): + label, enroll_id, test_id = line.strip().split(' ') + labels.append(int(label)) + enroll_ids.append(enroll_id.split('.')[0].replace('/', '-')) + test_ids.append(test_id.split('.')[0].replace('/', '-')) + + cos_sim_func = paddle.nn.CosineSimilarity(axis=1) + enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor( + np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')), + [enroll_ids, test_ids + ]) # (N, emb_size) + scores = cos_sim_func(enrol_embeddings, test_embeddings) + EER, threshold = compute_eer(np.asarray(labels), scores.numpy()) + logger.info( + f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}' + ) + + +if __name__ == "__main__": + # yapf: disable + parser = argparse.ArgumentParser(__doc__) + parser.add_argument('--device', + choices=['cpu', 'gpu'], + default="gpu", + help="Select which device to train model, defaults to gpu.") + parser.add_argument("--config", + default=None, + type=str, + help="configuration file") + parser.add_argument("--data-dir", + default="./data/", + type=str, + help="data directory") + parser.add_argument("--load-checkpoint", + type=str, + default='', + help="Directory to load model checkpoint to contiune trainning.") + args = parser.parse_args() + # yapf: enable + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + print(config) + main(args, config) diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py new file mode 100644 index 00000000..257b97ab --- /dev/null +++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py @@ -0,0 +1,351 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
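+# Example launch commands (config path is illustrative, not shipped with the repo):
+#
+#   # single card
+#   python train.py --device gpu --config conf/ecapa_tdnn.yaml --data-dir ./data/
+#
+#   # multi card, data parallel
+#   python -m paddle.distributed.launch --gpus "0,1,2,3" \
+#       train.py --device gpu --config conf/ecapa_tdnn.yaml --data-dir ./data/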
+import argparse +import os +import time + +import numpy as np +import paddle +from paddle.io import BatchSampler +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from yacs.config import CfgNode + +from paddleaudio.compliance.librosa import melspectrogram +from paddleaudio.datasets.voxceleb import VoxCeleb +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.io.augment import build_augment_pipeline +from paddlespeech.vector.io.augment import waveform_augment +from paddlespeech.vector.io.batch import batch_pad_right +from paddlespeech.vector.io.batch import feature_normalize +from paddlespeech.vector.io.batch import waveform_collate_fn +from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn +from paddlespeech.vector.modules.loss import AdditiveAngularMargin +from paddlespeech.vector.modules.loss import LogSoftmaxWrapper +from paddlespeech.vector.modules.sid_model import SpeakerIdetification +from paddlespeech.vector.training.scheduler import CyclicLRScheduler +from paddlespeech.vector.training.seeding import seed_everything +from paddlespeech.vector.utils.time import Timer + +logger = Log(__name__).getlog() + + +def main(args, config): + # stage0: set the training device, cpu or gpu + paddle.set_device(args.device) + + # stage1: we must call the paddle.distributed.init_parallel_env() api at the begining + paddle.distributed.init_parallel_env() + nranks = paddle.distributed.get_world_size() + local_rank = paddle.distributed.get_rank() + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline + # note: some cmd must do in rank==0, so wo will refactor the data prepare code + train_dataset = VoxCeleb('train', target_dir=args.data_dir) + dev_dataset = VoxCeleb('dev', target_dir=args.data_dir) + + if config.augment: + augment_pipeline = build_augment_pipeline(target_dir=args.data_dir) + else: + augment_pipeline = [] + + # stage3: build the dnn backbone model network + ecapa_tdnn = EcapaTdnn(**config.model) + + # stage4: build the speaker verification train instance with backbone model + model = SpeakerIdetification( + backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers) + + # stage5: build the optimizer, we now only construct the AdamW optimizer + # 140000 is single gpu steps + # so, in multi-gpu mode, wo reduce the step_size to 140000//nranks to enable CyclicLRScheduler + lr_schedule = CyclicLRScheduler( + base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks) + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_schedule, parameters=model.parameters()) + + # stage6: build the loss function, we now only support LogSoftmaxWrapper + criterion = LogSoftmaxWrapper( + loss_fn=AdditiveAngularMargin(margin=0.2, scale=30)) + + # stage7: confirm training start epoch + # if pre-trained model exists, start epoch confirmed by the pre-trained model + start_epoch = 0 + if args.load_checkpoint: + logger.info("load the check point") + args.load_checkpoint = os.path.abspath( + os.path.expanduser(args.load_checkpoint)) + try: + # load model checkpoint + state_dict = paddle.load( + os.path.join(args.load_checkpoint, 'model.pdparams')) + model.set_state_dict(state_dict) + + # load optimizer checkpoint + state_dict = paddle.load( + os.path.join(args.load_checkpoint, 'model.pdopt')) + optimizer.set_state_dict(state_dict) + if local_rank == 0: + logger.info(f'Checkpoint loaded from {args.load_checkpoint}') + except 
FileExistsError: + if local_rank == 0: + logger.info('Train from scratch.') + + try: + start_epoch = int(args.load_checkpoint[-1]) + logger.info(f'Restore training from epoch {start_epoch}.') + except ValueError: + pass + + # stage8: we build the batch sampler for paddle.DataLoader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=False) + train_loader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=config.num_workers, + collate_fn=waveform_collate_fn, + return_list=True, + use_buffer_reader=True, ) + + # stage9: start to train + # we will comment the training process + steps_per_epoch = len(train_sampler) + timer = Timer(steps_per_epoch * config.epochs) + last_saved_epoch = "" + timer.start() + + for epoch in range(start_epoch + 1, config.epochs + 1): + # at the begining, model must set to train mode + model.train() + + avg_loss = 0 + num_corrects = 0 + num_samples = 0 + train_reader_cost = 0.0 + train_feat_cost = 0.0 + train_run_cost = 0.0 + + reader_start = time.time() + for batch_idx, batch in enumerate(train_loader): + train_reader_cost += time.time() - reader_start + + # stage 9-1: batch data is audio sample points and speaker id label + feat_start = time.time() + waveforms, labels = batch['waveforms'], batch['labels'] + waveforms, lengths = batch_pad_right(waveforms.numpy()) + waveforms = paddle.to_tensor(waveforms) + + # stage 9-2: audio sample augment method, which is done on the audio sample point + # the original wavefrom and the augmented waveform is concatented in a batch + # eg. five augment method in the augment pipeline + # the final data nums is batch_size * [five + one] + # -> five augmented waveform batch plus one original batch waveform + if len(augment_pipeline) != 0: + waveforms = waveform_augment(waveforms, augment_pipeline) + labels = paddle.concat( + [labels for i in range(len(augment_pipeline) + 1)]) + + # stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram + feats = [] + for waveform in waveforms.numpy(): + feat = melspectrogram( + x=waveform, + sr=config.sr, + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + feats.append(feat) + feats = paddle.to_tensor(np.asarray(feats)) + + # stage 9-4: feature normalize, which help converge and imporve the performance + feats = feature_normalize( + feats, mean_norm=True, std_norm=False) # Features normalization + train_feat_cost += time.time() - feat_start + + # stage 9-5: model forward, such ecapa-tdnn, x-vector + train_start = time.time() + logits = model(feats) + + # stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin + loss = criterion(logits, labels) + + # stage 9-7: update the gradient and clear the gradient cache + loss.backward() + optimizer.step() + if isinstance(optimizer._learning_rate, + paddle.optimizer.lr.LRScheduler): + optimizer._learning_rate.step() + optimizer.clear_grad() + train_run_cost += time.time() - train_start + + # stage 9-8: Calculate average loss per batch + avg_loss += loss.numpy()[0] + + # stage 9-9: Calculate metrics, which is one-best accuracy + preds = paddle.argmax(logits, axis=1) + num_corrects += (preds == labels).numpy().sum() + num_samples += feats.shape[0] + timer.count() # step plus one in timer + + # stage 9-10: print the log information only on 0-rank per log-freq batchs + if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0: + lr = optimizer.get_lr() + avg_loss /= config.log_interval + avg_acc = 
num_corrects / num_samples + + print_msg = 'Train Epoch={}/{}, Step={}/{}'.format( + epoch, config.epochs, batch_idx + 1, steps_per_epoch) + print_msg += ' loss={:.4f}'.format(avg_loss) + print_msg += ' acc={:.4f}'.format(avg_acc) + print_msg += ' avg_reader_cost: {:.5f} sec,'.format( + train_reader_cost / config.log_interval) + print_msg += ' avg_feat_cost: {:.5f} sec,'.format( + train_feat_cost / config.log_interval) + print_msg += ' avg_train_cost: {:.5f} sec,'.format( + train_run_cost / config.log_interval) + print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format( + lr, timer.timing, timer.eta) + logger.info(print_msg) + + avg_loss = 0 + num_corrects = 0 + num_samples = 0 + train_reader_cost = 0.0 + train_feat_cost = 0.0 + train_run_cost = 0.0 + + reader_start = time.time() + + # stage 9-11: save the model parameters only on 0-rank per save-freq batchs + if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch: + if local_rank != 0: + paddle.distributed.barrier( + ) # Wait for valid step in main process + continue # Resume trainning on other process + + # stage 9-12: construct the valid dataset dataloader + dev_sampler = BatchSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + dev_loader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=waveform_collate_fn, + num_workers=config.num_workers, + return_list=True, ) + + # set the model to eval mode + model.eval() + num_corrects = 0 + num_samples = 0 + + # stage 9-13: evaluation the valid dataset batch data + logger.info('Evaluate on validation dataset') + with paddle.no_grad(): + for batch_idx, batch in enumerate(dev_loader): + waveforms, labels = batch['waveforms'], batch['labels'] + + feats = [] + for waveform in waveforms.numpy(): + feat = melspectrogram( + x=waveform, + sr=config.sr, + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + feats.append(feat) + + feats = paddle.to_tensor(np.asarray(feats)) + feats = feature_normalize( + feats, mean_norm=True, std_norm=False) + logits = model(feats) + + preds = paddle.argmax(logits, axis=1) + num_corrects += (preds == labels).numpy().sum() + num_samples += feats.shape[0] + + print_msg = '[Evaluation result]' + print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples) + logger.info(print_msg) + + # stage 9-14: Save model parameters + save_dir = os.path.join(args.checkpoint_dir, + 'epoch_{}'.format(epoch)) + last_saved_epoch = os.path.join('epoch_{}'.format(epoch), + "model.pdparams") + logger.info('Saving model checkpoint to {}'.format(save_dir)) + paddle.save(model.state_dict(), + os.path.join(save_dir, 'model.pdparams')) + paddle.save(optimizer.state_dict(), + os.path.join(save_dir, 'model.pdopt')) + + if nranks > 1: + paddle.distributed.barrier() # Main process + + # stage 10: create the final trained model.pdparams with soft link + if local_rank == 0: + final_model = os.path.join(args.checkpoint_dir, "model.pdparams") + logger.info(f"we will create the final model: {final_model}") + if os.path.islink(final_model): + logger.info( + f"An {final_model} already exists, we will rm is and create it again" + ) + os.unlink(final_model) + os.symlink(last_saved_epoch, final_model) + + +if __name__ == "__main__": + # yapf: disable + parser = argparse.ArgumentParser(__doc__) + parser.add_argument('--device', + choices=['cpu', 'gpu'], + default="cpu", + help="Select which device to train model, defaults to gpu.") + parser.add_argument("--config", + default=None, + type=str, + 
help="configuration file") + parser.add_argument("--data-dir", + default="./data/", + type=str, + help="data directory") + parser.add_argument("--load-checkpoint", + type=str, + default=None, + help="Directory to load model checkpoint to contiune trainning.") + parser.add_argument("--checkpoint-dir", + type=str, + default='./checkpoint', + help="Directory to save model checkpoints.") + + args = parser.parse_args() + # yapf: enable + + # https://yaml.org/type/float.html + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + print(config) + + main(args, config) diff --git a/paddlespeech/vector/io/augment.py b/paddlespeech/vector/io/augment.py new file mode 100644 index 00000000..6e508c37 --- /dev/null +++ b/paddlespeech/vector/io/augment.py @@ -0,0 +1,908 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# this is modified from SpeechBrain +# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py +import math +import os +from typing import List + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleaudio import load as load_audio +from paddleaudio.datasets.rirs_noises import OpenRIRNoise +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.io.signal_processing import compute_amplitude +from paddlespeech.vector.io.signal_processing import convolve1d +from paddlespeech.vector.io.signal_processing import dB_to_amplitude +from paddlespeech.vector.io.signal_processing import notch_filter +from paddlespeech.vector.io.signal_processing import reverberate + +logger = Log(__name__).getlog() + + +# TODO: Complete type-hint and doc string. 
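+# The layers below operate directly on batches of raw waveforms.
+# A minimal usage sketch (only the shapes are assumed here):
+#
+#   drop_freq = DropFreq(drop_count_low=1, drop_count_high=3)
+#   waveforms = paddle.randn([4, 16000])  # (batch, time), 1 second at 16 kHz
+#   augmented = drop_freq(waveforms)      # same shape as the input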
+class DropFreq(nn.Layer): + def __init__( + self, + drop_freq_low=1e-14, + drop_freq_high=1, + drop_count_low=1, + drop_count_high=2, + drop_width=0.05, + drop_prob=1, ): + super(DropFreq, self).__init__() + self.drop_freq_low = drop_freq_low + self.drop_freq_high = drop_freq_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.drop_width = drop_width + self.drop_prob = drop_prob + + def forward(self, waveforms): + # Don't drop (return early) 1-`drop_prob` portion of the batches + dropped_waveform = waveforms.clone() + if paddle.rand([1]) > self.drop_prob: + return dropped_waveform + + # Add channels dimension + if len(waveforms.shape) == 2: + dropped_waveform = dropped_waveform.unsqueeze(-1) + + # Pick number of frequencies to drop + drop_count = paddle.randint( + low=self.drop_count_low, high=self.drop_count_high + 1, shape=[1]) + + # Pick a frequency to drop + drop_range = self.drop_freq_high - self.drop_freq_low + drop_frequency = ( + paddle.rand([drop_count]) * drop_range + self.drop_freq_low) + + # Filter parameters + filter_length = 101 + pad = filter_length // 2 + + # Start with delta function + drop_filter = paddle.zeros([1, filter_length, 1]) + drop_filter[0, pad, 0] = 1 + + # Subtract each frequency + for frequency in drop_frequency: + notch_kernel = notch_filter(frequency, filter_length, + self.drop_width) + drop_filter = convolve1d(drop_filter, notch_kernel, pad) + + # Apply filter + dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad) + + # Remove channels dimension if added + return dropped_waveform.squeeze(-1) + + +class DropChunk(nn.Layer): + def __init__( + self, + drop_length_low=100, + drop_length_high=1000, + drop_count_low=1, + drop_count_high=10, + drop_start=0, + drop_end=None, + drop_prob=1, + noise_factor=0.0, ): + super(DropChunk, self).__init__() + self.drop_length_low = drop_length_low + self.drop_length_high = drop_length_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.drop_start = drop_start + self.drop_end = drop_end + self.drop_prob = drop_prob + self.noise_factor = noise_factor + + # Validate low < high + if drop_length_low > drop_length_high: + raise ValueError("Low limit must not be more than high limit") + if drop_count_low > drop_count_high: + raise ValueError("Low limit must not be more than high limit") + + # Make sure the length doesn't exceed end - start + if drop_end is not None and drop_end >= 0: + if drop_start > drop_end: + raise ValueError("Low limit must not be more than high limit") + + drop_range = drop_end - drop_start + self.drop_length_low = min(drop_length_low, drop_range) + self.drop_length_high = min(drop_length_high, drop_range) + + def forward(self, waveforms, lengths): + # Reading input list + lengths = (lengths * waveforms.shape[1]).astype('int64') + batch_size = waveforms.shape[0] + dropped_waveform = waveforms.clone() + + # Don't drop (return early) 1-`drop_prob` portion of the batches + if paddle.rand([1]) > self.drop_prob: + return dropped_waveform + + # Store original amplitude for computing white noise amplitude + clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1)) + + # Pick a number of times to drop + drop_times = paddle.randint( + low=self.drop_count_low, + high=self.drop_count_high + 1, + shape=[batch_size], ) + + # Iterate batch to set mask + for i in range(batch_size): + if drop_times[i] == 0: + continue + + # Pick lengths + length = paddle.randint( + low=self.drop_length_low, + high=self.drop_length_high + 
1, + shape=[drop_times[i]], ) + + # Compute range of starting locations + start_min = self.drop_start + if start_min < 0: + start_min += lengths[i] + start_max = self.drop_end + if start_max is None: + start_max = lengths[i] + if start_max < 0: + start_max += lengths[i] + start_max = max(0, start_max - length.max()) + + # Pick starting locations + start = paddle.randint( + low=start_min, + high=start_max + 1, + shape=[drop_times[i]], ) + + end = start + length + + # Update waveform + if not self.noise_factor: + for j in range(drop_times[i]): + if start[j] < end[j]: + dropped_waveform[i, start[j]:end[j]] = 0.0 + else: + # Uniform distribution of -2 to +2 * avg amplitude should + # preserve the average for normalization + noise_max = 2 * clean_amplitude[i] * self.noise_factor + for j in range(drop_times[i]): + # zero-center the noise distribution + noise_vec = paddle.rand([length[j]], dtype='float32') + + noise_vec = 2 * noise_max * noise_vec - noise_max + dropped_waveform[i, int(start[j]):int(end[j])] = noise_vec + + return dropped_waveform + + +class Resample(nn.Layer): + def __init__( + self, + orig_freq=16000, + new_freq=16000, + lowpass_filter_width=6, ): + super(Resample, self).__init__() + self.orig_freq = orig_freq + self.new_freq = new_freq + self.lowpass_filter_width = lowpass_filter_width + + # Compute rate for striding + self._compute_strides() + assert self.orig_freq % self.conv_stride == 0 + assert self.new_freq % self.conv_transpose_stride == 0 + + def _compute_strides(self): + # Compute new unit based on ratio of in/out frequencies + base_freq = math.gcd(self.orig_freq, self.new_freq) + input_samples_in_unit = self.orig_freq // base_freq + self.output_samples = self.new_freq // base_freq + + # Store the appropriate stride based on the new units + self.conv_stride = input_samples_in_unit + self.conv_transpose_stride = self.output_samples + + def forward(self, waveforms): + if not hasattr(self, "first_indices"): + self._indices_and_weights(waveforms) + + # Don't do anything if the frequencies are the same + if self.orig_freq == self.new_freq: + return waveforms + + unsqueezed = False + if len(waveforms.shape) == 2: + waveforms = waveforms.unsqueeze(1) + unsqueezed = True + elif len(waveforms.shape) == 3: + waveforms = waveforms.transpose([0, 2, 1]) + else: + raise ValueError("Input must be 2 or 3 dimensions") + + # Do resampling + resampled_waveform = self._perform_resample(waveforms) + + if unsqueezed: + resampled_waveform = resampled_waveform.squeeze(1) + else: + resampled_waveform = resampled_waveform.transpose([0, 2, 1]) + + return resampled_waveform + + def _perform_resample(self, waveforms): + # Compute output size and initialize + batch_size, num_channels, wave_len = waveforms.shape + window_size = self.weights.shape[1] + tot_output_samp = self._output_samples(wave_len) + resampled_waveform = paddle.zeros((batch_size, num_channels, + tot_output_samp)) + + # eye size: (num_channels, num_channels, 1) + eye = paddle.eye(num_channels).unsqueeze(2) + + # Iterate over the phases in the polyphase filter + for i in range(self.first_indices.shape[0]): + wave_to_conv = waveforms + first_index = int(self.first_indices[i].item()) + if first_index >= 0: + # trim the signal as the filter will not be applied + # before the first_index + wave_to_conv = wave_to_conv[:, :, first_index:] + + # pad the right of the signal to allow partial convolutions + # meaning compute values for partial windows (e.g. 
end of the + # window is outside the signal length) + max_index = (tot_output_samp - 1) // self.output_samples + end_index = max_index * self.conv_stride + window_size + current_wave_len = wave_len - first_index + right_padding = max(0, end_index + 1 - current_wave_len) + left_padding = max(0, -first_index) + wave_to_conv = paddle.nn.functional.pad( + wave_to_conv, [left_padding, right_padding], data_format='NCL') + conv_wave = paddle.nn.functional.conv1d( + x=wave_to_conv, + # weight=self.weights[i].repeat(num_channels, 1, 1), + weight=self.weights[i].expand((num_channels, 1, -1)), + stride=self.conv_stride, + groups=num_channels, ) + + # we want conv_wave[:, i] to be at + # output[:, i + n*conv_transpose_stride] + dilated_conv_wave = paddle.nn.functional.conv1d_transpose( + conv_wave, eye, stride=self.conv_transpose_stride) + + # pad dilated_conv_wave so it reaches the output length if needed. + left_padding = i + previous_padding = left_padding + dilated_conv_wave.shape[-1] + right_padding = max(0, tot_output_samp - previous_padding) + dilated_conv_wave = paddle.nn.functional.pad( + dilated_conv_wave, [left_padding, right_padding], + data_format='NCL') + dilated_conv_wave = dilated_conv_wave[:, :, :tot_output_samp] + + resampled_waveform += dilated_conv_wave + + return resampled_waveform + + def _output_samples(self, input_num_samp): + samp_in = int(self.orig_freq) + samp_out = int(self.new_freq) + + tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out) + ticks_per_input_period = tick_freq // samp_in + + # work out the number of ticks in the time interval + # [ 0, input_num_samp/samp_in ). + interval_length = input_num_samp * ticks_per_input_period + if interval_length <= 0: + return 0 + ticks_per_output_period = tick_freq // samp_out + + # Get the last output-sample in the closed interval, + # i.e. replacing [ ) with [ ]. Note: integer division rounds down. + # See http://en.wikipedia.org/wiki/Interval_(mathematics) for an + # explanation of the notation. + last_output_samp = interval_length // ticks_per_output_period + + # We need the last output-sample in the open interval, so if it + # takes us to the end of the interval exactly, subtract one. + if last_output_samp * ticks_per_output_period == interval_length: + last_output_samp -= 1 + + # First output-sample index is zero, so the number of output samples + # is the last output-sample plus one. 
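+        # Worked example: orig_freq=16000, new_freq=8000, input_num_samp=10
+        # gives tick_freq=16000, interval_length=10, ticks_per_output_period=2,
+        # so last_output_samp=10//2=5; since 5*2==10 it is reduced to 4 and
+        # the output has 4+1=5 samples, i.e. exact 2:1 downsampling.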
+ num_output_samp = last_output_samp + 1 + + return num_output_samp + + def _indices_and_weights(self, waveforms): + # Lowpass filter frequency depends on smaller of two frequencies + min_freq = min(self.orig_freq, self.new_freq) + lowpass_cutoff = 0.99 * 0.5 * min_freq + + assert lowpass_cutoff * 2 <= min_freq + window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff) + + assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2 + output_t = paddle.arange(start=0.0, end=self.output_samples) + output_t /= self.new_freq + min_t = output_t - window_width + max_t = output_t + window_width + + min_input_index = paddle.ceil(min_t * self.orig_freq) + max_input_index = paddle.floor(max_t * self.orig_freq) + num_indices = max_input_index - min_input_index + 1 + + max_weight_width = num_indices.max() + j = paddle.arange(max_weight_width, dtype='float32') + input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0) + delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1) + + weights = paddle.zeros_like(delta_t) + inside_window_indices = delta_t.abs().less_than( + paddle.to_tensor(window_width)) + + # raised-cosine (Hanning) window with width `window_width` + weights[inside_window_indices] = 0.5 * (1 + paddle.cos( + 2 * math.pi * lowpass_cutoff / self.lowpass_filter_width * + delta_t.masked_select(inside_window_indices))) + + t_eq_zero_indices = delta_t.equal(paddle.zeros_like(delta_t)) + t_not_eq_zero_indices = delta_t.not_equal(paddle.zeros_like(delta_t)) + + # sinc filter function + weights = paddle.where( + t_not_eq_zero_indices, + weights * paddle.sin(2 * math.pi * lowpass_cutoff * delta_t) / + (math.pi * delta_t), weights) + + # limit of the function at t = 0 + weights = paddle.where(t_eq_zero_indices, weights * 2 * lowpass_cutoff, + weights) + + # size (output_samples, max_weight_width) + weights /= self.orig_freq + + self.first_indices = min_input_index + self.weights = weights + + +class SpeedPerturb(nn.Layer): + def __init__( + self, + orig_freq, + speeds=[90, 100, 110], + perturb_prob=1.0, ): + super(SpeedPerturb, self).__init__() + self.orig_freq = orig_freq + self.speeds = speeds + self.perturb_prob = perturb_prob + + # Initialize index of perturbation + self.samp_index = 0 + + # Initialize resamplers + self.resamplers = [] + for speed in self.speeds: + config = { + "orig_freq": self.orig_freq, + "new_freq": self.orig_freq * speed // 100, + } + self.resamplers.append(Resample(**config)) + + def forward(self, waveform): + # Don't perturb (return early) 1-`perturb_prob` portion of the batches + if paddle.rand([1]) > self.perturb_prob: + return waveform.clone() + + # Perform a random perturbation + self.samp_index = paddle.randint(len(self.speeds), shape=[1]).item() + perturbed_waveform = self.resamplers[self.samp_index](waveform) + + return perturbed_waveform + + +class AddNoise(nn.Layer): + def __init__( + self, + noise_dataset=None, # None for white noise + num_workers=0, + snr_low=0, + snr_high=0, + mix_prob=1.0, + start_index=None, + normalize=False, ): + super(AddNoise, self).__init__() + + self.num_workers = num_workers + self.snr_low = snr_low + self.snr_high = snr_high + self.mix_prob = mix_prob + self.start_index = start_index + self.normalize = normalize + self.noise_dataset = noise_dataset + self.noise_dataloader = None + + def forward(self, waveforms, lengths=None): + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + # Copy clean waveform to initialize noisy waveform + noisy_waveform = waveforms.clone() + lengths = (lengths * 
waveforms.shape[1]).astype('int64').unsqueeze(1) + + # Don't add noise (return early) 1-`mix_prob` portion of the batches + if paddle.rand([1]) > self.mix_prob: + return noisy_waveform + + # Compute the average amplitude of the clean waveforms + clean_amplitude = compute_amplitude(waveforms, lengths) + + # Pick an SNR and use it to compute the mixture amplitude factors + SNR = paddle.rand((len(waveforms), 1)) + SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low + noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1) + new_noise_amplitude = noise_amplitude_factor * clean_amplitude + + # Scale clean signal appropriately + noisy_waveform *= 1 - noise_amplitude_factor + + # Loop through clean samples and create mixture + if self.noise_dataset is None: + white_noise = paddle.normal(shape=waveforms.shape) + noisy_waveform += new_noise_amplitude * white_noise + else: + tensor_length = waveforms.shape[1] + noise_waveform, noise_length = self._load_noise( + lengths, + tensor_length, ) + + # Rescale and add + noise_amplitude = compute_amplitude(noise_waveform, noise_length) + noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14) + noisy_waveform += noise_waveform + + # Normalizing to prevent clipping + if self.normalize: + abs_max, _ = paddle.max( + paddle.abs(noisy_waveform), axis=1, keepdim=True) + noisy_waveform = noisy_waveform / abs_max.clip(min=1.0) + + return noisy_waveform + + def _load_noise(self, lengths, max_length): + """ + Load a batch of noises + + args + lengths(Paddle.Tensor): Num samples of waveforms with shape (N, 1). + max_length(int): Width of a batch. + """ + lengths = lengths.squeeze(1) + batch_size = len(lengths) + + # Load a noise batch + if self.noise_dataloader is None: + + def noise_collate_fn(batch): + def pad(x, target_length, mode='constant', **kwargs): + x = np.asarray(x) + w = target_length - x.shape[0] + assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}' + return np.pad(x, [0, w], mode=mode, **kwargs) + + ids = [item['id'] for item in batch] + lengths = np.asarray([item['feat'].shape[0] for item in batch]) + waveforms = list( + map(lambda x: pad(x, max(max_length, lengths.max().item())), + [item['feat'] for item in batch])) + waveforms = np.stack(waveforms) + return {'ids': ids, 'feats': waveforms, 'lengths': lengths} + + # Create noise data loader. 
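+            # noise_collate_fn above right-pads every noise clip to at least the
+            # batch waveform width, so the start-index slice below is always valid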
+ self.noise_dataloader = paddle.io.DataLoader( + self.noise_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=self.num_workers, + collate_fn=noise_collate_fn, + return_list=True, ) + self.noise_data = iter(self.noise_dataloader) + + noise_batch, noise_len = self._load_noise_batch_of_size(batch_size) + + # Select a random starting location in the waveform + start_index = self.start_index + if self.start_index is None: + start_index = 0 + max_chop = (noise_len - lengths).min().clip(min=1) + start_index = paddle.randint(high=max_chop, shape=[1]) + + # Truncate noise_batch to max_length + noise_batch = noise_batch[:, start_index:start_index + max_length] + noise_len = (noise_len - start_index).clip(max=max_length).unsqueeze(1) + return noise_batch, noise_len + + def _load_noise_batch_of_size(self, batch_size): + """Concatenate noise batches, then chop to correct size""" + noise_batch, noise_lens = self._load_noise_batch() + + # Expand + while len(noise_batch) < batch_size: + noise_batch = paddle.concat((noise_batch, noise_batch)) + noise_lens = paddle.concat((noise_lens, noise_lens)) + + # Contract + if len(noise_batch) > batch_size: + noise_batch = noise_batch[:batch_size] + noise_lens = noise_lens[:batch_size] + + return noise_batch, noise_lens + + def _load_noise_batch(self): + """Load a batch of noises, restarting iteration if necessary.""" + try: + batch = next(self.noise_data) + except StopIteration: + self.noise_data = iter(self.noise_dataloader) + batch = next(self.noise_data) + + noises, lens = batch['feats'], batch['lengths'] + return noises, lens + + +class AddReverb(nn.Layer): + def __init__( + self, + rir_dataset, + reverb_prob=1.0, + rir_scale_factor=1.0, + num_workers=0, ): + super(AddReverb, self).__init__() + self.rir_dataset = rir_dataset + self.reverb_prob = reverb_prob + self.rir_scale_factor = rir_scale_factor + + # Create rir data loader. + def rir_collate_fn(batch): + def pad(x, target_length, mode='constant', **kwargs): + x = np.asarray(x) + w = target_length - x.shape[0] + assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}' + return np.pad(x, [0, w], mode=mode, **kwargs) + + ids = [item['id'] for item in batch] + lengths = np.asarray([item['feat'].shape[0] for item in batch]) + waveforms = list( + map(lambda x: pad(x, lengths.max().item()), + [item['feat'] for item in batch])) + waveforms = np.stack(waveforms) + return {'ids': ids, 'feats': waveforms, 'lengths': lengths} + + self.rir_dataloader = paddle.io.DataLoader( + self.rir_dataset, + collate_fn=rir_collate_fn, + num_workers=num_workers, + shuffle=True, + return_list=True, ) + + self.rir_data = iter(self.rir_dataloader) + + def forward(self, waveforms, lengths=None): + """ + Arguments + --------- + waveforms : tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + lengths : tensor + Shape should be a single dimension, `[batch]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. 
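+
+        Example
+        -------
+        A minimal sketch (dataset construction omitted)::
+
+            add_reverb = AddReverb(rir_dataset)
+            rev = add_reverb(waveforms)  # waveforms: [batch, time]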
+ """ + + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + # Don't add reverb (return early) 1-`reverb_prob` portion of the time + if paddle.rand([1]) > self.reverb_prob: + return waveforms.clone() + + # Add channels dimension if necessary + channel_added = False + if len(waveforms.shape) == 2: + waveforms = waveforms.unsqueeze(-1) + channel_added = True + + # Load and prepare RIR + rir_waveform = self._load_rir() + + # Compress or dilate RIR + if self.rir_scale_factor != 1: + rir_waveform = F.interpolate( + rir_waveform.transpose([0, 2, 1]), + scale_factor=self.rir_scale_factor, + mode="linear", + align_corners=False, + data_format='NCW', ) + # (N, C, L) -> (N, L, C) + rir_waveform = rir_waveform.transpose([0, 2, 1]) + + rev_waveform = reverberate( + waveforms, + rir_waveform, + self.rir_dataset.sample_rate, + rescale_amp="avg") + + # Remove channels dimension if added + if channel_added: + return rev_waveform.squeeze(-1) + + return rev_waveform + + def _load_rir(self): + try: + batch = next(self.rir_data) + except StopIteration: + self.rir_data = iter(self.rir_dataloader) + batch = next(self.rir_data) + + rir_waveform = batch['feats'] + + # Make sure RIR has correct channels + if len(rir_waveform.shape) == 2: + rir_waveform = rir_waveform.unsqueeze(-1) + + return rir_waveform + + +class AddBabble(nn.Layer): + def __init__( + self, + speaker_count=3, + snr_low=0, + snr_high=0, + mix_prob=1, ): + super(AddBabble, self).__init__() + self.speaker_count = speaker_count + self.snr_low = snr_low + self.snr_high = snr_high + self.mix_prob = mix_prob + + def forward(self, waveforms, lengths=None): + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + babbled_waveform = waveforms.clone() + lengths = (lengths * waveforms.shape[1]).unsqueeze(1) + batch_size = len(waveforms) + + # Don't mix (return early) 1-`mix_prob` portion of the batches + if paddle.rand([1]) > self.mix_prob: + return babbled_waveform + + # Pick an SNR and use it to compute the mixture amplitude factors + clean_amplitude = compute_amplitude(waveforms, lengths) + SNR = paddle.rand((batch_size, 1)) + SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low + noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1) + new_noise_amplitude = noise_amplitude_factor * clean_amplitude + + # Scale clean signal appropriately + babbled_waveform *= 1 - noise_amplitude_factor + + # For each speaker in the mixture, roll and add + babble_waveform = waveforms.roll((1, ), axis=0) + babble_len = lengths.roll((1, ), axis=0) + for i in range(1, self.speaker_count): + babble_waveform += waveforms.roll((1 + i, ), axis=0) + babble_len = paddle.concat( + [babble_len, babble_len.roll((1, ), axis=0)], axis=-1).max( + axis=-1, keepdim=True) + + # Rescale and add to mixture + babble_amplitude = compute_amplitude(babble_waveform, babble_len) + babble_waveform *= new_noise_amplitude / (babble_amplitude + 1e-14) + babbled_waveform += babble_waveform + + return babbled_waveform + + +class TimeDomainSpecAugment(nn.Layer): + def __init__( + self, + perturb_prob=1.0, + drop_freq_prob=1.0, + drop_chunk_prob=1.0, + speeds=[95, 100, 105], + sample_rate=16000, + drop_freq_count_low=0, + drop_freq_count_high=3, + drop_chunk_count_low=0, + drop_chunk_count_high=5, + drop_chunk_length_low=1000, + drop_chunk_length_high=2000, + drop_chunk_noise_factor=0, ): + super(TimeDomainSpecAugment, self).__init__() + self.speed_perturb = SpeedPerturb( + perturb_prob=perturb_prob, + orig_freq=sample_rate, + speeds=speeds, ) + self.drop_freq = DropFreq( + 
drop_prob=drop_freq_prob, + drop_count_low=drop_freq_count_low, + drop_count_high=drop_freq_count_high, ) + self.drop_chunk = DropChunk( + drop_prob=drop_chunk_prob, + drop_count_low=drop_chunk_count_low, + drop_count_high=drop_chunk_count_high, + drop_length_low=drop_chunk_length_low, + drop_length_high=drop_chunk_length_high, + noise_factor=drop_chunk_noise_factor, ) + + def forward(self, waveforms, lengths=None): + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + with paddle.no_grad(): + # Augmentation + waveforms = self.speed_perturb(waveforms) + waveforms = self.drop_freq(waveforms) + waveforms = self.drop_chunk(waveforms, lengths) + + return waveforms + + +class EnvCorrupt(nn.Layer): + def __init__( + self, + reverb_prob=1.0, + babble_prob=1.0, + noise_prob=1.0, + rir_dataset=None, + noise_dataset=None, + num_workers=0, + babble_speaker_count=0, + babble_snr_low=0, + babble_snr_high=0, + noise_snr_low=0, + noise_snr_high=0, + rir_scale_factor=1.0, ): + super(EnvCorrupt, self).__init__() + + # Initialize corrupters + if rir_dataset is not None and reverb_prob > 0.0: + self.add_reverb = AddReverb( + rir_dataset=rir_dataset, + num_workers=num_workers, + reverb_prob=reverb_prob, + rir_scale_factor=rir_scale_factor, ) + + if babble_speaker_count > 0 and babble_prob > 0.0: + self.add_babble = AddBabble( + speaker_count=babble_speaker_count, + snr_low=babble_snr_low, + snr_high=babble_snr_high, + mix_prob=babble_prob, ) + + if noise_dataset is not None and noise_prob > 0.0: + self.add_noise = AddNoise( + noise_dataset=noise_dataset, + num_workers=num_workers, + snr_low=noise_snr_low, + snr_high=noise_snr_high, + mix_prob=noise_prob, ) + + def forward(self, waveforms, lengths=None): + if lengths is None: + lengths = paddle.ones([len(waveforms)]) + + # Augmentation + with paddle.no_grad(): + if hasattr(self, "add_reverb"): + try: + waveforms = self.add_reverb(waveforms, lengths) + except Exception: + pass + if hasattr(self, "add_babble"): + waveforms = self.add_babble(waveforms, lengths) + if hasattr(self, "add_noise"): + waveforms = self.add_noise(waveforms, lengths) + + return waveforms + + +def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]: + """build augment pipeline + Note: this pipeline cannot be used in the paddle.DataLoader + + Returns: + List[paddle.nn.Layer]: all augment process + """ + logger.info("start to build the augment pipeline") + noise_dataset = OpenRIRNoise('noise', target_dir=target_dir) + rir_dataset = OpenRIRNoise('rir', target_dir=target_dir) + + wavedrop = TimeDomainSpecAugment( + sample_rate=16000, + speeds=[100], ) + speed_perturb = TimeDomainSpecAugment( + sample_rate=16000, + speeds=[95, 100, 105], ) + add_noise = EnvCorrupt( + noise_dataset=noise_dataset, + reverb_prob=0.0, + noise_prob=1.0, + noise_snr_low=0, + noise_snr_high=15, + rir_scale_factor=1.0, ) + add_rev = EnvCorrupt( + rir_dataset=rir_dataset, + reverb_prob=1.0, + noise_prob=0.0, + rir_scale_factor=1.0, ) + add_rev_noise = EnvCorrupt( + noise_dataset=noise_dataset, + rir_dataset=rir_dataset, + reverb_prob=1.0, + noise_prob=1.0, + noise_snr_low=0, + noise_snr_high=15, + rir_scale_factor=1.0, ) + + return [wavedrop, speed_perturb, add_noise, add_rev, add_rev_noise] + + +def waveform_augment(waveforms: paddle.Tensor, + augment_pipeline: List[paddle.nn.Layer]) -> paddle.Tensor: + """process the augment pipeline and return all the waveforms + + Args: + waveforms (paddle.Tensor): original batch waveform + augment_pipeline (List[paddle.nn.Layer]): agument pipeline process 
+
+    Returns:
+        paddle.Tensor: all audio waveforms, the original batch followed by one
+            augmented copy per pipeline stage, concatenated on the batch axis
+    """
+    # stage 0: store the original waveforms
+    waveforms_aug_list = [waveforms]
+
+    # augment the original batch waveform
+    for aug in augment_pipeline:
+        # stage 1: augment the data
+        waveforms_aug = aug(waveforms)  # (N, L)
+        if waveforms_aug.shape[1] >= waveforms.shape[1]:
+            # truncate to the original length
+            waveforms_aug = waveforms_aug[:, :waveforms.shape[1]]
+        else:
+            # pad on the right up to the original length
+            lengths_to_pad = waveforms.shape[1] - waveforms_aug.shape[1]
+            waveforms_aug = F.pad(
+                waveforms_aug.unsqueeze(-1), [0, lengths_to_pad],
+                data_format='NLC').squeeze(-1)
+        # stage 2: append the augmented waveform into the list
+        waveforms_aug_list.append(waveforms_aug)
+
+    # concatenate the original and augmented waveforms along the batch axis
+    return paddle.concat(waveforms_aug_list, axis=0)
diff --git a/paddlespeech/vector/io/batch.py b/paddlespeech/vector/io/batch.py
new file mode 100644
index 00000000..92ca990c
--- /dev/null
+++ b/paddlespeech/vector/io/batch.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+
+
+def waveform_collate_fn(batch):
+    waveforms = np.stack([item['feat'] for item in batch])
+    labels = np.stack([item['label'] for item in batch])
+
+    return {'waveforms': waveforms, 'labels': labels}
+
+
+def feature_normalize(feats: paddle.Tensor,
+                      mean_norm: bool=True,
+                      std_norm: bool=True,
+                      convert_to_numpy: bool=False):
+    # Features normalization if needed
+    # Note: numpy.mean differs from paddle.mean by roughly 1e-6
+    if convert_to_numpy:
+        feats_np = feats.numpy()
+        mean = feats_np.mean(axis=-1, keepdims=True) if mean_norm else 0
+        std = feats_np.std(axis=-1, keepdims=True) if std_norm else 1
+        feats_np = (feats_np - mean) / std
+        feats = paddle.to_tensor(feats_np, dtype=feats.dtype)
+    else:
+        mean = feats.mean(axis=-1, keepdim=True) if mean_norm else 0
+        std = feats.std(axis=-1, keepdim=True) if std_norm else 1
+        feats = (feats - mean) / std
+
+    return feats
+
+
+def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
+    x = np.asarray(x)
+    assert len(
+        x.shape) == 2, f'Only 2D arrays supported, but got shape: {x.shape}'
+
+    w = target_length - x.shape[axis]
+    assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[axis]}'
+
+    if axis == 0:
+        pad_width = [[0, w], [0, 0]]
+    else:
+        pad_width = [[0, 0], [0, w]]
+
+    return np.pad(x, pad_width, mode=mode, **kwargs)
+
+
+def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
+    ids = [item['id'] for item in batch]
+    lengths = np.asarray([item['feat'].shape[1] for item in batch])
+    feats = list(
+        map(lambda x: pad_right_2d(x, lengths.max()),
+            [item['feat'] for item in batch]))
+    feats = np.stack(feats)
+
+    # Features normalization if needed
+    for i in range(len(feats)):
+        feat = feats[i][:, :lengths[i]]  # Excluding pad values.
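+        # normalize each utterance with statistics computed on its valid
+        # (non-padded) frames only, so padding does not bias the mean/std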
+        mean = feat.mean(axis=-1, keepdims=True) if mean_norm else 0
+        std = feat.std(axis=-1, keepdims=True) if std_norm else 1
+        feats[i][:, :lengths[i]] = (feat - mean) / std
+        assert feats[i][:, lengths[
+            i]:].sum() == 0  # Padding values should all be 0.
+
+    # Convert the lengths to ratios of the max length: the longest utterance
+    # needs no padding and keeps ratio 1.0; every other utterance is padded to
+    # the max length and records the fraction of its frames that are valid.
+    lengths = (lengths / lengths.max()).astype(np.float32)
+
+    return {'ids': ids, 'feats': feats, 'lengths': lengths}
+
+
+def pad_right_to(array, target_shape, mode="constant", value=0):
+    """
+    This function takes a numpy array of arbitrary shape and pads it to the
+    target shape by appending values on the right.
+
+    Args:
+        array: numpy.array. Input array whose dimensions we need to pad.
+        target_shape : (list, tuple). Target shape for the padded array; its length must equal array.ndim.
+        mode : str. Pad mode, please refer to numpy.pad documentation.
+        value : float. Pad value, please refer to numpy.pad documentation.
+
+    Returns:
+        array: numpy.array. Padded array.
+        valid_vals : list. Proportion of original (non-padded) values for each dimension.
+    """
+    assert len(target_shape) == array.ndim
+    pads = []  # this contains the abs length of the padding for each dimension.
+    valid_vals = []  # this contains the relative lengths for each dimension.
+    i = 0  # iterating over target_shape ndims
+    while i < len(target_shape):
+        assert (target_shape[i] >= array.shape[i]
+                ), "Target shape must be >= original shape for every dim"
+        pads.append([0, target_shape[i] - array.shape[i]])
+        valid_vals.append(array.shape[i] / target_shape[i])
+        i += 1
+
+    array = np.pad(array, pads, mode=mode, constant_values=value)
+
+    return array, valid_vals
+
+
+def batch_pad_right(arrays, mode="constant", value=0):
+    """Given a list of numpy arrays, batch them together by padding each one
+    on the right so that all share the same shape.
+
+    Args:
+        arrays : list. List of arrays we wish to pad together.
+        mode : str. Padding mode, see numpy.pad documentation.
+        value : float. Padding value, see numpy.pad documentation.
+
+    Returns:
+        array : numpy.array. Padded array.
+        valid_vals : list. Proportion of original (non-padded) values for each array's last dimension.
+    """
+
+    if not len(arrays):
+        raise IndexError("arrays list must not be empty")
+
+    if len(arrays) == 1:
+        # if there is only one array in the batch we simply unsqueeze it.
+        return np.expand_dims(arrays[0], axis=0), np.array([1.0])
+
+    if not all(
+        [arrays[i].ndim == arrays[0].ndim for i in range(1, len(arrays))]):
+        raise ValueError("All arrays must have same number of dimensions")
+
+    # FIXME we limit the support here: we allow padding of only the last dimension
+    # need to remove this when feat extraction is updated to handle multichannel.
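+    # e.g. arrays of shapes (80, 120) and (80, 100) batch to shape (2, 80, 120),
+    # with valid proportions [1.0, 100 / 120] for the padded last dimension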
+    max_shape = []
+    for dim in range(arrays[0].ndim):
+        if dim != (arrays[0].ndim - 1):
+            if not all(
+                [x.shape[dim] == arrays[0].shape[dim] for x in arrays[1:]]):
+                raise ValueError(
+                    "arrays should have the same shape in all dimensions except the last one")
+        max_shape.append(max([x.shape[dim] for x in arrays]))
+
+    batched = []
+    valid = []
+    for t in arrays:
+        # for each array we apply pad_right_to
+        padded, valid_percent = pad_right_to(
+            t, max_shape, mode=mode, value=value)
+        batched.append(padded)
+        valid.append(valid_percent[-1])
+
+    batched = np.stack(batched)
+
+    return batched, np.array(valid)
diff --git a/paddlespeech/vector/io/signal_processing.py b/paddlespeech/vector/io/signal_processing.py
new file mode 100644
index 00000000..a61bf554
--- /dev/null
+++ b/paddlespeech/vector/io/signal_processing.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import paddle
+
+# TODO: Complete type-hint and doc string.
+
+
+def blackman_window(win_len, dtype=np.float32):
+    arcs = np.pi * np.arange(win_len) / float(win_len)
+    win = np.asarray(
+        [0.42 - 0.5 * np.cos(2 * arc) + 0.08 * np.cos(4 * arc) for arc in arcs],
+        dtype=dtype)
+    return paddle.to_tensor(win)
+
+
+def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
+    if len(waveforms.shape) == 1:
+        waveforms = waveforms.unsqueeze(0)
+
+    assert amp_type in ["avg", "peak"]
+    assert scale in ["linear", "dB"]
+
+    if amp_type == "avg":
+        if lengths is None:
+            out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
+        else:
+            wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
+            out = wav_sum / lengths
+    elif amp_type == "peak":
+        out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)
+    else:
+        raise NotImplementedError
+
+    if scale == "linear":
+        return out
+    elif scale == "dB":
+        return paddle.clip(20 * paddle.log10(out), min=-80)
+    else:
+        raise NotImplementedError
+
+
+def dB_to_amplitude(SNR):
+    return 10**(SNR / 20)
+
+
+def convolve1d(
+        waveform,
+        kernel,
+        padding=0,
+        pad_type="constant",
+        stride=1,
+        groups=1, ):
+    if len(waveform.shape) != 3:
+        raise ValueError("Convolve1D expects a 3-dimensional tensor")
+
+    # Padding can be a list [left_pad, right_pad] or an int
+    if isinstance(padding, list):
+        waveform = paddle.nn.functional.pad(
+            x=waveform,
+            pad=padding,
+            mode=pad_type,
+            data_format='NLC', )
+
+    # Move the time dimension to the last axis, as conv1d expects.
+    # (N, L, C) -> (N, C, L)
+    waveform = waveform.transpose([0, 2, 1])
+    kernel = kernel.transpose([0, 2, 1])
+
+    convolved = paddle.nn.functional.conv1d(
+        x=waveform,
+        weight=kernel,
+        stride=stride,
+        groups=groups,
+        padding=padding if not isinstance(padding, list) else 0, )
+
+    # Return the time dimension to the second axis.
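+    # (N, C, L) -> (N, L, C)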
+    return convolved.transpose([0, 2, 1])
+
+
+def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
+    # Check inputs
+    assert 0 < notch_freq <= 1
+    assert filter_width % 2 != 0
+    pad = filter_width // 2
+    inputs = paddle.arange(filter_width, dtype='float32') - pad
+
+    # Avoid frequencies that are too low
+    notch_freq += notch_width
+
+    # Define sinc function, avoiding division by zero
+    def sinc(x):
+        def _sinc(x):
+            return paddle.sin(x) / x
+
+        # The zero is at the middle index
+        res = paddle.concat(
+            [_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
+        return res
+
+    # Compute a low-pass filter with cutoff frequency notch_freq.
+    hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
+    hlpf *= blackman_window(filter_width)
+    hlpf /= paddle.sum(hlpf)
+
+    # Compute a high-pass filter with cutoff frequency notch_freq.
+    hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
+    hhpf *= blackman_window(filter_width)
+    hhpf /= -paddle.sum(hhpf)
+    hhpf[pad] += 1
+
+    # Adding the two filters creates the notch filter
+    return (hlpf + hhpf).reshape([1, -1, 1])
+
+
+def reverberate(waveforms,
+                rir_waveform,
+                sample_rate,
+                impulse_duration=0.3,
+                rescale_amp="avg"):
+    orig_shape = waveforms.shape
+
+    if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3:
+        raise NotImplementedError
+
+    # if inputs are mono tensors we reshape to 1, samples
+    if len(waveforms.shape) == 1:
+        waveforms = waveforms.unsqueeze(0).unsqueeze(-1)
+    elif len(waveforms.shape) == 2:
+        waveforms = waveforms.unsqueeze(-1)
+
+    if len(rir_waveform.shape) == 1:  # convolve1d expects a 3d tensor !
+        rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1)
+    elif len(rir_waveform.shape) == 2:
+        rir_waveform = rir_waveform.unsqueeze(-1)
+
+    # Compute the average amplitude of the clean
+    orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
+                                       rescale_amp)
+
+    # Compute index of the direct signal, so we can preserve alignment
+    impulse_index_start = rir_waveform.abs().argmax(axis=1).item()
+    impulse_index_end = min(
+        impulse_index_start + int(sample_rate * impulse_duration),
+        rir_waveform.shape[1])
+    rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :]
+    rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
+    rir_waveform = paddle.flip(rir_waveform, [1])
+
+    waveforms = convolve1d(
+        waveform=waveforms,
+        kernel=rir_waveform,
+        padding=[rir_waveform.shape[1] - 1, 0], )
+
+    # Rescale to the original amplitude of the clean waveform
+    waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude,
+                        rescale_amp)
+
+    if len(orig_shape) == 1:
+        waveforms = waveforms.squeeze(0).squeeze(-1)
+    if len(orig_shape) == 2:
+        waveforms = waveforms.squeeze(-1)
+
+    return waveforms
+
+
+def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
+    assert amp_type in ["peak", "avg"]
+    assert scale in ["linear", "dB"]
+
+    batch_added = False
+    if len(waveforms.shape) == 1:
+        batch_added = True
+        waveforms = waveforms.unsqueeze(0)
+
+    waveforms = normalize(waveforms, lengths, amp_type)
+
+    if scale == "linear":
+        out = target_lvl * waveforms
+    elif scale == "dB":
+        out = dB_to_amplitude(target_lvl) * waveforms
+    else:
+        raise NotImplementedError("Invalid scale, choose between dB and linear")
+
+    if batch_added:
+        out = out.squeeze(0)
+
+    return out
+
+
+def normalize(waveforms, lengths=None, amp_type="avg",
+              eps=1e-14):
+    assert amp_type in ["avg", "peak"]
+
+    batch_added = False
+    if len(waveforms.shape) == 1:
+        batch_added = True
+        waveforms = waveforms.unsqueeze(0)
+
+    den = compute_amplitude(waveforms, lengths, amp_type) + eps
+    # divide before removing the batch dim so the shapes broadcast correctly
+    waveforms = waveforms / den
+    if batch_added:
+        waveforms = waveforms.squeeze(0)
+    return waveforms
diff --git a/paddlespeech/vector/models/ecapa_tdnn.py b/paddlespeech/vector/models/ecapa_tdnn.py
index e493b800..0e7287cd 100644
--- a/paddlespeech/vector/models/ecapa_tdnn.py
+++ b/paddlespeech/vector/models/ecapa_tdnn.py
@@ -47,6 +47,19 @@ class Conv1d(nn.Layer):
             groups=1,
             bias=True,
             padding_mode="reflect", ):
+        """1-d convolution layer with SAME padding support.
+
+        Args:
+            in_channels (int): input channels or input data dimensions
+            out_channels (int): output channels or output data dimensions
+            kernel_size (int): kernel size of 1-d convolution
+            stride (int, optional): stride in 1-d convolution. Defaults to 1.
+            padding (str, optional): padding strategy; "same" keeps the output length. Defaults to "same".
+            dilation (int, optional): dilation in 1-d convolution. Defaults to 1.
+            groups (int, optional): groups in 1-d convolution. Defaults to 1.
+            bias (bool, optional): bias in 1-d convolution. Defaults to True.
+            padding_mode (str, optional): how values are filled at the borders. Defaults to "reflect".
+        """
         super().__init__()
 
         self.kernel_size = kernel_size
@@ -134,6 +147,15 @@ class TDNNBlock(nn.Layer):
                  kernel_size,
                  dilation,
                  activation=nn.ReLU, ):
+        """Implementation of the TDNN block
+
+        Args:
+            in_channels (int): input channels or input embedding dimensions
+            out_channels (int): output channels or output embedding dimensions
+            kernel_size (int): the kernel size of the TDNN network block
+            dilation (int): the dilation of the TDNN network block
+            activation (paddle.nn.Layer class, optional): the activation layer. Defaults to nn.ReLU.
+        """
         super().__init__()
         self.conv = Conv1d(
             in_channels=in_channels,
@@ -149,6 +171,15 @@ class TDNNBlock(nn.Layer):
 
 class Res2NetBlock(nn.Layer):
     def __init__(self, in_channels, out_channels, scale=8, dilation=1):
+        """Implementation of the Res2Net block with dilation
+        The paper is referred to as "Res2Net: A New Multi-scale Backbone Architecture",
+        whose url is https://arxiv.org/abs/1904.01169
+        Args:
+            in_channels (int): input channels or input dimensions
+            out_channels (int): output channels or output dimensions
+            scale (int, optional): scale in the res2net block. Defaults to 8.
+            dilation (int, optional): dilation of 1-d convolution in the TDNN block. Defaults to 1.
+        """
         super().__init__()
         assert in_channels % scale == 0
         assert out_channels % scale == 0
@@ -179,6 +210,14 @@ class Res2NetBlock(nn.Layer):
 
 class SEBlock(nn.Layer):
     def __init__(self, in_channels, se_channels, out_channels):
+        """Implementation of the SE (squeeze-and-excitation) block
+        The paper is referred to as "Squeeze-and-Excitation Networks",
+        whose url is https://arxiv.org/abs/1709.01507
+        Args:
+            in_channels (int): input channels or input data dimensions
+            se_channels (int): internal channels of the squeeze-excitation block
+            out_channels (int): output channels or output data dimensions
+        """
         super().__init__()
 
         self.conv1 = Conv1d(
@@ -275,6 +314,18 @@ class SERes2NetBlock(nn.Layer):
                  kernel_size=1,
                  dilation=1,
                  activation=nn.ReLU, ):
+        """Implementation of the Squeeze-Excitation Res2Net block in the ECAPA-TDNN network model
+        The paper is referred to as "Squeeze-and-Excitation Networks",
+        whose url is: https://arxiv.org/pdf/1709.01507.pdf
+        Args:
+            in_channels (int): input channels or input data dimensions
+            out_channels (int): output channels or output data dimensions
+            res2net_scale (int, optional): scale in the res2net block. Defaults to 8.
+            se_channels (int, optional): internal channels of the squeeze-excitation block. Defaults to 128.
+            kernel_size (int, optional): kernel size of 1-d convolution in the TDNN block. Defaults to 1.
+            dilation (int, optional): dilation of 1-d convolution in the TDNN block. Defaults to 1.
+            activation (paddle.nn.Layer class, optional): activation function. Defaults to nn.ReLU.
+        """
         super().__init__()
         self.out_channels = out_channels
         self.tdnn1 = TDNNBlock(
@@ -326,7 +377,21 @@ class EcapaTdnn(nn.Layer):
                  res2net_scale=8,
                  se_channels=128,
                  global_context=True, ):
-
+        """Implementation of the ECAPA-TDNN backbone model network
+        The paper is referred to as "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification",
+        whose url is: https://arxiv.org/abs/2005.07143
+        Args:
+            input_size (int): input feature dimension
+            lin_neurons (int, optional): speaker embedding size. Defaults to 192.
+            activation (paddle.nn.Layer class, optional): activation function. Defaults to nn.ReLU.
+            channels (list, optional): intermediate embedding dimensions. Defaults to [512, 512, 512, 512, 1536].
+            kernel_sizes (list, optional): kernel sizes of 1-d convolutions in the TDNN blocks. Defaults to [5, 3, 3, 3, 1].
+            dilations (list, optional): dilations of 1-d convolutions in the TDNN blocks. Defaults to [1, 2, 3, 4, 1].
+            attention_channels (int, optional): attention dimensions. Defaults to 128.
+            res2net_scale (int, optional): scale value in res2net. Defaults to 8.
+            se_channels (int, optional): dimensions of the squeeze-excitation block. Defaults to 128.
+            global_context (bool, optional): global context flag. Defaults to True.
+        """
         super().__init__()
         assert len(channels) == len(kernel_sizes)
         assert len(channels) == len(dilations)
diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py
new file mode 100644
index 00000000..1c80dda4
--- /dev/null
+++ b/paddlespeech/vector/modules/loss.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This is modified from SpeechBrain
+# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/nnet/losses.py
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class AngularMargin(nn.Layer):
+    def __init__(self, margin=0.0, scale=1.0):
+        """An implementation of Angular Margin (AM) proposed in the following
+        paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
+        Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)
+
+        Args:
+            margin (float, optional): The margin for cosine similarity. Defaults to 0.0.
+            scale (float, optional): The scale for cosine similarity. Defaults to 1.0.
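+
+        Example (a minimal usage sketch; the tensor names are illustrative):
+            >>> layer = AngularMargin(margin=0.2, scale=30.0)
+            >>> outputs = layer(cosine_similarities, one_hot_targets)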
+ """ + super(AngularMargin, self).__init__() + self.margin = margin + self.scale = scale + + def forward(self, outputs, targets): + outputs = outputs - self.margin * targets + return self.scale * outputs + + +class AdditiveAngularMargin(AngularMargin): + def __init__(self, margin=0.0, scale=1.0, easy_margin=False): + """The Implementation of Additive Angular Margin (AAM) proposed + in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition''' + (https://arxiv.org/abs/1906.07317) + + Args: + margin (float, optional): margin factor. Defaults to 0.0. + scale (float, optional): scale factor. Defaults to 1.0. + easy_margin (bool, optional): easy_margin flag. Defaults to False. + """ + super(AdditiveAngularMargin, self).__init__(margin, scale) + self.easy_margin = easy_margin + + self.cos_m = math.cos(self.margin) + self.sin_m = math.sin(self.margin) + self.th = math.cos(math.pi - self.margin) + self.mm = math.sin(math.pi - self.margin) * self.margin + + def forward(self, outputs, targets): + cosine = outputs.astype('float32') + sine = paddle.sqrt(1.0 - paddle.pow(cosine, 2)) + phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m) + if self.easy_margin: + phi = paddle.where(cosine > 0, phi, cosine) + else: + phi = paddle.where(cosine > self.th, phi, cosine - self.mm) + outputs = (targets * phi) + ((1.0 - targets) * cosine) + return self.scale * outputs + + +class LogSoftmaxWrapper(nn.Layer): + def __init__(self, loss_fn): + """Speaker identificatin loss function wrapper + including all of compositions of the loss transformation + Args: + loss_fn (_type_): the loss value of a batch + """ + super(LogSoftmaxWrapper, self).__init__() + self.loss_fn = loss_fn + self.criterion = paddle.nn.KLDivLoss(reduction="sum") + + def forward(self, outputs, targets, length=None): + targets = F.one_hot(targets, outputs.shape[1]) + try: + predictions = self.loss_fn(outputs, targets) + except TypeError: + predictions = self.loss_fn(outputs) + + predictions = F.log_softmax(predictions, axis=1) + loss = self.criterion(predictions, targets) / targets.sum() + return loss diff --git a/paddlespeech/vector/modules/sid_model.py b/paddlespeech/vector/modules/sid_model.py new file mode 100644 index 00000000..4045f75d --- /dev/null +++ b/paddlespeech/vector/modules/sid_model.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SpeakerIdetification(nn.Layer):
+    def __init__(
+            self,
+            backbone,
+            num_class,
+            lin_blocks=0,
+            lin_neurons=192,
+            dropout=0.1, ):
+        """The speaker identification model, which includes the speaker backbone network
+        and a linear transform to the number of speaker classes for training
+
+        Args:
+            backbone (paddle.nn.Layer class): the speaker identification backbone network model
+            num_class (int): the number of speaker classes in the training dataset
+            lin_blocks (int, optional): the number of linear blocks between the embedding and the final layer. Defaults to 0.
+            lin_neurons (int, optional): the output dimension of the final linear layer. Defaults to 192.
+            dropout (float, optional): the dropout factor on the embedding. Defaults to 0.1.
+        """
+        super(SpeakerIdetification, self).__init__()
+        # speaker identification backbone network model;
+        # the output of the backbone network is the target embedding
+        self.backbone = backbone
+        if dropout > 0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+        # construct the speaker classifier
+        input_size = self.backbone.emb_size
+        self.blocks = nn.LayerList()
+        for i in range(lin_blocks):
+            self.blocks.extend([
+                nn.BatchNorm1D(input_size),
+                nn.Linear(in_features=input_size, out_features=lin_neurons),
+            ])
+            input_size = lin_neurons
+
+        # the final layer
+        self.weight = paddle.create_parameter(
+            shape=(input_size, num_class),
+            dtype='float32',
+            attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()), )
+
+    def forward(self, x, lengths=None):
+        """Do the forward of the speaker identification model,
+        including the speaker embedding backbone and the classifier network
+
+        Args:
+            x (paddle.Tensor): input audio feats,
+                               shape=[batch, dimension, times]
+            lengths (paddle.Tensor, optional): input audio length,
+                               shape=[batch, times]. Defaults to None.
+
+        Returns:
+            paddle.Tensor: return the logits of the feats
+        """
+        # x.shape: (N, C, L)
+        x = self.backbone(x, lengths).squeeze(
+            -1)  # (N, emb_size, 1) -> (N, emb_size)
+        if self.dropout is not None:
+            x = self.dropout(x)
+
+        for fc in self.blocks:
+            x = fc(x)
+
+        # cosine-similarity logits: both the embedding and the class weights
+        # are l2-normalized before the linear projection
+        logits = F.linear(F.normalize(x), F.normalize(self.weight, axis=0))
+
+        return logits
diff --git a/paddlespeech/vector/training/scheduler.py b/paddlespeech/vector/training/scheduler.py
new file mode 100644
index 00000000..3dcac057
--- /dev/null
+++ b/paddlespeech/vector/training/scheduler.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
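+# Triangular cyclic learning rate: the rate rises linearly from base_lr to
+# max_lr over step_size steps, falls back over the next step_size steps, and
+# then repeats. Usage sketch (the optimizer wiring is illustrative):
+#   lr_schedule = CyclicLRScheduler(base_lr=1e-8, max_lr=1e-3, step_size=10000)
+#   optimizer = paddle.optimizer.AdamW(learning_rate=lr_schedule,
+#                                      parameters=model.parameters())
+#   # then call lr_schedule.step() once per training step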
+from paddle.optimizer.lr import LRScheduler
+
+
+class CyclicLRScheduler(LRScheduler):
+    def __init__(self,
+                 base_lr: float=1e-8,
+                 max_lr: float=1e-3,
+                 step_size: int=10000):
+
+        super(CyclicLRScheduler, self).__init__()
+
+        self.current_step = -1
+        self.base_lr = base_lr
+        self.max_lr = max_lr
+        self.step_size = step_size
+
+    def step(self):
+        # LRScheduler.__init__ invokes step() before our attributes exist;
+        # skip that first call.
+        if not hasattr(self, 'current_step'):
+            return
+
+        self.current_step += 1
+        if self.current_step >= 2 * self.step_size:
+            self.current_step %= 2 * self.step_size
+
+        self.last_lr = self.get_lr()
+
+    def get_lr(self):
+        p = self.current_step / (2 * self.step_size)  # Proportion in one cycle.
+        if p < 0.5:  # Increase
+            return self.base_lr + p / 0.5 * (self.max_lr - self.base_lr)
+        else:  # Decrease
+            return self.max_lr - (p / 0.5 - 1) * (self.max_lr - self.base_lr)
diff --git a/paddlespeech/vector/training/seeding.py b/paddlespeech/vector/training/seeding.py
new file mode 100644
index 00000000..0778a27d
--- /dev/null
+++ b/paddlespeech/vector/training/seeding.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+import numpy as np
+import paddle
+
+from paddlespeech.s2t.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+def seed_everything(seed: int):
+    """Seed paddle, random and np.random to help reproducibility."""
+    paddle.seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    logger.info(f"Set the seed of paddle, random, np.random to {seed}.")
diff --git a/paddlespeech/vector/utils/time.py b/paddlespeech/vector/utils/time.py
new file mode 100644
index 00000000..8e85b0e1
--- /dev/null
+++ b/paddlespeech/vector/utils/time.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
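+# Usage sketch (the step count and logging are illustrative):
+#   timer = Timer(total_step=70000)
+#   timer.start()
+#   for batch in train_loader:
+#       ...  # one training step
+#       timer.count()
+#       print(f"avg time per step: {timer.timing:.3f}s, eta: {timer.eta}")
+#   timer.stop()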
+import math
+import time
+
+
+class Timer(object):
+    '''Calculate running speed and estimated time of arrival (ETA)'''
+
+    def __init__(self, total_step: int):
+        self.total_step = total_step
+        self.last_start_step = 0
+        self.current_step = 0
+        self._is_running = True
+
+    def start(self):
+        self.last_time = time.time()
+        self.start_time = time.time()
+
+    def stop(self):
+        self._is_running = False
+        self.end_time = time.time()
+
+    def count(self) -> int:
+        if not self.current_step >= self.total_step:
+            self.current_step += 1
+        return self.current_step
+
+    @property
+    def timing(self) -> float:
+        run_steps = self.current_step - self.last_start_step
+        self.last_start_step = self.current_step
+        time_used = time.time() - self.last_time
+        self.last_time = time.time()
+        return time_used / run_steps
+
+    @property
+    def is_running(self) -> bool:
+        return self._is_running
+
+    @property
+    def eta(self) -> str:
+        if not self.is_running:
+            return '00:00:00'
+        # estimate the remaining time from the average time per finished step
+        elapsed_time = time.time() - self.start_time
+        time_per_step = elapsed_time / max(1, self.current_step)
+        remaining_time = time_per_step * (self.total_step - self.current_step)
+        return seconds_to_hms(remaining_time)
+
+
+def seconds_to_hms(seconds: int) -> str:
+    '''Convert the number of seconds to hh:mm:ss'''
+    h = math.floor(seconds / 3600)
+    m = math.floor((seconds - h * 3600) / 60)
+    s = int(seconds - h * 3600 - m * 60)
+    hms_str = '{:0>2}:{:0>2}:{:0>2}'.format(h, m, s)
+    return hms_str
diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh
index b0d18b3b..96ab84d6 100755
--- a/tests/unit/cli/test_cli.sh
+++ b/tests/unit/cli/test_cli.sh
@@ -43,3 +43,14 @@ paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
 paddlespeech stats --task asr
 paddlespeech stats --task tts
 paddlespeech stats --task cls
+
+# Speaker Verification
+wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
+paddlespeech vector --task spk --input 85236145389.wav
+
+echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job
+paddlespeech vector --task spk --input vec.job
+
+echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector --task spk
+rm 85236145389.wav
+rm vec.job
diff --git a/tests/unit/vector/conftest.py b/tests/unit/vector/conftest.py
new file mode 100644
index 00000000..7cac519b
--- /dev/null
+++ b/tests/unit/vector/conftest.py
@@ -0,0 +1,11 @@
+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default="cpu")
+
+
+def pytest_generate_tests(metafunc):
+    # This is called for every test. Only get/set command line arguments
+    # if the argument is specified in the list of test "fixturenames".
+    option_value = metafunc.config.option.device
+    if "device" in metafunc.fixturenames and option_value is not None:
+        metafunc.parametrize("device", [option_value])
diff --git a/tests/unit/vector/test_augment.py b/tests/unit/vector/test_augment.py
new file mode 100644
index 00000000..21d75bb3
--- /dev/null
+++ b/tests/unit/vector/test_augment.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
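+"""Unit tests for the time-domain augmentation layers.
+
+Run with pytest; the --device option added in conftest.py selects cpu or gpu.
+"""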
+import paddle
+
+
+def test_add_noise(tmpdir, device):
+    paddle.device.set_device(device)
+    from paddlespeech.vector.io.augment import AddNoise
+
+    test_waveform = paddle.sin(
+        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
+    test_noise = paddle.cos(
+        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
+    wav_lens = paddle.ones([1], dtype="float32")
+
+    # Edge cases
+    no_noise = AddNoise(mix_prob=0.0)
+    assert no_noise(test_waveform, wav_lens).allclose(test_waveform)
+
+
+def test_speed_perturb(device):
+    paddle.device.set_device(device)
+    from paddlespeech.vector.io.augment import SpeedPerturb
+
+    test_waveform = paddle.sin(
+        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
+
+    # Edge cases
+    no_perturb = SpeedPerturb(16000, perturb_prob=0.0)
+    assert no_perturb(test_waveform).allclose(test_waveform)
+    no_perturb = SpeedPerturb(16000, speeds=[100])
+    assert no_perturb(test_waveform).allclose(test_waveform)
+
+    # Half speed
+    half_speed = SpeedPerturb(16000, speeds=[50])
+    assert half_speed(test_waveform).allclose(test_waveform[:, ::2], atol=3e-1)
+
+
+def test_babble(device):
+    paddle.device.set_device(device)
+    from paddlespeech.vector.io.augment import AddBabble
+
+    test_waveform = paddle.stack(
+        (paddle.sin(paddle.arange(16000.0, dtype="float32")),
+         paddle.cos(paddle.arange(16000.0, dtype="float32")), ))
+    lengths = paddle.ones([2])
+
+    # Edge cases
+    no_babble = AddBabble(mix_prob=0.0)
+    assert no_babble(test_waveform, lengths).allclose(test_waveform)
+    no_babble = AddBabble(speaker_count=1, snr_low=1000, snr_high=1000)
+    assert no_babble(test_waveform, lengths).allclose(test_waveform)
+
+    # One babbler just averages the two speakers
+    babble = AddBabble(speaker_count=1).to(device)
+    expected = (test_waveform + test_waveform.roll(1, 0)) / 2
+    assert babble(test_waveform, lengths).allclose(expected, atol=1e-4)
+
+
+def test_drop_freq(device):
+    paddle.device.set_device(device)
+    from paddlespeech.vector.io.augment import DropFreq
+
+    test_waveform = paddle.sin(
+        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
+
+    # Edge cases
+    no_drop = DropFreq(drop_prob=0.0)
+    assert no_drop(test_waveform).allclose(test_waveform)
+    no_drop = DropFreq(drop_count_low=0, drop_count_high=0)
+    assert no_drop(test_waveform).allclose(test_waveform)
+
+    # Check case where frequency range *does not* include signal frequency
+    drop_diff_freq = DropFreq(drop_freq_low=0.5, drop_freq_high=0.9)
+    assert drop_diff_freq(test_waveform).allclose(test_waveform, atol=1e-1)
+
+    # Check case where frequency range *does* include signal frequency
+    drop_same_freq = DropFreq(drop_freq_low=0.28, drop_freq_high=0.28)
+    assert drop_same_freq(test_waveform).allclose(
+        paddle.zeros([1, 16000]), atol=4e-1)
+
+
+def test_drop_chunk(device):
+    paddle.device.set_device(device)
+    from paddlespeech.vector.io.augment import DropChunk
+
+    test_waveform = paddle.sin(
+        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
+    lengths = paddle.ones([1])
+
+    # Edge cases
+    no_drop = DropChunk(drop_prob=0.0)
+    assert no_drop(test_waveform, lengths).allclose(test_waveform)
+    no_drop = DropChunk(drop_length_low=0, drop_length_high=0)
+    assert no_drop(test_waveform, lengths).allclose(test_waveform)
+    no_drop = DropChunk(drop_count_low=0, drop_count_high=0)
+    assert no_drop(test_waveform, lengths).allclose(test_waveform)
+    no_drop = DropChunk(drop_start=0, drop_end=0)
+    assert no_drop(test_waveform, lengths).allclose(test_waveform)
+
+    # Specify all parameters to ensure it is deterministic
+    dropper = DropChunk(
+        drop_length_low=100,
+        drop_length_high=100,
+        drop_count_low=1,
+        drop_count_high=1,
+        drop_start=100,
+        drop_end=200,
+        noise_factor=0.0, )
+    expected_waveform = test_waveform.clone()
+    expected_waveform[:, 100:200] = 0.0
+
+    assert dropper(test_waveform, lengths).allclose(expected_waveform)
+
+    # Make sure amplitude is similar before and after
+    dropper = DropChunk(noise_factor=1.0)
+    drop_amplitude = dropper(test_waveform, lengths).abs().mean()
+    orig_amplitude = test_waveform.abs().mean()
+    assert drop_amplitude.allclose(orig_amplitude, atol=1e-2)
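+
+
+def test_normalize(device):
+    # A minimal sanity check for signal_processing.normalize (a sketch added
+    # alongside the tests above; it assumes normalize as defined in this patch).
+    paddle.device.set_device(device)
+    from paddlespeech.vector.io.signal_processing import normalize
+
+    test_waveform = paddle.sin(
+        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
+
+    normalized = normalize(test_waveform, amp_type="avg")
+    # After average normalization, the mean absolute amplitude should be ~1.
+    assert abs(float(paddle.abs(normalized).mean()) - 1.0) < 1e-4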