Merge pull request #1630 from Honei/vox12

[vec]voxceleb convert dataset format to paddlespeech
pull/1690/head
Honei 3 years ago committed by GitHub
commit 48e0177767

@ -34,14 +34,14 @@ from utils.utility import unzip
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/28'
URL_ROOT = '--no-check-certificate http://www.openslr.org/resources/28'
DATA_URL = URL_ROOT + '/rirs_noises.zip'
MD5_DATA = 'e6f48e257286e05de56413b4779d8ffb'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/Aishell",
default=DATA_HOME + "/rirs_noise",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
@ -81,6 +81,10 @@ def create_manifest(data_dir, manifest_path_prefix):
},
ensure_ascii=False))
manifest_path = manifest_path_prefix + '.' + dtype
if not os.path.exists(os.path.dirname(manifest_path)):
os.makedirs(os.path.dirname(manifest_path))
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')

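For reference, the manifests produced by these prepare scripts are jsonlines files: one JSON object per utterance. A hypothetical entry, limited to the fields that the csv-conversion scripts later in this PR read back (`utt`, `utt2spk`, `feat`, `feat_shape`); real manifests may carry more fields:

```python
import json

# Hypothetical manifest entry; values are illustrative only.
line = json.dumps(
    {
        "utt": "id10001-1zcIwhmdeo4-00001.wav",  # utterance id
        "utt2spk": "id10001",                    # speaker label
        "feat": "/path/to/wav/id10001/1zcIwhmdeo4/00001.wav",
        "feat_shape": (8.12, ),                  # duration in seconds
    },
    ensure_ascii=False)
print(line)
```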
@ -149,7 +149,7 @@ def prepare_dataset(base_url, data_list, target_dir, manifest_path,
# we will download the voxceleb1 data to ${target_dir}/vox1/dev/ or ${target_dir}/vox1/test directory
if not os.path.exists(os.path.join(target_dir, "wav")):
# download all dataset parts
print("start to download the vox1 dev zip package")
print(f"start to download the vox1 zip package to {target_dir}")
for zip_part in data_list.keys():
download_url = " --no-check-certificate " + base_url + "/" + zip_part
download(

@ -22,10 +22,12 @@ import codecs
import glob
import json
import os
import subprocess
from pathlib import Path
import soundfile
from utils.utility import check_md5sum
from utils.utility import download
from utils.utility import unzip
@ -35,12 +37,22 @@ DATA_HOME = os.path.expanduser('.')
BASE_URL = "--no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/"
# dev data
DEV_DATA_URL = BASE_URL + '/vox2_aac.zip'
DEV_MD5SUM = "bbc063c46078a602ca71605645c2a402"
DEV_LIST = {
"vox2_dev_aac_partaa": "da070494c573e5c0564b1d11c3b20577",
"vox2_dev_aac_partab": "17fe6dab2b32b48abaf1676429cdd06f",
"vox2_dev_aac_partac": "1de58e086c5edf63625af1cb6d831528",
"vox2_dev_aac_partad": "5a043eb03e15c5a918ee6a52aad477f9",
"vox2_dev_aac_partae": "cea401b624983e2d0b2a87fb5d59aa60",
"vox2_dev_aac_partaf": "fc886d9ba90ab88e7880ee98effd6ae9",
"vox2_dev_aac_partag": "d160ecc3f6ee3eed54d55349531cb42e",
"vox2_dev_aac_partah": "6b84a81b9af72a9d9eecbb3b1f602e65",
}
DEV_TARGET_DATA = "vox2_dev_aac_parta* vox2_dev_aac.zip bbc063c46078a602ca71605645c2a402"
# test data
TEST_DATA_URL = BASE_URL + '/vox2_test_aac.zip'
TEST_MD5SUM = "0d2b3ea430a821c33263b5ea37ede312"
TEST_LIST = {"vox2_test_aac.zip": "0d2b3ea430a821c33263b5ea37ede312"}
TEST_TARGET_DATA = "vox2_test_aac.zip vox2_test_aac.zip 0d2b3ea430a821c33263b5ea37ede312"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
@ -68,6 +80,14 @@ args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
"""Generate the voxceleb2 dataset manifest file.
We will create the ${manifest_path_prefix}.vox2 as the final manifest file
The dev and test wav info will be put in one manifest file.
Args:
data_dir (str): voxceleb2 wav directory, which includes the dev and test subsets
manifest_path_prefix (str): manifest file prefix
"""
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_path = os.path.join(data_dir, "**", "*.wav")
@ -119,7 +139,19 @@ def create_manifest(data_dir, manifest_path_prefix):
print(f"{total_sec / total_num} sec/utt", file=f)
def download_dataset(url, md5sum, target_dir, dataset):
def download_dataset(base_url, data_list, target_data, target_dir, dataset):
"""Download the voxceleb2 zip package
Args:
base_url (str): the base url for the voxceleb2 dataset download
data_list (dict): map from each zip part name to its md5 value
target_data (str): the final dataset zip info
target_dir (str): the dataset stored directory
dataset (str): the dataset name, dev or test
Raises:
RuntimeError: if the md5 checksum of the reassembled zip fails
"""
if not os.path.exists(target_dir):
os.makedirs(target_dir)
@ -129,9 +161,34 @@ def download_dataset(url, md5sum, target_dir, dataset):
# but the test dataset will unzip to aac
# so, we create the ${target_dir}/test and unzip the m4a to test dir
if not os.path.exists(os.path.join(target_dir, dataset)):
filepath = download(url, md5sum, target_dir)
print(f"start to download the vox2 zip package to {target_dir}")
for zip_part in data_list.keys():
download_url = " --no-check-certificate " + base_url + "/" + zip_part
download(
url=download_url,
md5sum=data_list[zip_part],
target_dir=target_dir)
# concatenate all the parts into the target zip file
all_target_part, target_name, target_md5sum = target_data.split()
target_name = os.path.join(target_dir, target_name)
if not os.path.exists(target_name):
pack_part_cmd = "cat {}/{} > {}".format(target_dir, all_target_part,
target_name)
subprocess.call(pack_part_cmd, shell=True)
# check the target zip file md5sum
if not check_md5sum(target_name, target_md5sum):
raise RuntimeError("{} MD5 checksum failed".format(target_name))
else:
print("Check {} md5sum successfully".format(target_name))
if dataset == "test":
unzip(filepath, os.path.join(target_dir, "test"))
# we need to make the test directory
unzip(target_name, os.path.join(target_dir, "test"))
else:
# unzip the dev zip package, which will create the dev directory
unzip(target_name, target_dir)
def main():
@ -142,14 +199,16 @@ def main():
print("download: {}".format(args.download))
if args.download:
download_dataset(
url=DEV_DATA_URL,
md5sum=DEV_MD5SUM,
base_url=BASE_URL,
data_list=DEV_LIST,
target_data=DEV_TARGET_DATA,
target_dir=args.target_dir,
dataset="dev")
download_dataset(
url=TEST_DATA_URL,
md5sum=TEST_MD5SUM,
base_url=BASE_URL,
data_list=TEST_LIST,
target_data=TEST_TARGET_DATA,
target_dir=args.target_dir,
dataset="test")

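The download logic above fetches each zip part, reassembles them with `cat`, and verifies the result against the recorded md5. `check_md5sum` comes from `utils.utility` and is not shown in this diff; a minimal equivalent, streaming the multi-GB archive through `hashlib`, might look like:

```python
import hashlib

def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through md5 without loading it all into memory."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            md5.update(block)
    return md5.hexdigest()

# After `cat vox2_dev_aac_parta* > vox2_dev_aac.zip`, the reassembled
# archive should hash to the value recorded in DEV_TARGET_DATA.
assert md5_of("vox2_dev_aac.zip") == "bbc063c46078a602ca71605645c2a402"
```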
@ -1,14 +1,16 @@
###########################################
# Data #
###########################################
# we should explicitly specify the wav path of vox2 audio data converted from m4a
vox2_base_path:
augment: True
batch_size: 16
batch_size: 32
num_workers: 2
num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
skip_prep: False
split_ratio: 0.9
chunk_duration: 3.0 # seconds
random_chunk: True
verification_file: data/vox1/veri_test2.txt
###########################################################
# FEATURE EXTRACTION SETTING #
@ -26,7 +28,6 @@ hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
# if we want to use another model, please choose another configuration yaml file
model:
input_size: 80
# "channels": [512, 512, 512, 512, 1536],
channels: [1024, 1024, 1024, 1024, 3072]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
@ -38,8 +39,8 @@ model:
###########################################
seed: 1986 # following the speechbrain configuration
epochs: 10
save_interval: 1
log_interval: 1
save_interval: 10
log_interval: 10
learning_rate: 1e-8

@ -0,0 +1,53 @@
###########################################
# Data #
###########################################
augment: True
batch_size: 16
num_workers: 2
num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
skip_prep: False
split_ratio: 0.9
chunk_duration: 3.0 # seconds
random_chunk: True
verification_file: data/vox1/veri_test2.txt
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
###########################################################
# MODEL SETTING #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want to use another model, please choose another configuration yaml file
model:
input_size: 80
channels: [512, 512, 512, 512, 1536]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
attention_channels: 128
lin_neurons: 192
###########################################
# Training #
###########################################
seed: 1986 # following the speechbrain configuration
epochs: 100
save_interval: 10
log_interval: 10
learning_rate: 1e-8
###########################################
# Testing #
###########################################
global_embedding_norm: True
embedding_mean_norm: True
embedding_std_norm: False

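These yaml files are consumed through `yacs`, as the prepare and training scripts later in this diff do; a minimal sketch (the yaml path here is hypothetical):

```python
from yacs.config import CfgNode

# Load the training yaml the same way the scripts in this PR do.
config = CfgNode(new_allowed=True)
config.merge_from_file("conf/ecapa_tdnn.yaml")  # hypothetical path
print(config.num_speakers, config.chunk_duration, config.split_ratio)
```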
@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage=1
stage=0
stop_stage=100
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
@ -30,29 +30,114 @@ dir=$1
conf_path=$2
mkdir -p ${dir}
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# we should use local/convert.sh to convert m4a to wav
python3 local/data_prepare.py \
--data-dir ${dir} \
--config ${conf_path}
fi
# Generally the `MAIN_ROOT` refers to the root of PaddleSpeech,
# which is defined in the path.sh
# And we will download the voxceleb data and rirs noise to ${MAIN_ROOT}/dataset
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
--manifest_prefix="data/vox1/manifest" \
# download data, generate manifests
# we will generate the manifest.{dev,test} file from ${TARGET_DIR}/voxceleb/vox1/{dev,test} directory
# and generate the meta info and download the trial file
# manifest.dev: 148642
# manifest.test: 4847
echo "Start to download vox1 dataset and generate the manifest files "
python3 ${TARGET_DIR}/voxceleb/voxceleb1.py \
--manifest_prefix="${dir}/vox1/manifest" \
--target_dir="${TARGET_DIR}/voxceleb/vox1/"
if [ $? -ne 0 ]; then
echo "Prepare voxceleb failed. Terminated."
exit 1
fi
if [ $? -ne 0 ]; then
echo "Prepare voxceleb1 failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# download voxceleb2 data
# we will download the data and unzip the package
# and we will store the m4a file in ${TARGET_DIR}/voxceleb/vox2/{dev,test}
echo "start to download vox2 dataset"
python3 ${TARGET_DIR}/voxceleb/voxceleb2.py \
--download \
--target_dir="${TARGET_DIR}/voxceleb/vox2/"
if [ $? -ne 0 ]; then
echo "Download voxceleb2 dataset failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# convert the m4a to wav
# and we will not delete the original m4a file
echo "start to convert the m4a to wav"
bash local/convert.sh ${TARGET_DIR}/voxceleb/vox2/test/ || exit 1;
if [ $? -ne 0 ]; then
echo "Convert voxceleb2 dataset from m4a to wav failed. Terminated."
exit 1
fi
echo "m4a convert to wav operation finished"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# generate the vox2 manifest file from wav file
# we will generate the ${dir}/vox2/manifest.vox2
# because we use all the vox2 data for training, we collect it in one manifest file
echo "start to generate the vox2 manifest files"
python3 ${TARGET_DIR}/voxceleb/voxceleb2.py \
--generate \
--manifest_prefix="${dir}/vox2/manifest" \
--target_dir="${TARGET_DIR}/voxceleb/vox2/"
# for dataset in train dev test; do
# mv data/manifest.${dataset} data/manifest.${dataset}.raw
# done
fi
if [ $? -ne 0 ]; then
echo "Prepare voxceleb2 dataset failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# generate the vox csv file
# Currently, our training system uses csv files for the dataset
echo "convert the json format to csv format to be compatible with the training process"
python3 local/make_vox_csv_dataset_from_json.py\
--train "${dir}/vox1/manifest.dev" "${dir}/vox2/manifest.vox2"\
--test "${dir}/vox1/manifest.test" \
--target_dir "${dir}/vox/" \
--config ${conf_path}
if [ $? -ne 0 ]; then
echo "Prepare voxceleb failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# generate the open rir noise manifest file
echo "generate the open rir noise manifest file"
python3 ${TARGET_DIR}/rir_noise/rir_noise.py\
--manifest_prefix="${dir}/rir_noise/manifest" \
--target_dir="${TARGET_DIR}/rir_noise/"
if [ $? -ne 0 ]; then
echo "Prepare rir_noise failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
# generate the open rir noise csv file
echo "generate the open rir noise csv file"
python3 local/make_rirs_noise_csv_dataset_from_json.py \
--noise_dir="${TARGET_DIR}/rir_noise/" \
--data_dir="${dir}/rir_noise/" \
--config ${conf_path}
if [ $? -ne 0 ]; then
echo "Prepare rir_noise failed. Terminated."
exit 1
fi
fi

@ -0,0 +1,167 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment.
Currently, the speaker identification training process uses the csv format.
"""
import argparse
import csv
import os
from typing import List
import tqdm
from yacs.config import CfgNode
from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks
logger = Log(__name__).getlog()
def get_chunks_list(wav_file: str,
split_chunks: bool,
base_path: str,
chunk_duration: float=3.0) -> List[List[str]]:
"""Get the single audio file info
Args:
wav_file (str): the wav audio file path to get the segment info list from
split_chunks (bool): audio split flag
base_path (str): the audio base path
chunk_duration (float): the chunk duration.
if split_chunks is set, we split the audio into multiple chunk segments.
"""
waveform, sr = load_audio(wav_file)
audio_id = wav_file.split("/rir_noise/")[-1].split(".")[0]
audio_duration = waveform.shape[0] / sr
ret = []
if split_chunks and audio_duration > chunk_duration: # Split into pieces of chunk_duration seconds.
uniq_chunks_list = get_chunks(chunk_duration, audio_id, audio_duration)
for idx, chunk in enumerate(uniq_chunks_list):
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
# currently, all vector csv data formats use one representation:
# id, duration, wav, start, stop, label
# in rirs noise, all the labels are 'noise'
# the label is a string and we will convert it to an integer in training
ret.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
"noise"
])
else: # Keep whole audio.
ret.append(
[audio_id, audio_duration, wav_file, 0, waveform.shape[0], "noise"])
return ret
def generate_csv(wav_files,
output_file: str,
base_path: str,
split_chunks: bool=True):
"""Prepare the csv file according the wav files
Args:
wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool): audio split flag
"""
logger.info(f'Generating csv: {output_file}')
header = ["utt_id", "duration", "wav", "start", "stop", "label"]
csv_lines = []
for item in tqdm.tqdm(wav_files):
csv_lines.extend(
get_chunks_list(
item, base_path=base_path, split_chunks=split_chunks))
if not os.path.exists(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file))
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def prepare_data(args, config):
"""Convert the jsonline format to csv format
Args:
args (argparse.Namespace): scripts args
config (CfgNode): yaml configuration content
"""
# if the external config sets the skip_prep flag, we will do nothing
if config.skip_prep:
return
base_path = args.noise_dir
wav_path = os.path.join(base_path, "RIRS_NOISES")
logger.info(f"base path: {base_path}")
logger.info(f"wav path: {wav_path}")
rir_list = os.path.join(wav_path, "real_rirs_isotropic_noises", "rir_list")
rir_files = []
with open(rir_list, 'r') as f:
for line in f.readlines():
rir_file = line.strip().split(' ')[-1]
rir_files.append(os.path.join(base_path, rir_file))
noise_list = os.path.join(wav_path, "pointsource_noises", "noise_list")
noise_files = []
with open(noise_list, 'r') as f:
for line in f.readlines():
noise_file = line.strip().split(' ')[-1]
noise_files.append(os.path.join(base_path, noise_file))
csv_path = os.path.join(args.data_dir, 'csv')
logger.info(f"csv path: {csv_path}")
generate_csv(
rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path)
generate_csv(
noise_files, os.path.join(csv_path, 'noise.csv'), base_path=base_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--noise_dir",
default=None,
required=True,
help="The noise dataset dataset directory.")
parser.add_argument(
"--data_dir",
default=None,
required=True,
help="The target directory stores the csv files")
parser.add_argument(
"--config",
default=None,
required=True,
type=str,
help="configuration file")
args = parser.parse_args()
# parse the yaml config file
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
# prepare the csv file from jsonlines files
prepare_data(args, config)

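Reading the generated csv back shows the row layout; the path below assumes the `--data_dir` used in stage 7 of `local/data.sh`:

```python
import csv

# Each row carries the fields named in `header`:
# utt_id, duration, wav, start, stop, label.
with open("data/rir_noise/csv/noise.csv") as f:
    for row in csv.DictReader(f):
        start, stop = int(row["start"]), int(row["stop"])
        print(row["utt_id"], row["label"], stop - start, "samples")
        break
```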
@ -0,0 +1,251 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment.
Currently, the speaker identification training process uses the csv format.
"""
import argparse
import csv
import json
import os
import random
import tqdm
from yacs.config import CfgNode
from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks
logger = Log(__name__).getlog()
def prepare_csv(wav_files, output_file, config, split_chunks=True):
"""Prepare the csv file according the wav files
Args:
wav_files (list): all the audio list to prepare the csv file
output_file (str): the output csv file
config (CfgNode): yaml configuration content
split_chunks (bool, optional): audio split flag. Defaults to True.
"""
if not os.path.exists(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file))
csv_lines = []
header = ["utt_id", "duration", "wav", "start", "stop", "label"]
# voxceleb meta info for each training utterance segment
# we extract a segment from an utterance to train
# and the segment's span lies between the start and stop sample points in the original wav file
# each field in the meta info means as follows:
# utt_id: the utterance segment name, which is unique in the training dataset
# duration: the total utterance time
# wav: utterance file path, which should be an absolute path
# start: start point in the original wav file, in sample points
# stop: stop point in the original wav file, in sample points
# label: the utterance segment's label name,
# which is the speaker name in the speaker verification domain
for item in tqdm.tqdm(wav_files, total=len(wav_files)):
item = json.loads(item.strip())
audio_id = item['utt'].replace(".wav",
"") # we remove the wav suffix name
audio_duration = item['feat_shape'][0]
wav_file = item['feat']
label = audio_id.split('-')[
0] # speaker name in speaker verification domain
waveform, sr = load_audio(wav_file)
if split_chunks:
uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
audio_duration)
for chunk in uniq_chunks_list:
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
# id, duration, wav, start, stop, label
# in vector, the label is the speaker id
csv_lines.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
label
])
else:
csv_lines.append([
audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
])
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def get_enroll_test_list(dataset_list, verification_file):
"""Get the enroll and test utterance list from all the voxceleb1 test utterance dataset.
Generally, we get the enroll and test utterances from the verification file.
The verification file format is as follows:
target/nontarget enroll-utt test-utt,
where 0 means nontarget and 1 means target, e.g.:
0 a.wav b.wav
1 a.wav a.wav
Args:
dataset_list (list): all the dataset to get the test utterances
verification_file (str): voxceleb1 trial file
"""
logger.info(f"verification file: {verification_file}")
enroll_audios = set()
test_audios = set()
with open(verification_file, 'r') as f:
for line in f:
_, enroll_file, test_file = line.strip().split(' ')
enroll_audios.add('-'.join(enroll_file.split('/')))
test_audios.add('-'.join(test_file.split('/')))
enroll_files = []
test_files = []
for dataset in dataset_list:
with open(dataset, 'r') as f:
for line in f:
# audio_id may be in enroll and test at the same time
# eg: 1 a.wav a.wav
# the audio a.wav is enroll and test file at the same time
audio_id = json.loads(line.strip())['utt']
if audio_id in enroll_audios:
enroll_files.append(line)
if audio_id in test_audios:
test_files.append(line)
enroll_files = sorted(enroll_files)
test_files = sorted(test_files)
return enroll_files, test_files
def get_train_dev_list(dataset_list, target_dir, split_ratio):
"""Get the train and dev utterance list from all the training utterance dataset.
Generally, we use split_ratio as the train dataset ratio,
and the remaining utterances (ratio 1 - split_ratio) form the dev dataset
Args:
dataset_list (list): all the dataset to get the all utterances
target_dir (str): the target train and dev directory,
we will create the csv directory to store the {train,dev}.csv file
split_ratio (float): train dataset ratio in all utterance list
"""
logger.info("start to get train and dev utt list")
if not os.path.exists(os.path.join(target_dir, "meta")):
os.makedirs(os.path.join(target_dir, "meta"))
audio_files = []
speakers = set()
for dataset in dataset_list:
with open(dataset, 'r') as f:
for line in f:
# the label is speaker name
label_name = json.loads(line.strip())['utt2spk']
speakers.add(label_name)
audio_files.append(line.strip())
speakers = sorted(speakers)
logger.info(f"we get {len(speakers)} speakers from all the train dataset")
with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f:
for label_id, label_name in enumerate(speakers):
f.write(f'{label_name} {label_id}\n')
logger.info(
f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}'
)
# the split_ratio is for train dataset
# the remaining is for dev dataset
split_idx = int(split_ratio * len(audio_files))
audio_files = sorted(audio_files)
random.shuffle(audio_files)
train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:]
logger.info(
f"we get train utterances: {len(train_files)}, dev utterance: {len(dev_files)}"
)
return train_files, dev_files
def prepare_data(args, config):
"""Convert the jsonline format to csv format
Args:
args (argparse.Namespace): scripts args
config (CfgNode): yaml configuration content
"""
# stage0: set the random seed
random.seed(config.seed)
# if the external config sets the skip_prep flag, we will do nothing
if config.skip_prep:
return
# stage 1: prepare the enroll and test csv file
# And we generate the speaker to label file label2id.txt
logger.info("start to prepare the data csv file")
enroll_files, test_files = get_enroll_test_list(
[args.test], verification_file=config.verification_file)
prepare_csv(
enroll_files,
os.path.join(args.target_dir, "csv", "enroll.csv"),
config,
split_chunks=False)
prepare_csv(
test_files,
os.path.join(args.target_dir, "csv", "test.csv"),
config,
split_chunks=False)
# stage 2: prepare the train and dev csv file
# we get the train dataset ratio as config.split_ratio
# and the remaining is dev dataset
logger.info("start to prepare the data csv file")
train_files, dev_files = get_train_dev_list(
args.train, target_dir=args.target_dir, split_ratio=config.split_ratio)
prepare_csv(train_files,
os.path.join(args.target_dir, "csv", "train.csv"), config)
prepare_csv(dev_files,
os.path.join(args.target_dir, "csv", "dev.csv"), config)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--train",
required=True,
nargs='+',
help="The jsonline files list for train.")
parser.add_argument(
"--test", required=True, help="The jsonline file for test")
parser.add_argument(
"--target_dir",
default=None,
required=True,
help="The target directory stores the csv files and meta file.")
parser.add_argument(
"--config",
default=None,
required=True,
type=str,
help="configuration file")
args = parser.parse_args()
# parse the yaml config file
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
# prepare the csv file from jsonlines files
prepare_data(args, config)

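The trial-file handling in `get_enroll_test_list` above boils down to the following sketch: each line is `<label> <enroll-utt> <test-utt>`, and path separators are mapped to `-` so the ids match the manifest `utt` field:

```python
# Minimal sketch; the trial file path comes from the yaml's verification_file.
enroll_audios, test_audios = set(), set()
with open("data/vox1/veri_test2.txt") as f:
    for line in f:
        label, enroll_file, test_file = line.strip().split(' ')
        # "id10270/x6uYqmx31kE/00001.wav" -> "id10270-x6uYqmx31kE-00001.wav"
        enroll_audios.add('-'.join(enroll_file.split('/')))
        test_audios.add('-'.join(test_file.split('/')))
print(len(enroll_audios), len(test_audios))
```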
@ -18,24 +18,22 @@ set -e
#######################################################################
# stage 0: data prepare, including voxceleb1 download and generate {train,dev,enroll,test}.csv
# voxceleb2 data is m4a format, so we need user to convert the m4a to wav yourselves as described in Readme.md with the script local/convert.sh
# voxceleb2 data is m4a format, so we need to convert the m4a to wav ourselves with the script local/convert.sh
# stage 1: train the speaker identification model
# stage 2: test speaker identification
# stage 3: extract the training embedding to train the LDA and PLDA
# stage 3: (todo) extract the training embedding to train the LDA and PLDA
######################################################################
# we can set the variable PPAUDIO_HOME to specify the root directory of the downloaded vox1 and vox2 dataset
# default the dataset will be stored in the ~/.paddleaudio/
# the vox2 dataset is stored in m4a format; we need to convert the audio from m4a to wav ourselves
# and put all of them to ${PPAUDIO_HOME}/datasets/vox2
# we will find the wav from ${PPAUDIO_HOME}/datasets/vox1/wav and ${PPAUDIO_HOME}/datasets/vox2/wav
# export PPAUDIO_HOME=
# and put all of them to ${MAIN_ROOT}/datasets/vox2
# we will find the wav from ${MAIN_ROOT}/datasets/vox1/{dev,test}/wav and ${MAIN_ROOT}/datasets/vox2/wav
stage=0
stop_stage=50
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
# otherwise, we will store the wav info to data/vox1 and data/vox2 directory respectively
# vox2 wav path, we must convert the m4a format to wav format
dir=data/ # data info directory
@ -64,6 +62,6 @@ if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
# if [ $stage -le 3 ]; then
# # stage 2: extract the training embedding to train the LDA and PLDA
# # stage 3: extract the training embedding to train the LDA and PLDA
# # todo: extract the training embedding
# fi

@ -261,7 +261,7 @@ class VoxCeleb(Dataset):
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
header = ["ID", "duration", "wav", "start", "stop", "spk_id"]
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
# Note: this may raise a C++ exception, but the program will execute fine
# so we can ignore the exception
with Pool(cpu_count()) as p:

@ -21,10 +21,11 @@ from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode
from paddleaudio.datasets import VoxCeleb
from paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.io.embedding_norm import InputNormalization
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
@ -32,6 +33,91 @@ from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
id2embedding):
"""compute the dataset embeddings
Args:
data_loader (_type_): _description_
model (_type_): _description_
mean_var_norm_emb (_type_): _description_
config (_type_): _description_
"""
logger.info(
f'Computing embeddings on {data_loader.dataset.csv_path} dataset')
with paddle.no_grad():
for batch_idx, batch in enumerate(tqdm(data_loader)):
# stage 8-1: extract the audio embedding
ids, feats, lengths = batch['ids'], batch['feats'], batch['lengths']
embeddings = model.backbone(feats, lengths).squeeze(
-1) # (N, emb_size, 1) -> (N, emb_size)
# Global embedding normalization.
# if we use the global embedding norm
# the EER can be reduced by about 10% relative
if config.global_embedding_norm and mean_var_norm_emb:
lengths = paddle.ones([embeddings.shape[0]])
embeddings = mean_var_norm_emb(embeddings, lengths)
# Update embedding dict.
id2embedding.update(dict(zip(ids, embeddings)))
def compute_verification_scores(id2embedding, train_cohort, config):
labels = []
enroll_ids = []
test_ids = []
logger.info(f"read the trial from {config.verification_file}")
cos_sim_func = paddle.nn.CosineSimilarity(axis=-1)
scores = []
with open(config.verification_file, 'r') as f:
for line in f.readlines():
label, enroll_id, test_id = line.strip().split(' ')
enroll_id = enroll_id.split('.')[0].replace('/', '-')
test_id = test_id.split('.')[0].replace('/', '-')
labels.append(int(label))
enroll_emb = id2embedding[enroll_id]
test_emb = id2embedding[test_id]
score = cos_sim_func(enroll_emb, test_emb).item()
if "score_norm" in config:
# Getting norm stats for enroll impostors
enroll_rep = paddle.tile(
enroll_emb, repeat_times=[train_cohort.shape[0], 1])
score_e_c = cos_sim_func(enroll_rep, train_cohort)
if "cohort_size" in config:
score_e_c, _ = paddle.topk(
score_e_c, k=config.cohort_size, axis=0)
mean_e_c = paddle.mean(score_e_c, axis=0)
std_e_c = paddle.std(score_e_c, axis=0)
# Getting norm stats for test impostors
test_rep = paddle.tile(
test_emb, repeat_times=[train_cohort.shape[0], 1])
score_t_c = cos_sim_func(test_rep, train_cohort)
if "cohort_size" in config:
score_t_c, _ = paddle.topk(
score_t_c, k=config.cohort_size, axis=0)
mean_t_c = paddle.mean(score_t_c, axis=0)
std_t_c = paddle.std(score_t_c, axis=0)
if config.score_norm == "s-norm":
score_e = (score - mean_e_c) / std_e_c
score_t = (score - mean_t_c) / std_t_c
score = 0.5 * (score_e + score_t)
elif config.score_norm == "z-norm":
score = (score - mean_e_c) / std_e_c
elif config.score_norm == "t-norm":
score = (score - mean_t_c) / std_t_c
scores.append(score)
return scores, labels
def main(args, config):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
@ -58,9 +144,8 @@ def main(args, config):
# stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb(
subset='enroll',
target_dir=args.data_dir,
enroll_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/enroll.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
@ -68,16 +153,15 @@ def main(args, config):
hop_length=config.hop_size)
enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enroll_dataset,
shuffle=False) # no need to shuffle when extracting embeddings.
enroll_loader = DataLoader(enroll_dataset,
batch_sampler=enroll_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)
test_dataset = VoxCeleb(
subset='test',
target_dir=args.data_dir,
test_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/test.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
@ -85,7 +169,7 @@ def main(args, config):
hop_length=config.hop_size)
test_sampler = BatchSampler(
test_dataset, batch_size=config.batch_size, shuffle=True)
test_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset,
batch_sampler=test_sampler,
collate_fn=lambda x: batch_feature_normalize(
@ -97,75 +181,65 @@ def main(args, config):
# stage6: global embedding norm to improve the performance
logger.info(f"global embedding norm: {config.global_embedding_norm}")
if config.global_embedding_norm:
global_embedding_mean = None
global_embedding_std = None
mean_norm_flag = config.embedding_mean_norm
std_norm_flag = config.embedding_std_norm
batch_count = 0
# stage7: Compute embeddings of audios in enrol and test dataset from model.
if config.global_embedding_norm:
mean_var_norm_emb = InputNormalization(
norm_type="global",
mean_norm=config.embedding_mean_norm,
std_norm=config.embedding_std_norm)
if "score_norm" in config:
logger.info(f"we will do score norm: {config.score_norm}")
train_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/train.csv"),
feat_type='melspectrogram',
n_train_snts=config.n_train_snts,
random_chunk=False,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
train_sampler = BatchSampler(
train_dataset, batch_size=config.batch_size, shuffle=False)
train_loader = DataLoader(train_dataset,
batch_sampler=train_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)
id2embedding = {}
# Run multiple times to make embedding normalization more stable.
for i in range(2):
for dl in [enrol_loader, test_loader]:
logger.info(
f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
)
with paddle.no_grad():
for batch_idx, batch in enumerate(tqdm(dl)):
# stage 8-1: extract the audio embedding
ids, feats, lengths = batch['ids'], batch['feats'], batch[
'lengths']
embeddings = model.backbone(feats, lengths).squeeze(
-1).numpy() # (N, emb_size, 1) -> (N, emb_size)
# Global embedding normalization.
# if we use the global embedding norm
# the EER can be reduced by about 10% relative
if config.global_embedding_norm:
batch_count += 1
current_mean = embeddings.mean(
axis=0) if mean_norm_flag else 0
current_std = embeddings.std(
axis=0) if std_norm_flag else 1
# Update global mean and std.
if global_embedding_mean is None and global_embedding_std is None:
global_embedding_mean, global_embedding_std = current_mean, current_std
else:
weight = 1 / batch_count # Weight decay by batches.
global_embedding_mean = (
1 - weight
) * global_embedding_mean + weight * current_mean
global_embedding_std = (
1 - weight
) * global_embedding_std + weight * current_std
# Apply global embedding normalization.
embeddings = (embeddings - global_embedding_mean
) / global_embedding_std
# Update embedding dict.
id2embedding.update(dict(zip(ids, embeddings)))
logger.info("First loop for enroll and test dataset")
compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
id2embedding)
compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
id2embedding)
logger.info("Second loop for enroll and test dataset")
compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
id2embedding)
compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
id2embedding)
mean_var_norm_emb.save(
os.path.join(args.load_checkpoint, "mean_var_norm_emb"))
# stage 8: Compute cosine scores.
labels = []
enroll_ids = []
test_ids = []
logger.info(f"read the trial from {VoxCeleb.veri_test_file}")
with open(VoxCeleb.veri_test_file, 'r') as f:
for line in f.readlines():
label, enroll_id, test_id = line.strip().split(' ')
labels.append(int(label))
enroll_ids.append(enroll_id.split('.')[0].replace('/', '-'))
test_ids.append(test_id.split('.')[0].replace('/', '-'))
cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
np.asarray([id2embedding[uttid] for uttid in ids], dtype='float32')),
[enroll_ids, test_ids
]) # (N, emb_size)
scores = cos_sim_func(enrol_embeddings, test_embeddings)
train_cohort = None
if "score_norm" in config:
train_embeddings = {}
# cohort embeddings do not get mean and std norm
compute_dataset_embedding(train_loader, model, None, config,
train_embeddings)
train_cohort = paddle.stack(list(train_embeddings.values()))
# compute the scores
scores, labels = compute_verification_scores(id2embedding, train_cohort,
config)
# compute the EER and threshold
scores = paddle.to_tensor(scores)
EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
logger.info(
f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'

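The score-normalization branch above implements the usual z-/t-/s-norm family. A hedged numpy sketch of the s-norm case, given a raw cosine score and cohort scores against the train embeddings:

```python
import numpy as np

def s_norm(score, score_e_c, score_t_c, cohort_size=None):
    """Normalize a raw score with enroll- and test-side cohort statistics."""
    if cohort_size is not None:  # keep only the top-k cohort scores
        score_e_c = np.sort(score_e_c)[-cohort_size:]
        score_t_c = np.sort(score_t_c)[-cohort_size:]
    score_e = (score - score_e_c.mean()) / score_e_c.std()
    score_t = (score - score_t_c.mean()) / score_t_c.std()
    return 0.5 * (score_e + score_t)

# score_e_c / score_t_c would be cosine similarities between the enroll/test
# embedding and every cohort (train) embedding; random values here.
print(s_norm(0.62, np.random.rand(500), np.random.rand(500), cohort_size=200))
```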
@ -23,13 +23,13 @@ from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import batch_pad_right
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
@ -54,8 +54,12 @@ def main(args, config):
# stage2: data prepare, such as vox1 and vox2 data, and the augment noise data and pipeline
# note: some commands must run in rank==0, so we will refactor the data prepare code
train_dataset = VoxCeleb('train', target_dir=args.data_dir)
dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)
train_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
dev_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
@ -67,7 +71,7 @@ def main(args, config):
# stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb.num_speakers)
backbone=ecapa_tdnn, num_class=config.num_speakers)
# stage5: build the optimizer, we now only construct the AdamW optimizer
# 140000 is single gpu steps
@ -193,15 +197,15 @@ def main(args, config):
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()
train_run_cost += time.time() - train_start
# stage 9-8: Calculate average loss per batch
avg_loss += loss.numpy()[0]
avg_loss = loss.item()
# stage 9-9: Calculate metrics, which is one-best accuracy
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
train_run_cost += time.time() - train_start
timer.count() # step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batches
@ -220,8 +224,9 @@ def main(args, config):
train_feat_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(
train_run_cost / config.log_interval)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
print_msg += ' lr={:.4E} step/sec={:.2f} ips:{:.5f}| ETA {}'.format(
lr, timer.timing, timer.ips, timer.eta)
logger.info(print_msg)
avg_loss = 0

@ -14,6 +14,7 @@
# this is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
import math
import os
from typing import List
import numpy as np
@ -21,8 +22,8 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleaudio.datasets.rirs_noises import OpenRIRNoise
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.io.signal_processing import compute_amplitude
from paddlespeech.vector.io.signal_processing import convolve1d
from paddlespeech.vector.io.signal_processing import dB_to_amplitude
@ -509,7 +510,7 @@ class AddNoise(nn.Layer):
assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
return np.pad(x, [0, w], mode=mode, **kwargs)
ids = [item['id'] for item in batch]
ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[0] for item in batch])
waveforms = list(
map(lambda x: pad(x, max(max_length, lengths.max().item())),
@ -589,7 +590,7 @@ class AddReverb(nn.Layer):
assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
return np.pad(x, [0, w], mode=mode, **kwargs)
ids = [item['id'] for item in batch]
ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[0] for item in batch])
waveforms = list(
map(lambda x: pad(x, lengths.max().item()),
@ -839,8 +840,10 @@ def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
List[paddle.nn.Layer]: all augment process
"""
logger.info("start to build the augment pipeline")
noise_dataset = OpenRIRNoise('noise', target_dir=target_dir)
rir_dataset = OpenRIRNoise('rir', target_dir=target_dir)
noise_dataset = CSVDataset(csv_path=os.path.join(target_dir,
"rir_noise/csv/noise.csv"))
rir_dataset = CSVDataset(csv_path=os.path.join(target_dir,
"rir_noise/csv/rir.csv"))
wavedrop = TimeDomainSpecAugment(
sample_rate=16000,

@ -17,6 +17,17 @@ import paddle
def waveform_collate_fn(batch):
"""Wrap the waveform into a batch form
Args:
batch (list): the waveform list from the dataloader
each item of the batch includes several fields:
feat: the utterance waveform data
label: the utterance label encoding data
Returns:
dict: the batch data to dataloader
"""
waveforms = np.stack([item['feat'] for item in batch])
labels = np.stack([item['label'] for item in batch])
@ -27,6 +38,18 @@ def feature_normalize(feats: paddle.Tensor,
mean_norm: bool=True,
std_norm: bool=True,
convert_to_numpy: bool=False):
"""Do one utterance feature normalization
Args:
feats (paddle.Tensor): the original utterance feat, such as fbank, mfcc
mean_norm (bool, optional): mean norm flag. Defaults to True.
std_norm (bool, optional): std norm flag. Defaults to True.
convert_to_numpy (bool, optional): convert the paddle.tensor to numpy
and do feature norm with numpy. Defaults to False.
Returns:
paddle.Tensor : the normalized feats
"""
# Features normalization if needed
# numpy.mean differs a little from paddle.mean, by about 1e-6
if convert_to_numpy:
@ -60,7 +83,17 @@ def pad_right_2d(x, target_length, axis=-1, mode='constant', **kwargs):
def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
ids = [item['id'] for item in batch]
"""Do batch utterance features normalization
Args:
batch (list): the batch feature from dataloader
mean_norm (bool, optional): mean normalization flag. Defaults to True.
std_norm (bool, optional): std normalization flag. Defaults to True.
Returns:
dict: the normalized batch features
"""
ids = [item['utt_id'] for item in batch]
lengths = np.asarray([item['feat'].shape[1] for item in batch])
feats = list(
map(lambda x: pad_right_2d(x, lengths.max()),

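For intuition, the `mean_norm=True, std_norm=False` setting used throughout this PR amounts to subtracting a per-utterance mean; a numpy sketch, assuming the `(n_mels, T)` feat layout that `pad_right_2d` pads on the last axis:

```python
import numpy as np

feat = np.random.randn(80, 200).astype("float32")  # fake fbank: 80 mels, 200 frames
# Subtract the per-utterance mean over the time axis.
feat = feat - feat.mean(axis=-1, keepdims=True)
print(feat.mean(axis=-1)[:3])  # ~0 for every mel bin
```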
@ -0,0 +1,192 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
# the audio meta info in the vector CSVDataset
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# label: the utterance segment's label id
@dataclass
class meta_info:
"""the audio meta info in the vector CSVDataset
Args:
utt_id (str): the utterance segment name
duration (float): utterance segment time
wav (str): utterance file path
start (int): start point in the original wav file
stop (int): stop point in the original wav file
label (str): the utterance segment's label id
"""
utt_id: str
duration: float
wav: str
start: int
stop: int
label: str
# csv dataset supported feature types
# raw: returns the pcm data sample points
# melspectrogram: fbank feature
feat_funcs = {
'raw': None,
'melspectrogram': melspectrogram,
}
class CSVDataset(Dataset):
def __init__(self,
csv_path,
label2id_path=None,
config=None,
random_chunk=True,
feat_type: str="raw",
n_train_snts: int=-1,
**kwargs):
"""Implement the CSV Dataset
Args:
csv_path (str): csv dataset file path
label2id_path (str): the utterance label to integer id map file path
config (CfgNode): yaml config
random_chunk (bool): whether to select a random chunk from the utterance
feat_type (str): dataset feature type. if it is raw, it returns pcm data.
n_train_snts (int): select the first n_train_snts samples from the dataset.
if n_train_snts = -1, the dataset will load all the samples.
Default value is -1.
kwargs : feature type args
"""
super().__init__()
self.csv_path = csv_path
self.label2id_path = label2id_path
self.config = config
self.random_chunk = random_chunk
self.feat_type = feat_type
self.n_train_snts = n_train_snts
self.feat_config = kwargs
self.id2label = {}
self.label2id = {}
self.data = self.load_data_csv()
self.load_speaker_to_label()
def load_data_csv(self):
"""Load the csv dataset content and store them in the data property
the csv dataset's format has six fields,
that is audio_id or utt_id, audio duration, segment start point, segment stop point
and utterance label.
Note in training period, the utterance label must has a map to integer id in label2id_path
Returns:
list: the csv data with meta_info type
"""
data = []
with open(self.csv_path, 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav, start, stop, spk_id = line.strip(
).split(',')
data.append(
meta_info(audio_id,
float(duration), wav,
int(start), int(stop), spk_id))
if self.n_train_snts > 0:
sample_num = min(self.n_train_snts, len(data))
data = data[0:sample_num]
return data
def load_speaker_to_label(self):
"""Load the utterance label map content.
In the vector domain, we call the utterance label the speaker label.
It is the real speaker label in the speaker verification domain,
and the language label in language identification.
"""
if not self.label2id_path:
logger.warning("No speaker id to label file")
return
with open(self.label2id_path, 'r') as f:
for line in f.readlines():
label_name, label_id = line.strip().split(' ')
self.label2id[label_name] = int(label_id)
self.id2label[int(label_id)] = label_name
def convert_to_record(self, idx: int):
"""convert the dataset sample to training record the CSV Dataset
Args:
idx (int) : the request index in all the dataset
"""
sample = self.data[idx]
record = {}
# To show all fields of the dataclass: `fields(sample)`
for field in fields(sample):
record[field.name] = getattr(sample, field.name)
waveform, sr = load_audio(record['wav'])
# randomly select a chunk of audio samples from the audio
if self.config and self.config.random_chunk:
num_wav_samples = waveform.shape[0]
num_chunk_samples = int(self.config.chunk_duration * sr)
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
stop = start + num_chunk_samples
else:
start = record['start']
stop = record['stop']
# we only return the waveform as feat
waveform = waveform[start:stop]
# all available feature types are in feat_funcs
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
if self.label2id:
record.update({'label': self.label2id[record['label']]})
return record
def __getitem__(self, idx):
"""Return the specific index sample
Args:
idx (int) : the request index in all the dataset
"""
return self.convert_to_record(idx)
def __len__(self):
"""Return the dataset length
Returns:
int: the length num of the dataset
"""
return len(self.data)

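Hypothetical usage of the `CSVDataset` above, mirroring how the training script in this PR constructs it (paths are the ones `local/data.sh` writes):

```python
# Raw waveform records, with labels mapped to integer ids via label2id.txt.
train_dataset = CSVDataset(
    csv_path="data/vox/csv/train.csv",
    label2id_path="data/vox/meta/label2id.txt")
record = train_dataset[0]
print(record['utt_id'], record['label'], record['feat'].shape)
```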
@ -0,0 +1,214 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import paddle
class InputNormalization:
spk_dict_mean: Dict[int, paddle.Tensor]
spk_dict_std: Dict[int, paddle.Tensor]
spk_dict_count: Dict[int, int]
def __init__(
self,
mean_norm=True,
std_norm=True,
norm_type="global", ):
"""Do feature or embedding mean and std norm
Args:
mean_norm (bool, optional): mean norm flag. Defaults to True.
std_norm (bool, optional): std norm flag. Defaults to True.
norm_type (str, optional): norm type. Defaults to "global".
"""
super().__init__()
self.training = True
self.mean_norm = mean_norm
self.std_norm = std_norm
self.norm_type = norm_type
self.glob_mean = paddle.to_tensor([0], dtype="float32")
self.glob_std = paddle.to_tensor([0], dtype="float32")
self.spk_dict_mean = {}
self.spk_dict_std = {}
self.spk_dict_count = {}
self.weight = 1.0
self.count = 0
self.eps = 1e-10
def __call__(self,
x,
lengths,
spk_ids=paddle.to_tensor([], dtype="float32")):
"""Returns the tensor with the surrounding context.
Args:
x (paddle.Tensor): A batch of tensors.
lengths (paddle.Tensor): A batch of tensors containing the relative length of each
sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
computing stats on zero-padded steps.
spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
It is used to perform per-speaker normalization when
norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32").
Returns:
paddle.Tensor: The normalized feature or embedding
"""
N_batches = x.shape[0]
# print(f"x shape: {x.shape[1]}")
current_means = []
current_stds = []
for snt_id in range(N_batches):
# Avoiding padded time steps
# actual size is the actual time data length
actual_size = paddle.round(lengths[snt_id] *
x.shape[1]).astype("int32")
# computing actual time data statistics
current_mean, current_std = self._compute_current_stats(
x[snt_id, 0:actual_size, ...].unsqueeze(0))
current_means.append(current_mean)
current_stds.append(current_std)
if self.norm_type == "global":
current_mean = paddle.mean(paddle.stack(current_means), axis=0)
current_std = paddle.mean(paddle.stack(current_stds), axis=0)
if self.norm_type == "global":
if self.training:
if self.count == 0:
self.glob_mean = current_mean
self.glob_std = current_std
else:
self.weight = 1 / (self.count + 1)
self.glob_mean = (
1 - self.weight
) * self.glob_mean + self.weight * current_mean
self.glob_std = (
1 - self.weight
) * self.glob_std + self.weight * current_std
self.glob_mean = self.glob_mean.detach()
self.glob_std = self.glob_std.detach()
self.count = self.count + 1
x = (x - self.glob_mean) / (self.glob_std)
return x
def _compute_current_stats(self, x):
"""Returns the tensor with the surrounding context.
Args:
x (paddle.Tensor): A batch of tensors.
Returns:
the statistics of the data
"""
# Compute current mean
if self.mean_norm:
current_mean = paddle.mean(x, axis=0).detach()
else:
current_mean = paddle.to_tensor([0.0], dtype="float32")
# Compute current std
if self.std_norm:
current_std = paddle.std(x, axis=0).detach()
else:
current_std = paddle.to_tensor([1.0], dtype="float32")
# Improving numerical stability of std
current_std = paddle.maximum(current_std,
self.eps * paddle.ones_like(current_std))
return current_mean, current_std
def _statistics_dict(self):
"""Fills the dictionary containing the normalization statistics.
"""
state = {}
state["count"] = self.count
state["glob_mean"] = self.glob_mean
state["glob_std"] = self.glob_std
state["spk_dict_mean"] = self.spk_dict_mean
state["spk_dict_std"] = self.spk_dict_std
state["spk_dict_count"] = self.spk_dict_count
return state
def _load_statistics_dict(self, state):
"""Loads the dictionary containing the statistics.
Arguments
---------
state : dict
A dictionary containing the normalization statistics.
"""
self.count = state["count"]
if isinstance(state["glob_mean"], int):
self.glob_mean = state["glob_mean"]
self.glob_std = state["glob_std"]
else:
self.glob_mean = state["glob_mean"] # .to(self.device_inp)
self.glob_std = state["glob_std"] # .to(self.device_inp)
# Loading the spk_dict_mean in the right device
self.spk_dict_mean = {}
for spk in state["spk_dict_mean"]:
self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]
# Loading the spk_dict_std in the right device
self.spk_dict_std = {}
for spk in state["spk_dict_std"]:
self.spk_dict_std[spk] = state["spk_dict_std"][spk]
self.spk_dict_count = state["spk_dict_count"]
return state
def to(self, device):
"""Puts the needed tensors in the right device.
"""
self.glob_mean = self.glob_mean.to(device)
self.glob_std = self.glob_std.to(device)
for spk in self.spk_dict_mean:
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
return self
def save(self, path):
"""Save statistic dictionary.
Args:
path (str): A path where to save the dictionary.
"""
stats = self._statistics_dict()
paddle.save(stats, path)
def _load(self, path, end_of_epoch=False, device=None):
"""Load statistic dictionary.
Arguments
---------
path : str
The path of the statistic dictionary
device : str, None
Unused; paddle.load does not take a map_location argument.
"""
del end_of_epoch # Unused here.
stats = paddle.load(path)
self._load_statistics_dict(stats)

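A small sketch of driving `InputNormalization` the way `test.py` above does, on a batch of embeddings; shapes follow the `(N, emb_size)` output of the ecapa-tdnn backbone:

```python
import paddle

# Global mean norm only, matching embedding_mean_norm / embedding_std_norm.
mean_var_norm_emb = InputNormalization(
    mean_norm=True, std_norm=False, norm_type="global")
embeddings = paddle.randn([16, 192])           # fake batch of embeddings
lengths = paddle.ones([embeddings.shape[0]])   # relative lengths, no padding
embeddings = mean_var_norm_emb(embeddings, lengths)
```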
@ -23,6 +23,7 @@ class Timer(object):
self.last_start_step = 0
self.current_step = 0
self._is_running = True
self.cur_ips = 0
def start(self):
self.last_time = time.time()
@ -43,12 +44,17 @@ class Timer(object):
self.last_start_step = self.current_step
time_used = time.time() - self.last_time
self.last_time = time.time()
self.cur_ips = run_steps / time_used
return time_used / run_steps
@property
def is_running(self) -> bool:
return self._is_running
@property
def ips(self) -> float:
return self.cur_ips
@property
def eta(self) -> str:
if not self.is_running:

@ -0,0 +1,32 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def get_chunks(seg_dur, audio_id, audio_duration):
"""Get all chunk segments from a utterance
Args:
seg_dur (float): segment chunk duration, seconds
audio_id (str): utterance name,
audio_duration (float): utterance duration, seconds
Returns:
List: all the chunk segments
"""
num_chunks = int(audio_duration / seg_dur) # all in seconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
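For example, a 10-second utterance split into 3-second chunks; each chunk id encodes its start and end time, which the csv scripts above parse back with `chunk.split("_")[-2:]`:

```python
print(get_chunks(3.0, "id10001-utt1", 10.0))
# ['id10001-utt1_0.0_3.0', 'id10001-utt1_3.0_6.0', 'id10001-utt1_6.0_9.0']
```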