Merge pull request #1523 from Honei/vox12

[vector] ecapa-tdnn on voxceleb
pull/1613/head
Honei 2 years ago committed by GitHub
commit e6e72b445a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -59,12 +59,17 @@ DEV_TARGET_DATA = "vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f5
TEST_LIST = {"vox1_test_wav.zip": "185fdc63c3c739954633d50379a3d102"}
TEST_TARGET_DATA = "vox1_test_wav.zip vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102"
# kaldi trial
# this trial file is organized by kaldi according the official file,
# which is a little different with the official trial veri_test2.txt
KALDI_BASE_URL = "http://www.openslr.org/resources/49/"
TRIAL_LIST = {"voxceleb1_test_v2.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7"}
TRIAL_TARGET_DATA = "voxceleb1_test_v2.txt voxceleb1_test_v2.txt 29fc7cc1c5d59f0816dc15d6e8be60f7"
# voxceleb trial
TRIAL_BASE_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/"
TRIAL_LIST = {
"veri_test.txt": "29fc7cc1c5d59f0816dc15d6e8be60f7", # voxceleb1
"veri_test2.txt": "b73110731c9223c1461fe49cb48dddfc", # voxceleb1(cleaned)
"list_test_hard.txt": "21c341b6b2168eea2634df0fb4b8fff1", # voxceleb1-H
"list_test_hard2.txt": "857790e09d579a68eb2e339a090343c8", # voxceleb1-H(cleaned)
"list_test_all.txt": "b9ecf7aa49d4b656aa927a8092844e4a", # voxceleb1-E
"list_test_all2.txt": "a53e059deb562ffcfc092bf5d90d9f3a" # voxceleb1-E(cleaned)
}
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
@ -82,7 +87,7 @@ args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
print(f"Creating manifest {manifest_path_prefix} from {data_dir}")
json_lines = []
data_path = os.path.join(data_dir, "wav", "**", "*.wav")
total_sec = 0.0
@ -114,6 +119,9 @@ def create_manifest(data_dir, manifest_path_prefix):
# voxceleb1 is given explicit in the path
data_dir_name = Path(data_dir).name
manifest_path_prefix = manifest_path_prefix + "." + data_dir_name
if not os.path.exists(os.path.dirname(manifest_path_prefix)):
os.makedirs(os.path.dirname(manifest_path_prefix))
with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f:
for line in json_lines:
f.write(line + "\n")
@ -133,11 +141,13 @@ def create_manifest(data_dir, manifest_path_prefix):
def prepare_dataset(base_url, data_list, target_dir, manifest_path,
target_data):
if not os.path.exists(target_dir):
os.mkdir(target_dir)
os.makedirs(target_dir)
# wav directory already exists, it need do nothing
# we will download the voxceleb1 data to ${target_dir}/vox1/dev/ or ${target_dir}/vox1/test directory
if not os.path.exists(os.path.join(target_dir, "wav")):
# download all dataset part
print("start to download the vox1 dev zip package")
for zip_part in data_list.keys():
download_url = " --no-check-certificate " + base_url + "/" + zip_part
download(
@ -166,11 +176,20 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path,
# create the manifest file
create_manifest(data_dir=target_dir, manifest_path_prefix=manifest_path)
def prepare_trial(base_url, data_list, target_dir):
    """Fetch every trial list that is not already present in ``target_dir``.

    Args:
        base_url: URL prefix the trial file name is appended to.
        data_list: mapping of trial file name -> expected md5 checksum.
        target_dir: directory the trial files are stored in; created when
            it does not exist yet.
    """
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    for trial, md5sum in data_list.items():
        # a trial file that is already on disk is not downloaded again
        if os.path.exists(os.path.join(target_dir, trial)):
            continue
        download_url = " --no-check-certificate " + base_url + "/" + trial
        download(url=download_url, md5sum=md5sum, target_dir=target_dir)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
# prepare the vox1 dev data
prepare_dataset(
base_url=BASE_URL,
data_list=DEV_LIST,
@ -178,6 +197,7 @@ def main():
manifest_path=args.manifest_prefix,
target_data=DEV_TARGET_DATA)
# prepare the vox1 test data
prepare_dataset(
base_url=BASE_URL,
data_list=TEST_LIST,
@ -185,6 +205,13 @@ def main():
manifest_path=args.manifest_prefix,
target_data=TEST_TARGET_DATA)
# prepare the vox1 trial
prepare_trial(
base_url=TRIAL_BASE_URL,
data_list=TRIAL_LIST,
target_dir=os.path.dirname(args.manifest_prefix)
)
print("Manifest prepare done!")

@ -0,0 +1,163 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare VoxCeleb2 dataset
Download and unpack the voxceleb2 data files.
Voxceleb2 data is stored as the m4a format,
so we need convert the m4a to wav with the convert.sh scripts
"""
import argparse
import codecs
import glob
import json
import os
import subprocess
from pathlib import Path
import soundfile
from utils.utility import check_md5sum
from utils.utility import download
from utils.utility import unzip
# By default all data is stored relative to the current working directory
# (os.path.expanduser('.') returns '.' unchanged).
DATA_HOME = os.path.expanduser('.')

# The "--no-check-certificate" option is embedded in the URL prefix, so it is
# handed to the download helper together with every URL built from it
# (presumably consumed by a command line download tool -- verify against
# utils.utility.download).
# NOTE(review): BASE_URL ends with '/' and the file names below start with
# '/', so the resulting URLs contain a double slash.
BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/"

# dev data
DEV_DATA_URL = BASE_URL + '/vox2_aac.zip'
DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402"

# test data
TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip'
TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312"

# Command line interface: --download and --generate are independent switches,
# both off by default.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/voxceleb2/",
    type=str,
    # this script handles VoxCeleb2; the help text used to say "voxceleb1"
    help="Directory to save the voxceleb2 dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument("--download",
    default=False,
    action="store_true",
    help="Download the voxceleb2 dataset. (default: %(default)s)")
parser.add_argument("--generate",
    default=False,
    action="store_true",
    help="Generate the manifest files. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
    """Scan ``data_dir`` for wav files and write a manifest plus a summary.

    Two files are written:

    * ``<manifest_path_prefix>.<basename of data_dir>``: one json object per
      line with the keys ``utt``, ``utt2spk``, ``feat``, ``feat_shape`` and
      ``text``.
    * ``voxceleb2.<basename of data_dir>.meta`` next to the manifest: totals
      for utterances, speakers, hours and per-utterance averages.

    Args:
        data_dir: directory that is searched recursively for ``*.wav`` files.
        manifest_path_prefix: path prefix of the manifest file. It may be a
            bare file name (e.g. the default ``manifest``), in which case the
            files are written to the current working directory.
    """
    print("Creating manifest %s ..." % manifest_path_prefix)
    json_lines = []
    data_path = os.path.join(data_dir, "**", "*.wav")
    total_sec = 0.0
    total_text = 0.0
    total_num = 0
    speakers = set()
    for audio_path in glob.glob(data_path, recursive=True):
        # the last three path components form the utterance id and the first
        # of them is used as the speaker id
        # (presumably <speaker>/<session>/<utterance>.wav -- verify layout)
        parts = audio_path.split("/")
        audio_id = "-".join(parts[-3:])
        utt2spk = parts[-3]
        duration = soundfile.info(audio_path).duration
        text = ""
        json_lines.append(
            json.dumps(
                {
                    "utt": audio_id,
                    "utt2spk": str(utt2spk),
                    "feat": audio_path,
                    "feat_shape": (duration, ),
                    "text": text  # compatible with asr data format
                },
                ensure_ascii=False))

        total_sec += duration
        total_text += len(text)
        total_num += 1
        speakers.add(utt2spk)

    # data_dir_name refer to dev or test
    # voxceleb2 is given explicit in the path
    data_dir_name = Path(data_dir).name
    manifest_path_prefix = manifest_path_prefix + "." + data_dir_name
    manifest_dir = os.path.dirname(manifest_path_prefix)
    # a bare prefix such as "manifest" has an empty dirname; os.makedirs("")
    # raises, so the directory is only created when there is one to create
    if manifest_dir:
        os.makedirs(manifest_dir, exist_ok=True)

    with codecs.open(manifest_path_prefix, 'w', encoding='utf-8') as f:
        for line in json_lines:
            f.write(line + "\n")

    meta_path = os.path.join(manifest_dir, "voxceleb2." +
                             data_dir_name) + ".meta"
    # guard the averages so that an empty data_dir produces a summary
    # instead of a ZeroDivisionError
    text_per_sec = total_text / total_sec if total_sec else 0.0
    sec_per_utt = total_sec / total_num if total_num else 0.0
    with codecs.open(meta_path, 'w', encoding='utf-8') as f:
        print(f"{total_num} utts", file=f)
        print(f"{len(speakers)} speakers", file=f)
        print(f"{total_sec / (60 * 60)} h", file=f)
        print(f"{total_text} text", file=f)
        print(f"{text_per_sec} text/sec", file=f)
        print(f"{sec_per_utt} sec/utt", file=f)
def download_dataset(url, md5sum, target_dir, dataset):
    """Download one VoxCeleb2 archive unless its subset directory exists.

    Args:
        url: download URL of the archive.
        md5sum: expected md5 checksum, forwarded to ``download``.
        target_dir: directory the archive is downloaded to; created when
            it does not exist yet.
        dataset: subset name; ``<target_dir>/<dataset>`` marks the subset as
            already prepared.
    """
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    subset_dir = os.path.join(target_dir, dataset)
    print("target dir {}".format(subset_dir))

    # an existing subset directory means there is nothing left to do
    if os.path.exists(subset_dir):
        return

    filepath = download(url, md5sum, target_dir)
    # only the test archive is unpacked here, into ${target_dir}/test
    # NOTE(review): the dev archive is downloaded but not unzipped by this
    # function -- confirm that this is intended.
    if dataset == "test":
        unzip(filepath, os.path.join(target_dir, "test"))
def main():
    """Run the download and/or manifest generation requested on the CLI."""
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    # download and unpack the vox2 data (dev first, then test)
    print("download: {}".format(args.download))
    if args.download:
        archives = (
            ("dev", DEV_DATA_URL, DEV_MD5SUM),
            ("test", TEST_DATA_URL, TEST_MD5SUM), )
        for subset, subset_url, subset_md5 in archives:
            download_dataset(
                url=subset_url,
                md5sum=subset_md5,
                target_dir=args.target_dir,
                dataset=subset)

        print("VoxCeleb2 download is done!")

    if args.generate:
        create_manifest(
            args.target_dir, manifest_path_prefix=args.manifest_prefix)
if __name__ == '__main__':
main()

@ -6,3 +6,51 @@ sv0 - speaker verfication with softmax backend etc, all python code
sv1 - depends on kaldi, speaker verification with plda/sc backend,
more info refer to the sv1/readme.txt
## VoxCeleb2 preparation
VoxCeleb2 audio files are released in m4a format. All the VoxCeleb2 m4a audio files must be converted to wav files before feeding them into PaddleSpeech.
Please, follow these steps to prepare the dataset correctly:
1. Download Voxceleb2.
You can find download instructions here: http://www.robots.ox.ac.uk/~vgg/data/voxceleb/
2. Convert .m4a to wav
VoxCeleb2 stores files with the m4a audio format. To use them in PaddleSpeech, you have to convert all the m4a audio files into wav files.
``` shell
ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s
```
You can do the conversion using ffmpeg (see https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830). This operation might take several hours and only needs to be done once.
3. Put all the wav files in a folder called `wav`. You should have something like `voxceleb2/wav/id*/*.wav` (e.g, `voxceleb2/wav/id00012/21Uxsk56VDQ/00001.wav`)
## voxceleb dataset summary
|dataset | vox1 - dev | vox1 - test |vox2 - dev| vox2 - test|
|---------|-----------|------------|-----------|----------|
|spks | 1211 |40 | 5994 | 118|
|utts | 148642 | 4874 | 1092009 |36273|
| time(h) | 340.4 | 11.2 | 2360.2 |79.9 |
## trial summary
| trial | filename | nums | positive | negative |
|--------|-----------|--------|-------|------|
| VoxCeleb1 | veri_test.txt | 37720 | 18860 | 18860 |
| VoxCeleb1(cleaned) | veri_test2.txt | 37611 | 18802 | 18809 |
| VoxCeleb1-H | list_test_hard.txt | 552536 | 276270 | 276266 |
|VoxCeleb1-H(cleaned) |list_test_hard2.txt | 550894 | 275488 | 275406 |
|VoxCeleb1-E | list_test_all.txt | 581480 | 290743 | 290737 |
|VoxCeleb1-E(cleaned) | list_test_all2.txt |579818 |289921 |289897 |

@ -0,0 +1,52 @@
###########################################
# Data #
###########################################
# we should explicitly specify the wav path of vox2 audio data converted from m4a
vox2_base_path:
augment: True
batch_size: 16
num_workers: 2
num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
random_chunk: True
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
###########################################################
# MODEL SETTING #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want to use another model, please choose another configuration yaml file
model:
input_size: 80
# "channels": [512, 512, 512, 512, 1536],
channels: [1024, 1024, 1024, 1024, 3072]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
attention_channels: 128
lin_neurons: 192
###########################################
# Training #
###########################################
seed: 1986 # taken from the speechbrain configuration
epochs: 10
save_interval: 1
log_interval: 1
learning_rate: 1e-8
###########################################
# Testing #
###########################################
global_embedding_norm: True
embedding_mean_norm: True
embedding_std_norm: False

@ -0,0 +1,58 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Data preparation script: takes <data-dir> and <conf-path> and runs the
# stages selected by ${stage} .. ${stop_stage}.
# NOTE: with the default stage=1 the stage-0 block below (data_prepare.py)
# is skipped unless the stage value is lowered.
stage=1
stop_stage=100
# MAIN_ROOT is not set in this script; it has to come from the environment.
# presumably parse_options.sh applies --stage / --stop-stage overrides -- verify
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
# exactly two positional arguments are required
if [ $# -ne 2 ] ; then
echo "Usage: $0 [options] <data-dir> <conf-path>";
echo "e.g.: $0 ./data/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit 1;
fi
dir=$1
conf_path=$2
mkdir -p ${dir}
# stage 0: run local/data_prepare.py with the data directory and the config
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# we should use the local/convert.sh convert m4a to wav
python3 local/data_prepare.py \
--data-dir ${dir} \
--config ${conf_path}
fi
# the dataset scripts and the downloaded data live under ${MAIN_ROOT}/dataset
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
# stage 1: run voxceleb1.py; manifests use the prefix data/vox1/manifest
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
--manifest_prefix="data/vox1/manifest" \
--target_dir="${TARGET_DIR}/voxceleb/vox1/"
# stop with a non-zero exit code when the python script failed
if [ $? -ne 0 ]; then
echo "Prepare voxceleb failed. Terminated."
exit 1
fi
# for dataset in train dev test; do
# mv data/manifest.${dataset} data/manifest.${dataset}.raw
# done
fi

@ -0,0 +1,71 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle
from yacs.config import CfgNode
from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def main(args, config):
    """Generate the dataset info needed before training.

    Args:
        args: parsed command line arguments; ``args.data_dir`` is used as
            the target directory.
        config: configuration node; ``seed``, ``vox2_base_path`` and
            ``augment`` are read.
    """
    # stage 0: every data preparation step is executed in cpu mode
    paddle.set_device("cpu")

    # a fixed random seed is a must for multiprocess training
    seed_everything(config.seed)

    # stage 1: build the voxceleb dataset info.
    # Note: a c++ exception may be reported here, but the program keeps
    # running fine, so it is ignored.
    # The vox2 base path is handed over explicitly.
    logger.info("start to generate the voxceleb dataset info")
    voxceleb_train = VoxCeleb(
        'train',
        target_dir=args.data_dir,
        vox2_base_path=config.vox2_base_path)

    # stage 2: build the augment noise info, only when augmentation is on
    if not config.augment:
        return
    logger.info("start to generate the augment dataset info")
    noise_pipeline = build_augment_pipeline(target_dir=args.data_dir)
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the yaml configuration and run
    # the data preparation.
    # yapf: disable
    # The first positional parameter of ArgumentParser is `prog`, so the
    # module docstring has to be passed as `description` explicitly.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    # new_allowed=True lets the yaml file introduce keys that are not
    # predefined on the node.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # freeze before use so the configuration cannot be changed afterwards
    config.freeze()
    print(config)

    main(args, config)

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Extract the speaker embedding of a single audio file with a trained model.
# All inputs are given as options; the script takes no positional arguments.
. ./path.sh

stage=0
stop_stage=100
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
conf_path=conf/ecapa_tdnn.yaml
audio_path="demo/voxceleb/00001.wav"
use_gpu=true

# presumably parse_options.sh maps the --option flags listed in the usage
# text onto the variables defined above -- verify
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;

if [ $# -ne 0 ] ; then
    echo "Usage: $0 [options]";
    # the example shows options only, because positional arguments are rejected
    echo "e.g.: $0 --exp-dir exp/ecapa-tdnn-vox12-big/ --conf-path conf/ecapa_tdnn.yaml --audio-path demo/voxceleb/00001.wav"
    echo "Options: "
    echo " --use-gpu <true,false|true> # specify whether gpu is to be used for extracting the embedding"
    echo " --stage <stage|0> # Used to run a partially-completed data process from somewhere in the middle."
    echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
    echo " --exp-dir # experiment directory, which contains the model.pdparams"
    echo " --conf-path # configuration file for extracting the embedding"
    echo " --audio-path # audio-path, which will be processed to extract the embedding"
    exit 1;
fi

# set the test device
device="cpu"
if ${use_gpu}; then
    device="gpu"
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # extract the audio embedding; paths are quoted so that values containing
    # spaces are passed as a single argument
    python3 ${BIN_DIR}/extract_emb.py --device ${device} \
        --config "${conf_path}" \
        --audio-path "${audio_path}" --load-checkpoint "${exp_dir}"
fi

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Evaluation script: takes <data-dir> <exp-dir> <conf-path> and runs test.py.
stage=1
stop_stage=100
# NOTE(review): use_gpu is defined here but not referenced again in this script
use_gpu=true # if true, we run on GPU.
# MAIN_ROOT and BIN_DIR are not set in this script; they have to come from
# the environment.
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
# exactly three positional arguments are required
if [ $# -ne 3 ] ; then
echo "Usage: $0 [options] <data-dir> <exp-dir> <conf-path>";
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit 1;
fi
dir=$1
exp_dir=$2
conf_path=$3
# stage 1: ${exp_dir} is handed to test.py as the checkpoint to load
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test the model and compute the eer metrics
python3 ${BIN_DIR}/test.py \
--data-dir ${dir} \
--load-checkpoint ${exp_dir} \
--config ${conf_path}
fi

@ -0,0 +1,61 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Training script: takes <data-dir> <exp-dir> <conf-path> and launches
# train.py through paddle.distributed.launch.
# NOTE: the default stage is 0, but the only stage block in this script is
# stage 1.
stage=0
stop_stage=100
use_gpu=true # if true, we run on GPU.
# MAIN_ROOT and BIN_DIR are not set in this script; they have to come from
# the environment.
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
# exactly three positional arguments are required
if [ $# -ne 3 ] ; then
echo "Usage: $0 [options] <data-dir> <exp-dir> <conf-path>";
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit 1;
fi
dir=$1
exp_dir=$2
conf_path=$3
# count the comma separated entries of CUDA_VISIBLE_DEVICES; the value is
# only used for the log message below
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
# select the training device; ${use_gpu} is executed as a command, so it is
# expected to hold "true" or "false"
device="cpu"
if ${use_gpu}; then
device="gpu"
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train the speaker identification task with voxceleb data
# and we will create the trained model parameters in ${exp_dir}/model.pdparams as the soft link
# Note: we will store the log file in exp/log directory
python3 -m paddle.distributed.launch --gpus=$CUDA_VISIBLE_DEVICES \
${BIN_DIR}/train.py --device ${device} --checkpoint-dir ${exp_dir} \
--data-dir ${dir} --config ${conf_path}
fi
# $? is the status of the if statement above: the exit status of the python
# command when the stage ran, 0 when the stage was skipped
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0

@ -0,0 +1,28 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Environment setup shared by the example scripts.
# MAIN_ROOT is resolved three directory levels above the current working
# directory, so this file has to be sourced from the example directory.
export MAIN_ROOT=`realpath ${PWD}/../../../`
# make the executables in ${MAIN_ROOT} and ${MAIN_ROOT}/utils available
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
# do not write .pyc files
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# make the python packages under ${MAIN_ROOT} importable
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# BIN_DIR points at the train/test/extract scripts of the selected model
MODEL=ecapa_tdnn
export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}

@ -0,0 +1,69 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Top-level recipe: data preparation, training and testing of the ecapa-tdnn
# speaker model, selected by ${stage} .. ${stop_stage}.
. ./path.sh
# abort on the first failing command
set -e
#######################################################################
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
# voxceleb2 data is in m4a format, so users need to convert the m4a to wav themselves as described in Readme.md with the script local/convert.sh
# stage 1: train the speaker identification model
# stage 2: test speaker identification
# stage 3: extract the training embedding to train the LDA and PLDA
######################################################################
# we can set the variable PPAUDIO_HOME to specify the root directory of the downloaded vox1 and vox2 dataset
# by default the dataset will be stored in ~/.paddleaudio/
# the vox2 dataset is stored in m4a format, so you need to convert the audio from m4a to wav yourself
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
# export PPAUDIO_HOME=
stage=0
stop_stage=50
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
# vox2 wav path, we must convert the m4a format to wav format
dir=data/ # data info directory
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
conf_path=conf/ecapa_tdnn.yaml
# gpu ids exported as CUDA_VISIBLE_DEVICES for the training stage
gpus=0,1,2,3
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
mkdir -p ${exp_dir}
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
bash ./local/data.sh ${dir} ${conf_path}|| exit -1;
fi
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# stage 1: train the speaker identification model
CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
fi
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# stage 2: get the speaker verification scores with cosine function
# now we only support use cosine to get the scores
# testing always runs with gpu 0 only
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}
fi
# if [ $stage -le 3 ]; then
# # stage 3: extract the training embedding to train the LDA and PLDA
# # todo: extract the training embedding
# fi

@ -0,0 +1 @@
../../../utils/

@ -15,3 +15,5 @@ from .esc50 import ESC50
from .gtzan import GTZAN
from .tess import TESS
from .urban_sound import UrbanSound8K
from .voxceleb import VoxCeleb
from .rirs_noises import OpenRIRNoise

@ -0,0 +1,205 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import glob
import os
import random
from typing import Dict
from typing import List
from typing import Tuple
from paddle.io import Dataset
from tqdm import tqdm
from ..backends import load as load_audio
from ..backends import save as save_wav
from ..utils import DATA_HOME
from ..utils import decompress
from ..utils.download import download_and_decompress
from .dataset import feat_funcs
__all__ = ['OpenRIRNoise']
class OpenRIRNoise(Dataset):
    """Room impulse response / noise dataset built from the RIRS_NOISES archive.

    On first use the archive is downloaded and decompressed into
    ``base_path`` and two csv files (``rir.csv`` and ``noise.csv``) with the
    columns ``id``, ``duration`` and ``wav`` are generated in ``csv_path``.
    Every item is a dict holding these three fields plus ``feat``.
    """
    archieves = [
        {
            'url': 'http://www.openslr.org/resources/28/rirs_noises.zip',
            'md5': 'e6f48e257286e05de56413b4779d8ffb',
        },
    ]

    sample_rate = 16000
    meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav'))
    base_path = os.path.join(DATA_HOME, 'open_rir_noise')
    wav_path = os.path.join(base_path, 'RIRS_NOISES')
    # default csv location; an instance created with ``target_dir`` overrides
    # it for that instance only
    csv_path = os.path.join(base_path, 'csv')
    subsets = ['rir', 'noise']

    def __init__(self,
                 subset: str='rir',
                 feat_type: str='raw',
                 target_dir=None,
                 random_chunk: bool=True,
                 chunk_duration: float=3.0,
                 seed: int=0,
                 **kwargs):
        """Prepare the data (if needed) and load the info of one subset.

        Args:
            subset: one of ``subsets`` (``'rir'`` or ``'noise'``).
            feat_type: key into ``feat_funcs`` selecting the feature function.
            target_dir: when given, the csv files are stored in
                ``<target_dir>/open_rir_noise/csv`` instead of the default.
            random_chunk: stored on the instance; not used by this class.
            chunk_duration: chunk length in seconds used when long audio is
                split while the csv files are generated.
            seed: currently unused (the seeding call is commented out).
            **kwargs: forwarded to the feature function.
        """
        assert subset in self.subsets, \
            'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)

        self.subset = subset
        self.feat_type = feat_type
        self.feat_config = kwargs
        self.random_chunk = random_chunk
        self.chunk_duration = chunk_duration

        # Keep the csv location on the instance. Assigning to the class
        # attribute would silently redirect every other instance (including
        # ones created later without target_dir) to this target_dir.
        if target_dir:
            self.csv_path = os.path.join(target_dir, "open_rir_noise", "csv")
        self._data = self._get_data()
        super(OpenRIRNoise, self).__init__()

        # Set up a seed to reproduce training or predicting result.
        # random.seed(seed)

    def _get_data(self):
        """Download/prepare the data when missing and return the shuffled
        list of ``meta_info`` records of ``self.subset``."""
        # Download audio files.
        print(f"rirs noises base path: {self.base_path}")
        if not os.path.isdir(self.base_path):
            download_and_decompress(
                self.archieves, self.base_path, decompress=True)
        else:
            print(
                f"{self.base_path} already exists, we will not download and decompress again"
            )

        # Data preparation: the csv files are only generated when the csv
        # directory does not exist yet.
        print(f"prepare the csv to {self.csv_path}")
        if not os.path.isdir(self.csv_path):
            os.makedirs(self.csv_path)
            self.prepare_data()

        data = []
        # Read with the csv module, mirroring generate_csv(): a plain
        # split(',') breaks on quoted fields (e.g. a path containing a comma).
        csv_file = os.path.join(self.csv_path, f'{self.subset}.csv')
        with open(csv_file, 'r', newline='') as rf:
            reader = csv.reader(rf, delimiter=",", quotechar='"')
            next(reader, None)  # skip the header row
            for audio_id, duration, wav in reader:
                data.append(self.meta_info(audio_id, float(duration), wav))

        random.shuffle(data)
        return data

    def _convert_to_record(self, idx: int):
        """Return the record of item ``idx``: its meta fields plus ``feat``."""
        sample = self._data[idx]

        record = {}
        # To show all fields in a namedtuple: `type(sample)._fields`
        for field in type(sample)._fields:
            record[field] = getattr(sample, field)

        waveform, sr = load_audio(record['wav'])

        assert self.feat_type in feat_funcs.keys(), \
            f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
        feat_func = feat_funcs[self.feat_type]
        # a falsy feature function means the waveform itself is the feature
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record.update({'feat': feat})
        return record

    @staticmethod
    def _get_chunks(seg_dur, audio_id, audio_duration):
        """Return the chunk ids ``<audio_id>_<start>_<end>`` of all complete
        chunks of ``seg_dur`` seconds that fit into ``audio_duration``."""
        # seg_dur and audio_duration are in seconds; the remainder that does
        # not fill a complete chunk is dropped
        num_chunks = int(audio_duration / seg_dur)
        chunk_lst = [
            audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
            for i in range(num_chunks)
        ]
        return chunk_lst

    def _get_audio_info(self, wav_file: str,
                        split_chunks: bool) -> List[List[str]]:
        """Return the csv rows ``[id, duration, wav]`` of one audio file.

        When ``split_chunks`` is set and the audio is longer than
        ``self.chunk_duration``, the audio is cut into chunks that are saved
        as new wav files below ``base_path``, one row per chunk; otherwise a
        single row for the whole file is returned.
        """
        waveform, sr = load_audio(wav_file)
        # the id is the path below ".../open_rir_noise/", cut at the first "."
        audio_id = wav_file.split("/open_rir_noise/")[-1].split(".")[0]
        audio_duration = waveform.shape[0] / sr

        ret = []
        if split_chunks and audio_duration > self.chunk_duration:  # Split into pieces of self.chunk_duration seconds.
            uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
                                                audio_duration)

            for idx, chunk in enumerate(uniq_chunks_list):
                s, e = chunk.split("_")[-2:]  # Timestamps of start and end
                start_sample = int(float(s) * sr)
                end_sample = int(float(e) * sr)
                new_wav_file = os.path.join(self.base_path,
                                            audio_id + f'_chunk_{idx+1:02}.wav')
                save_wav(waveform[start_sample:end_sample], sr, new_wav_file)
                # id, duration, new_wav
                ret.append([chunk, self.chunk_duration, new_wav_file])
        else:  # Keep whole audio.
            ret.append([audio_id, audio_duration, wav_file])
        return ret

    def generate_csv(self,
                     wav_files: List[str],
                     output_file: str,
                     split_chunks: bool=True):
        """Write the csv file ``output_file`` describing ``wav_files``."""
        print(f'Generating csv: {output_file}')
        header = ["id", "duration", "wav"]

        infos = list(
            tqdm(
                map(self._get_audio_info, wav_files, [split_chunks] * len(
                    wav_files)),
                total=len(wav_files)))

        csv_lines = []
        for info in infos:
            csv_lines.extend(info)

        # newline="" is what the csv module expects for files it writes
        with open(output_file, mode="w", newline="") as csv_f:
            csv_writer = csv.writer(
                csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
            csv_writer.writerow(header)
            csv_writer.writerows(csv_lines)

    def prepare_data(self):
        """Generate ``rir.csv`` and ``noise.csv`` from the list files shipped
        in the archive; the last space separated field of every list line is
        the audio path relative to ``base_path``."""
        rir_list = os.path.join(self.wav_path, "real_rirs_isotropic_noises",
                                "rir_list")
        rir_files = []
        with open(rir_list, 'r') as f:
            for line in f:
                rir_file = line.strip().split(' ')[-1]
                rir_files.append(os.path.join(self.base_path, rir_file))

        noise_list = os.path.join(self.wav_path, "pointsource_noises",
                                  "noise_list")
        noise_files = []
        with open(noise_list, 'r') as f:
            for line in f:
                noise_file = line.strip().split(' ')[-1]
                noise_files.append(os.path.join(self.base_path, noise_file))

        self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv'))
        self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv'))

    def __getitem__(self, idx):
        return self._convert_to_record(idx)

    def __len__(self):
        return len(self._data)

@ -0,0 +1,358 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import glob
import os
import random
from multiprocessing import cpu_count
from typing import Dict
from typing import List
from typing import Tuple
from paddle.io import Dataset
from pathos.multiprocessing import Pool
from tqdm import tqdm
from ..backends import load as load_audio
from ..utils import DATA_HOME
from ..utils import decompress
from ..utils.download import download_and_decompress
from .dataset import feat_funcs
__all__ = ['VoxCeleb']
# VoxCeleb dataset: downloads the VoxCeleb1 audio/meta files when missing,
# prepares per-subset csv files and serves records read from them.
class VoxCeleb(Dataset):
# Base URL of the VoxCeleb1 audio archives.
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
# Dev audio is published as four split parts; `_get_data` concatenates them
# into a single vox1_dev_wav.zip before extraction.
archieves_audio_dev = [
{
'url': source_url + 'vox1_dev_wav_partaa',
'md5': 'e395d020928bc15670b570a21695ed96',
},
{
'url': source_url + 'vox1_dev_wav_partab',
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
},
{
'url': source_url + 'vox1_dev_wav_partac',
'md5': '017d579a2a96a077f40042ec33e51512',
},
{
'url': source_url + 'vox1_dev_wav_partad',
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
},
]
# Test audio archive (single zip file).
archieves_audio_test = [
{
'url': source_url + 'vox1_test_wav.zip',
'md5': '185fdc63c3c739954633d50379a3d102',
},
]
# Trial list; `prepare_data` reads it to find the enroll/test utterances.
archieves_meta = [
{
'url':
'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
'md5':
'b73110731c9223c1461fe49cb48dddfc',
},
]
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
sample_rate = 16000
# One row of a subset csv file; field order matches the csv columns.
meta_info = collections.namedtuple(
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
# Default on-disk layout below DATA_HOME. NOTE: `__init__` rewrites csv_path,
# meta_path and veri_test_file on the class when `target_dir` is given.
base_path = os.path.join(DATA_HOME, 'vox1')
wav_path = os.path.join(base_path, 'wav')
meta_path = os.path.join(base_path, 'meta')
veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
csv_path = os.path.join(base_path, 'csv')
# Subset names accepted by `__init__`; each maps to `<subset>.csv`.
subsets = ['train', 'dev', 'enroll', 'test']
def __init__(
        self,
        subset: str='train',
        feat_type: str='raw',
        random_chunk: bool=True,
        chunk_duration: float=3.0,  # seconds
        split_ratio: float=0.9,  # train split ratio
        seed: int=0,
        target_dir: str=None,
        vox2_base_path=None,
        **kwargs):
    """Prepare the VoxCeleb data (if needed) and load one subset's audio info.

    Args:
        subset (str, optional): one of 'train', 'dev', 'enroll' or 'test'. Defaults to 'train'.
        feat_type (str, optional): feature type; must be a key of `feat_funcs`. Defaults to 'raw'.
        random_chunk (bool, optional): pick a random chunk of each audio instead of the
            start/stop stored in the csv. Defaults to True.
        chunk_duration (float, optional): chunk length in seconds. Defaults to 3.0.
        split_ratio (float, optional): share of the non-test audio used for 'train';
            the rest goes to 'dev'. Defaults to 0.9.
        seed (int, optional): accepted but currently not used by this constructor. Defaults to 0.
        target_dir (str, optional): when set, csv and meta files are kept under
            `<target_dir>/voxceleb`. Defaults to None.
        vox2_base_path (optional): directory with VoxCeleb2 wav files. Defaults to None.
        **kwargs: extra options stored as `feat_config` and passed to the feature function.

    Note:
        `csv_path`, `meta_path` and `veri_test_file` are assigned on the class
        `VoxCeleb` itself, so they are shared by every instance.
    """
    assert subset in self.subsets, \
        'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)

    # What to serve and how to turn audio into features.
    self.subset = subset
    self.feat_type = feat_type
    self.feat_config = kwargs
    self.spk_id2label = {}

    # Chunking / splitting options.
    self.random_chunk = random_chunk
    self.chunk_duration = chunk_duration
    self.split_ratio = split_ratio

    # Storage locations.
    self.target_dir = target_dir or VoxCeleb.base_path
    self.vox2_base_path = vox2_base_path

    # A user supplied target dir relocates the csv and meta directories.
    if target_dir:
        VoxCeleb.csv_path = os.path.join(target_dir, "voxceleb", 'csv')
        VoxCeleb.meta_path = os.path.join(target_dir, "voxceleb", 'meta')
    VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
                                           'veri_test2.txt')

    self._data = self._get_data()
    super(VoxCeleb, self).__init__()
def _get_data(self):
    """Make sure audio, meta and csv files exist, then load the subset's rows.

    Steps:
        1. If ``self.wav_path`` is missing, download the dev parts and the
           test archive, merge the dev parts into one zip and extract it.
        2. If ``self.meta_path`` is missing, download the trial list.
        3. If ``self.csv_path`` is missing, create it and run ``prepare_data``.
        4. Read ``<subset>.csv`` into ``meta_info`` tuples and fill
           ``self.spk_id2label`` from ``spk_id2label.txt``.

    Returns:
        list: one ``meta_info`` namedtuple per csv row of the subset.
    """
    import shutil  # local import: not part of the module-level imports

    # Users are expected to have vox1 dev and test wavs under vox1/wav,
    # so the presence of that directory decides whether to download.
    print(f"wav base path: {self.wav_path}")
    if not os.path.isdir(self.wav_path):
        print(f"start to download the voxceleb1 dataset")
        # The dev set comes as split parts that must be merged before unzip.
        download_and_decompress(
            self.archieves_audio_dev, self.base_path, decompress=False)
        # The test set is a regular zip: download and extract directly.
        download_and_decompress(
            self.archieves_audio_test, self.base_path, decompress=True)

        dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
        print(f'Concatenating all parts to: {dev_zipfile}')
        # Merge in Python instead of `os.system("cat ...")`: no shell quoting
        # problems with unusual paths and I/O errors raise instead of being
        # silently ignored. Sorting reproduces the shell glob order.
        part_files = sorted(
            glob.glob(os.path.join(self.base_path, "vox1_dev_wav_parta*")))
        with open(dev_zipfile, 'wb') as merged:
            for part in part_files:
                with open(part, 'rb') as part_f:
                    shutil.copyfileobj(part_f, merged)

        # `decompress` takes only the archive path.
        # NOTE(review): presumably it extracts next to the archive, i.e. into
        # self.base_path - confirm against the download helper.
        decompress(dev_zipfile)

    # Download meta files.
    if not os.path.isdir(self.meta_path):
        print("prepare the meta data")
        download_and_decompress(
            self.archieves_meta, self.meta_path, decompress=False)

    # Data preparation.
    if not os.path.isdir(self.csv_path):
        os.makedirs(self.csv_path)
        self.prepare_data()

    data = []
    subset_csv = os.path.join(self.csv_path, f'{self.subset}.csv')
    print(f"read the {self.subset} from {subset_csv}")
    # Parse with csv.reader (the file is written by csv.writer), so quoted
    # fields such as paths containing a comma are read back correctly.
    with open(subset_csv, 'r', newline='') as rf:
        reader = csv.reader(rf)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:  # tolerate blank lines
                continue
            audio_id, duration, wav, start, stop, spk_id = row
            data.append(
                self.meta_info(audio_id,
                               float(duration), wav,
                               int(start), int(stop), spk_id))

    with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
        for line in f:
            spk_id, label = line.strip().split(' ')
            self.spk_id2label[spk_id] = int(label)

    return data
def _convert_to_record(self, idx: int):
    """Build the record (meta fields, features and optional label) for one item.

    Args:
        idx (int): index into ``self._data``.

    Returns:
        dict: all ``meta_info`` fields, plus ``feat`` and, for the 'train'
        and 'dev' subsets, ``label``.
    """
    sample = self._data[idx]
    # Copy every namedtuple field into a plain dict.
    record = {field: getattr(sample, field) for field in type(sample)._fields}

    waveform, sr = load_audio(record['wav'])

    if self.random_chunk:
        # Pick a random window of chunk_duration seconds.
        num_wav_samples = waveform.shape[0]
        num_chunk_samples = int(self.chunk_duration * sr)
        # random.randint raises ValueError on an empty range, which happened
        # whenever the audio was not longer than the chunk. Clamp to 0 so
        # such audio is used from its beginning (the slice below simply
        # returns what is available).
        max_start = max(0, num_wav_samples - num_chunk_samples - 1)
        start = random.randint(0, max_start)
        stop = start + num_chunk_samples
    else:
        # Use the boundaries stored in the csv file.
        start = record['start']
        stop = record['stop']

    waveform = waveform[start:stop]

    assert self.feat_type in feat_funcs.keys(), \
        f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
    feat_func = feat_funcs[self.feat_type]
    # A falsy entry in feat_funcs means "use the raw waveform".
    feat = feat_func(
        waveform, sr=sr, **self.feat_config) if feat_func else waveform

    record.update({'feat': feat})
    if self.subset in ['train', 'dev']:
        # Labels are only available in train and dev.
        record.update({'label': self.spk_id2label[record['spk_id']]})

    return record
@staticmethod
def _get_chunks(seg_dur, audio_id, audio_duration):
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
def _get_audio_info(self, wav_file: str,
                    split_chunks: bool) -> List[List[str]]:
    """Produce the csv rows (id, duration, wav, start, stop, spk_id) for one file.

    The last three "/"-separated path components are taken as speaker,
    session and utterance; the audio id is ``<spk>-<sess>-<utt stem>``.

    Args:
        wav_file (str): path of the audio file.
        split_chunks (bool): True to emit one row per chunk of
            ``self.chunk_duration`` seconds, False to emit a single row
            covering the whole audio.

    Returns:
        List[List[str]]: the rows describing this file.
    """
    waveform, sr = load_audio(wav_file)
    spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
    audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
    audio_duration = waveform.shape[0] / sr

    if not split_chunks:
        # Keep the whole audio as one entry.
        return [[
            audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
        ]]

    rows = []
    for chunk in self._get_chunks(self.chunk_duration, audio_id,
                                  audio_duration):
        # The chunk id ends with "_<start>_<end>" timestamps.
        begin, end = chunk.split("_")[-2:]
        rows.append([
            chunk, audio_duration, wav_file,
            int(float(begin) * sr),
            int(float(end) * sr), spk_id
        ])
    return rows
def generate_csv(self,
                 wav_files: List[str],
                 output_file: str,
                 split_chunks: bool=True):
    """Describe ``wav_files`` in parallel and write the rows to a CSV file.

    Args:
        wav_files (List[str]): audio files to describe.
        output_file (str): CSV file to create (overwritten if it exists).
        split_chunks (bool, optional): forwarded to ``_get_audio_info``;
            True splits every audio into chunks. Defaults to True.
    """
    print(f'Generating csv: {output_file}')
    header = ["ID", "duration", "wav", "start", "stop", "spk_id"]

    # Note: this may raise a C++ exception message, but the program still
    # completes correctly, so it can be ignored.
    with Pool(cpu_count()) as p:
        infos = list(
            tqdm(
                p.imap(lambda x: self._get_audio_info(x, split_chunks),
                       wav_files),
                total=len(wav_files)))

    csv_lines = []
    for info in infos:
        csv_lines.extend(info)

    # The csv module requires newline="" so that its own "\r\n" line
    # terminator is not translated a second time (blank rows on Windows).
    with open(output_file, mode="w", newline="") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(header)
        csv_writer.writerows(csv_lines)
def prepare_data(self):
    """Create spk_id2label.txt and the train/dev/enroll/test csv files.

    Speakers that occur in the trial list (``self.veri_test_file``) are kept
    out of the train and dev subsets. The remaining audio is shuffled and
    split by ``self.split_ratio`` into train and dev.
    """
    print("start to prepare the data csv file")

    # Collect the enroll and test utterances named in the trial list.
    enroll_set, test_set = set(), set()
    with open(self.veri_test_file, 'r') as f:
        for line in f.readlines():
            _, enrol_file, test_file = line.strip().split(' ')
            enroll_set.add(os.path.join(self.wav_path, enrol_file))
            test_set.add(os.path.join(self.wav_path, test_file))
    enroll_files = sorted(enroll_set)
    test_files = sorted(test_set)

    # The speaker id is the first path component after "/wav/".
    test_spks = {
        trial_file.split('/wav/')[1].split('/')[0]
        for trial_file in (enroll_files + test_files)
    }

    # Gather every wav file whose speaker is not a trial speaker.
    audio_files = []
    speakers = set()
    print("Getting file list...")
    for path in [self.wav_path, self.vox2_base_path]:
        # An unset or non-existing directory (typically vox2) is skipped.
        if not path or not os.path.exists(path):
            print(f"{path} is an invalid path, please check again, "
                  "and we will ignore the vox2 base path")
            continue
        pattern = os.path.join(path, "**", "*.wav")
        for file in glob.glob(pattern, recursive=True):
            spk = file.split('/wav/')[1].split('/')[0]
            if spk not in test_spks:
                speakers.add(spk)
                audio_files.append(file)

    print(
        f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}"
    )
    # Labels are the positions of the speaker ids in sorted order
    # (1211 vox1, 5994 vox2, 7205 vox1+2).
    with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
        for label, spk_id in enumerate(sorted(speakers)):
            f.write(f'{spk_id} {label}\n')

    # Sort first so the shuffle starts from a deterministic order.
    audio_files = sorted(audio_files)
    random.shuffle(audio_files)
    split_idx = int(self.split_ratio * len(audio_files))
    train_files = audio_files[:split_idx]
    dev_files = audio_files[split_idx:]

    # Train/dev are chunked (default); enroll/test keep whole utterances.
    self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
    self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
    self.generate_csv(
        enroll_files,
        os.path.join(self.csv_path, 'enroll.csv'),
        split_chunks=False)
    self.generate_csv(
        test_files,
        os.path.join(self.csv_path, 'test.csv'),
        split_chunks=False)
def __getitem__(self, idx):
    """Return the record built by ``_convert_to_record`` for index ``idx``."""
    record = self._convert_to_record(idx)
    return record
def __len__(self):
return len(self._data)

@ -12,4 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .dtw import dtw_distance
from .eer import compute_eer
from .eer import compute_minDCF
from .mcd import mcd_distance

@ -0,0 +1,100 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
import paddle
from sklearn.metrics import roc_curve
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
    """Compute the equal error rate and the score threshold where it occurs.

    The operating point is the ROC point at which the false negative rate
    and the false positive rate are closest to each other.

    Args:
        labels (np.ndarray): trial labels, shape [N], one-dimensional.
        scores (np.ndarray): trial scores, shape [N], one-dimensional.

    Returns:
        List[float]: the pair (eer, threshold); eer is the false positive
        rate at the selected operating point.
    """
    fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
    fnr = 1 - tpr
    # Index where |FNR - FPR| is smallest, ignoring NaNs.
    eer_index = np.nanargmin(np.absolute(fnr - fpr))
    return fpr[eer_index], threshold[eer_index]
def compute_minDCF(positive_scores,
                   negative_scores,
                   c_miss=1.0,
                   c_fa=1.0,
                   p_target=0.01):
    """Compute the minimum detection cost (minDCF) and its threshold.

    This is modified from SpeechBrain
    https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509

    The cost evaluated at every candidate threshold is
        C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
    where p_miss is the share of positive scores at or below the threshold
    and p_fa is the share of negative scores above it. The smallest C_det
    and the threshold that yields it are returned.

    Args:
        positive_scores (Paddle.Tensor): The scores from entries of the same class.
        negative_scores (Paddle.Tensor): The scores from entries of different classes.
        c_miss (float, optional): Cost assigned to a missing error (default 1.0).
        c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
        p_target (float, optional): Prior probability of having a target (default 0.01).

    Returns:
        List[float]: min dcf and the specific threshold
    """
    # Flatten inputs that carry extra dimensions.
    if len(positive_scores.shape) > 1:
        positive_scores = positive_scores.squeeze()
    if len(negative_scores.shape) > 1:
        negative_scores = negative_scores.squeeze()

    # Candidate thresholds: every distinct score ...
    all_scores = paddle.concat([positive_scores, negative_scores])
    thresholds = paddle.unique(paddle.sort(all_scores))

    # ... plus the midpoint between each pair of neighbouring scores.
    midpoints = (thresholds[0:-1] + thresholds[1:]) / 2
    thresholds = paddle.sort(paddle.concat([thresholds, midpoints]))

    # False rejection rate (miss detection): one copy of the positive scores
    # per threshold, compared against the thresholds.
    pos_tiled = paddle.concat(len(thresholds) * [positive_scores.unsqueeze(0)])
    pos_rejected = pos_tiled.transpose(perm=[1, 0]) <= thresholds
    p_miss = (pos_rejected.sum(0)).astype("float32") / pos_tiled.shape[1]
    del pos_tiled
    del pos_rejected

    # False acceptance rate (false alarm), computed the same way.
    neg_tiled = paddle.concat(len(thresholds) * [negative_scores.unsqueeze(0)])
    neg_accepted = neg_tiled.transpose(perm=[1, 0]) > thresholds
    p_fa = (neg_accepted.sum(0)).astype("float32") / neg_tiled.shape[1]
    del neg_tiled
    del neg_accepted

    c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
    c_min = paddle.min(c_det, axis=0)
    min_index = paddle.argmin(c_det, axis=0)

    return float(c_min), float(thresholds[min_index])

@ -37,7 +37,9 @@ def decompress(file: str):
download._decompress(file)
def download_and_decompress(archives: List[Dict[str, str]], path: str):
def download_and_decompress(archives: List[Dict[str, str]],
path: str,
decompress: bool=True):
"""
Download archieves and decompress to specific path.
"""
@ -47,8 +49,8 @@ def download_and_decompress(archives: List[Dict[str, str]], path: str):
for archive in archives:
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}'
download.get_path_from_url(archive['url'], path, archive['md5'])
download.get_path_from_url(
archive['url'], path, archive['md5'], decompress=decompress)
def load_state_dict_from_url(url: str, path: str, md5: str=None):

@ -21,5 +21,6 @@ from .st import STExecutor
from .stats import StatsExecutor
from .text import TextExecutor
from .tts import TTSExecutor
from .vector import VectorExecutor
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -0,0 +1,14 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import VectorExecutor

@ -0,0 +1,354 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
import paddle
import soundfile
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
pretrained_models = {
# The tags for pretrained_models are "{model_name}_{dataset}-{sr}",
# e.g. "ecapatdnn_voxceleb12-16k".
# Command line and python api use "{model_name}_{dataset}" as --model; the "-{sr}" suffix is
# derived from --sample_rate. Usage:
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12 --sample_rate 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_0.tar.gz',
'md5':
'85ff08ce0ef406b8c6d7b5ffc5b2b48f',
'cfg_path':
'conf/model.yaml',
'ckpt_path':
'model/model',
},
}
model_alias = {
"ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn",
}
@cli_register(
name="paddlespeech.vector",
description="Speech to vector embedding infer command.")
class VectorExecutor(BaseExecutor):
def __init__(self):
    """Create the executor and declare its command line arguments.

    The help texts describe the vector (speaker embedding) task; the sample
    rate help matches the single value the parser accepts.
    """
    super(VectorExecutor, self).__init__()

    self.parser = argparse.ArgumentParser(
        prog="paddlespeech.vector", add_help=True)
    self.parser.add_argument(
        "--model",
        type=str,
        default="ecapatdnn_voxceleb12",
        choices=["ecapatdnn_voxceleb12"],
        help="Choose model type of vector task.")
    self.parser.add_argument(
        "--task",
        type=str,
        default="spk",
        choices=["spk"],
        help="task type in vector domain")
    self.parser.add_argument(
        "--input",
        type=str,
        default=None,
        help="Audio file to extract the embedding from.")
    self.parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        choices=[16000],
        help="Choose the audio sample rate of the model. Only 16000 is supported.")
    self.parser.add_argument(
        "--ckpt_path",
        type=str,
        default=None,
        help="Checkpoint file of model.")
    self.parser.add_argument(
        '--config',
        type=str,
        default=None,
        help='Config of vector task. Use default config when it is None.')
    self.parser.add_argument(
        "--device",
        type=str,
        default=paddle.get_device(),
        help="Choose device to execute model inference.")
    self.parser.add_argument(
        '-d',
        '--job_dump_result',
        action='store_true',
        help='Save job result into file.')
    self.parser.add_argument(
        '-v',
        '--verbose',
        action='store_true',
        help='Increase logger verbosity of current task.')
def execute(self, argv: List[str]) -> bool:
    """Command line entry for vector model

    Args:
        argv (List[str]): command line args list

    Returns:
        bool:
            False: at least one audio raised an error
            True: every audio was processed successfully
    """
    # stage 0: parse the args and get the required args
    parser_args = self.parser.parse_args(argv)
    model = parser_args.model
    sample_rate = parser_args.sample_rate
    config = parser_args.config
    ckpt_path = parser_args.ckpt_path
    device = parser_args.device

    # stage 1: configure the verbose flag
    if not parser_args.verbose:
        self.disable_task_loggers()

    # stage 2: read the input data and store them as a mapping id -> input
    task_source = self.get_task_source(parser_args.input)
    logger.info(f"task source: {task_source}")

    # stage 3: process the audio one by one
    task_result = OrderedDict()
    has_exceptions = False
    for id_, input_ in task_source.items():
        try:
            # `device` must go by keyword: the sixth positional parameter of
            # __call__ is `force_yes`, so passing it positionally silently
            # dropped the --device choice.
            res = self(
                input_, model, sample_rate, config, ckpt_path, device=device)
            task_result[id_] = res
        except Exception as e:
            # Keep going; the failure is reported in place of the result.
            has_exceptions = True
            task_result[id_] = f'{e.__class__.__name__}: {e}'

    logger.info("task result as follows: ")
    logger.info(f"{task_result}")

    # stage 4: process all the task results
    self.process_task_results(parser_args.input, task_result,
                              parser_args.job_dump_result)

    # stage 5: report whether every audio was processed without error
    return not has_exceptions
@stats_wrapper
def __call__(self,
             audio_file: os.PathLike,
             model: str='ecapatdnn_voxceleb12',
             sample_rate: int=16000,
             config: os.PathLike=None,
             ckpt_path: os.PathLike=None,
             force_yes: bool=False,
             device=paddle.get_device()):
    """Extract the embedding of one audio file (python api entry).

    Args:
        audio_file (os.PathLike): audio file to process.
        model (str, optional): model type, "{model_name}_{dataset}". The
            default uses an underscore because `_init_from_path` builds the
            pretrained tag from it and splits the model name at the last
            '_'; the former hyphenated default could never be resolved.
        sample_rate (int, optional): sample rate expected by the model.
        config (os.PathLike, optional): config file; pretrained one if None.
        ckpt_path (os.PathLike, optional): checkpoint path without the
            ".pdparams" suffix; pretrained one if None.
        force_yes (bool, optional): accepted but not used by this method.
        device (optional): device passed to `paddle.device.set_device`.

    Returns:
        The value returned by `postprocess` (the embedding).

    Note:
        The process exits with status -1 when `_check` rejects the input.
    """
    audio_file = os.path.abspath(audio_file)
    if not self._check(audio_file, sample_rate):
        sys.exit(-1)

    logger.info(f"device type: {device}")
    paddle.device.set_device(device)
    self._init_from_path(model, sample_rate, config, ckpt_path)
    self.preprocess(model, audio_file)
    self.infer(model)
    res = self.postprocess()

    return res
def _get_pretrained_path(self, tag: str) -> os.PathLike:
    """Download (if needed) the pretrained model `tag` and return its directory.

    Args:
        tag (str): key of `pretrained_models`.

    Returns:
        os.PathLike: absolute path reported by `download_and_decompress`.
    """
    assert tag in pretrained_models, \
        ('The model "{}" you want to use has not been supported,'
         'please choose other models.\n'
         'The support models includes\n\t\t{}').format(
             tag, "\n\t\t".join(list(pretrained_models.keys())))

    # Every model is stored in its own directory below MODEL_HOME.
    model_dir = os.path.join(MODEL_HOME, tag)
    decompressed_path = os.path.abspath(
        download_and_decompress(pretrained_models[tag], model_dir))
    logger.info(
        'Use pretrained model stored in: {}'.format(decompressed_path))
    return decompressed_path
def _init_from_path(self,
                    model_type: str='ecapatdnn_voxceleb12',
                    sample_rate: int=16000,
                    cfg_path: Optional[os.PathLike]=None,
                    ckpt_path: Optional[os.PathLike]=None):
    """Resolve config/checkpoint paths, build the model and load its weights.

    Does nothing when the executor already has a `model` attribute.

    Args:
        model_type (str, optional): "{model_name}_{dataset}"; the part before
            the last '_' selects the network class from `model_alias`.
        sample_rate (int, optional): 16000 selects the "16k" pretrained tag,
            any other value the "8k" one.
        cfg_path (Optional[os.PathLike], optional): user config file.
        ckpt_path (Optional[os.PathLike], optional): user checkpoint path
            without the ".pdparams" suffix.
    """
    if hasattr(self, "model"):
        logger.info("Model has been initialized")
        return

    # stage 1: locate the config and the checkpoint. The pretrained model is
    # used unless BOTH a config and a checkpoint are supplied.
    use_pretrained = cfg_path is None or ckpt_path is None
    if use_pretrained:
        sample_rate_str = "16k" if sample_rate == 16000 else "8k"
        tag = model_type + "-" + sample_rate_str
        logger.info(f"load the pretrained model: {tag}")
        res_path = self._get_pretrained_path(tag)
        pretrained = pretrained_models[tag]
        self.res_path = res_path
        self.cfg_path = os.path.join(res_path, pretrained['cfg_path'])
        self.ckpt_path = os.path.join(res_path,
                                      pretrained['ckpt_path'] + '.pdparams')
    else:
        self.cfg_path = os.path.abspath(cfg_path)
        self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
        # The resource dir is two levels above the config file.
        self.res_path = os.path.dirname(
            os.path.dirname(os.path.abspath(self.cfg_path)))

    logger.info(f"start to read the ckpt from {self.ckpt_path}")
    logger.info(f"read the config from {self.cfg_path}")
    logger.info(f"get the res path {self.res_path}")

    # stage 2: read the config
    self.config = CfgNode(new_allowed=True)
    self.config.merge_from_file(self.cfg_path)

    # stage 3: import the network class; the model name is everything in
    # front of the last '_' of the model type.
    logger.info("start to dynamic import the model class")
    model_name = model_type[:model_type.rindex('_')]
    logger.info(f"model name {model_name}")
    model_class = dynamic_import(model_name, model_alias)
    backbone = model_class(**self.config.model)
    self.model = SpeakerIdetification(
        backbone=backbone, num_class=self.config.num_speakers)
    self.model.eval()

    # stage 4: load the model parameters
    logger.info("start to set the model parameters to model")
    self.model.set_state_dict(paddle.load(self.ckpt_path))
    logger.info("create the model instance success")
@paddle.no_grad()
def infer(self, model_type: str):
    """Run the backbone on the prepared inputs and store the embedding.

    Reads "feats" and "lengths" from `self._inputs` and writes the result to
    `self._outputs["embedding"]`.

    Args:
        model_type (str): model type; not used by this method.
    """
    feats, lengths = self._inputs["feats"], self._inputs["lengths"]
    logger.info("start to do backbone network model forward")
    logger.info(
        f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")

    # embedding from (1, emb_size, 1) -> (emb_size)
    backbone_out = self.model.backbone(feats, lengths)
    embedding = backbone_out.squeeze().numpy()
    logger.info(f"embedding size: {embedding.shape}")

    self._outputs["embedding"] = embedding
def postprocess(self) -> Union[str, os.PathLike]:
    """Return the embedding stored by `infer` in `self._outputs`."""
    embedding = self._outputs["embedding"]
    return embedding
def preprocess(self, model_type: str, input_file: Union[str, os.PathLike]):
    """Load the audio, extract normalized features and store the model inputs.

    Fills `self._inputs["feats"]` and `self._inputs["lengths"]`. The process
    exits with status -1 when the feature extraction raises.

    Args:
        model_type (str): model type; not used by this method.
        input_file (Union[str, os.PathLike]): audio file to process.
    """
    audio_file = input_file
    if isinstance(audio_file, (str, os.PathLike)):
        logger.info(f"Preprocess audio file: {audio_file}")

    # stage 1: load the audio
    waveform, sr = load_audio(audio_file)
    logger.info(f"load the audio sample points, shape is: {waveform.shape}")

    # stage 2: get the audio feat, using the settings from the model config
    cfg = self.config
    try:
        feat = melspectrogram(
            x=waveform,
            sr=cfg.sr,
            n_mels=cfg.n_mels,
            window_size=cfg.window_size,
            hop_length=cfg.hop_size)
        logger.info(f"extract the audio feat, shape is: {feat.shape}")
    except Exception as e:
        logger.info(f"feat occurs exception {e}")
        sys.exit(-1)

    # Add the batch axis in front of the single feature matrix.
    feat = paddle.to_tensor(feat).unsqueeze(0)
    # in inference period, the lengths is all one without padding
    lengths = paddle.ones([1])
    feat = feature_normalize(feat, mean_norm=True, std_norm=False)
    logger.info(f"feats shape: {feat.shape}")

    self._inputs["feats"] = feat
    self._inputs["lengths"] = lengths
    logger.info("audio extract the feat success")
def _check(self, audio_file: str, sample_rate: int):
    """Validate the requested sample rate and the input audio file.

    Every failure is logged and reported through the return value, so the
    caller alone decides whether to abort.

    Args:
        audio_file (str): path of the audio file to check.
        sample_rate (int): sample rate the model expects; also stored in
            `self.sample_rate`.

    Returns:
        bool: True when the file exists, can be read by soundfile and has
        the requested sample rate; False otherwise.
    """
    self.sample_rate = sample_rate
    if self.sample_rate not in (16000, 8000):
        logger.error(
            "invalid sample rate, please input --sample_rate 8000 or --sample_rate 16000"
        )
        return False

    if isinstance(audio_file, (str, os.PathLike)):
        if not os.path.isfile(audio_file):
            logger.error("Please input the right audio file path")
            return False

    logger.info("checking the aduio file format......")
    try:
        # Only the sample rate is needed; the samples are discarded.
        _, audio_sample_rate = soundfile.read(
            audio_file, dtype="float32", always_2d=True)
    except Exception as e:
        logger.exception(e)
        logger.error(
            "can not open the audio file, please check the audio file format is 'wav'. \n"
            "you can try to use sox to change the file format.\n"
            "For example: \n"
            "sample rate: 16k \n"
            "sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n"
            "sample rate: 8k \n"
            "sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n"
        )
        return False

    logger.info(f"The sample rate is {audio_sample_rate}")
    if audio_sample_rate != self.sample_rate:
        # No resampling is performed here, so the mismatch is an error.
        logger.error(
            "The sample rate of the input file is {}, but the model expects {}.\n"
            "Please input the 16k 16 bit 1 channel wav file.".format(
                audio_sample_rate, self.sample_rate))
        return False

    logger.info("The audio file format is right")
    return True

@ -0,0 +1,119 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import paddle
from yacs.config import CfgNode
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def extract_audio_embedding(args, config):
    """Load a trained speaker model and return the embedding of one audio file.

    Args:
        args: parsed command line arguments; `device`, `load_checkpoint` and
            `audio_path` are read, and `load_checkpoint` is replaced by its
            absolute, user-expanded form.
        config: configuration node providing `seed`, `model`,
            `num_speakers`, `sr`, `n_mels`, `window_size` and `hop_size`.

    Returns:
        The squeezed backbone output converted with `.numpy()`.
    """
    # stage 0: select the device (cpu or gpu) and fix the random seed
    paddle.set_device(args.device)
    seed_everything(config.seed)

    # stage 1: build the backbone and wrap it in the speaker model
    backbone = EcapaTdnn(**config.model)
    model = SpeakerIdetification(
        backbone=backbone, num_class=config.num_speakers)

    # stage 2: load the pre-trained parameters from <checkpoint dir>/model.pdparams
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    params_file = os.path.join(args.load_checkpoint, 'model.pdparams')
    model.set_state_dict(paddle.load(params_file))
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage 3: the model must run in eval mode
    model.eval()

    # stage 4: read the audio (a one-dimensional array) and extract the embedding
    waveform, sr = load_audio(args.audio_path)

    # The feature is [dim, time]; a leading batch axis of size one is added,
    # giving [1, dim, time]. Timing starts here and covers feature extraction
    # plus the forward pass.
    start_time = time.time()
    feat = paddle.to_tensor(
        melspectrogram(
            x=waveform,
            sr=config.sr,
            n_mels=config.n_mels,
            window_size=config.window_size,
            hop_length=config.hop_size)).unsqueeze(0)

    # in inference period, the lengths is all one without padding
    lengths = paddle.ones([1])
    feat = feature_normalize(feat, mean_norm=True, std_norm=False)

    # (1, emb_size, 1) -> (emb_size)
    embedding = model.backbone(feat, lengths).squeeze().numpy()
    elapsed_time = time.time() - start_time

    # stage 5: report the real time factor (processing time / audio length)
    audio_length = waveform.shape[0] / sr
    rtf = elapsed_time / audio_length
    logger.info(f"{args.device} rft={rtf}")

    return embedding
if __name__ == "__main__":
    # Script entry: parse the arguments, load the config and extract the
    # embedding of a single audio file.
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="cpu",
                        help="Select which device to run the model on, defaults to cpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load the model checkpoint from.")
    parser.add_argument("--audio-path",
                        default="./data/demo.wav",
                        type=str,
                        help="Single audio file path")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Freeze so that later code cannot modify the configuration.
    config.freeze()
    print(config)

    extract_audio_embedding(args, config)

@ -0,0 +1,205 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode
from paddleaudio.datasets import VoxCeleb
from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def main(args, config):
    """Evaluate a trained speaker model on the VoxCeleb verification trials.

    Loads ``model.pdparams`` from ``args.load_checkpoint``, computes one
    embedding per utterance of the ``enroll`` and ``test`` subsets with the
    model backbone, scores every trial listed in ``VoxCeleb.veri_test_file``
    with cosine similarity and logs the resulting EER and threshold.

    Args:
        args: Parsed command line arguments. Reads ``device``, ``data_dir``
            and ``load_checkpoint``; ``load_checkpoint`` is overwritten with
            its absolute, user-expanded form.
        config: Configuration node. Reads ``seed``, ``model``,
            ``num_speakers``, ``n_mels``, ``window_size``, ``hop_size``,
            ``batch_size``, ``num_workers``, ``global_embedding_norm`` and,
            only when global normalization is enabled,
            ``embedding_mean_norm`` and ``embedding_std_norm``.
    """
    # stage0: set the device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=config.num_speakers)

    # stage3: load the pre-trained model
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage4: construct the enroll and test dataloader
    def build_loader(subset):
        """Build the shuffled, feature-normalizing loader of one subset."""
        dataset = VoxCeleb(
            subset=subset,
            target_dir=args.data_dir,
            feat_type='melspectrogram',
            random_chunk=False,
            n_mels=config.n_mels,
            window_size=config.window_size,
            hop_length=config.hop_size)
        # Shuffle to make embedding normalization more robust.
        sampler = BatchSampler(
            dataset, batch_size=config.batch_size, shuffle=True)
        return DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=lambda x: batch_feature_normalize(
                x, mean_norm=True, std_norm=False),
            num_workers=config.num_workers,
            return_list=True, )

    enrol_loader = build_loader('enroll')
    test_loader = build_loader('test')

    # stage5: we must set the model to eval mode
    model.eval()

    # stage6: global embedding norm to improve the performance
    logger.info(f"global embedding norm: {config.global_embedding_norm}")
    if config.global_embedding_norm:
        global_embedding_mean = None
        global_embedding_std = None
        mean_norm_flag = config.embedding_mean_norm
        std_norm_flag = config.embedding_std_norm
        batch_count = 0

    # stage7: compute embeddings of audios in enrol and test dataset from model.
    id2embedding = {}
    # Run multiple times to make embedding normalization more stable; the
    # second pass overwrites the embeddings of the first one.
    for i in range(2):
        for dl in [enrol_loader, test_loader]:
            logger.info(
                f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
            )
            with paddle.no_grad():
                for batch in tqdm(dl):
                    # stage 7-1: extract the audio embedding
                    ids, feats, lengths = batch['ids'], batch['feats'], batch[
                        'lengths']
                    embeddings = model.backbone(feats, lengths).squeeze(
                        -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)

                    # Global embedding normalization.
                    # if we use the global embedding norm
                    # eer can reduce about relative 10%
                    if config.global_embedding_norm:
                        batch_count += 1
                        current_mean = embeddings.mean(
                            axis=0) if mean_norm_flag else 0
                        current_std = embeddings.std(
                            axis=0) if std_norm_flag else 1
                        # Update global mean and std with a running average
                        # over all batches seen so far.
                        if global_embedding_mean is None and global_embedding_std is None:
                            global_embedding_mean, global_embedding_std = current_mean, current_std
                        else:
                            weight = 1 / batch_count  # Weight decay by batches.
                            global_embedding_mean = (
                                1 - weight
                            ) * global_embedding_mean + weight * current_mean
                            global_embedding_std = (
                                1 - weight
                            ) * global_embedding_std + weight * current_std
                        # Apply global embedding normalization.
                        embeddings = (embeddings - global_embedding_mean
                                      ) / global_embedding_std

                    # Update embedding dict.
                    id2embedding.update(dict(zip(ids, embeddings)))

    # stage 8: Compute cosine scores.
    labels = []
    enroll_ids = []
    test_ids = []
    logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
    with open(VoxCeleb.veri_test_file, 'r') as f:
        for line in f:
            fields = line.split()
            # A blank (e.g. trailing) line carries no trial; skipping it
            # avoids an unpacking error on an otherwise valid file.
            if not fields:
                continue
            label, enroll_id, test_id = fields
            labels.append(int(label))
            # "spk/video/utt.wav" -> "spk-video-utt", the key used in id2embedding
            enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
            test_ids.append(test_id.split('.')[0].replace('/', '-'))

    def stack_embeddings(utt_ids):
        """Gather the embeddings of `utt_ids` into one (N, emb_size) tensor."""
        return paddle.to_tensor(
            np.asarray(
                [id2embedding[uttid] for uttid in utt_ids], dtype='float32'))

    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
    enrol_embeddings = stack_embeddings(enroll_ids)
    test_embeddings = stack_embeddings(test_ids)
    scores = cos_sim_func(enrol_embeddings, test_embeddings)
    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
    logger.info(
        f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
    )
if __name__ == "__main__":
    # yapf: disable
    # The first positional parameter of ArgumentParser is `prog`, so the
    # module docstring has to be passed explicitly as the description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to run the evaluation on, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load the model checkpoint to evaluate.")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    # new_allowed=True lets the YAML file introduce keys that are not
    # predeclared on the empty node.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Make the configuration read-only before handing it to main().
    config.freeze()
    print(config)

    main(args, config)

@ -0,0 +1,351 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer
logger = Log(__name__).getlog()
def _waveforms_to_feats(waveforms, config):
    """Turn a batch of waveforms into mean-normalized mel-spectrogram features."""
    feats = [
        melspectrogram(
            x=waveform,
            sr=config.sr,
            n_mels=config.n_mels,
            window_size=config.window_size,
            hop_length=config.hop_size) for waveform in waveforms.numpy()
    ]
    feats = paddle.to_tensor(np.asarray(feats))
    # Feature normalization helps convergence and improves the performance.
    return feature_normalize(feats, mean_norm=True, std_norm=False)


def main(args, config):
    """Train the ECAPA-TDNN speaker identification model on VoxCeleb.

    Runs (optionally distributed) training, periodically evaluates on the
    ``dev`` subset and saves checkpoints under
    ``<args.checkpoint_dir>/epoch_<N>``. After training, rank 0 points the
    symlink ``<args.checkpoint_dir>/model.pdparams`` at the last saved
    parameters.

    Args:
        args: Parsed command line arguments. Reads ``device``, ``data_dir``,
            ``load_checkpoint`` and ``checkpoint_dir``; ``load_checkpoint``
            is overwritten with its absolute, user-expanded form when set.
        config: Configuration node. Reads ``seed``, ``augment``, ``model``,
            ``learning_rate``, ``batch_size``, ``num_workers``, ``epochs``,
            ``sr``, ``n_mels``, ``window_size``, ``hop_size``,
            ``log_interval`` and ``save_interval``.
    """
    # stage0: set the training device, cpu or gpu
    paddle.set_device(args.device)

    # stage1: we must call the paddle.distributed.init_parallel_env() api at the beginning
    paddle.distributed.init_parallel_env()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipeline
    # note: some cmd must do in rank==0, so we will refactor the data prepare code
    train_dataset = VoxCeleb('train', target_dir=args.data_dir)
    dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)

    if config.augment:
        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
    else:
        augment_pipeline = []

    # stage3: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage4: build the speaker verification train instance with backbone model
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)

    # stage5: build the optimizer, we now only construct the AdamW optimizer
    #         140000 is single gpu steps
    #         so, in multi-gpu mode, we reduce the step_size to 140000//nranks to enable CyclicLRScheduler
    lr_schedule = CyclicLRScheduler(
        base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_schedule, parameters=model.parameters())

    # stage6: build the loss function, we now only support LogSoftmaxWrapper
    criterion = LogSoftmaxWrapper(
        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))

    # stage7: confirm training start epoch
    #         if pre-trained model exists, start epoch confirmed by the pre-trained model
    start_epoch = 0
    if args.load_checkpoint:
        logger.info("load the check point")
        args.load_checkpoint = os.path.abspath(
            os.path.expanduser(args.load_checkpoint))
        params_path = os.path.join(args.load_checkpoint, 'model.pdparams')
        opt_path = os.path.join(args.load_checkpoint, 'model.pdopt')

        # Check for the files explicitly: the previous `except FileExistsError`
        # could never be triggered by a missing checkpoint, so the
        # "train from scratch" fallback was unreachable.
        if os.path.isfile(params_path) and os.path.isfile(opt_path):
            # load model checkpoint
            model.set_state_dict(paddle.load(params_path))
            # load optimizer checkpoint
            optimizer.set_state_dict(paddle.load(opt_path))
            if local_rank == 0:
                logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

            # Recover the epoch from the trailing digits of the directory
            # name ("epoch_12" -> 12); reading only the last character
            # would resume "epoch_10" from epoch 0.
            dir_name = os.path.basename(args.load_checkpoint)
            epoch_digits = dir_name[len(dir_name.rstrip('0123456789')):]
            if epoch_digits:
                start_epoch = int(epoch_digits)
                logger.info(f'Restore training from epoch {start_epoch}.')
        else:
            if local_rank == 0:
                logger.info('Train from scratch.')

    # stage8: we build the batch sampler for paddle.DataLoader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=False)
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.num_workers,
        collate_fn=waveform_collate_fn,
        return_list=True,
        use_buffer_reader=True, )

    # stage9: start to train
    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * config.epochs)
    last_saved_epoch = ""
    timer.start()

    for epoch in range(start_epoch + 1, config.epochs + 1):
        # at the beginning, model must set to train mode
        model.train()

        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        train_reader_cost = 0.0
        train_feat_cost = 0.0
        train_run_cost = 0.0

        reader_start = time.time()
        for batch_idx, batch in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start

            # stage 9-1: batch data is audio sample points and speaker id label
            feat_start = time.time()
            waveforms, labels = batch['waveforms'], batch['labels']
            waveforms, lengths = batch_pad_right(waveforms.numpy())
            waveforms = paddle.to_tensor(waveforms)

            # stage 9-2: audio sample augment method, which is done on the audio sample point
            #            the original waveform and the augmented waveform is concatenated in a batch
            #            eg. five augment method in the augment pipeline
            #            the final data nums is batch_size * [five + one]
            #            -> five augmented waveform batch plus one original batch waveform
            if augment_pipeline:
                waveforms = waveform_augment(waveforms, augment_pipeline)
                # Repeat the labels once per copy of the batch.
                labels = paddle.concat([labels] * (len(augment_pipeline) + 1))

            # stage 9-3 / 9-4: extract the audio feats and normalize them
            feats = _waveforms_to_feats(waveforms, config)
            train_feat_cost += time.time() - feat_start

            # stage 9-5: model forward, such ecapa-tdnn, x-vector
            train_start = time.time()
            logits = model(feats)

            # stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
            loss = criterion(logits, labels)

            # stage 9-7: update the gradient and clear the gradient cache
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()
            train_run_cost += time.time() - train_start

            # stage 9-8: Calculate average loss per batch
            avg_loss += loss.numpy()[0]

            # stage 9-9: Calculate metrics, which is one-best accuracy
            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]
            timer.count()  # step plus one in timer

            # stage 9-10: print the log information only on 0-rank per log-freq batches
            if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
                lr = optimizer.get_lr()
                avg_loss /= config.log_interval
                avg_acc = num_corrects / num_samples

                print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
                    epoch, config.epochs, batch_idx + 1, steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' avg_reader_cost: {:.5f} sec,'.format(
                    train_reader_cost / config.log_interval)
                print_msg += ' avg_feat_cost: {:.5f} sec,'.format(
                    train_feat_cost / config.log_interval)
                print_msg += ' avg_train_cost: {:.5f} sec,'.format(
                    train_run_cost / config.log_interval)
                print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.info(print_msg)

                # Restart the running statistics for the next log window.
                avg_loss = 0
                num_corrects = 0
                num_samples = 0
                train_reader_cost = 0.0
                train_feat_cost = 0.0
                train_run_cost = 0.0

            reader_start = time.time()

        # stage 9-11: save the model parameters only on 0-rank per save-freq epochs
        if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
            if local_rank != 0:
                paddle.distributed.barrier(
                )  # Wait for valid step in main process
                continue  # Resume training on other process

            # stage 9-12: construct the valid dataset dataloader
            dev_sampler = BatchSampler(
                dev_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                drop_last=False)
            dev_loader = DataLoader(
                dev_dataset,
                batch_sampler=dev_sampler,
                collate_fn=waveform_collate_fn,
                num_workers=config.num_workers,
                return_list=True, )

            # set the model to eval mode
            model.eval()
            # Dedicated counters and loop variable so that validation does
            # not shadow the training-loop state.
            dev_corrects = 0
            dev_samples = 0

            # stage 9-13: evaluation the valid dataset batch data
            logger.info('Evaluate on validation dataset')
            with paddle.no_grad():
                for dev_batch in dev_loader:
                    waveforms, labels = dev_batch['waveforms'], dev_batch[
                        'labels']
                    feats = _waveforms_to_feats(waveforms, config)
                    logits = model(feats)

                    preds = paddle.argmax(logits, axis=1)
                    dev_corrects += (preds == labels).numpy().sum()
                    dev_samples += feats.shape[0]

            print_msg = '[Evaluation result]'
            print_msg += ' dev_acc={:.4f}'.format(dev_corrects / dev_samples)
            logger.info(print_msg)

            # stage 9-14: Save model parameters
            save_dir = os.path.join(args.checkpoint_dir,
                                    'epoch_{}'.format(epoch))
            # Kept relative to checkpoint_dir so the final symlink stays
            # valid when the checkpoint directory is moved.
            last_saved_epoch = os.path.join('epoch_{}'.format(epoch),
                                            "model.pdparams")
            logger.info('Saving model checkpoint to {}'.format(save_dir))
            paddle.save(model.state_dict(),
                        os.path.join(save_dir, 'model.pdparams'))
            paddle.save(optimizer.state_dict(),
                        os.path.join(save_dir, 'model.pdopt'))

            if nranks > 1:
                paddle.distributed.barrier()  # Main process

    # stage 10: create the final trained model.pdparams with soft link
    if local_rank == 0:
        final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
        if not last_saved_epoch:
            # Nothing was saved in this run (e.g. no epoch hit save_interval);
            # linking to an empty target would raise.
            logger.info(
                f"no checkpoint was saved in this run, skip creating {final_model}"
            )
        else:
            logger.info(f"we will create the final model: {final_model}")
            if os.path.islink(final_model):
                logger.info(
                    f"An {final_model} already exists, we will rm is and create it again"
                )
                os.unlink(final_model)
            os.symlink(last_saved_epoch, final_model)
if __name__ == "__main__":
    # yapf: disable
    # The first positional parameter of ArgumentParser is `prog`, so the
    # module docstring has to be passed explicitly as the description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="cpu",
                        help="Select which device to train model, defaults to cpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default=None,
                        help="Directory to load model checkpoint to contiune trainning.")
    parser.add_argument("--checkpoint-dir",
                        type=str,
                        default='./checkpoint',
                        help="Directory to save model checkpoints.")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    # new_allowed=True lets the YAML file introduce keys that are not
    # predeclared on the empty node.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Make the configuration read-only before handing it to main().
    config.freeze()
    print(config)

    main(args, config)

@ -0,0 +1,908 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import math
import os
from typing import List
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleaudio import load as load_audio
from paddleaudio.datasets.rirs_noises import OpenRIRNoise
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.signal_processing import compute_amplitude
from paddlespeech.vector.io.signal_processing import convolve1d
from paddlespeech.vector.io.signal_processing import dB_to_amplitude
from paddlespeech.vector.io.signal_processing import notch_filter
from paddlespeech.vector.io.signal_processing import reverberate
logger = Log(__name__).getlog()
# TODO: Complete type-hint and doc string.
class DropFreq(nn.Layer):
    """Augmentation layer that notch-filters random frequencies out of waveforms.

    Args:
        drop_freq_low: Lower bound of the range the dropped frequencies are
            sampled from.
        drop_freq_high: Upper bound of that range.
        drop_count_low: Minimum number of frequencies dropped per call.
        drop_count_high: Maximum number of frequencies dropped per call
            (inclusive).
        drop_width: Width argument forwarded to ``notch_filter``.
        drop_prob: Probability that a call applies the filter at all.
    """

    def __init__(
            self,
            drop_freq_low=1e-14,
            drop_freq_high=1,
            drop_count_low=1,
            drop_count_high=2,
            drop_width=0.05,
            drop_prob=1, ):
        super().__init__()
        self.drop_prob = drop_prob
        self.drop_width = drop_width
        self.drop_count_low = drop_count_low
        self.drop_count_high = drop_count_high
        self.drop_freq_low = drop_freq_low
        self.drop_freq_high = drop_freq_high

    def forward(self, waveforms):
        """Return a filtered copy of `waveforms`; the input is not modified.

        NOTE(review): presumably `waveforms` is (batch, time) or
        (batch, time, channels) -- confirm against `convolve1d`.
        """
        result = waveforms.clone()

        # With probability 1 - drop_prob the batch is handed back untouched.
        if paddle.rand([1]) > self.drop_prob:
            return result

        # 2-D input gets a trailing channel axis for the convolution.
        if len(waveforms.shape) == 2:
            result = result.unsqueeze(-1)

        # How many frequencies to remove (upper bound is inclusive).
        num_drops = paddle.randint(
            low=self.drop_count_low, high=self.drop_count_high + 1, shape=[1])

        # Which frequencies to remove, uniform in [low, high).
        freq_span = self.drop_freq_high - self.drop_freq_low
        frequencies = paddle.rand([num_drops]) * freq_span + self.drop_freq_low

        kernel_size = 101
        half_width = kernel_size // 2

        # Begin with an identity (delta) kernel and fold one notch filter
        # per selected frequency into it.
        combined_filter = paddle.zeros([1, kernel_size, 1])
        combined_filter[0, half_width, 0] = 1
        for freq in frequencies:
            notch = notch_filter(freq, kernel_size, self.drop_width)
            combined_filter = convolve1d(combined_filter, notch, half_width)

        # Filter the waveforms once with the combined kernel.
        result = convolve1d(result, combined_filter, half_width)

        # Drop the trailing axis again when it has size one.
        return result.squeeze(-1)
class DropChunk(nn.Layer):
    """Augmentation layer that zeroes, or replaces with noise, random chunks of waveforms.

    Args:
        drop_length_low: Minimum length (in samples) of a dropped chunk.
        drop_length_high: Maximum length (in samples) of a dropped chunk,
            inclusive.
        drop_count_low: Minimum number of chunks dropped per waveform.
        drop_count_high: Maximum number of chunks dropped per waveform,
            inclusive.
        drop_start: Lowest sample index a chunk may start at; a negative
            value is offset by the waveform length in ``forward``.
        drop_end: Sample index bounding where chunks may lie; ``None`` means
            the waveform length, a negative value is offset by the waveform
            length in ``forward``.
        drop_prob: Probability that a call drops anything at all.
        noise_factor: When falsy the chunks are set to zero; otherwise they
            are filled with uniform noise scaled by this factor and the
            amplitude of the clean waveform.

    Raises:
        ValueError: If a low limit is greater than the matching high limit.
    """

    def __init__(
            self,
            drop_length_low=100,
            drop_length_high=1000,
            drop_count_low=1,
            drop_count_high=10,
            drop_start=0,
            drop_end=None,
            drop_prob=1,
            noise_factor=0.0, ):
        super(DropChunk, self).__init__()
        self.drop_length_low = drop_length_low
        self.drop_length_high = drop_length_high
        self.drop_count_low = drop_count_low
        self.drop_count_high = drop_count_high
        self.drop_start = drop_start
        self.drop_end = drop_end
        self.drop_prob = drop_prob
        self.noise_factor = noise_factor

        # Validate low <= high
        if drop_length_low > drop_length_high:
            raise ValueError("Low limit must not be more than high limit")
        if drop_count_low > drop_count_high:
            raise ValueError("Low limit must not be more than high limit")

        # Make sure the length doesn't exceed end - start
        # (only possible here when drop_end is an absolute, non-negative index)
        if drop_end is not None and drop_end >= 0:
            if drop_start > drop_end:
                raise ValueError("Low limit must not be more than high limit")

            drop_range = drop_end - drop_start
            self.drop_length_low = min(drop_length_low, drop_range)
            self.drop_length_high = min(drop_length_high, drop_range)

    def forward(self, waveforms, lengths):
        """Return a copy of `waveforms` with random chunks dropped.

        Args:
            waveforms: Batch of waveforms; indexed as ``[i, start:end]``
                below. NOTE(review): presumably shape (batch, time) -- confirm.
            lengths: Per-waveform lengths; multiplied by
                ``waveforms.shape[1]`` below, so presumably relative lengths
                in [0, 1] -- confirm against callers.

        Returns:
            The modified clone; `waveforms` itself is left untouched.
        """
        # Reading input list
        # Convert the lengths into integer sample counts.
        lengths = (lengths * waveforms.shape[1]).astype('int64')
        batch_size = waveforms.shape[0]
        dropped_waveform = waveforms.clone()

        # Don't drop (return early) 1-`drop_prob` portion of the batches
        if paddle.rand([1]) > self.drop_prob:
            return dropped_waveform

        # Store original amplitude for computing white noise amplitude
        clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1))

        # Pick a number of times to drop
        # (one draw per waveform; the upper bound is made inclusive by the +1)
        drop_times = paddle.randint(
            low=self.drop_count_low,
            high=self.drop_count_high + 1,
            shape=[batch_size], )

        # Iterate batch to set mask
        for i in range(batch_size):
            if drop_times[i] == 0:
                continue

            # Pick lengths
            # (one chunk length per drop of this waveform)
            length = paddle.randint(
                low=self.drop_length_low,
                high=self.drop_length_high + 1,
                shape=[drop_times[i]], )

            # Compute range of starting locations
            # Negative bounds are offset by this waveform's length.
            start_min = self.drop_start
            if start_min < 0:
                start_min += lengths[i]
            start_max = self.drop_end
            if start_max is None:
                start_max = lengths[i]
            if start_max < 0:
                start_max += lengths[i]
            # Leave room for the longest chunk drawn above.
            start_max = max(0, start_max - length.max())

            # Pick starting locations
            start = paddle.randint(
                low=start_min,
                high=start_max + 1,
                shape=[drop_times[i]], )

            end = start + length

            # Update waveform
            if not self.noise_factor:
                # Plain dropping: silence every non-empty chunk.
                for j in range(drop_times[i]):
                    if start[j] < end[j]:
                        dropped_waveform[i, start[j]:end[j]] = 0.0
            else:
                # Uniform distribution of -2 to +2 * avg amplitude should
                # preserve the average for normalization
                noise_max = 2 * clean_amplitude[i] * self.noise_factor
                for j in range(drop_times[i]):
                    # zero-center the noise distribution
                    noise_vec = paddle.rand([length[j]], dtype='float32')
                    noise_vec = 2 * noise_max * noise_vec - noise_max
                    dropped_waveform[i, int(start[j]):int(end[j])] = noise_vec

        return dropped_waveform
class Resample(nn.Layer):
    """Layer that resamples waveforms from `orig_freq` to `new_freq`.

    The resampling filter (``self.weights``) and the per-phase start indices
    (``self.first_indices``) are built lazily on the first ``forward`` call.

    Args:
        orig_freq: Sampling rate of the input waveforms.
        new_freq: Sampling rate of the output waveforms.
        lowpass_filter_width: Controls the width of the low-pass filter
            window (see ``_indices_and_weights``).
    """

    def __init__(
            self,
            orig_freq=16000,
            new_freq=16000,
            lowpass_filter_width=6, ):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.lowpass_filter_width = lowpass_filter_width

        # Compute rate for striding
        self._compute_strides()
        # Sanity checks: both strides are derived from the gcd of the two
        # rates, so they must divide their respective rate.
        assert self.orig_freq % self.conv_stride == 0
        assert self.new_freq % self.conv_transpose_stride == 0

    def _compute_strides(self):
        """Derive the convolution strides from the ratio of the two rates.

        Sets ``self.output_samples``, ``self.conv_stride`` and
        ``self.conv_transpose_stride``.
        """
        # Compute new unit based on ratio of in/out frequencies
        base_freq = math.gcd(self.orig_freq, self.new_freq)
        input_samples_in_unit = self.orig_freq // base_freq
        self.output_samples = self.new_freq // base_freq

        # Store the appropriate stride based on the new units
        self.conv_stride = input_samples_in_unit
        self.conv_transpose_stride = self.output_samples

    def forward(self, waveforms):
        """Resample `waveforms` and return the result.

        Args:
            waveforms: 2-D or 3-D tensor. A 2-D input gets an axis inserted
                at position 1; a 3-D input has its last two axes swapped
                before resampling and swapped back afterwards.

        Returns:
            `waveforms` itself when both rates are equal, otherwise the
            resampled tensor laid out like the input.

        Raises:
            ValueError: If the input is neither 2-D nor 3-D.
        """
        # Build the filter bank once, on first use.
        if not hasattr(self, "first_indices"):
            self._indices_and_weights(waveforms)

        # Don't do anything if the frequencies are the same
        if self.orig_freq == self.new_freq:
            return waveforms

        unsqueezed = False
        if len(waveforms.shape) == 2:
            waveforms = waveforms.unsqueeze(1)
            unsqueezed = True
        elif len(waveforms.shape) == 3:
            waveforms = waveforms.transpose([0, 2, 1])
        else:
            raise ValueError("Input must be 2 or 3 dimensions")

        # Do resampling
        resampled_waveform = self._perform_resample(waveforms)

        # Undo the layout change applied above.
        if unsqueezed:
            resampled_waveform = resampled_waveform.squeeze(1)
        else:
            resampled_waveform = resampled_waveform.transpose([0, 2, 1])

        return resampled_waveform

    def _perform_resample(self, waveforms):
        """Apply the precomputed filter bank to a 3-D batch of waveforms.

        Args:
            waveforms: Tensor whose shape unpacks into
                (batch_size, num_channels, wave_len).

        Returns:
            Tensor of shape (batch_size, num_channels, tot_output_samp).
        """
        # Compute output size and initialize
        batch_size, num_channels, wave_len = waveforms.shape
        window_size = self.weights.shape[1]
        tot_output_samp = self._output_samples(wave_len)
        resampled_waveform = paddle.zeros((batch_size, num_channels,
                                           tot_output_samp))

        # eye size: (num_channels, num_channels, 1)
        eye = paddle.eye(num_channels).unsqueeze(2)

        # Iterate over the phases in the polyphase filter
        for i in range(self.first_indices.shape[0]):
            wave_to_conv = waveforms
            first_index = int(self.first_indices[i].item())
            if first_index >= 0:
                # trim the signal as the filter will not be applied
                # before the first_index
                wave_to_conv = wave_to_conv[:, :, first_index:]

            # pad the right of the signal to allow partial convolutions
            # meaning compute values for partial windows (e.g. end of the
            # window is outside the signal length)
            max_index = (tot_output_samp - 1) // self.output_samples
            end_index = max_index * self.conv_stride + window_size
            current_wave_len = wave_len - first_index
            right_padding = max(0, end_index + 1 - current_wave_len)
            # A negative first_index means the window starts before the
            # signal, which is compensated by left padding.
            left_padding = max(0, -first_index)
            wave_to_conv = paddle.nn.functional.pad(
                wave_to_conv, [left_padding, right_padding], data_format='NCL')
            conv_wave = paddle.nn.functional.conv1d(
                x=wave_to_conv,
                # weight=self.weights[i].repeat(num_channels, 1, 1),
                weight=self.weights[i].expand((num_channels, 1, -1)),
                stride=self.conv_stride,
                groups=num_channels, )

            # we want conv_wave[:, i] to be at
            # output[:, i + n*conv_transpose_stride]
            dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
                conv_wave, eye, stride=self.conv_transpose_stride)

            # pad dilated_conv_wave so it reaches the output length if needed.
            left_padding = i
            previous_padding = left_padding + dilated_conv_wave.shape[-1]
            right_padding = max(0, tot_output_samp - previous_padding)
            dilated_conv_wave = paddle.nn.functional.pad(
                dilated_conv_wave, [left_padding, right_padding],
                data_format='NCL')
            dilated_conv_wave = dilated_conv_wave[:, :, :tot_output_samp]

            # Accumulate the contribution of this phase.
            resampled_waveform += dilated_conv_wave

        return resampled_waveform

    def _output_samples(self, input_num_samp):
        """Return the number of output samples produced for `input_num_samp` input samples."""
        samp_in = int(self.orig_freq)
        samp_out = int(self.new_freq)

        # tick_freq is the least common multiple of the two rates.
        tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
        ticks_per_input_period = tick_freq // samp_in

        # work out the number of ticks in the time interval
        # [ 0, input_num_samp/samp_in ).
        interval_length = input_num_samp * ticks_per_input_period
        if interval_length <= 0:
            return 0
        ticks_per_output_period = tick_freq // samp_out

        # Get the last output-sample in the closed interval,
        # i.e. replacing [ ) with [ ]. Note: integer division rounds down.
        # See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
        # explanation of the notation.
        last_output_samp = interval_length // ticks_per_output_period

        # We need the last output-sample in the open interval, so if it
        # takes us to the end of the interval exactly, subtract one.
        if last_output_samp * ticks_per_output_period == interval_length:
            last_output_samp -= 1

        # First output-sample index is zero, so the number of output samples
        # is the last output-sample plus one.
        num_output_samp = last_output_samp + 1

        return num_output_samp

    def _indices_and_weights(self, waveforms):
        """Build the filter bank and store it on the layer.

        Sets ``self.first_indices`` (first input index used by each output
        phase) and ``self.weights`` (windowed-sinc filter of each phase).
        The `waveforms` argument is not used in the body.
        """
        # Lowpass filter frequency depends on smaller of two frequencies
        min_freq = min(self.orig_freq, self.new_freq)
        lowpass_cutoff = 0.99 * 0.5 * min_freq

        assert lowpass_cutoff * 2 <= min_freq
        window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)

        assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
        # Time stamp of each output sample within one unit.
        output_t = paddle.arange(start=0.0, end=self.output_samples)
        output_t /= self.new_freq
        min_t = output_t - window_width
        max_t = output_t + window_width

        # Range of input samples falling inside the window of each output.
        min_input_index = paddle.ceil(min_t * self.orig_freq)
        max_input_index = paddle.floor(max_t * self.orig_freq)
        num_indices = max_input_index - min_input_index + 1

        max_weight_width = num_indices.max()
        j = paddle.arange(max_weight_width, dtype='float32')
        input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
        # Time offset between every candidate input sample and its output.
        delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)

        weights = paddle.zeros_like(delta_t)
        inside_window_indices = delta_t.abs().less_than(
            paddle.to_tensor(window_width))

        # raised-cosine (Hanning) window with width `window_width`
        weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
            2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
            delta_t.masked_select(inside_window_indices)))

        t_eq_zero_indices = delta_t.equal(paddle.zeros_like(delta_t))
        t_not_eq_zero_indices = delta_t.not_equal(paddle.zeros_like(delta_t))

        # sinc filter function
        weights = paddle.where(
            t_not_eq_zero_indices,
            weights * paddle.sin(2 * math.pi * lowpass_cutoff * delta_t) /
            (math.pi * delta_t), weights)

        # limit of the function at t = 0
        weights = paddle.where(t_eq_zero_indices, weights * 2 * lowpass_cutoff,
                               weights)

        # size (output_samples, max_weight_width)
        weights /= self.orig_freq

        self.first_indices = min_input_index
        self.weights = weights
class SpeedPerturb(nn.Layer):
    """Augmentation layer that changes the speed of waveforms by resampling.

    One ``Resample`` layer is created per entry of `speeds`; each call to
    ``forward`` picks one of them at random.

    Args:
        orig_freq: Sampling rate of the input waveforms.
        speeds: Speeds in percent of the original speed. Each one is mapped
            to the target rate ``orig_freq * speed // 100``.
        perturb_prob: Probability that a call perturbs the batch at all.
    """

    def __init__(
            self,
            orig_freq,
            speeds=(90, 100, 110),
            perturb_prob=1.0, ):
        super(SpeedPerturb, self).__init__()
        self.orig_freq = orig_freq
        # Keep a private copy: the default must not be a list shared by all
        # instances, and later changes to the caller's sequence must not
        # desynchronize `speeds` from the resamplers built below.
        self.speeds = list(speeds)
        self.perturb_prob = perturb_prob

        # Initialize index of perturbation
        self.samp_index = 0

        # Initialize resamplers, one per speed
        self.resamplers = [
            Resample(
                orig_freq=self.orig_freq,
                new_freq=self.orig_freq * speed // 100)
            for speed in self.speeds
        ]

    def forward(self, waveform):
        """Return `waveform` resampled with a randomly chosen speed.

        With probability ``1 - perturb_prob`` an unmodified clone is
        returned instead. The index of the chosen speed is stored in
        ``self.samp_index``.
        """
        # Don't perturb (return early) 1-`perturb_prob` portion of the batches
        if paddle.rand([1]) > self.perturb_prob:
            return waveform.clone()

        # Perform a random perturbation
        self.samp_index = paddle.randint(len(self.speeds), shape=[1]).item()
        perturbed_waveform = self.resamplers[self.samp_index](waveform)

        return perturbed_waveform
class AddNoise(nn.Layer):
def __init__(
        self,
        noise_dataset=None,  # None for white noise
        num_workers=0,
        snr_low=0,
        snr_high=0,
        mix_prob=1.0,
        start_index=None,
        normalize=False, ):
    """Store the mixing options; the noise data loader is created lazily.

    Args:
        noise_dataset: Dataset providing the noise clips; ``None`` selects
            generated white noise.
        num_workers: Worker count handed to the noise data loader.
        snr_low: Lower bound of the sampled signal-to-noise ratio.
        snr_high: Upper bound of the sampled signal-to-noise ratio.
        mix_prob: Probability that a call mixes in noise at all.
        start_index: Fixed offset into the noise clips; ``None`` selects a
            random offset.
        normalize: Whether to rescale the mixture to prevent clipping.
    """
    super(AddNoise, self).__init__()

    # Noise source; the loader is only built once noise is first requested.
    self.noise_dataset = noise_dataset
    self.noise_dataloader = None
    self.num_workers = num_workers

    # Mixing behaviour.
    self.snr_low, self.snr_high = snr_low, snr_high
    self.mix_prob = mix_prob
    self.start_index = start_index
    self.normalize = normalize
def forward(self, waveforms, lengths=None):
    """Return a noisy copy of `waveforms`; the input is not modified.

    Args:
        waveforms: Batch of clean waveforms. NOTE(review): presumably shape
            (batch, time) -- confirm against `compute_amplitude`.
        lengths: Per-waveform lengths; multiplied by ``waveforms.shape[1]``
            below, so presumably relative lengths in [0, 1] -- confirm.
            Defaults to all ones.

    Returns:
        The clean clone with probability ``1 - mix_prob``, otherwise the
        mixture of the rescaled clean signal and the noise.
    """
    if lengths is None:
        lengths = paddle.ones([len(waveforms)])

    # Copy clean waveform to initialize noisy waveform
    noisy_waveform = waveforms.clone()
    lengths = (lengths * waveforms.shape[1]).astype('int64').unsqueeze(1)

    # Don't add noise (return early) 1-`mix_prob` portion of the batches
    if paddle.rand([1]) > self.mix_prob:
        return noisy_waveform

    # Compute the average amplitude of the clean waveforms
    clean_amplitude = compute_amplitude(waveforms, lengths)

    # Pick an SNR and use it to compute the mixture amplitude factors
    SNR = paddle.rand((len(waveforms), 1))
    SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low
    noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1)
    new_noise_amplitude = noise_amplitude_factor * clean_amplitude

    # Scale clean signal appropriately
    noisy_waveform *= 1 - noise_amplitude_factor

    # Loop through clean samples and create mixture
    if self.noise_dataset is None:
        white_noise = paddle.normal(shape=waveforms.shape)
        noisy_waveform += new_noise_amplitude * white_noise
    else:
        tensor_length = waveforms.shape[1]
        noise_waveform, noise_length = self._load_noise(
            lengths,
            tensor_length, )

        # Rescale and add (the epsilon guards against silent noise clips)
        noise_amplitude = compute_amplitude(noise_waveform, noise_length)
        noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14)
        noisy_waveform += noise_waveform

    # Normalizing to prevent clipping
    if self.normalize:
        # paddle.max returns only the maxima (a single tensor), unlike
        # torch.max which returns a (values, indices) pair, so the result
        # must not be unpacked.
        abs_max = paddle.max(
            paddle.abs(noisy_waveform), axis=1, keepdim=True)
        # Only scale down: peaks below 1.0 are left as they are.
        noisy_waveform = noisy_waveform / abs_max.clip(min=1.0)

    return noisy_waveform
def _load_noise(self, lengths, max_length):
    """Fetch one batch of noise clips cut to at most `max_length` samples.

    The noise data loader is created on the first call and reused
    afterwards.

    Args:
        lengths (paddle.Tensor): Num samples of waveforms with shape (N, 1).
        max_length (int): Width of a batch.

    Returns:
        Tuple ``(noise_batch, noise_len)``: the truncated noise batch and
        the remaining clip lengths with a trailing axis added.
    """
    wave_lens = lengths.squeeze(1)
    num_waves = len(wave_lens)

    if self.noise_dataloader is None:

        def pad_right(sample, target_length):
            # Zero-pad a 1-D clip at the end up to `target_length`.
            sample = np.asarray(sample)
            missing = target_length - sample.shape[0]
            assert missing >= 0, f'Target length {target_length} is less than origin length {sample.shape[0]}'
            return np.pad(sample, [0, missing], mode='constant')

        def noise_collate_fn(batch):
            # Pad every clip to the longest clip of the batch, but never to
            # less than `max_length`.
            clip_ids = [item['id'] for item in batch]
            clip_lens = np.asarray([item['feat'].shape[0] for item in batch])
            padded = [
                pad_right(item['feat'],
                          max(max_length, clip_lens.max().item()))
                for item in batch
            ]
            return {
                'ids': clip_ids,
                'feats': np.stack(padded),
                'lengths': clip_lens
            }

        # Create noise data loader.
        self.noise_dataloader = paddle.io.DataLoader(
            self.noise_dataset,
            batch_size=num_waves,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=noise_collate_fn,
            return_list=True, )
        self.noise_data = iter(self.noise_dataloader)

    noise_batch, noise_len = self._load_noise_batch_of_size(num_waves)

    # Use the configured offset, or draw a random one that stays inside
    # the noise clips.
    offset = self.start_index
    if offset is None:
        max_chop = (noise_len - wave_lens).min().clip(min=1)
        offset = paddle.randint(high=max_chop, shape=[1])

    # Truncate noise_batch to max_length
    noise_batch = noise_batch[:, offset:offset + max_length]
    noise_len = (noise_len - offset).clip(max=max_length).unsqueeze(1)
    return noise_batch, noise_len
def _load_noise_batch_of_size(self, batch_size):
    """Return a noise batch holding exactly ``batch_size`` clips.

    A loaded batch that is too small is doubled (concatenated with itself)
    until it is large enough; surplus clips are then cut off.
    """
    noise_batch, noise_lens = self._load_noise_batch()

    # Grow by repeating what was loaded.
    while len(noise_batch) < batch_size:
        noise_batch, noise_lens = (
            paddle.concat((noise_batch, noise_batch)),
            paddle.concat((noise_lens, noise_lens)), )

    # Shrink to the requested size.
    if len(noise_batch) > batch_size:
        noise_batch, noise_lens = (noise_batch[:batch_size],
                                   noise_lens[:batch_size])

    return noise_batch, noise_lens
def _load_noise_batch(self):
    """Return the next ``(noises, lengths)`` pair from the noise loader.

    When the current iterator is exhausted a new one is opened on the same
    loader, so the noise data is cycled through indefinitely.
    """
    try:
        batch = next(self.noise_data)
    except StopIteration:
        # End of the noise data: start another pass.
        self.noise_data = iter(self.noise_dataloader)
        batch = next(self.noise_data)

    return batch['feats'], batch['lengths']
class AddReverb(nn.Layer):
    """Reverberate waveforms with impulse responses drawn from a dataset."""

    def __init__(
            self,
            rir_dataset,
            reverb_prob=1.0,
            rir_scale_factor=1.0,
            num_workers=0, ):
        """Store the settings and open an iterator over the RIR dataset.

        Args:
            rir_dataset: Dataset of impulse responses. Items are read through
                their 'id' and 'feat' entries, and the dataset's
                ``sample_rate`` attribute is used when reverberating.
            reverb_prob (float, optional): Reverb is applied unless a uniform
                random draw exceeds this value. Defaults to 1.0.
            rir_scale_factor (float, optional): Scale factor used to
                interpolate the impulse response along time; 1 disables the
                interpolation. Defaults to 1.0.
            num_workers (int, optional): Worker count handed to the data
                loader. Defaults to 0.
        """
        super(AddReverb, self).__init__()
        self.rir_dataset = rir_dataset
        self.reverb_prob = reverb_prob
        self.rir_scale_factor = rir_scale_factor

        def _right_pad(x, target_length, mode='constant', **kwargs):
            # Pad a 1-D response on the right up to target_length samples.
            x = np.asarray(x)
            w = target_length - x.shape[0]
            assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
            return np.pad(x, [0, w], mode=mode, **kwargs)

        def rir_collate_fn(batch):
            ids = [item['id'] for item in batch]
            lengths = np.asarray([item['feat'].shape[0] for item in batch])
            # Pad every response to the longest one in the batch.
            waveforms = np.stack([
                _right_pad(item['feat'], lengths.max().item())
                for item in batch
            ])
            return {'ids': ids, 'feats': waveforms, 'lengths': lengths}

        # Create rir data loader.
        self.rir_dataloader = paddle.io.DataLoader(
            self.rir_dataset,
            collate_fn=rir_collate_fn,
            num_workers=num_workers,
            shuffle=True,
            return_list=True, )
        self.rir_data = iter(self.rir_dataloader)

    def forward(self, waveforms, lengths=None):
        """Apply reverberation to a batch of waveforms.

        Arguments
        ---------
        waveforms : tensor
            Shape should be `[batch, time]` or `[batch, time, channels]`.
        lengths : tensor
            Shape should be a single dimension, `[batch]`.

        Returns
        -------
        Tensor of shape `[batch, time]` or `[batch, time, channels]`.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        # Don't add reverb (return early) 1-`reverb_prob` portion of the time
        if paddle.rand([1]) > self.reverb_prob:
            return waveforms.clone()

        # A 2-D batch gets a temporary channel axis that is removed again
        # before returning.
        channel_added = len(waveforms.shape) == 2
        if channel_added:
            waveforms = waveforms.unsqueeze(-1)

        rir_waveform = self._load_rir()

        # Compress or dilate the impulse response along the time axis.
        if self.rir_scale_factor != 1:
            # interpolate works on (N, C, L), so swap the last two axes
            # before and after the call.
            rir_waveform = F.interpolate(
                rir_waveform.transpose([0, 2, 1]),
                scale_factor=self.rir_scale_factor,
                mode="linear",
                align_corners=False,
                data_format='NCW', )
            rir_waveform = rir_waveform.transpose([0, 2, 1])

        rev_waveform = reverberate(
            waveforms,
            rir_waveform,
            self.rir_dataset.sample_rate,
            rescale_amp="avg")

        return rev_waveform.squeeze(-1) if channel_added else rev_waveform

    def _load_rir(self):
        """Return the next impulse response, restarting the loader if needed."""
        try:
            batch = next(self.rir_data)
        except StopIteration:
            # The dataset was consumed once: begin another pass.
            self.rir_data = iter(self.rir_dataloader)
            batch = next(self.rir_data)

        rir_waveform = batch['feats']

        # Make sure the response carries a channel axis.
        if len(rir_waveform.shape) == 2:
            rir_waveform = rir_waveform.unsqueeze(-1)
        return rir_waveform
class AddBabble(nn.Layer):
    """Create babble noise by mixing in other waveforms of the same batch."""

    def __init__(
            self,
            speaker_count=3,
            snr_low=0,
            snr_high=0,
            mix_prob=1, ):
        """
        Args:
            speaker_count (int, optional): Number of rolled copies of the
                batch that are summed to form the babble. Defaults to 3.
            snr_low (float, optional): Lower bound of the SNR draw, in dB.
                Defaults to 0.
            snr_high (float, optional): Upper bound of the SNR draw, in dB.
                Defaults to 0.
            mix_prob (float, optional): Babble is mixed in unless a uniform
                random draw exceeds this value. Defaults to 1.
        """
        super(AddBabble, self).__init__()
        self.speaker_count = speaker_count
        self.snr_low = snr_low
        self.snr_high = snr_high
        self.mix_prob = mix_prob

    def forward(self, waveforms, lengths=None):
        """Mix every waveform with a sum of other waveforms from the batch.

        Args:
            waveforms (paddle.Tensor): Clean batch; axis 0 is the batch axis
                and axis 1 the sample axis.
            lengths (paddle.Tensor, optional): Relative length of every
                waveform. Defaults to all ones.

        Returns:
            paddle.Tensor: The babbled copy of ``waveforms``.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        babbled_waveform = waveforms.clone()
        # Relative lengths -> sample counts, shaped (N, 1).
        lengths = (lengths * waveforms.shape[1]).unsqueeze(1)
        batch_size = len(waveforms)

        # Don't mix (return early) 1-`mix_prob` portion of the batches
        if paddle.rand([1]) > self.mix_prob:
            return babbled_waveform

        # Pick an SNR per waveform and derive the babble's amplitude share.
        clean_amplitude = compute_amplitude(waveforms, lengths)
        SNR = paddle.rand((batch_size, 1))
        SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low
        noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1)
        new_noise_amplitude = noise_amplitude_factor * clean_amplitude

        # The clean signal keeps the remaining share of the amplitude.
        babbled_waveform *= 1 - noise_amplitude_factor

        # Rolling along the batch axis pairs every waveform with its
        # neighbours; shifts 1..speaker_count are summed into the babble.
        babble_waveform = waveforms.roll((1, ), axis=0)
        babble_len = lengths.roll((1, ), axis=0)
        for shift in range(2, self.speaker_count + 1):
            babble_waveform += waveforms.roll((shift, ), axis=0)
            # Element-wise maximum of the running length and its rolled copy.
            babble_len = paddle.concat(
                [babble_len, babble_len.roll((1, ), axis=0)], axis=-1).max(
                    axis=-1, keepdim=True)

        # Rescale and add to mixture; the epsilon avoids division by zero.
        babble_amplitude = compute_amplitude(babble_waveform, babble_len)
        babble_waveform *= new_noise_amplitude / (babble_amplitude + 1e-14)
        babbled_waveform += babble_waveform
        return babbled_waveform
class TimeDomainSpecAugment(nn.Layer):
    """Time-domain augmentation chaining speed perturbation, frequency
    dropping and chunk dropping."""

    def __init__(
            self,
            perturb_prob=1.0,
            drop_freq_prob=1.0,
            drop_chunk_prob=1.0,
            speeds=None,
            sample_rate=16000,
            drop_freq_count_low=0,
            drop_freq_count_high=3,
            drop_chunk_count_low=0,
            drop_chunk_count_high=5,
            drop_chunk_length_low=1000,
            drop_chunk_length_high=2000,
            drop_chunk_noise_factor=0, ):
        """Build the three augmentation stages.

        Args:
            perturb_prob (float, optional): Forwarded to ``SpeedPerturb``.
                Defaults to 1.0.
            drop_freq_prob (float, optional): Forwarded to ``DropFreq`` as
                ``drop_prob``. Defaults to 1.0.
            drop_chunk_prob (float, optional): Forwarded to ``DropChunk`` as
                ``drop_prob``. Defaults to 1.0.
            speeds (list, optional): Speed settings forwarded to
                ``SpeedPerturb``. Defaults to ``[95, 100, 105]``.
            sample_rate (int, optional): Forwarded to ``SpeedPerturb`` as
                ``orig_freq``. Defaults to 16000.
            drop_freq_count_low (int, optional): Defaults to 0.
            drop_freq_count_high (int, optional): Defaults to 3.
            drop_chunk_count_low (int, optional): Defaults to 0.
            drop_chunk_count_high (int, optional): Defaults to 5.
            drop_chunk_length_low (int, optional): Defaults to 1000.
            drop_chunk_length_high (int, optional): Defaults to 2000.
            drop_chunk_noise_factor (float, optional): Forwarded to
                ``DropChunk`` as ``noise_factor``. Defaults to 0.
        """
        super(TimeDomainSpecAugment, self).__init__()
        # A fresh list per instance: a list literal used as the default
        # value would be shared by every instance created without `speeds`.
        if speeds is None:
            speeds = [95, 100, 105]

        self.speed_perturb = SpeedPerturb(
            perturb_prob=perturb_prob,
            orig_freq=sample_rate,
            speeds=speeds, )
        self.drop_freq = DropFreq(
            drop_prob=drop_freq_prob,
            drop_count_low=drop_freq_count_low,
            drop_count_high=drop_freq_count_high, )
        self.drop_chunk = DropChunk(
            drop_prob=drop_chunk_prob,
            drop_count_low=drop_chunk_count_low,
            drop_count_high=drop_chunk_count_high,
            drop_length_low=drop_chunk_length_low,
            drop_length_high=drop_chunk_length_high,
            noise_factor=drop_chunk_noise_factor, )

    def forward(self, waveforms, lengths=None):
        """Run the three stages in order, without tracking gradients.

        Args:
            waveforms (paddle.Tensor): Batch of waveforms.
            lengths (paddle.Tensor, optional): Per-waveform lengths, only
                passed to the chunk-dropping stage. Defaults to all ones.

        Returns:
            paddle.Tensor: The augmented waveforms.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        with paddle.no_grad():
            # Augmentation
            waveforms = self.speed_perturb(waveforms)
            waveforms = self.drop_freq(waveforms)
            waveforms = self.drop_chunk(waveforms, lengths)
        return waveforms
class EnvCorrupt(nn.Layer):
    """Environmental corruption: optional reverb, babble and noise stages."""

    def __init__(
            self,
            reverb_prob=1.0,
            babble_prob=1.0,
            noise_prob=1.0,
            rir_dataset=None,
            noise_dataset=None,
            num_workers=0,
            babble_speaker_count=0,
            babble_snr_low=0,
            babble_snr_high=0,
            noise_snr_low=0,
            noise_snr_high=0,
            rir_scale_factor=1.0, ):
        """Create only the corrupters that are actually enabled.

        Args:
            reverb_prob (float, optional): Probability setting for reverb;
                the stage is skipped when it is not positive.
            babble_prob (float, optional): Probability setting for babble;
                the stage is skipped when it is not positive.
            noise_prob (float, optional): Probability setting for noise;
                the stage is skipped when it is not positive.
            rir_dataset (optional): Impulse-response dataset; reverb is only
                enabled when one is given.
            noise_dataset (optional): Noise dataset; the noise stage is only
                enabled when one is given.
            num_workers (int, optional): Worker count for the data loaders.
            babble_speaker_count (int, optional): Babble is only enabled
                when this is greater than 0.
            babble_snr_low (float, optional): Lower SNR bound for babble.
            babble_snr_high (float, optional): Upper SNR bound for babble.
            noise_snr_low (float, optional): Lower SNR bound for noise.
            noise_snr_high (float, optional): Upper SNR bound for noise.
            rir_scale_factor (float, optional): Forwarded to ``AddReverb``.
        """
        super(EnvCorrupt, self).__init__()

        # Initialize corrupters. A disabled stage simply has no attribute,
        # which is what forward() tests for.
        if rir_dataset is not None and reverb_prob > 0.0:
            self.add_reverb = AddReverb(
                rir_dataset=rir_dataset,
                num_workers=num_workers,
                reverb_prob=reverb_prob,
                rir_scale_factor=rir_scale_factor, )

        if babble_speaker_count > 0 and babble_prob > 0.0:
            self.add_babble = AddBabble(
                speaker_count=babble_speaker_count,
                snr_low=babble_snr_low,
                snr_high=babble_snr_high,
                mix_prob=babble_prob, )

        if noise_dataset is not None and noise_prob > 0.0:
            self.add_noise = AddNoise(
                noise_dataset=noise_dataset,
                num_workers=num_workers,
                snr_low=noise_snr_low,
                snr_high=noise_snr_high,
                mix_prob=noise_prob, )

    def forward(self, waveforms, lengths=None):
        """Apply the enabled corruptions in the order reverb, babble, noise.

        Args:
            waveforms (paddle.Tensor): Batch of waveforms.
            lengths (paddle.Tensor, optional): Per-waveform lengths.
                Defaults to all ones.

        Returns:
            paddle.Tensor: The corrupted waveforms.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        # Augmentation
        with paddle.no_grad():
            if hasattr(self, "add_reverb"):
                try:
                    waveforms = self.add_reverb(waveforms, lengths)
                except Exception as e:
                    # Reverb stays best-effort (the batch is kept as it is),
                    # but the failure is reported instead of being hidden.
                    logger.warning(
                        f"add_reverb failed, the batch is left without reverb: {e!r}"
                    )
            if hasattr(self, "add_babble"):
                waveforms = self.add_babble(waveforms, lengths)
            if hasattr(self, "add_noise"):
                waveforms = self.add_noise(waveforms, lengths)

        return waveforms
def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
    """Build the list of augmentation layers.

    Note: this pipeline cannot be used in the paddle.DataLoader

    Args:
        target_dir (optional): Passed to ``OpenRIRNoise`` for both the noise
            and the rir dataset. Defaults to None.

    Returns:
        List[paddle.nn.Layer]: all augment process, in the order wavedrop,
            speed perturbation, noise, reverb, reverb plus noise.
    """
    logger.info("start to build the augment pipeline")
    noise_dataset = OpenRIRNoise('noise', target_dir=target_dir)
    rir_dataset = OpenRIRNoise('rir', target_dir=target_dir)

    # Settings shared by the two stages that add noise.
    noise_settings = dict(
        noise_prob=1.0,
        noise_snr_low=0,
        noise_snr_high=15, )

    # Only one speed (100), so this stage effectively just drops
    # frequencies and chunks.
    wavedrop = TimeDomainSpecAugment(
        sample_rate=16000,
        speeds=[100], )
    speed_perturb = TimeDomainSpecAugment(
        sample_rate=16000,
        speeds=[95, 100, 105], )
    add_noise = EnvCorrupt(
        noise_dataset=noise_dataset,
        reverb_prob=0.0,
        rir_scale_factor=1.0,
        **noise_settings)
    add_rev = EnvCorrupt(
        rir_dataset=rir_dataset,
        reverb_prob=1.0,
        noise_prob=0.0,
        rir_scale_factor=1.0, )
    add_rev_noise = EnvCorrupt(
        noise_dataset=noise_dataset,
        rir_dataset=rir_dataset,
        reverb_prob=1.0,
        rir_scale_factor=1.0,
        **noise_settings)

    return [wavedrop, speed_perturb, add_noise, add_rev, add_rev_noise]
def waveform_augment(waveforms: paddle.Tensor,
                     augment_pipeline: List[paddle.nn.Layer]) -> paddle.Tensor:
    """Run every augmentation on the batch and stack all results.

    Args:
        waveforms (paddle.Tensor): original batch waveform
        augment_pipeline (List[paddle.nn.Layer]): augment pipeline process

    Returns:
        paddle.Tensor: the original waveforms followed by one augmented copy
            per pipeline entry, concatenated along axis 0.
    """
    target_len = waveforms.shape[1]
    # The untouched batch always comes first.
    collected = [waveforms]

    for aug in augment_pipeline:
        augmented = aug(waveforms)  # (N, L)

        # Bring the result back to the input width so that all copies can
        # be concatenated.
        missing = target_len - augmented.shape[1]
        if missing > 0:
            # Too short: pad on the right of the time axis.
            augmented = F.pad(
                augmented.unsqueeze(-1), [0, missing],
                data_format='NLC').squeeze(-1)
        else:
            # Long enough: truncate to the input width.
            augmented = augmented[:, :target_len]

        collected.append(augmented)

    return paddle.concat(collected, axis=0)

@ -0,0 +1,166 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import numpy as np
import paddle
def waveform_collate_fn(batch):
    """Stack the 'feat' and 'label' entries of a batch of items.

    Args:
        batch (list): Items providing a 'feat' and a 'label' entry.

    Returns:
        dict: ``{'waveforms': stacked feats, 'labels': stacked labels}``.
    """
    feats = [item['feat'] for item in batch]
    labels = [item['label'] for item in batch]
    return {'waveforms': np.stack(feats), 'labels': np.stack(labels)}
def feature_normalize(feats: paddle.Tensor,
                      mean_norm: bool=True,
                      std_norm: bool=True,
                      convert_to_numpy: bool=False):
    """Normalize features along their last axis.

    Args:
        feats (paddle.Tensor): Features to normalize.
        mean_norm (bool, optional): Subtract the mean. Defaults to True.
        std_norm (bool, optional): Divide by the standard deviation.
            Defaults to True.
        convert_to_numpy (bool, optional): Compute the statistics with numpy
            and convert the result back to a tensor. Defaults to False.

    Returns:
        paddle.Tensor: The normalized features.
    """
    if not convert_to_numpy:
        mean = feats.mean(axis=-1, keepdim=True) if mean_norm else 0
        std = feats.std(axis=-1, keepdim=True) if std_norm else 1
        return (feats - mean) / std

    # The numpy path exists because its results differ slightly from the
    # paddle ones (the original author reports about 1e-6 for the mean).
    # NOTE(review): numpy's std defaults to the population estimate while
    # the paddle one is presumably unbiased by default - confirm that the
    # two paths are meant to agree.
    feats_np = feats.numpy()
    mean = feats_np.mean(axis=-1, keepdims=True) if mean_norm else 0
    std = feats_np.std(axis=-1, keepdims=True) if std_norm else 1
    normalized = (feats_np - mean) / std
    return paddle.to_tensor(normalized, dtype=feats.dtype)
def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
    """Pad a 2-D array at the end of one axis up to ``target_length``.

    Args:
        x: Array-like with exactly two dimensions.
        target_length (int): Wanted size of ``axis`` after padding.
        axis (int, optional): Axis to pad; 0 / -2 select the first axis,
            1 / -1 the second one. Defaults to -1.
        mode (str, optional): Padding mode of ``numpy.pad``.
            Defaults to 'constant'.
        **kwargs: Extra keyword arguments forwarded to ``numpy.pad``.

    Returns:
        numpy.ndarray: The padded array.

    Raises:
        AssertionError: If ``x`` is not 2-D or is already longer than
            ``target_length`` along ``axis``.
        IndexError: If ``axis`` is not a valid axis of a 2-D array.
    """
    x = np.asarray(x)
    assert len(
        x.shape) == 2, f'Only 2D arrays supported, but got shape: {x.shape}'
    # Indexing the shape also rejects an out-of-range axis (IndexError).
    w = target_length - x.shape[axis]
    assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[axis]}'

    # -2 addresses the first axis just like 0 does; treating it as "not 0"
    # would pad the wrong axis.
    if axis in (0, -2):
        pad_width = [[0, w], [0, 0]]
    else:
        pad_width = [[0, 0], [0, w]]
    return np.pad(x, pad_width, mode=mode, **kwargs)
def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
    """Pad a batch of 2-D features to a common width and normalize each one.

    Args:
        batch (list): Items providing an 'id' and a 2-D 'feat' entry whose
            second axis is the one that varies in length.
        mean_norm (bool, optional): Subtract the per-row mean.
            Defaults to True.
        std_norm (bool, optional): Divide by the per-row standard deviation.
            Defaults to True.

    Returns:
        dict: 'ids', the padded and normalized 'feats', and 'lengths' as
            float32 ratios of the longest item.
    """
    ids = [item['id'] for item in batch]
    lengths = np.asarray([item['feat'].shape[1] for item in batch])
    feats = np.stack(
        [pad_right_2d(item['feat'], lengths.max()) for item in batch])

    # Normalize each utterance over its valid frames only. `feat` is a view
    # into `feats`, so the assignment below updates the stacked array.
    for feat, length in zip(feats, lengths):
        valid = feat[:, :length]  # Excluding pad values.
        mean = valid.mean(axis=-1, keepdims=True) if mean_norm else 0
        std = valid.std(axis=-1, keepdims=True) if std_norm else 1
        feat[:, :length] = (valid - mean) / std
        # Padding values should all be 0.
        assert feat[:, length:].sum() == 0

    # Convert the absolute lengths into ratios of the longest utterance:
    # the longest one maps to 1.0, every padded one to a value below 1.0.
    lengths = (lengths / lengths.max()).astype(np.float32)

    return {'ids': ids, 'feats': feats, 'lengths': lengths}
def pad_right_to(array, target_shape, mode="constant", value=0):
    """Pad a numpy array of arbitrary shape to ``target_shape`` by appending
    values on the right of every dimension.

    Args:
        array: input numpy array. Input array whose dimension we need to pad.
        target_shape : (list, tuple). Target shape we want for the target
            array; its len must be equal to array.ndim
        mode : str. Pad mode, please refer to numpy.pad documentation.
        value : float. Pad value, only used when ``mode`` is "constant".

    Returns:
        array: numpy.array. Padded array.
        valid_vals : list. List containing proportion for each dimension of
            original, non-padded values.

    Raises:
        AssertionError: If the number of dimensions differs or a target
            dimension is smaller than the corresponding array dimension.
    """
    assert len(target_shape) == array.ndim

    pads = []  # absolute padding [before, after] for each dimension
    valid_vals = []  # relative valid length for each dimension
    for target, current in zip(target_shape, array.shape):
        assert (target >= current
                ), "Target shape must be >= original shape for every dim"
        pads.append([0, target - current])
        valid_vals.append(current / target)

    # numpy.pad accepts `constant_values` for the constant mode only and
    # raises a ValueError when it is passed together with any other mode.
    extra = {"constant_values": value} if mode == "constant" else {}
    array = numpy.pad(array, pads, mode=mode, **extra)

    return array, valid_vals
def batch_pad_right(arrays, mode="constant", value=0):
    """Given a list of numpy arrays it batches them together by padding to
    the right on the last dimension in order to get same length for all.

    Args:
        arrays : list. List of array we wish to pad together.
        mode : str. Padding mode see numpy.pad documentation.
        value : float. Padding value see numpy.pad documentation.

    Returns:
        array : numpy.array. Padded array.
        valid_vals : numpy.array. Proportion of original, non-padded values
            along the last dimension, one entry per input array.

    Raises:
        IndexError: If ``arrays`` is empty or the arrays do not all have the
            same number of dimensions.
        EnvironmentError: If the arrays differ in any dimension other than
            the last one.
    """
    if not len(arrays):
        raise IndexError("arrays list must not be empty")

    if len(arrays) == 1:
        # if there is only one array in the batch we simply unsqueeze it.
        return numpy.expand_dims(arrays[0], axis=0), numpy.array([1.0])

    first = arrays[0]
    # Every array has to agree with the first one; checking with `any`
    # would let a single matching array hide the mismatching ones.
    if not all(arr.ndim == first.ndim for arr in arrays[1:]):
        raise IndexError("All arrays must have same number of dimensions")

    # FIXME we limit the support here: we allow padding of only the last dimension
    # need to remove this when feat extraction is updated to handle multichannel.
    last_dim = first.ndim - 1
    max_shape = []
    for dim in range(first.ndim):
        if dim != last_dim and not all(
                x.shape[dim] == first.shape[dim] for x in arrays[1:]):
            raise EnvironmentError(
                "arrays should have same dimensions except for last one")
        max_shape.append(max(x.shape[dim] for x in arrays))

    batched = []
    valid = []
    for t in arrays:
        # for each array we apply pad_right_to
        padded, valid_percent = pad_right_to(
            t, max_shape, mode=mode, value=value)
        batched.append(padded)
        # Only the ratio of the last (padded) dimension is reported.
        valid.append(valid_percent[-1])

    batched = numpy.stack(batched)

    return batched, numpy.array(valid)

@ -0,0 +1,219 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
# TODO: Complete type-hint and doc string.
def blackman_window(win_len, dtype=np.float32):
    """Return a Blackman window of ``win_len`` points as a paddle tensor.

    The phase of point ``k`` is ``pi * k / win_len``, i.e. the window is
    sampled over ``win_len`` (not ``win_len - 1``) intervals.

    Args:
        win_len (int): Number of points of the window.
        dtype (optional): numpy dtype of the result. Defaults to np.float32.

    Returns:
        paddle.Tensor: The window values.
    """
    arcs = np.pi * np.arange(win_len) / float(win_len)
    # Evaluate the whole window in one vectorized expression instead of a
    # Python-level loop over the individual points.
    win = 0.42 - 0.5 * np.cos(2 * arcs) + 0.08 * np.cos(4 * arcs)
    return paddle.to_tensor(win.astype(dtype))
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
    """Compute the amplitude of a batch of waveforms along axis 1.

    Args:
        waveforms (paddle.Tensor): Waveforms; a 1-D input is treated as a
            batch of one.
        lengths (optional): Divisor used for the "avg" amplitude instead of
            a plain mean over axis 1. Defaults to None.
        amp_type (str, optional): "avg" (mean absolute value) or "peak"
            (maximum absolute value). Defaults to "avg".
        scale (str, optional): "linear" or "dB". Defaults to "linear".

    Returns:
        paddle.Tensor: The amplitudes, with axis 1 kept with size 1. For the
            "dB" scale the result is clipped from below at -80.
    """
    if len(waveforms.shape) == 1:
        waveforms = waveforms.unsqueeze(0)

    assert amp_type in ["avg", "peak"]
    assert scale in ["linear", "dB"]

    magnitude = paddle.abs(waveforms)
    if amp_type == "avg":
        if lengths is None:
            out = paddle.mean(magnitude, axis=1, keepdim=True)
        else:
            # Divide by the given lengths so that padding does not lower
            # the average.
            out = paddle.sum(magnitude, axis=1, keepdim=True) / lengths
    elif amp_type == "peak":
        out = paddle.max(magnitude, axis=1, keepdim=True)
    else:
        raise NotImplementedError

    if scale == "linear":
        return out
    if scale == "dB":
        return paddle.clip(20 * paddle.log10(out), min=-80)
    raise NotImplementedError
def dB_to_amplitude(SNR):
    """Convert a value in decibels into a linear amplitude ratio.

    Args:
        SNR: Value in dB; anything supporting division and being used as an
            exponent works (number, tensor, ...).

    Returns:
        ``10 ** (SNR / 20)``.
    """
    exponent = SNR / 20
    return 10**exponent
def convolve1d(
        waveform,
        kernel,
        padding=0,
        pad_type="constant",
        stride=1,
        groups=1, ):
    """Convolve a batch of waveforms with a kernel along the time axis.

    Args:
        waveform (paddle.Tensor): 3-dimensional input, laid out as (N, L, C).
        kernel (paddle.Tensor): Kernel; its last two axes are swapped before
            it is handed to the convolution.
        padding (int or list, optional): A list is applied explicitly with
            ``pad_type`` before the convolution; any other value is passed
            to the convolution itself. Defaults to 0.
        pad_type (str, optional): Padding mode used for a list ``padding``.
            Defaults to "constant".
        stride (int, optional): Stride of the convolution. Defaults to 1.
        groups (int, optional): Groups of the convolution. Defaults to 1.

    Returns:
        paddle.Tensor: The convolved signal with time on the second axis.

    Raises:
        ValueError: If ``waveform`` is not 3-dimensional.
    """
    if len(waveform.shape) != 3:
        raise ValueError("Convolve1D expects a 3-dimensional tensor")

    # A list padding is applied here; the convolution then gets no padding.
    explicit_padding = isinstance(padding, list)
    if explicit_padding:
        waveform = paddle.nn.functional.pad(
            x=waveform,
            pad=padding,
            mode=pad_type,
            data_format='NLC', )
    conv_padding = 0 if explicit_padding else padding

    # The convolution expects the time axis last: (N, L, C) -> (N, C, L)
    waveform = waveform.transpose([0, 2, 1])
    kernel = kernel.transpose([0, 2, 1])

    convolved = paddle.nn.functional.conv1d(
        x=waveform,
        weight=kernel,
        stride=stride,
        groups=groups,
        padding=conv_padding, )

    # Return time dimension to the second dimension.
    return convolved.transpose([0, 2, 1])
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
    """Build a notch filter as the sum of a low-pass and a high-pass filter.

    Args:
        notch_freq (float): Notch frequency, must lie in (0, 1].
        filter_width (int, optional): Number of filter taps, must be odd.
            Defaults to 101.
        notch_width (float, optional): Offset separating the two cutoff
            frequencies. Defaults to 0.05.

    Returns:
        paddle.Tensor: The filter, reshaped to (1, filter_width, 1).
    """
    # Check inputs
    assert 0 < notch_freq <= 1
    assert filter_width % 2 != 0

    half = filter_width // 2
    # Tap positions centred on zero: -half .. half.
    inputs = paddle.arange(filter_width, dtype='float32') - half

    # Avoid frequencies that are too low
    notch_freq += notch_width

    def sinc(x):
        # sin(x) / x with the centre tap (where x is 0) set to 1 explicitly
        # to avoid a division by zero.
        left = x[:half]
        right = x[half + 1:]
        return paddle.concat([
            paddle.sin(left) / left, paddle.ones([1]),
            paddle.sin(right) / right
        ])

    # Low-pass part, scaled by the cutoff one notch_width below the shifted
    # notch frequency, windowed and normalized to unit sum.
    hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
    hlpf *= blackman_window(filter_width)
    hlpf /= paddle.sum(hlpf)

    # High-pass part: a negated, normalized low-pass one notch_width above,
    # plus a unit impulse at the centre tap.
    hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
    hhpf *= blackman_window(filter_width)
    hhpf /= -paddle.sum(hhpf)
    hhpf[half] += 1

    # Adding filters creates notch filter
    return (hlpf + hhpf).reshape([1, -1, 1])
def reverberate(waveforms,
                rir_waveform,
                sample_rate,
                impulse_duration=0.3,
                rescale_amp="avg"):
    """Convolve waveforms with a room impulse response (RIR).

    Args:
        waveforms (paddle.Tensor): Signal with 1, 2 or 3 dimensions.
        rir_waveform (paddle.Tensor): Impulse response with 1, 2 or 3
            dimensions.
        sample_rate (int): Sample rate used to convert ``impulse_duration``
            into a number of samples.
        impulse_duration (float, optional): Length of the response that is
            kept after its strongest sample, in seconds. Defaults to 0.3.
        rescale_amp (str, optional): Amplitude type ("avg" or "peak") that
            is restored after the convolution. Defaults to "avg".

    Returns:
        paddle.Tensor: The reverberated signal, with the same number of
            dimensions as the input.

    Raises:
        NotImplementedError: If either input has more than 3 dimensions.
    """
    orig_ndim = len(waveforms.shape)
    if orig_ndim > 3 or len(rir_waveform.shape) > 3:
        raise NotImplementedError

    # Bring both inputs to 3 dimensions, which convolve1d requires.
    if orig_ndim == 1:
        waveforms = waveforms.unsqueeze(0).unsqueeze(-1)
    elif orig_ndim == 2:
        waveforms = waveforms.unsqueeze(-1)

    rir_ndim = len(rir_waveform.shape)
    if rir_ndim == 1:
        rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1)
    elif rir_ndim == 2:
        rir_waveform = rir_waveform.unsqueeze(-1)

    # Amplitude of the clean signal, restored after the convolution.
    orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
                                       rescale_amp)

    # Start the response at its strongest sample (the direct signal) so the
    # output stays aligned with the input, and keep at most
    # `impulse_duration` seconds of it.
    impulse_index_start = rir_waveform.abs().argmax(axis=1).item()
    impulse_index_end = min(
        impulse_index_start + int(sample_rate * impulse_duration),
        rir_waveform.shape[1])
    rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :]
    rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
    # The kernel is reversed along time before it is applied.
    rir_waveform = paddle.flip(rir_waveform, [1])

    # Left-pad by (kernel length - 1) so the output keeps the input length.
    waveforms = convolve1d(
        waveform=waveforms,
        kernel=rir_waveform,
        padding=[rir_waveform.shape[1] - 1, 0], )

    # Rescale to the amplitude of the clean waveform
    waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude,
                        rescale_amp)

    # Drop the axes that were added above.
    if orig_ndim == 1:
        waveforms = waveforms.squeeze(0).squeeze(-1)
    if orig_ndim == 2:
        waveforms = waveforms.squeeze(-1)

    return waveforms
def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
    """Normalize waveforms and scale them to a target level.

    Args:
        waveforms (paddle.Tensor): Waveforms; a 1-D input is treated as a
            batch of one and returned as 1-D again.
        lengths: Lengths forwarded to ``normalize``.
        target_lvl: Target level, linear or in dB depending on ``scale``.
        amp_type (str, optional): "peak" or "avg". Defaults to "avg".
        scale (str, optional): "linear" or "dB". Defaults to "linear".

    Returns:
        paddle.Tensor: The rescaled waveforms.
    """
    assert amp_type in ["peak", "avg"]
    assert scale in ["linear", "dB"]

    batch_added = len(waveforms.shape) == 1
    if batch_added:
        waveforms = waveforms.unsqueeze(0)

    waveforms = normalize(waveforms, lengths, amp_type)

    # A dB target is converted to a linear gain first.
    if scale == "linear":
        gain = target_lvl
    elif scale == "dB":
        gain = dB_to_amplitude(target_lvl)
    else:
        raise NotImplementedError("Invalid scale, choose between dB and linear")
    out = gain * waveforms

    if batch_added:
        out = out.squeeze(0)

    return out
def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
    """Divide waveforms by their amplitude.

    Args:
        waveforms (paddle.Tensor): Waveforms; a 1-D input is treated as a
            batch of one and returned as 1-D again.
        lengths (optional): Lengths forwarded to ``compute_amplitude``.
            Defaults to None.
        amp_type (str, optional): "avg" or "peak". Defaults to "avg".
        eps (float, optional): Added to the amplitude to avoid a division
            by zero. Defaults to 1e-14.

    Returns:
        paddle.Tensor: The normalized waveforms, with the same number of
            dimensions as the input.
    """
    assert amp_type in ["avg", "peak"]

    batch_added = False
    if len(waveforms.shape) == 1:
        batch_added = True
        waveforms = waveforms.unsqueeze(0)

    den = compute_amplitude(waveforms, lengths, amp_type) + eps
    out = waveforms / den

    # Remove the batch axis only after the division: `den` keeps its batch
    # axis, so dividing an already squeezed 1-D signal by it would broadcast
    # the result back to two dimensions.
    if batch_added:
        out = out.squeeze(0)
    return out

@ -47,6 +47,19 @@ class Conv1d(nn.Layer):
groups=1,
bias=True,
padding_mode="reflect", ):
"""_summary_
Args:
in_channels (int): input channel or input data dimensions
out_channels (int): output channel or output data dimensions
kernel_size (int): kernel size of 1-d convolution
stride (int, optional): stride in 1-d convolution. Defaults to 1.
padding (str, optional): padding value. Defaults to "same".
dilation (int, optional): dilation in 1-d convolution. Defaults to 1.
groups (int, optional): groups in 1-d convolution. Defaults to 1.
bias (bool, optional): bias in 1-d convolution . Defaults to True.
padding_mode (str, optional): padding mode. Defaults to "reflect".
"""
super().__init__()
self.kernel_size = kernel_size
@ -134,6 +147,15 @@ class TDNNBlock(nn.Layer):
kernel_size,
dilation,
activation=nn.ReLU, ):
"""Implementation of TDNN network
Args:
in_channels (int): input channels or input embedding dimensions
out_channels (int): output channels or output embedding dimensions
kernel_size (int): the kernel size of the TDNN network block
dilation (int): the dilation of the TDNN network block
activation (paddle class, optional): the activation layers. Defaults to nn.ReLU.
"""
super().__init__()
self.conv = Conv1d(
in_channels=in_channels,
@ -149,6 +171,15 @@ class TDNNBlock(nn.Layer):
class Res2NetBlock(nn.Layer):
def __init__(self, in_channels, out_channels, scale=8, dilation=1):
"""Implementation of Res2Net Block with dilation
The paper is referred to as "Res2Net: A New Multi-scale Backbone Architecture",
whose url is https://arxiv.org/abs/1904.01169
Args:
in_channels (int): input channels or input dimensions
out_channels (int): output channels or output dimensions
scale (int, optional): scale in res2net block. Defaults to 8.
dilation (int, optional): dilation of 1-d convolution in TDNN block. Defaults to 1.
"""
super().__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
@ -179,6 +210,14 @@ class Res2NetBlock(nn.Layer):
class SEBlock(nn.Layer):
def __init__(self, in_channels, se_channels, out_channels):
"""Implementation of SEBlock
The paper is referred to as "Squeeze-and-Excitation Networks"
whose url is https://arxiv.org/abs/1709.01507
Args:
in_channels (int): input channels or input data dimensions
se_channels (int): dimensions of the squeeze-excitation block
out_channels (int): output channels or output data dimensions
"""
super().__init__()
self.conv1 = Conv1d(
@ -275,6 +314,18 @@ class SERes2NetBlock(nn.Layer):
kernel_size=1,
dilation=1,
activation=nn.ReLU, ):
"""Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model
The paper is referred to as "Squeeze-and-Excitation Networks"
whose url is: https://arxiv.org/pdf/1709.01507.pdf
Args:
in_channels (int): input channels or input data dimensions
out_channels (int): output channels or output data dimensions
res2net_scale (int, optional): scale in the res2net block. Defaults to 8.
se_channels (int, optional): embedding dimensions of res2net block. Defaults to 128.
kernel_size (int, optional): kernel size of 1-d convolution in TDNN block. Defaults to 1.
dilation (int, optional): dilation of 1-d convolution in TDNN block. Defaults to 1.
activation (paddle.nn.class, optional): activation function. Defaults to nn.ReLU.
"""
super().__init__()
self.out_channels = out_channels
self.tdnn1 = TDNNBlock(
@ -326,7 +377,21 @@ class EcapaTdnn(nn.Layer):
res2net_scale=8,
se_channels=128,
global_context=True, ):
"""Implementation of ECAPA-TDNN backbone model network
The paper is referred to as "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification"
whose url is: https://arxiv.org/abs/2005.07143
Args:
input_size (int): input feature dimension
lin_neurons (int, optional): speaker embedding size. Defaults to 192.
activation (paddle.nn.class, optional): activation function. Defaults to nn.ReLU.
channels (list, optional): inter embedding dimension. Defaults to [512, 512, 512, 512, 1536].
kernel_sizes (list, optional): kernel size of 1-d convolution in TDNN block . Defaults to [5, 3, 3, 3, 1].
dilations (list, optional): dilations of 1-d convolution in TDNN block. Defaults to [1, 2, 3, 4, 1].
attention_channels (int, optional): attention dimensions. Defaults to 128.
res2net_scale (int, optional): scale value in res2net. Defaults to 8.
se_channels (int, optional): dimensions of squeeze-excitation block. Defaults to 128.
global_context (bool, optional): global context flag. Defaults to True.
"""
super().__init__()
assert len(channels) == len(kernel_sizes)
assert len(channels) == len(dilations)

@ -0,0 +1,93 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/nnet/losses.py
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class AngularMargin(nn.Layer):
    """Angular Margin (AM) transform applied to cosine-similarity outputs."""

    def __init__(self, margin=0.0, scale=1.0):
        """An implementation of Angular Margin (AM) proposed in the following
        paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
        Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)

        Args:
            margin (float, optional): The margin for cosine similarity. Defaults to 0.0.
            scale (float, optional): The scale for cosine similarity. Defaults to 1.0.
        """
        super(AngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, outputs, targets):
        """Subtract the margin at the target positions, then scale.

        Args:
            outputs (paddle.Tensor): Cosine-similarity outputs.
            targets (paddle.Tensor): Targets multiplied with the margin
                (one-hot when called through ``LogSoftmaxWrapper``).

        Returns:
            paddle.Tensor: ``scale * (outputs - margin * targets)``.
        """
        penalized = outputs - self.margin * targets
        return self.scale * penalized
class AdditiveAngularMargin(AngularMargin):
    """Additive Angular Margin (AAM) transform for cosine-similarity outputs."""

    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        """The Implementation of Additive Angular Margin (AAM) proposed
        in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
        (https://arxiv.org/abs/1906.07317)

        Args:
            margin (float, optional): margin factor. Defaults to 0.0.
            scale (float, optional): scale factor. Defaults to 1.0.
            easy_margin (bool, optional): easy_margin flag. Defaults to False.
        """
        super(AdditiveAngularMargin, self).__init__(margin, scale)
        self.easy_margin = easy_margin

        # Constants of cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # Threshold and fallback penalty used where the margin is not added
        # to the angle (see forward).
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        """Apply the additive angular margin at the target positions.

        Args:
            outputs (paddle.Tensor): Cosine-similarity outputs.
            targets (paddle.Tensor): Weights selecting where the margin is
                applied (one-hot when called through ``LogSoftmaxWrapper``).

        Returns:
            paddle.Tensor: The scaled outputs, with ``cos(theta + m)`` in
                place of ``cos(theta)`` at the target positions.
        """
        cosine = outputs.astype('float32')
        # Rounding can push |cosine| marginally above 1; clipping keeps the
        # square root from turning that into NaN.
        sine = paddle.sqrt(paddle.clip(1.0 - paddle.pow(cosine, 2), min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            # Only apply the margin where the cosine is positive.
            phi = paddle.where(cosine > 0, phi, cosine)
        else:
            # At or below the threshold, use a linear penalty instead.
            phi = paddle.where(cosine > self.th, phi, cosine - self.mm)
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs
class LogSoftmaxWrapper(nn.Layer):
    """Speaker identification loss: output transform, log-softmax and a
    summed KL divergence against one-hot targets."""

    def __init__(self, loss_fn):
        """Speaker identification loss function wrapper
        including all of compositions of the loss transformation

        Args:
            loss_fn (callable): transform applied to the outputs before the
                log-softmax; it is called with ``(outputs, targets)`` and,
                if that raises a TypeError, with ``(outputs)`` only.
        """
        super(LogSoftmaxWrapper, self).__init__()
        self.loss_fn = loss_fn
        self.criterion = paddle.nn.KLDivLoss(reduction="sum")

    def forward(self, outputs, targets, length=None):
        """Compute the loss of a batch.

        Args:
            outputs (paddle.Tensor): Network outputs; axis 1 is the class
                axis.
            targets (paddle.Tensor): Class indices, converted to one-hot
                vectors here.
            length (optional): Not used by this implementation.

        Returns:
            paddle.Tensor: The summed KL divergence divided by the sum of
                the one-hot targets.
        """
        num_classes = outputs.shape[1]
        targets = F.one_hot(targets, num_classes)

        # Some transforms take the targets as well, others only the outputs.
        try:
            transformed = self.loss_fn(outputs, targets)
        except TypeError:
            transformed = self.loss_fn(outputs)

        log_probs = F.log_softmax(transformed, axis=1)
        return self.criterion(log_probs, targets) / targets.sum()

@ -0,0 +1,87 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class SpeakerIdetification(nn.Layer):
    """Speaker identification model: a backbone that produces embeddings
    followed by a cosine-similarity classifier."""

    def __init__(
            self,
            backbone,
            num_class,
            lin_blocks=0,
            lin_neurons=192,
            dropout=0.1, ):
        """The speaker identification model, which includes the speaker backbone network
        and a linear transform to the speaker class num in training

        Args:
            backbone (Paddle.nn.Layer class): the speaker identification backbone network model
            num_class (int): the speaker class num in the training dataset
            lin_blocks (int, optional): the number of linear blocks between the embedding and the final linear layer. Defaults to 0.
            lin_neurons (int, optional): the output dimension of each linear block. Defaults to 192.
            dropout (float, optional): the dropout factor on the embedding. Defaults to 0.1.
        """
        super(SpeakerIdetification, self).__init__()
        # The output of the backbone network is the target embedding.
        self.backbone = backbone
        # Dropout is only created for a positive factor.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Optional classifier blocks: batch norm followed by a linear layer.
        input_size = self.backbone.emb_size
        self.blocks = nn.LayerList()
        for _ in range(lin_blocks):
            self.blocks.extend([
                nn.BatchNorm1D(input_size),
                nn.Linear(in_features=input_size, out_features=lin_neurons),
            ])
            input_size = lin_neurons

        # The final layer: one weight column per speaker class.
        self.weight = paddle.create_parameter(
            shape=(input_size, num_class),
            dtype='float32',
            attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()), )

    def forward(self, x, lengths=None):
        """Do the speaker identification model forward,
        including the speaker embedding model and the classifier model network

        Args:
            x (paddle.Tensor): input audio feats,
                shape=[batch, dimension, times]
            lengths (paddle.Tensor, optional): input audio length, passed on
                to the backbone. Defaults to None.

        Returns:
            paddle.Tensor: return the logits of the feats
        """
        # x.shape: (N, C, L)
        embedding = self.backbone(x, lengths)
        x = embedding.squeeze(-1)  # (N, emb_size, 1) -> (N, emb_size)
        if self.dropout is not None:
            x = self.dropout(x)

        for fc in self.blocks:
            x = fc(x)

        # Both operands are normalized, so the logits are cosine similarities
        # between the embedding and every class weight.
        return F.linear(F.normalize(x), F.normalize(self.weight, axis=0))

@ -0,0 +1,45 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.optimizer.lr import LRScheduler
class CyclicLRScheduler(LRScheduler):
    """Cyclic learning-rate schedule that ramps linearly up and then down.

    One cycle lasts ``2 * step_size`` calls to :meth:`step`: the rate rises
    linearly from ``base_lr`` to ``max_lr`` during the first half and falls
    linearly back to ``base_lr`` during the second half.
    """

    def __init__(self,
                 base_lr: float=1e-8,
                 max_lr: float=1e-3,
                 step_size: int=10000):
        """Create the scheduler.

        Args:
            base_lr (float, optional): the lowest learning rate of a cycle.
            max_lr (float, optional): the highest learning rate of a cycle.
            step_size (int, optional): the number of steps in half a cycle.
        """
        super().__init__()
        self.current_step = -1
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size

    def step(self):
        """Advance one step inside the cycle and refresh ``last_lr``."""
        # step() can run before ``current_step`` has been assigned.
        # NOTE(review): presumably the base-class constructor calls step();
        # confirm against the LRScheduler implementation.
        if not hasattr(self, 'current_step'):
            return

        cycle_len = 2 * self.step_size
        self.current_step += 1
        if self.current_step >= cycle_len:
            # Wrap around to the beginning of the next cycle.
            self.current_step %= cycle_len
        self.last_lr = self.get_lr()

    def get_lr(self):
        """Return the learning rate for the current position in the cycle."""
        cycle_pos = self.current_step / (2 * self.step_size)
        lr_span = self.max_lr - self.base_lr
        if cycle_pos < 0.5:
            # Rising half of the cycle.
            return self.base_lr + cycle_pos / 0.5 * lr_span
        # Falling half of the cycle.
        return self.max_lr - (cycle_pos / 0.5 - 1) * lr_span

@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
import random
import numpy as np
import paddle
def seed_everything(seed: int):
    """Seed paddle, random and np.random with the same value to help reproducibility.

    Args:
        seed (int): the seed passed to ``paddle.seed``, ``random.seed`` and
            ``np.random.seed``, in that order.
    """
    for set_seed in (paddle.seed, random.seed, np.random.seed):
        set_seed(seed)
    logger.info(f"Set the seed of paddle, random, np.random to {seed}.")

@ -0,0 +1,66 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
class Timer(object):
    '''Calculate running speed and estimated time of arrival (ETA).

    Typical use: call :meth:`start` once, call :meth:`count` after every
    step, and read :attr:`timing` / :attr:`eta` whenever progress is logged.
    '''

    def __init__(self, total_step: int):
        '''Create a timer for a run of ``total_step`` steps.

        Args:
            total_step (int): the total number of steps the run will take.
        '''
        self.total_step = total_step
        self.last_start_step = 0
        self.current_step = 0
        self._is_running = True
        # Initialise the timestamps here so that `timing` and `eta` do not
        # raise AttributeError when start() was never called.
        now = time.time()
        self.last_time = now
        self.start_time = now

    def start(self):
        '''Reset the reference timestamps to the current time.'''
        now = time.time()
        self.last_time = now
        self.start_time = now

    def stop(self):
        '''Mark the timer as stopped and record the end time.'''
        self._is_running = False
        self.end_time = time.time()

    def count(self) -> int:
        '''Register one finished step (capped at ``total_step``).

        Returns:
            int: the number of steps counted so far.
        '''
        if self.current_step < self.total_step:
            self.current_step += 1
        return self.current_step

    @property
    def timing(self) -> float:
        '''Average seconds per step since the previous read of this property.

        Reading the property closes the current measuring window and opens a
        new one. When no step has been counted since the last read, ``0.0``
        is returned and the window is left open, instead of dividing by zero.
        '''
        run_steps = self.current_step - self.last_start_step
        if run_steps <= 0:
            return 0.0
        now = time.time()
        time_used = now - self.last_time
        self.last_start_step = self.current_step
        self.last_time = now
        return time_used / run_steps

    @property
    def is_running(self) -> bool:
        '''Whether :meth:`stop` has not been called yet.'''
        return self._is_running

    @property
    def eta(self) -> str:
        '''Estimated remaining time formatted by ``seconds_to_hms``.

        The estimate extrapolates the average time per counted step to the
        steps that are still left. ``'00:00:00'`` is returned once the timer
        is stopped, and ``'--:--:--'`` while no step has been counted yet,
        because no speed can be derived at that point.
        '''
        if not self.is_running:
            return '00:00:00'
        if self.current_step <= 0:
            return '--:--:--'
        elapsed = time.time() - self.start_time
        steps_left = max(self.total_step - self.current_step, 0)
        remaining_time = elapsed / self.current_step * steps_left
        return seconds_to_hms(remaining_time)
def seconds_to_hms(seconds: int) -> str:
    '''Format a duration given in seconds as ``hh:mm:ss``.

    Each field is left-padded with zeros to at least two characters;
    fractional seconds are truncated.
    '''
    hours = math.floor(seconds / 3600)
    leftover = seconds - hours * 3600
    minutes = math.floor(leftover / 60)
    secs = int(leftover - minutes * 60)
    return ':'.join('{:0>2}'.format(part) for part in (hours, minutes, secs))

@ -43,3 +43,16 @@ paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
paddlespeech stats --task asr
paddlespeech stats --task tts
paddlespeech stats --task cls

# Speaker Verification
# Fetch the demo utterance (-c resumes a partially downloaded file).
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav

# Case 1: a single wav file passed through --input.
paddlespeech vector --task spk --input 85236145389.wav

# Case 2: a job file passed through --input; each line is "<id> <wav path>".
echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job
paddlespeech vector --task spk --input vec.job

# Case 3: the same job format piped on stdin instead of using --input.
echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector --task spk

# Remove the downloaded audio and the generated job file.
rm 85236145389.wav
rm vec.job

@ -0,0 +1,11 @@
def pytest_addoption(parser):
    """Register the ``--device`` command-line option, defaulting to ``"cpu"``."""
    parser.addoption(
        "--device",
        action="store",
        default="cpu", )
def pytest_generate_tests(metafunc):
# This is called for every test. Only get/set command line arguments
# if the argument is specified in the list of test "fixturenames".
option_value = metafunc.config.option.device
if "device" in metafunc.fixturenames and option_value is not None:
metafunc.parametrize("device", [option_value])

@ -0,0 +1,138 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import Dataset
def test_add_noise(tmpdir, device):
    """AddNoise with ``mix_prob=0.0`` must leave the waveform unchanged.

    Args:
        tmpdir: pytest fixture; not used by the body, kept so the test
            signature stays as it was.
        device (str): the device name handed to ``paddle.device.set_device``.
    """
    paddle.device.set_device(device)
    from paddlespeech.vector.io.augment import AddNoise

    # A (1, 16000) sine waveform and its length tensor.
    test_waveform = paddle.sin(
        paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
    wav_lens = paddle.ones([1], dtype="float32")

    # Edge case: with a mixing probability of zero the output must match
    # the input.
    no_noise = AddNoise(mix_prob=0.0)
    assert no_noise(test_waveform, wav_lens).allclose(test_waveform)
def test_speed_perturb(device):
    """Check SpeedPerturb on its no-op settings and on ``speeds=[50]``."""
    paddle.device.set_device(device)
    from paddlespeech.vector.io.augment import SpeedPerturb

    sine = paddle.sin(paddle.arange(16000.0, dtype="float32")).unsqueeze(0)

    # Edge cases: these settings are expected to return the input unchanged.
    for no_op_kwargs in ({"perturb_prob": 0.0}, {"speeds": [100]}):
        identity = SpeedPerturb(16000, **no_op_kwargs)
        assert identity(sine).allclose(sine)

    # Half speed: the output is compared with every second input sample,
    # within a loose tolerance.
    halver = SpeedPerturb(16000, speeds=[50])
    assert halver(sine).allclose(sine[:, ::2], atol=3e-1)
def test_babble(device):
    """Check AddBabble on its no-op settings and with a single babbler."""
    paddle.device.set_device(device)
    from paddlespeech.vector.io.augment import AddBabble

    # A batch of two waveforms: a sine and a cosine, 16000 samples each.
    samples = paddle.arange(16000.0, dtype="float32")
    batch = paddle.stack((paddle.sin(samples), paddle.cos(samples), ))
    lens = paddle.ones([2])

    # Edge cases: these settings are expected to return the input unchanged.
    no_op_configs = (
        {"mix_prob": 0.0},
        {"speaker_count": 1, "snr_low": 1000, "snr_high": 1000},
    )
    for config in no_op_configs:
        identity = AddBabble(**config)
        assert identity(batch, lens).allclose(batch)

    # One babbler just averages the two speakers.
    babbler = AddBabble(speaker_count=1).to(device)
    averaged = (batch + batch.roll(1, 0)) / 2
    assert babbler(batch, lens).allclose(averaged, atol=1e-4)
def test_drop_freq(device):
    """Check DropFreq on its no-op settings and on two frequency ranges."""
    paddle.device.set_device(device)
    from paddlespeech.vector.io.augment import DropFreq

    sine = paddle.sin(paddle.arange(16000.0, dtype="float32")).unsqueeze(0)

    # Edge cases: these settings are expected to return the input unchanged.
    no_op_configs = (
        {"drop_prob": 0.0},
        {"drop_count_low": 0, "drop_count_high": 0},
    )
    for config in no_op_configs:
        identity = DropFreq(**config)
        assert identity(sine).allclose(sine)

    # Frequency range that *does not* include the signal frequency: the
    # waveform should stay close to the original.
    off_band = DropFreq(drop_freq_low=0.5, drop_freq_high=0.9)
    assert off_band(sine).allclose(sine, atol=1e-1)

    # Frequency range that *does* include the signal frequency: the
    # waveform should end up close to all zeros.
    on_band = DropFreq(drop_freq_low=0.28, drop_freq_high=0.28)
    silence = paddle.zeros([1, 16000])
    assert on_band(sine).allclose(silence, atol=4e-1)
def test_drop_chunk(device):
    """Check DropChunk on its no-op settings, a fixed chunk and the amplitude."""
    paddle.device.set_device(device)
    from paddlespeech.vector.io.augment import DropChunk

    sine = paddle.sin(paddle.arange(16000.0, dtype="float32")).unsqueeze(0)
    lens = paddle.ones([1])

    # Edge cases: these settings are expected to return the input unchanged.
    no_op_configs = (
        {"drop_prob": 0.0},
        {"drop_length_low": 0, "drop_length_high": 0},
        {"drop_count_low": 0, "drop_count_high": 0},
        {"drop_start": 0, "drop_end": 0},
    )
    for config in no_op_configs:
        identity = DropChunk(**config)
        assert identity(sine, lens).allclose(sine)

    # Specify all parameters to ensure it is deterministic: exactly one
    # chunk of 100 samples is dropped between samples 100 and 200.
    chunker = DropChunk(
        drop_length_low=100,
        drop_length_high=100,
        drop_count_low=1,
        drop_count_high=1,
        drop_start=100,
        drop_end=200,
        noise_factor=0.0, )
    zeroed = sine.clone()
    zeroed[:, 100:200] = 0.0
    assert chunker(sine, lens).allclose(zeroed)

    # Make sure the amplitude is similar before and after.
    noisy_chunker = DropChunk(noise_factor=1.0)
    amp_after = noisy_chunker(sine, lens).abs().mean()
    amp_before = sine.abs().mean()
    assert amp_after.allclose(amp_before, atol=1e-2)
Loading…
Cancel
Save