Merge pull request #1630 from Honei/vox12
[vec]voxceleb convert dataset format to paddlespeechpull/1690/head
commit
48e0177767
@ -0,0 +1,53 @@
|
|||||||
|
###########################################
|
||||||
|
# Data #
|
||||||
|
###########################################
|
||||||
|
augment: True
|
||||||
|
batch_size: 16
|
||||||
|
num_workers: 2
|
||||||
|
num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
|
||||||
|
shuffle: True
|
||||||
|
skip_prep: False
|
||||||
|
split_ratio: 0.9
|
||||||
|
chunk_duration: 3.0 # seconds
|
||||||
|
random_chunk: True
|
||||||
|
verification_file: data/vox1/veri_test2.txt
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
# currently, we only support fbank
|
||||||
|
sr: 16000 # sample rate
|
||||||
|
n_mels: 80
|
||||||
|
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
|
||||||
|
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# MODEL SETTING #
|
||||||
|
###########################################################
|
||||||
|
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
|
||||||
|
# to use another model, please choose another configuration yaml file
|
||||||
|
model:
|
||||||
|
input_size: 80
|
||||||
|
channels: [512, 512, 512, 512, 1536]
|
||||||
|
kernel_sizes: [5, 3, 3, 3, 1]
|
||||||
|
dilations: [1, 2, 3, 4, 1]
|
||||||
|
attention_channels: 128
|
||||||
|
lin_neurons: 192
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Training #
|
||||||
|
###########################################
|
||||||
|
seed: 1986 # taken from the speechbrain configuration
|
||||||
|
epochs: 100
|
||||||
|
save_interval: 10
|
||||||
|
log_interval: 10
|
||||||
|
learning_rate: 1.0e-8 # keep the decimal point: YAML 1.1 loaders such as PyYAML load a bare 1e-8 as a string, not a float
|
||||||
|
|
||||||
|
|
||||||
|
###########################################
|
||||||
|
# Testing #
|
||||||
|
###########################################
|
||||||
|
global_embedding_norm: True
|
||||||
|
embedding_mean_norm: True
|
||||||
|
embedding_std_norm: False
|
||||||
|
|
@ -0,0 +1,167 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Convert the RIRS_NOISES rir and noise list files to csv format data in voxceleb experiment.
|
||||||
|
Currently, the speaker identification training process uses the csv format.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddleaudio import load as load_audio
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
from paddlespeech.vector.utils.vector_utils import get_chunks
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
def get_chunks_list(wav_file: str,
                    split_chunks: bool,
                    base_path: str,
                    chunk_duration: float=3.0) -> List[List[str]]:
    """Build the csv rows that describe a single noise/RIR audio file.

    Args:
        wav_file (str): path of the wav file to describe.
        split_chunks (bool): audio split flag. When set and the audio is
            longer than ``chunk_duration``, one row is produced per chunk
            returned by ``get_chunks``; otherwise a single row covers the
            whole file.
        base_path (str): root directory of the noise dataset. Only used to
            derive the utterance id when ``wav_file`` does not contain the
            ``/rir_noise/`` path marker.
        chunk_duration (float): chunk length in seconds. Defaults to 3.0.

    Returns:
        List[List[str]]: rows laid out as
            ``[utt_id, duration, wav, start, stop, label]``. ``duration`` is
            the whole-file duration in seconds, ``start``/``stop`` are sample
            indexes and ``label`` is always the string "noise".
    """
    waveform, sr = load_audio(wav_file)
    num_samples = waveform.shape[0]
    audio_duration = num_samples / sr

    # The utterance id is the path below the dataset root without its
    # extension. Cutting at the first "." (the previous behaviour) produced an
    # empty id for paths such as "./noise/x.wav" that lack the marker, so the
    # extension is removed with splitext and base_path is the fallback root.
    marker = "/rir_noise/"
    if marker in wav_file:
        rel_path = wav_file.split(marker)[-1]
    elif base_path:
        rel_path = os.path.relpath(wav_file, base_path)
    else:
        rel_path = wav_file
    audio_id = os.path.splitext(rel_path)[0]

    # currently, all vector csv data format use one representation
    # id, duration, wav, start, stop, label
    # in rirs noise, all the label name is 'noise'
    # the label is string type and we will convert it to integer type in training
    ret = []
    if split_chunks and audio_duration > chunk_duration:
        # Split into pieces of chunk_duration seconds.
        uniq_chunks_list = get_chunks(chunk_duration, audio_id, audio_duration)

        for chunk in uniq_chunks_list:
            # The last two "_"-separated fields of a chunk id are parsed as
            # the start and end timestamps in seconds.
            s, e = chunk.split("_")[-2:]
            start_sample = int(float(s) * sr)
            end_sample = int(float(e) * sr)
            ret.append([
                chunk, audio_duration, wav_file, start_sample, end_sample,
                "noise"
            ])
    else:
        # Keep whole audio.
        ret.append(
            [audio_id, audio_duration, wav_file, 0, num_samples, "noise"])
    return ret
|
||||||
|
|
||||||
|
|
||||||
|
def generate_csv(wav_files,
                 output_file: str,
                 base_path: str,
                 split_chunks: bool=True):
    """Prepare the csv file according the wav files

    Args:
        wav_files (list): all the audio list to prepare the csv file
        output_file (str): the output csv file
        base_path (str): the noise dataset root, forwarded to get_chunks_list
        split_chunks (bool): audio split flag, forwarded to get_chunks_list.
            Defaults to True.
    """
    logger.info(f'Generating csv: {output_file}')
    header = ["utt_id", "duration", "wav", "start", "stop", "label"]

    # Collect the rows of every audio file before touching the output file.
    csv_lines = []
    for item in tqdm.tqdm(wav_files):
        csv_lines.extend(
            get_chunks_list(
                item, base_path=base_path, split_chunks=split_chunks))

    # dirname is "" when output_file is a bare file name; os.makedirs("")
    # would raise, so the directory is only created when there is one.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # newline="" lets the csv module control the line endings itself, as its
    # documentation requires for files handed to csv.writer.
    with open(output_file, mode="w", newline="") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(header)
        csv_writer.writerows(csv_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_wav_list(list_file: str, base_path: str) -> List[str]:
    """Return the wav paths named in a list file, joined onto base_path."""
    wav_files = []
    with open(list_file, 'r') as f:
        for line in f:
            fields = line.split()
            # A blank line would otherwise turn into "base_path/" and fail
            # later when that directory is opened as audio.
            if not fields:
                continue
            # The wav path is the last whitespace-separated field of a line.
            wav_files.append(os.path.join(base_path, fields[-1]))
    return wav_files


def prepare_data(args, config):
    """Convert the rir and noise list files to csv format

    Reads ``RIRS_NOISES/real_rirs_isotropic_noises/rir_list`` and
    ``RIRS_NOISES/pointsource_noises/noise_list`` below ``args.noise_dir`` and
    writes ``rir.csv`` and ``noise.csv`` into ``<args.data_dir>/csv``.

    Args:
        args (argparse.Namespace): scripts args; ``noise_dir`` and
            ``data_dir`` are read
        config (CfgNode): yaml configuration content; only ``skip_prep`` is
            read
    """
    # if external config set the skip_prep flag, we will do nothing
    if config.skip_prep:
        return

    base_path = args.noise_dir
    wav_path = os.path.join(base_path, "RIRS_NOISES")
    logger.info(f"base path: {base_path}")
    logger.info(f"wav path: {wav_path}")

    rir_list = os.path.join(wav_path, "real_rirs_isotropic_noises", "rir_list")
    rir_files = _read_wav_list(rir_list, base_path)

    noise_list = os.path.join(wav_path, "pointsource_noises", "noise_list")
    noise_files = _read_wav_list(noise_list, base_path)

    csv_path = os.path.join(args.data_dir, 'csv')
    logger.info(f"csv path: {csv_path}")
    generate_csv(
        rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path)
    generate_csv(
        noise_files, os.path.join(csv_path, 'noise.csv'), base_path=base_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    # (flag, add_argument keyword arguments); every option is mandatory.
    option_table = (
        ("--noise_dir", dict(
            default=None,
            required=True,
            help="The noise dataset dataset directory.")),
        ("--data_dir", dict(
            default=None,
            required=True,
            help="The target directory stores the csv files")),
        ("--config", dict(
            default=None,
            required=True,
            type=str,
            help="configuration file")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Load the yaml configuration into a node that accepts new keys.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Write the rir and noise csv files.
    prepare_data(args, config)
|
@ -0,0 +1,251 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment.
|
||||||
|
Currently, the speaker identification training process uses the csv format.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddleaudio import load as load_audio
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
from paddlespeech.vector.utils.vector_utils import get_chunks
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_csv(wav_files, output_file, config, split_chunks=True):
    """Prepare the csv file according the wav files

    Args:
        wav_files (list): jsonline strings, one per utterance; the keys
            ``utt``, ``feat`` and ``feat_shape`` are read from each of them
        output_file (str): the output csv file
        config (CfgNode): yaml configuration content; ``chunk_duration`` is
            read when ``split_chunks`` is set
        split_chunks (bool, optional): audio split flag. Defaults to True.
    """
    # dirname is "" when output_file is a bare file name; os.makedirs("")
    # would raise, so the directory is only created when there is one.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    csv_lines = []
    header = ["utt_id", "duration", "wav", "start", "stop", "label"]
    # voxceleb meta info for each training utterance segment
    # we extract a segment from a utterance to train
    # and the segment's period is between start and stop time point in the original wav file
    # each field in the meta info means as follows:
    # utt_id: the utterance segment name, which is unique in training dataset
    # duration: the total utterance time
    # wav: utterance file path, which should be absolute path
    # start: start point in the original wav file sample point range
    # stop: stop point in the original wav file sample point range
    # label: the utterance segment's label name,
    #        which is speaker name in speaker verification domain
    for item in tqdm.tqdm(wav_files, total=len(wav_files)):
        item = json.loads(item.strip())
        # we remove the wav suffix name
        audio_id = item['utt'].replace(".wav", "")
        audio_duration = item['feat_shape'][0]
        wav_file = item['feat']
        # speaker name in speaker verification domain
        label = audio_id.split('-')[0]
        # The audio is loaded for its sample rate and its length in samples.
        waveform, sr = load_audio(wav_file)
        if split_chunks:
            uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
                                          audio_duration)
            for chunk in uniq_chunks_list:
                # The last two "_"-separated fields of a chunk id are parsed
                # as the start and end timestamps in seconds.
                s, e = chunk.split("_")[-2:]
                start_sample = int(float(s) * sr)
                end_sample = int(float(e) * sr)
                # id, duration, wav, start, stop, label
                # in vector, the label is the speaker id
                csv_lines.append([
                    chunk, audio_duration, wav_file, start_sample, end_sample,
                    label
                ])
        else:
            csv_lines.append([
                audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
            ])

    # newline="" lets the csv module control the line endings itself, as its
    # documentation requires for files handed to csv.writer.
    with open(output_file, mode="w", newline="") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(header)
        csv_writer.writerows(csv_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def get_enroll_test_list(dataset_list, verification_file):
    """Get the enroll and test utterance list from all the voxceleb1 test utterance dataset.
    Generally, we get the enroll and test utterances from the verification file.
    The verification file format as follows:
    target/nontarget enroll-utt test-utt,
    we set 0 as nontarget and 1 as target, eg:
    0 a.wav b.wav
    1 a.wav a.wav

    Args:
        dataset_list (list): all the dataset to get the test utterances
        verification_file (str): voxceleb1 trial file

    Returns:
        tuple: ``(enroll_files, test_files)``, two sorted lists holding the
            unmodified jsonlines whose ``utt`` appears in the enroll or the
            test column of the trial file.
    """
    logger.info(f"verification file: {verification_file}")
    enroll_audios = set()
    test_audios = set()
    with open(verification_file, 'r') as f:
        for line in f:
            # split() tolerates repeated spaces and tabs; blank lines (for
            # example a trailing newline at the end of the file) are skipped
            # instead of breaking the 3-way unpacking below.
            fields = line.split()
            if not fields:
                continue
            _, enroll_file, test_file = fields
            # "/" in the trial paths is mapped to "-" before the lookup
            # against the jsonline "utt" values below.
            enroll_audios.add(enroll_file.replace('/', '-'))
            test_audios.add(test_file.replace('/', '-'))

    enroll_files = []
    test_files = []
    for dataset in dataset_list:
        with open(dataset, 'r') as f:
            for line in f:
                # audio_id may be in enroll and test at the same time
                # eg: 1 a.wav a.wav
                # the audio a.wav is enroll and test file at the same time
                audio_id = json.loads(line.strip())['utt']
                if audio_id in enroll_audios:
                    enroll_files.append(line)
                if audio_id in test_audios:
                    test_files.append(line)

    enroll_files = sorted(enroll_files)
    test_files = sorted(test_files)

    return enroll_files, test_files
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_dev_list(dataset_list, target_dir, split_ratio):
    """Split all the training utterances into a train and a dev list.

    The first ``split_ratio`` share of the shuffled utterances becomes the
    train list and the rest becomes the dev list. The sorted speaker names
    are also written, one ``<name> <id>`` pair per line, to
    ``<target_dir>/meta/label2id.txt``.

    Args:
        dataset_list (list): jsonline files holding all the utterances
        target_dir (str): directory in which the ``meta`` folder is created
        split_ratio (float): train dataset ratio in all utterance list

    Returns:
        tuple: ``(train_files, dev_files)``, two lists of stripped jsonlines
    """
    logger.info("start to get train and dev utt list")
    meta_dir = os.path.join(target_dir, "meta")
    if not os.path.exists(meta_dir):
        os.makedirs(meta_dir)

    # Gather every utterance line together with the set of speaker names.
    utterances = []
    speaker_names = set()
    for manifest in dataset_list:
        with open(manifest, 'r') as reader:
            for raw_line in reader:
                entry = raw_line.strip()
                # the label is speaker name
                speaker_names.add(json.loads(entry)['utt2spk'])
                utterances.append(entry)
    ordered_speakers = sorted(speaker_names)
    logger.info(
        f"we get {len(ordered_speakers)} speakers from all the train dataset")

    # The integer id of a speaker is its rank in the sorted name list.
    label_map_path = os.path.join(target_dir, "meta", "label2id.txt")
    with open(label_map_path, 'w') as writer:
        writer.writelines(f'{name} {index}\n'
                          for index, name in enumerate(ordered_speakers))
    logger.info(f'we store the speakers to {label_map_path}')

    # Sort first so that the shuffle only depends on the random state,
    # then cut the shuffled list at the train/dev boundary.
    utterances = sorted(utterances)
    random.shuffle(utterances)
    boundary = int(split_ratio * len(utterances))
    train_files = utterances[:boundary]
    dev_files = utterances[boundary:]
    logger.info(
        f"we get train utterances: {len(train_files)}, dev utterance: {len(dev_files)}"
    )
    return train_files, dev_files
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data(args, config):
    """Convert the jsonline format to csv format

    Writes ``enroll.csv``, ``test.csv``, ``train.csv`` and ``dev.csv`` into
    ``<args.target_dir>/csv``.

    Args:
        args (argparse.Namespace): scripts args; ``train``, ``test`` and
            ``target_dir`` are read
        config (CfgNode): yaml configuration content; ``seed``, ``skip_prep``,
            ``verification_file`` and ``split_ratio`` are read here
    """
    # stage 0: the seed is set even when the preparation is skipped
    random.seed(config.seed)

    # nothing else to do when the external config sets the skip_prep flag
    if config.skip_prep:
        return

    # stage 1: enroll and test csv files, written without chunk splitting
    logger.info("start to prepare the data csv file")
    enroll_files, test_files = get_enroll_test_list(
        [args.test], verification_file=config.verification_file)
    for csv_name, utterances in (("enroll.csv", enroll_files),
                                 ("test.csv", test_files)):
        prepare_csv(
            utterances,
            os.path.join(args.target_dir, "csv", csv_name),
            config,
            split_chunks=False)

    # stage 2: train and dev csv files
    # config.split_ratio is the train share, the remaining is the dev dataset
    logger.info("start to prepare the data csv file")
    train_files, dev_files = get_train_dev_list(
        args.train, target_dir=args.target_dir, split_ratio=config.split_ratio)
    for csv_name, utterances in (("train.csv", train_files),
                                 ("dev.csv", dev_files)):
        prepare_csv(utterances,
                    os.path.join(args.target_dir, "csv", csv_name), config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    # (flag, add_argument keyword arguments); every option is mandatory.
    option_table = (
        ("--train", dict(
            required=True,
            nargs='+',
            help="The jsonline files list for train.")),
        ("--test", dict(required=True, help="The jsonline file for test")),
        ("--target_dir", dict(
            default=None,
            required=True,
            help="The target directory stores the csv files and meta file.")),
        ("--config", dict(
            default=None,
            required=True,
            type=str,
            help="configuration file")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Load the yaml configuration into a node that accepts new keys.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Write the csv files from the jsonlines files.
    prepare_data(args, config)
|
@ -0,0 +1,192 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from dataclasses import fields
|
||||||
|
from paddle.io import Dataset
|
||||||
|
|
||||||
|
from paddleaudio import load as load_audio
|
||||||
|
from paddleaudio.compliance.librosa import melspectrogram
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
# the audio meta info in the vector CSVDataset
|
||||||
|
# utt_id: the utterance segment name
|
||||||
|
# duration: utterance segment time
|
||||||
|
# wav: utterance file path
|
||||||
|
# start: start point in the original wav file
|
||||||
|
# stop: stop point in the original wav file
|
||||||
|
# label: the utterance segment's label id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class meta_info:
    """the audio meta info in the vector CSVDataset

    Args:
        utt_id (str): the utterance segment name
        duration (float): the duration value stored in the csv row
        wav (str): utterance file path
        start (int): start point in the original wav file
        stop (int): stop point in the original wav file
        label (str): the utterance segment's label name as stored in the csv
    """
    utt_id: str
    duration: float
    wav: str
    start: int
    stop: int
    label: str
|
||||||
|
|
||||||
|
|
||||||
|
# Feature extractors available to the csv dataset, keyed by feature type:
#   raw: no extractor, the pcm data sample points are returned
#   melspectrogram: fbank feature
feat_funcs = dict(raw=None, melspectrogram=melspectrogram)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVDataset(Dataset):
    """Dataset that serves the utterance segments listed in a csv file."""

    def __init__(self,
                 csv_path,
                 label2id_path=None,
                 config=None,
                 random_chunk=True,
                 feat_type: str="raw",
                 n_train_snts: int=-1,
                 **kwargs):
        """Implement the CSV Dataset

        Args:
            csv_path (str): csv dataset file path
            label2id_path (str): the utterance label to integer id map file path
            config (CfgNode): yaml config; ``random_chunk`` and
                ``chunk_duration`` are read from it when a record is built
            random_chunk (bool): stored on the instance. Note that
                ``convert_to_record`` decides on random chunking from
                ``config.random_chunk``, not from this value.
            feat_type (str): dataset feature type. if it is raw, it return pcm data.
            n_train_snts (int): select the n_train_snts sample from the dataset.
                if n_train_snts = -1, dataset will load all the sample.
                Default value is -1.
            kwargs : feature type args
        """
        super().__init__()
        self.csv_path = csv_path
        self.label2id_path = label2id_path
        self.config = config
        self.random_chunk = random_chunk
        self.feat_type = feat_type
        self.n_train_snts = n_train_snts
        self.feat_config = kwargs
        self.id2label = {}
        self.label2id = {}
        self.data = self.load_data_csv()
        self.load_speaker_to_label()

    def load_data_csv(self):
        """Load the csv dataset content and store them in the data property
        the csv dataset's format has six fields,
        that is audio_id or utt_id, audio duration, wav path,
        segment start point, segment stop point and utterance label.
        Note in training period, the utterance label must has a map to integer id in label2id_path

        Returns:
            list: the csv data with meta_info type
        """
        import csv

        data = []
        # The csv module undoes the quoting applied when the file was
        # written, so a field that itself contains a comma stays one field;
        # a plain split(',') would break such rows apart.
        with open(self.csv_path, 'r', newline='') as rf:
            reader = csv.reader(rf)
            next(reader, None)  # the first row is the header
            for row in reader:
                if not row:
                    # blank line
                    continue
                audio_id, duration, wav, start, stop, spk_id = row
                data.append(
                    meta_info(audio_id,
                              float(duration), wav,
                              int(start), int(stop), spk_id))

        # Keep only the first n_train_snts samples when a positive limit is set.
        if self.n_train_snts > 0:
            data = data[:self.n_train_snts]

        return data

    def load_speaker_to_label(self):
        """Load the utterance label map content.
        In vector domain, we call the utterance label as speaker label.
        The speaker label is real speaker label in speaker verification domain,
        and in language identification is language label.

        Each line of the map file holds ``<label_name> <label_id>``; both
        ``label2id`` and ``id2label`` are filled from it. Without a map file
        only a warning is logged and both maps stay empty.
        """
        if not self.label2id_path:
            logger.warning("No speaker id to label file")
            return

        with open(self.label2id_path, 'r') as f:
            for line in f:
                label_name, label_id = line.strip().split(' ')
                self.label2id[label_name] = int(label_id)
                self.id2label[int(label_id)] = label_name

    def convert_to_record(self, idx: int):
        """convert the dataset sample to training record the CSV Dataset

        Args:
            idx (int) : the request index in all the dataset

        Returns:
            dict: every ``meta_info`` field of the sample plus ``feat``, the
                selected waveform span or the feature computed from it. When
                a label map was loaded, ``label`` is replaced by its integer
                id.
        """
        # `random` is not imported at module level in this file, so it is
        # imported here; without it the random-chunk branch raised NameError.
        import random

        sample = self.data[idx]

        # Copy every dataclass field of the sample into a plain dict.
        record = {
            field.name: getattr(sample, field.name)
            for field in fields(sample)
        }

        waveform, sr = load_audio(record['wav'])

        # random select a chunk audio samples from the audio
        if self.config and self.config.random_chunk:
            num_wav_samples = waveform.shape[0]
            num_chunk_samples = int(self.config.chunk_duration * sr)
            # randint is inclusive on both ends and raises ValueError for an
            # empty range. The upper bound is clamped to 0 so that audio no
            # longer than one chunk is returned from its beginning (the slice
            # below then simply stops at the end of the audio).
            max_start = max(0, num_wav_samples - num_chunk_samples - 1)
            start = random.randint(0, max_start)
            stop = start + num_chunk_samples
        else:
            start = record['start']
            stop = record['stop']

        # we only return the waveform as feat
        waveform = waveform[start:stop]

        # all available feature type is in feat_funcs
        assert self.feat_type in feat_funcs.keys(), \
            f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record['feat'] = feat
        if self.label2id:
            record['label'] = self.label2id[record['label']]

        return record

    def __getitem__(self, idx):
        """Return the specific index sample

        Args:
            idx (int) : the request index in all the dataset
        """
        return self.convert_to_record(idx)

    def __len__(self):
        """Return the dataset length

        Returns:
            int: the length num of the dataset
        """
        return len(self.data)
|
@ -0,0 +1,214 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
class InputNormalization:
|
||||||
|
spk_dict_mean: Dict[int, paddle.Tensor]
|
||||||
|
spk_dict_std: Dict[int, paddle.Tensor]
|
||||||
|
spk_dict_count: Dict[int, int]
|
||||||
|
|
||||||
|
def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global", ):
    """Set up the feature or embedding mean and std normalizer

    Args:
        mean_norm (bool, optional): mean norm flag. Defaults to True.
        std_norm (bool, optional): std norm flag. Defaults to True.
        norm_type (str, optional): norm type. Defaults to "global".
    """
    super().__init__()

    # Normalization switches chosen by the caller.
    self.mean_norm, self.std_norm = mean_norm, std_norm
    self.norm_type = norm_type
    # The running statistics are only updated while this flag is set.
    self.training = True

    # Running global statistics and the number of updates made so far.
    self.glob_mean = paddle.to_tensor([0], dtype="float32")
    self.glob_std = paddle.to_tensor([0], dtype="float32")
    self.count = 0
    self.weight = 1.0

    # Per-speaker statistics containers, empty at construction time.
    self.spk_dict_mean = {}
    self.spk_dict_std = {}
    self.spk_dict_count = {}

    # Lower bound applied to the std for numerical stability.
    self.eps = 1e-10
|
||||||
|
|
||||||
|
def __call__(self, x, lengths, spk_ids=None):
    """Normalize a batch with the running global mean and std.

    Only ``norm_type == "global"`` changes ``x``; for any other norm type
    the input is returned unchanged.

    Args:
        x (paddle.Tensor): A batch of tensors.
        lengths (paddle.Tensor): A batch of tensors containing the relative length of each
                                 sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
                                 computing stats on zero-padded steps.
        spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker
                                 (e.g, [0 10 6]). It is not read by this method.
                                 Defaults to None; the previous default built a
                                 tensor once at definition time and shared it
                                 between all calls.

    Returns:
        paddle.Tensor: The normalized feature or embedding
    """
    N_batches = x.shape[0]
    current_means = []
    current_stds = []

    for snt_id in range(N_batches):
        # Avoiding padded time steps:
        # actual_size is the number of valid steps along axis 1
        actual_size = paddle.round(lengths[snt_id] *
                                   x.shape[1]).astype("int32")
        # computing actual time data statistics
        # NOTE(review): the slice is unsqueezed to a leading dimension of
        # size 1 and _compute_current_stats reduces over axis 0 - confirm
        # that reducing over this singleton axis is intended.
        current_mean, current_std = self._compute_current_stats(
            x[snt_id, 0:actual_size, ...].unsqueeze(0))
        current_means.append(current_mean)
        current_stds.append(current_std)

    if self.norm_type == "global":
        # Batch statistics: the average of the per-sentence statistics,
        # computed once after the loop instead of on every iteration.
        current_mean = paddle.mean(paddle.stack(current_means), axis=0)
        current_std = paddle.mean(paddle.stack(current_stds), axis=0)

        if self.training:
            if self.count == 0:
                self.glob_mean = current_mean
                self.glob_std = current_std
            else:
                # Running average over all the batches seen so far.
                self.weight = 1 / (self.count + 1)
                self.glob_mean = (
                    1 - self.weight
                ) * self.glob_mean + self.weight * current_mean
                self.glob_std = (
                    1 - self.weight
                ) * self.glob_std + self.weight * current_std

            # detach() returns a new tensor, so the result has to be
            # assigned back; a bare call leaves the attributes untouched.
            self.glob_mean = self.glob_mean.detach()
            self.glob_std = self.glob_std.detach()

            self.count = self.count + 1

        # NOTE(review): glob_std still holds its initial value 0 when this
        # runs with training disabled before any update - confirm callers
        # never do that.
        x = (x - self.glob_mean) / (self.glob_std)

    return x
|
||||||
|
|
||||||
|
def _compute_current_stats(self, x):
    """Compute the mean and standard deviation of ``x`` along axis 0.

    Args:
        x (paddle.Tensor): A batch of tensors.

    Returns:
        tuple: ``(current_mean, current_std)``. When ``self.mean_norm`` is
            false the mean is the constant tensor ``[0.0]``; when
            ``self.std_norm`` is false the std is the constant tensor
            ``[1.0]``. Computed statistics are detached from the graph.
    """
    # Mean: measured on the data, or a neutral 0 when mean normalization is off.
    mean = (paddle.mean(x, axis=0).detach()
            if self.mean_norm else paddle.to_tensor([0.0], dtype="float32"))

    # Std: measured on the data, or a neutral 1 when std normalization is off.
    std = (paddle.std(x, axis=0).detach()
           if self.std_norm else paddle.to_tensor([1.0], dtype="float32"))

    # Keep every std entry at or above self.eps for numerical stability.
    std = paddle.maximum(std, self.eps * paddle.ones_like(std))

    return mean, std
|
||||||
|
|
||||||
|
def _statistics_dict(self):
    """Gather the normalization statistics of this instance into a dict.

    Returns:
        dict: Maps ``"count"``, ``"glob_mean"``, ``"glob_std"``,
            ``"spk_dict_mean"``, ``"spk_dict_std"`` and ``"spk_dict_count"``
            to the attributes of the same name. The values are the
            attribute objects themselves, not copies.
    """
    return {
        "count": self.count,
        "glob_mean": self.glob_mean,
        "glob_std": self.glob_std,
        "spk_dict_mean": self.spk_dict_mean,
        "spk_dict_std": self.spk_dict_std,
        "spk_dict_count": self.spk_dict_count,
    }
|
||||||
|
|
||||||
|
def _load_statistics_dict(self, state):
    """Restore the normalization statistics from a dictionary.

    Args:
        state (dict): A dictionary containing the normalization statistics,
            with the keys ``"count"``, ``"glob_mean"``, ``"glob_std"``,
            ``"spk_dict_mean"``, ``"spk_dict_std"`` and ``"spk_dict_count"``.

    Returns:
        dict: The same ``state`` object that was passed in.
    """
    self.count = state["count"]

    # The global statistics are stored exactly as found in ``state``,
    # whether they are int placeholders or tensors; no device transfer
    # is performed here.
    self.glob_mean = state["glob_mean"]
    self.glob_std = state["glob_std"]

    # Per-speaker means: start from a fresh dict, then copy entry by entry.
    self.spk_dict_mean = {}
    mean_src = state["spk_dict_mean"]
    self.spk_dict_mean.update((spk, mean_src[spk]) for spk in mean_src)

    # Per-speaker stds: same shallow copy into a fresh dict.
    self.spk_dict_std = {}
    std_src = state["spk_dict_std"]
    self.spk_dict_std.update((spk, std_src[spk]) for spk in std_src)

    # The per-speaker counts are taken by reference, not copied.
    self.spk_dict_count = state["spk_dict_count"]

    return state
|
||||||
|
|
||||||
|
def to(self, device):
    """Move the layer and its stored statistics to ``device``.

    Args:
        device: Target device, forwarded to the parent ``to`` and to the
            ``to`` method of every stored statistic.

    Returns:
        The object returned by the parent class ``to`` call, with its
        global and per-speaker statistics moved to ``device``.
    """
    # NOTE(review): this relies on the parent ``to`` returning the layer
    # itself (not None) -- confirm for the Paddle version in use.
    moved = super(InputNormalization, self).to(device)

    moved.glob_mean = moved.glob_mean.to(device)
    moved.glob_std = moved.glob_std.to(device)

    # Only the keys of spk_dict_mean are walked; spk_dict_std is indexed
    # with those same keys.
    for speaker in moved.spk_dict_mean:
        moved.spk_dict_mean[speaker] = moved.spk_dict_mean[speaker].to(device)
        moved.spk_dict_std[speaker] = moved.spk_dict_std[speaker].to(device)

    return moved
|
||||||
|
|
||||||
|
def save(self, path):
    """Write the normalization statistics to disk.

    The dictionary produced by ``self._statistics_dict()`` is serialized
    with ``paddle.save``.

    Args:
        path (str): A path where to save the dictionary.
    """
    paddle.save(self._statistics_dict(), path)
|
||||||
|
|
||||||
|
def _load(self, path, end_of_epoch=False, device=None):
    """Load the statistic dictionary from disk and restore it.

    Args:
        path (str): The path of the statistic dictionary.
        end_of_epoch (bool, optional): Unused; accepted for interface
            compatibility. Defaults to False.
        device (str, optional): Unused; accepted for interface
            compatibility. Defaults to None.
    """
    # Both arguments are kept only so existing callers keep working.
    del end_of_epoch, device

    # ``paddle.load`` has no ``map_location`` option (that is the
    # ``torch.load`` spelling). It validates its extra keyword configs and
    # rejects unknown ones, so forwarding ``map_location`` made this call
    # fail before anything was loaded.
    stats = paddle.load(path)
    self._load_statistics_dict(stats)
|
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
def get_chunks(seg_dur, audio_id, audio_duration):
    """Split an utterance into consecutive fixed-length chunk identifiers.

    Only whole chunks are produced; a trailing remainder shorter than
    ``seg_dur`` is dropped.

    Args:
        seg_dur (float): segment chunk duration, seconds
        audio_id (str): utterance name
        audio_duration (float): utterance duration, seconds

    Returns:
        List: chunk identifiers of the form ``"<audio_id>_<start>_<end>"``,
            where start and end are in seconds.
    """
    chunk_lst = []
    # Number of whole chunks that fit into the utterance (all in seconds).
    for idx in range(int(audio_duration / seg_dur)):
        start = idx * seg_dur
        stop = start + seg_dur
        chunk_lst.append("_".join((audio_id, str(start), str(stop))))
    return chunk_lst
|
Loading…
Reference in new issue