commit
fc72295334
@@ -0,0 +1,62 @@
###########################################################
#                AMI DATA PREPARE SETTING                 #
###########################################################
split_type: 'full_corpus_asr'
skip_TNO: True
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
mic_type: 'Mix-Headset'
vad_type: 'oracle'
max_subseg_dur: 3.0
overlap: 1.5
# Some more exp folders (for cleaner structure).
embedding_dir: emb #!ref <save_folder>/emb
meta_data_dir: metadata #!ref <save_folder>/metadata
ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
der_dir: DER #!ref <save_folder>/DER

###########################################################
#               FEATURE EXTRACTION SETTING                #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 # 25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_size: 160 # 10ms, sample rate 16000, 10 * 16000 / 1000 = 160
#left_frames: 0
#right_frames: 0
#deltas: False

###########################################################
#                     MODEL SETTING                       #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if you want to use another model, please choose another configuration yaml file
seed: 1234
emb_dim: 192
batch_size: 16
model:
  input_size: 80
  channels: [1024, 1024, 1024, 1024, 3072]
  kernel_sizes: [5, 3, 3, 3, 1]
  dilations: [1, 2, 3, 4, 1]
  attention_channels: 128
  lin_neurons: 192
# Will automatically download ECAPA-TDNN model (best).

###########################################################
#               SPECTRAL CLUSTERING SETTING               #
###########################################################
backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
affinity: 'cos' # options: cos, nn
max_num_spkrs: 10
oracle_n_spkrs: True

###########################################################
#                 DER EVALUATION SETTING                  #
###########################################################
ignore_overlap: True
forgiveness_collar: 0.25
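For reference, a minimal sketch of how this config is consumed by the scripts in this change (the yaml path below is assumed; the CfgNode usage mirrors local/compute_embdding.py and local/experiment.py):

# Hedged sketch: load the config the same way the scripts in this change do.
from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.merge_from_file("conf/ecapa_tdnn.yaml")  # assumed path for this yaml
config.freeze()

# window_size / hop_size are frame lengths in samples:
# 25 ms * 16000 / 1000 = 400 and 10 ms * 16000 / 1000 = 160.
assert config.window_size == 25 * config.sr // 1000
assert config.hop_size == 10 * config.sr // 1000

# The nested `model` block is unpacked straight into the backbone constructor,
# i.e. EcapaTdnn(**config.model) in local/compute_embdding.py.
print(dict(config.model))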
@@ -0,0 +1,231 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import pickle
import sys

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset_from_json import JSONDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

# Logger setup
logger = Log(__name__).getlog()


def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
    """Prepares metadata for a given recording ID.

    Arguments
    ---------
    full_meta_data : json
        Full meta (json) containing all the recordings
    rec_id : str
        The recording ID for which meta (json) has to be prepared
    out_meta_file : str
        Path of the output meta (json) file.
    """
    subset = {}
    for key in full_meta_data:
        k = str(key)
        if k.startswith(rec_id):
            subset[key] = full_meta_data[key]

    with open(out_meta_file, mode="w") as json_f:
        json.dump(subset, json_f, indent=2)


def create_dataloader(json_file, batch_size):
    """Creates the datasets and their data processing pipelines.
    This is used for multi-mic processing.
    """
    # create datasets
    dataset = JSONDataset(
        json_file=json_file,
        feat_type='melspectrogram',
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_size)

    # create dataloader
    batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
    dataloader = DataLoader(dataset,
                            batch_sampler=batch_sampler,
                            collate_fn=lambda x: batch_feature_normalize(
                                x, mean_norm=True, std_norm=False),
                            return_list=True)

    return dataloader


def main(args, config):
    # set the training device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)

    # stage3: load the pre-trained model
    # we get the last model from the epoch and save_interval
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load model checkpoint to sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # set the model to eval mode
    model.eval()

    # load meta data
    meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
    with open(meta_file, "r") as f:
        full_meta = json.load(f)

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + args.dataset
    i = 1

    msg = "Extracting embeddings for the " + args.dataset + " set"
    logger.info(msg)

    if len(all_rec_ids) <= 0:
        msg = "No recording IDs found! Please check if the meta_data json file is properly generated."
        logger.error(msg)
        sys.exit()

    # extract embeddings for the different recordings in the dataset.
    for rec_id in tqdm(all_rec_ids):
        # This tag will be displayed in the log.
        tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Embedding %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # embedding directory.
        if not os.path.exists(
                os.path.join(args.data_dir, config.embedding_dir, split)):
            os.makedirs(
                os.path.join(args.data_dir, config.embedding_dir, split))

        # file to store embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
                                           split, emb_file_name)

        # prepare a metadata (json) for one recording. This is basically a subset of full_meta.
        # let's keep this meta-info in the embedding directory itself.
        json_file_name = rec_id + "." + config.mic_type + ".json"
        meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
                                         split, json_file_name)

        # write subset (meta for one recording) json metadata.
        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)

        # prepare data loader.
        diary_set_loader = create_dataloader(meta_per_rec_file,
                                             config.batch_size)

        # extract embeddings (skip if already done).
        if not os.path.isfile(diary_stat_emb_file):
            logger.debug("Extracting deep embeddings")
            embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
            segset = []

            for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
                # extract the audio embedding
                ids, feats, lengths = batch['ids'], batch['feats'], batch[
                    'lengths']
                seg = [x for x in ids]
                segset = segset + seg
                emb = model.backbone(feats, lengths).squeeze(
                    -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
                embeddings = np.concatenate((embeddings, emb), axis=0)

            segset = np.array(segset, dtype="|O")
            stat_obj = EmbeddingMeta(
                segset=segset,
                stats=embeddings, )
            logger.debug("Saving Embeddings...")
            with open(diary_stat_emb_file, "wb") as output:
                pickle.dump(stat_obj, output)

        else:
            logger.debug("Skipping embedding extraction (already present).")


# Begin experiment!
if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        '--device',
        default="gpu",
        help="Select which device to perform diarization, defaults to gpu.")
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../save/",
        type=str,
        help="processed data directory")
    parser.add_argument(
        "--dataset",
        choices=['dev', 'eval'],
        default="dev",
        type=str,
        help="Select which dataset to extract embeddings from, defaults to dev")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default='',
        help="Directory to load model checkpoint to compute embeddings.")
    args = parser.parse_args()
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()

    main(args, config)
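A small usage sketch for the per-recording embedding files this script writes (the concrete path is illustrative; the naming pattern and the EmbeddingMeta fields are assumed from the code above):

# Hedged sketch: inspect one <rec_id>.<mic_type>.emb_stat.pkl written by compute_embdding.py.
import pickle

# Illustrative path following <data-dir>/<embedding_dir>/AMI_<dataset>/<rec_id>.<mic_type>.emb_stat.pkl
emb_stat_file = "../save/emb/AMI_dev/ES2011a.Mix-Headset.emb_stat.pkl"

with open(emb_stat_file, "rb") as f:
    stat_obj = pickle.load(f)  # EmbeddingMeta(segset=..., stats=...)

print(len(stat_obj.segset))  # number of sub-segments in the recording (assumed attribute)
print(stat_obj.stats.shape)  # (num_subsegments, emb_dim); emb_dim is 192 in this config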
@@ -1,49 +0,0 @@
#!/bin/bash

stage=1

TARGET_DIR=${MAIN_ROOT}/dataset/ami
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/

save_folder=${MAIN_ROOT}/examples/ami/sd0/data
ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata

set=L

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -u
set -o pipefail

mkdir -p ${save_folder}

if [ ${stage} -le 0 ]; then
    # Download AMI corpus, You need around 10GB of free space to get whole data
    # The signals are too large to package in this way,
    # so you need to use the chooser to indicate which ones you wish to download
    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
    echo "Annotations: AMI manual annotations v1.6.2 "
    echo "Signals: "
    echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
    echo "2) Select media streams: Just select Headset mix"
    exit 0;
fi

if [ ${stage} -le 1 ]; then
    echo "AMI Data preparation"

    python local/ami_prepare.py --data_folder ${data_folder} \
            --manual_annot_folder ${manual_annot_folder} \
            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
            --meta_data_dir ${meta_data_dir}

    if [ $? -ne 0 ]; then
        echo "Prepare AMI failed. Please check log message."
        exit 1
    fi
fi

echo "AMI data preparation done."
exit 0
@@ -0,0 +1,428 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import pickle
import shutil
import sys

import numpy as np
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster import diarization as diar
from utils.DER import DER

# Logger setup
logger = Log(__name__).getlog()


def diarize_dataset(
        full_meta,
        split_type,
        n_lambdas,
        pval,
        save_dir,
        config,
        n_neighbors=10, ):
    """This function diarizes all the recordings in a given dataset. It performs
    computation of embeddings and clusters them using spectral clustering (or other backends).
    The output speaker boundary file is stored in the RTTM format.
    """

    # prepare `spkr_info` only once when the oracle num of speakers is selected.
    # spkr_info is essential to obtain the number of speakers from the groundtruth.
    if config.oracle_n_spkrs is True:
        full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
                                          "fullref_ami_" + split_type + ".rttm")
        rttm = diar.read_rttm(full_ref_rttm_file)

        spkr_info = list(  # noqa F841
            filter(lambda x: x.startswith("SPKR-INFO"), rttm))

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + split_type
    i = 1

    # adding tag for directory path.
    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
    tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)

    # make out rttm dir
    out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
                                split, tag)
    if not os.path.exists(out_rttm_dir):
        os.makedirs(out_rttm_dir)

    # diarizing different recordings in a dataset.
    for rec_id in tqdm(all_rec_ids):
        # this tag will be displayed in the log.
        tag = ("[" + str(split_type) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Diarizing %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # load embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
                                           split, emb_file_name)
        if not os.path.isfile(diary_stat_emb_file):
            msg = "Embedding file %s not found! Please check if the embedding file is properly generated." % (
                diary_stat_emb_file)
            logger.error(msg)
            sys.exit()
        with open(diary_stat_emb_file, "rb") as in_file:
            diary_obj = pickle.load(in_file)

        out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"

        # processing starts from here.
        if config.oracle_n_spkrs is True:
            # oracle num of speakers.
            num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
        else:
            if config.affinity == "nn":
                # num of speakers tuned on the dev set (only for nn affinity).
                num_spkrs = n_lambdas
            else:
                # num of speakers will be estimated using the max eigen gap for cos based affinity.
                # so adding None here. Will use this None later-on.
                num_spkrs = None

        if config.backend == "kmeans":
            diar.do_kmeans_clustering(
                diary_obj,
                out_rttm_file,
                rec_id,
                num_spkrs,
                pval, )

        if config.backend == "SC":
            # go for Spectral Clustering (SC).
            diar.do_spec_clustering(
                diary_obj,
                out_rttm_file,
                rec_id,
                num_spkrs,
                pval,
                config.affinity,
                n_neighbors, )

        # can be used for AHC later. Likewise one can add different backends here.
        if config.backend == "AHC":
            # call AHC
            threshold = pval  # pval for AHC is nothing but the threshold.
            diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)

    # once all RTTM outputs are generated, concatenate individual RTTM files to obtain a single RTTM file.
    # this is not needed but just staying with the standards.
    concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
    logger.debug("Concatenating individual RTTM files...")
    with open(concate_rttm_file, "w") as cat_file:
        for f in glob.glob(out_rttm_dir + "/*.rttm"):
            if f == concate_rttm_file:
                continue
            with open(f, "r") as indi_rttm_file:
                shutil.copyfileobj(indi_rttm_file, cat_file)

    msg = "The system generated RTTM file for the %s set : %s" % (
        split_type, concate_rttm_file, )
    logger.debug(msg)

    return concate_rttm_file


def dev_pval_tuner(full_meta, save_dir, config):
    """Tuning p_value for the affinity matrix.
    The p_value is used so that only p% of the values in each row are retained.
    """

    DER_list = []
    prange = np.arange(0.002, 0.015, 0.001)

    n_lambdas = None  # using it as a flag later.
    for p_v in prange:
        # Process the whole dataset for this value of p_v.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
                                            save_dir, config)

        ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
                                     "fullref_ami_dev.rttm")
        sys_rttm_file = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm_file,
            sys_rttm_file,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

        if config.oracle_n_spkrs is True and config.backend == "kmeans":
            # no need for a p_val search. Note p_val is needed for SC for both oracle and est num of speakers.
            # p_val is needed in oracle_n_spkr=False when using the kmeans backend.
            break

    # Take the p_val that gave the minimum DER on the dev dataset.
    tuned_p_val = prange[DER_list.index(min(DER_list))]

    return tuned_p_val


def dev_ahc_threshold_tuner(full_meta, save_dir, config):
    """Tuning the threshold for the affinity matrix. This function is called when AHC is used as the backend.
    """

    DER_list = []
    prange = np.arange(0.0, 1.0, 0.1)

    n_lambdas = None  # using it as a flag later.

    # Note: p_val is the threshold in case of AHC.
    for p_v in prange:
        # Process the whole dataset for this value of p_v.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
                                            save_dir, config)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

        if config.oracle_n_spkrs is True:
            break  # no need for a threshold search.

    # Take the p_val that gave the minimum DER on the dev dataset.
    tuned_p_val = prange[DER_list.index(min(DER_list))]

    return tuned_p_val


def dev_nn_tuner(full_meta, save_dir, config):
    """Tuning n_neighbors on the dev set. Assuming oracle num of speakers.
    This is used when nn based affinity is selected.
    """

    DER_list = []
    pval = None

    # Now assuming oracle num of speakers.
    n_lambdas = 4

    for nn in range(5, 15):

        # Process the whole dataset for this value of nn.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
                                            save_dir, config, nn)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append([nn, DER_])

        if config.oracle_n_spkrs is True and config.backend == "kmeans":
            break

    DER_list.sort(key=lambda x: x[1])
    tuned_nn = DER_list[0]

    return tuned_nn[0]


def dev_tuner(full_meta, save_dir, config):
    """Tuning n_components on the dev set. Used for nn based affinity matrix.
    Note: This is a very basic tuning for nn based affinity.
    This is work in progress till we find a better way.
    """

    DER_list = []
    pval = None
    for n_lambdas in range(1, config.max_num_spkrs + 1):

        # Process the whole dataset for this value of n_lambdas.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
                                            save_dir, config)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

    # Take the n_lambdas with minimum DER.
    tuned_n_lambdas = DER_list.index(min(DER_list)) + 1

    return tuned_n_lambdas


def main(args, config):
    # AMI Dev Set: Tune hyperparams on the dev set.
    # Read the metadata file for the dev set generated during embedding computation.
    dev_meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_dev." + config.mic_type + ".subsegs.json", )
    with open(dev_meta_file, "r") as f:
        meta_dev = json.load(f)

    full_meta = meta_dev

    # Processing starts from here.
    # The following few lines select options for the different backends and affinity matrices
    # and find the best values for the hyperparameters using the dev set.
    ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
                                 "fullref_ami_dev.rttm")
    best_nn = None
    if config.affinity == "nn":
        logger.info("Tuning for nn (Multiple iterations over AMI Dev set)")
        best_nn = dev_nn_tuner(full_meta, args.data_dir, config)

    n_lambdas = None
    best_pval = None

    if config.affinity == "cos" and (config.backend == "SC" or
                                     config.backend == "kmeans"):
        # oracle num_spkrs or not, doesn't matter for kmeans and SC backends
        # cos: Tune for the best pval for SC /kmeans (for unknown num of spkrs)
        logger.info(
            "Tuning for p-value for SC (Multiple iterations over AMI Dev set)")
        best_pval = dev_pval_tuner(full_meta, args.data_dir, config)

    elif config.backend == "AHC":
        logger.info("Tuning for threshold-value for AHC")
        best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
                                                 config)
        best_pval = best_threshold
    else:
        # NN for unknown num of speakers (can be used in the future)
        if config.oracle_n_spkrs is False:
            # nn: Tune the number of eigen components (to be updated later)
            logger.info(
                "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
            )
            # dev_tuner is used for tuning the num of components in NN. Can be used in the future.
            n_lambdas = dev_tuner(full_meta, args.data_dir, config)

    # load 'dev' and 'eval' metadata files.
    full_meta_dev = full_meta  # current full_meta is for 'dev'
    eval_meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_eval." + config.mic_type + ".subsegs.json", )
    with open(eval_meta_file, "r") as f:
        full_meta_eval = json.load(f)

    # tag to be appended to the final output DER files. Writing DER for individual files.
    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
    tag = (
        type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)

    # perform final diarization on 'dev' and 'eval' with the best hyperparams.
    final_DERs = {}
    out_der_dir = os.path.join(args.data_dir, config.der_dir)
    if not os.path.exists(out_der_dir):
        os.makedirs(out_der_dir)

    for split_type in ["dev", "eval"]:
        if split_type == "dev":
            full_meta = full_meta_dev
        else:
            full_meta = full_meta_eval

        # performing diarization.
        msg = "Diarizing using best hyperparams: " + split_type + " set"
        logger.info(msg)
        out_boundaries = diarize_dataset(
            full_meta,
            split_type,
            n_lambdas=n_lambdas,
            pval=best_pval,
            n_neighbors=best_nn,
            save_dir=args.data_dir,
            config=config)

        # computing DER.
        msg = "Computing DERs for " + split_type + " set"
        logger.info(msg)
        ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir,
                                "fullref_ami_" + split_type + ".rttm")
        sys_rttm = out_boundaries
        [MS, FA, SER, DER_vals] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar,
            individual_file_scores=True, )

        # writing DER values to a file. Append tag.
        der_file_name = split_type + "_DER_" + tag
        out_der_file = os.path.join(out_der_dir, der_file_name)
        msg = "Writing DER file to: " + out_der_file
        logger.info(msg)
        diar.write_ders_file(ref_rttm, DER_vals, out_der_file)

        msg = ("AMI " + split_type + " set DER = %s %%\n" %
               (str(round(DER_vals[-1], 2))))
        logger.info(msg)
        final_DERs[split_type] = round(DER_vals[-1], 2)

    # final print DERs
    msg = (
        "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n"
        % (str(final_DERs["dev"]), str(final_DERs["eval"])))
    logger.info(msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../data/",
        type=str,
        help="processed data directory")
    args = parser.parse_args()
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()

    main(args, config)
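For orientation, a hedged sketch of the scoring call used above; the RTTM paths are placeholders that follow the directory layout produced by this script, and the DER call mirrors the one in main():

# Hedged sketch: score one system RTTM against the reference, as experiment.py does per split.
from utils.DER import DER

ref_rttm = "../save/ref_rttms/fullref_ami_dev.rttm"  # written by data preparation (illustrative path)
sys_rttm = "../save/sys_rttms/Mix-Headset/AMI_dev/oracle_cos_SC/sys_output.rttm"  # written by diarize_dataset (illustrative path)

[MS, FA, SER, DER_vals] = DER(
    ref_rttm,
    sys_rttm,
    True,  # ignore_overlap, as in config.ignore_overlap
    0.25,  # forgiveness collar in seconds, as in config.forgiveness_collar
    individual_file_scores=True, )

# With individual_file_scores=True, the script above reads the overall score from the last entry.
print("overall DER (%):", round(DER_vals[-1], 2))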
@@ -0,0 +1,49 @@
#!/bin/bash

stage=0
set=L

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -o pipefail

data_folder=$1
manual_annot_folder=$2
save_folder=$3
pretrained_model_dir=$4
conf_path=$5
device=$6

ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata

if [ ${stage} -le 0 ]; then
    echo "AMI Data preparation"
    python local/ami_prepare.py --data_folder ${data_folder} \
            --manual_annot_folder ${manual_annot_folder} \
            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
            --meta_data_dir ${meta_data_dir}

    if [ $? -ne 0 ]; then
        echo "Prepare AMI failed. Please check the log message."
        exit 1
    fi
    echo "AMI data preparation done."
fi

if [ ${stage} -le 1 ]; then
    # extract embeddings for the dev and eval datasets
    for name in dev eval; do
        python local/compute_embdding.py --config ${conf_path} \
                --data-dir ${save_folder} \
                --device ${device} \
                --dataset ${name} \
                --load-checkpoint ${pretrained_model_dir}
    done
fi

if [ ${stage} -le 2 ]; then
    # tune hyperparams on the dev set
    # perform final diarization on 'dev' and 'eval' with the best hyperparams
    python local/experiment.py --config ${conf_path} \
            --data-dir ${save_folder}
fi
@@ -1,14 +1,46 @@
#!/bin/bash
-. path.sh || exit 1;
+. ./path.sh || exit 1;
set -e

-stage=1
+stage=0
+
+#TARGET_DIR=${MAIN_ROOT}/dataset/ami
+TARGET_DIR=/home/dataset/AMI
+data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
+manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
+
+save_folder=./save
+pretraind_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model
+conf_path=conf/ecapa_tdnn.yaml
+device=gpu

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

-if [ ${stage} -le 1 ]; then
-    # prepare data
-    bash ./local/data.sh || exit -1
-fi
+if [ $stage -le 0 ]; then
+    # Prepare data
+    # Download the AMI corpus. You need around 10GB of free space to get the whole data.
+    # The signals are too large to package in this way,
+    # so you need to use the chooser to indicate which ones you wish to download.
+    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
+    echo "Annotations: AMI manual annotations v1.6.2"
+    echo "Signals:"
+    echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
+    echo "2) Select media streams: Just select Headset mix"
+fi
+
+if [ $stage -le 1 ]; then
+    # Download the pretrained model
+    wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
+    rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    echo "Download the pretrained ECAPA-TDNN model to path: ${pretraind_model_dir}"
+fi
+
+if [ $stage -le 2 ]; then
+    # Tune hyperparams on the dev set and perform final diarization on dev and eval with the best hyperparams.
+    echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path}
+    bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
+        ${save_folder} ${pretraind_model_dir} ${conf_path} ${device} || exit 1
+fi
@@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset

from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.compliance.librosa import mfcc


@dataclass
class meta_info:
    """The audio meta info in the vector JSONDataset.

    Args:
        id (str): the segment name
        duration (float): segment time
        wav (str): wav file path
        start (int): start point in the original wav file
        stop (int): stop point in the original wav file
        record_id (str): the recording id
    """
    id: str
    duration: float
    wav: str
    start: int
    stop: int
    record_id: str


# feature types supported by the json dataset
feat_funcs = {
    'raw': None,
    'melspectrogram': melspectrogram,
    'mfcc': mfcc,
}


class JSONDataset(Dataset):
    """
    Dataset built from a json file.
    """

    def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
        """
        Args:
            json_file (:obj:`str`): Data prep JSON file.
            feat_type (:obj:`str`, `optional`, defaults to `raw`):
                The feature type the user wants to extract from an audio file.
        """
        if feat_type not in feat_funcs.keys():
            raise RuntimeError(
                f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
            )

        self.json_file = json_file
        self.feat_type = feat_type
        self.feat_config = kwargs
        self._data = self._get_data()
        super(JSONDataset, self).__init__()

    def _get_data(self):
        with open(self.json_file, "r") as f:
            meta_data = json.load(f)
        data = []
        for key in meta_data:
            sub_seg = meta_data[key]["wav"]
            wav = sub_seg["file"]
            duration = sub_seg["duration"]
            start = sub_seg["start"]
            stop = sub_seg["stop"]
            rec_id = str(key).rsplit("_", 2)[0]
            data.append(
                meta_info(
                    str(key),
                    float(duration), wav, int(start), int(stop), str(rec_id)))
        return data

    def _convert_to_record(self, idx: int):
        sample = self._data[idx]

        record = {}
        # copy all the dataclass fields into the record dict
        for field in fields(sample):
            record[field.name] = getattr(sample, field.name)

        waveform, sr = load_audio(record['wav'])
        waveform = waveform[record['start']:record['stop']]

        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record.update({'feat': feat})

        return record

    def __getitem__(self, idx):
        return self._convert_to_record(idx)

    def __len__(self):
        return len(self._data)
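A hedged usage sketch for JSONDataset; the json path is a placeholder for one of the per-recording subset files, and the keyword arguments mirror create_dataloader in local/compute_embdding.py:

# Hedged sketch: read one record and build a normalized dataloader from a subset json.
from paddle.io import BatchSampler
from paddle.io import DataLoader

from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset_from_json import JSONDataset

dataset = JSONDataset(
    json_file="../save/emb/AMI_dev/ES2011a.Mix-Headset.json",  # illustrative per-recording subset json
    feat_type='melspectrogram',
    n_mels=80,
    window_size=400,
    hop_length=160)

# Each item is a dict with the meta_info fields plus the computed 'feat'.
record = dataset[0]
print(record['id'], record['record_id'], record['feat'].shape)

# Batched loading with per-utterance mean normalization, as in compute_embdding.py.
batch_sampler = BatchSampler(dataset, batch_size=16, shuffle=False)
loader = DataLoader(dataset,
                    batch_sampler=batch_sampler,
                    collate_fn=lambda x: batch_feature_normalize(
                        x, mean_norm=True, std_norm=False),
                    return_list=True)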