Merge pull request #1651 from ccrrong/ami

[vec] add speaker diarization pipeline
pull/1690/head
qingen 3 years ago committed by GitHub
commit fc72295334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,62 @@
###########################################################
# AMI DATA PREPARE SETTING #
###########################################################
split_type: 'full_corpus_asr'
skip_TNO: True
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
mic_type: 'Mix-Headset'
vad_type: 'oracle'
max_subseg_dur: 3.0
overlap: 1.5
# Some more exp folders (for cleaner structure).
embedding_dir: emb #!ref <save_folder>/emb
meta_data_dir: metadata #!ref <save_folder>/metadata
ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
der_dir: DER #!ref <save_folder>/DER
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
#left_frames: 0
#right_frames: 0
#deltas: False
###########################################################
# MODEL SETTING #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want use another model, please choose another configuration yaml file
seed: 1234
emb_dim: 192
batch_size: 16
model:
input_size: 80
channels: [1024, 1024, 1024, 1024, 3072]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
attention_channels: 128
lin_neurons: 192
# Will automatically download ECAPA-TDNN model (best).
###########################################################
# SPECTRAL CLUSTERING SETTING #
###########################################################
backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
affinity: 'cos' # options: cos, nn
max_num_spkrs: 10
oracle_n_spkrs: True
###########################################################
# DER EVALUATION SETTING #
###########################################################
ignore_overlap: True
forgiveness_collar: 0.25

@ -0,0 +1,231 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import pickle
import sys
import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm.contrib import tqdm
from yacs.config import CfgNode
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset_from_json import JSONDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
# Logger setup
logger = Log(__name__).getlog()
def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
"""Prepares metadata for a given recording ID.
Arguments
---------
full_meta_data : json
Full meta (json) containing all the recordings
rec_id : str
The recording ID for which meta (json) has to be prepared
out_meta_file : str
Path of the output meta (json) file.
"""
subset = {}
for key in full_meta_data:
k = str(key)
if k.startswith(rec_id):
subset[key] = full_meta_data[key]
with open(out_meta_file, mode="w") as json_f:
json.dump(subset, json_f, indent=2)
def create_dataloader(json_file, batch_size):
"""Creates the datasets and their data processing pipelines.
This is used for multi-mic processing.
"""
# create datasets
dataset = JSONDataset(
json_file=json_file,
feat_type='melspectrogram',
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)
# create dataloader
batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
dataloader = DataLoader(dataset,
batch_sampler=batch_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
return_list=True)
return dataloader
def main(args, config):
# set the training device, cpu or gpu
paddle.set_device(args.device)
# set the random seed
seed_everything(config.seed)
# stage1: build the dnn backbone model network
ecapa_tdnn = EcapaTdnn(**config.model)
# stage2: build the speaker verification eval instance with backbone model
model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)
# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
# load model checkpoint to sid model
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
# set the model to eval mode
model.eval()
# load meta data
meta_file = os.path.join(
args.data_dir,
config.meta_data_dir,
"ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
with open(meta_file, "r") as f:
full_meta = json.load(f)
# get all the recording IDs in this dataset.
all_keys = full_meta.keys()
A = [word.rstrip().split("_")[0] for word in all_keys]
all_rec_ids = list(set(A[1:]))
all_rec_ids.sort()
split = "AMI_" + args.dataset
i = 1
msg = "Extra embdding for " + args.dataset + " set"
logger.info(msg)
if len(all_rec_ids) <= 0:
msg = "No recording IDs found! Please check if meta_data json file is properly generated."
logger.error(msg)
sys.exit()
# extra different recordings embdding in a dataset.
for rec_id in tqdm(all_rec_ids):
# This tag will be displayed in the log.
tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
str(len(all_rec_ids)) + "]")
i = i + 1
# log message.
msg = "Embdding %s : %s " % (tag, rec_id)
logger.debug(msg)
# embedding directory.
if not os.path.exists(
os.path.join(args.data_dir, config.embedding_dir, split)):
os.makedirs(
os.path.join(args.data_dir, config.embedding_dir, split))
# file to store embeddings.
emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
split, emb_file_name)
# prepare a metadata (json) for one recording. This is basically a subset of full_meta.
# lets keep this meta-info in embedding directory itself.
json_file_name = rec_id + "." + config.mic_type + ".json"
meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
split, json_file_name)
# write subset (meta for one recording) json metadata.
prepare_subset_json(full_meta, rec_id, meta_per_rec_file)
# prepare data loader.
diary_set_loader = create_dataloader(meta_per_rec_file,
config.batch_size)
# extract embeddings (skip if already done).
if not os.path.isfile(diary_stat_emb_file):
logger.debug("Extracting deep embeddings")
embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
segset = []
for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
# extrac the audio embedding
ids, feats, lengths = batch['ids'], batch['feats'], batch[
'lengths']
seg = [x for x in ids]
segset = segset + seg
emb = model.backbone(feats, lengths).squeeze(
-1).numpy() # (N, emb_size, 1) -> (N, emb_size)
embeddings = np.concatenate((embeddings, emb), axis=0)
segset = np.array(segset, dtype="|O")
stat_obj = EmbeddingMeta(
segset=segset,
stats=embeddings, )
logger.debug("Saving Embeddings...")
with open(diary_stat_emb_file, "wb") as output:
pickle.dump(stat_obj, output)
else:
logger.debug("Skipping embedding extraction (as already present).")
# Begin experiment!
if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
'--device',
default="gpu",
help="Select which device to perform diarization, defaults to gpu.")
parser.add_argument(
"--config", default=None, type=str, help="configuration file")
parser.add_argument(
"--data-dir",
default="../save/",
type=str,
help="processsed data directory")
parser.add_argument(
"--dataset",
choices=['dev', 'eval'],
default="dev",
type=str,
help="Select which dataset to extra embdding, defaults to dev")
parser.add_argument(
"--load-checkpoint",
type=str,
default='',
help="Directory to load model checkpoint to compute embeddings.")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
config.freeze()
main(args, config)

@ -1,49 +0,0 @@
#!/bin/bash
stage=1
TARGET_DIR=${MAIN_ROOT}/dataset/ami
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
save_folder=${MAIN_ROOT}/examples/ami/sd0/data
ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata
set=L
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -u
set -o pipefail
mkdir -p ${save_folder}
if [ ${stage} -le 0 ]; then
# Download AMI corpus, You need around 10GB of free space to get whole data
# The signals are too large to package in this way,
# so you need to use the chooser to indicate which ones you wish to download
echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
echo "Annotations: AMI manual annotations v1.6.2 "
echo "Signals: "
echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
echo "2) Select media streams: Just select Headset mix"
exit 0;
fi
if [ ${stage} -le 1 ]; then
echo "AMI Data preparation"
python local/ami_prepare.py --data_folder ${data_folder} \
--manual_annot_folder ${manual_annot_folder} \
--save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
--meta_data_dir ${meta_data_dir}
if [ $? -ne 0 ]; then
echo "Prepare AMI failed. Please check log message."
exit 1
fi
fi
echo "AMI data preparation done."
exit 0

@ -0,0 +1,428 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import pickle
import shutil
import sys
import numpy as np
from tqdm.contrib import tqdm
from yacs.config import CfgNode
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster import diarization as diar
from utils.DER import DER
# Logger setup
logger = Log(__name__).getlog()
def diarize_dataset(
full_meta,
split_type,
n_lambdas,
pval,
save_dir,
config,
n_neighbors=10, ):
"""This function diarizes all the recordings in a given dataset. It performs
computation of embedding and clusters them using spectral clustering (or other backends).
The output speaker boundary file is stored in the RTTM format.
"""
# prepare `spkr_info` only once when Oracle num of speakers is selected.
# spkr_info is essential to obtain number of speakers from groundtruth.
if config.oracle_n_spkrs is True:
full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
"fullref_ami_" + split_type + ".rttm")
rttm = diar.read_rttm(full_ref_rttm_file)
spkr_info = list( # noqa F841
filter(lambda x: x.startswith("SPKR-INFO"), rttm))
# get all the recording IDs in this dataset.
all_keys = full_meta.keys()
A = [word.rstrip().split("_")[0] for word in all_keys]
all_rec_ids = list(set(A[1:]))
all_rec_ids.sort()
split = "AMI_" + split_type
i = 1
# adding tag for directory path.
type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)
# make out rttm dir
out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
split, tag)
if not os.path.exists(out_rttm_dir):
os.makedirs(out_rttm_dir)
# diarizing different recordings in a dataset.
for rec_id in tqdm(all_rec_ids):
# this tag will be displayed in the log.
tag = ("[" + str(split_type) + ": " + str(i) + "/" +
str(len(all_rec_ids)) + "]")
i = i + 1
# log message.
msg = "Diarizing %s : %s " % (tag, rec_id)
logger.debug(msg)
# load embeddings.
emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
split, emb_file_name)
if not os.path.isfile(diary_stat_emb_file):
msg = "Embdding file %s not found! Please check if embdding file is properly generated." % (
diary_stat_emb_file)
logger.error(msg)
sys.exit()
with open(diary_stat_emb_file, "rb") as in_file:
diary_obj = pickle.load(in_file)
out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"
# processing starts from here.
if config.oracle_n_spkrs is True:
# oracle num of speakers.
num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
else:
if config.affinity == "nn":
# num of speakers tunned on dev set (only for nn affinity).
num_spkrs = n_lambdas
else:
# num of speakers will be estimated using max eigen gap for cos based affinity.
# so adding None here. Will use this None later-on.
num_spkrs = None
if config.backend == "kmeans":
diar.do_kmeans_clustering(
diary_obj,
out_rttm_file,
rec_id,
num_spkrs,
pval, )
if config.backend == "SC":
# go for Spectral Clustering (SC).
diar.do_spec_clustering(
diary_obj,
out_rttm_file,
rec_id,
num_spkrs,
pval,
config.affinity,
n_neighbors, )
# can used for AHC later. Likewise one can add different backends here.
if config.backend == "AHC":
# call AHC
threshold = pval # pval for AHC is nothing but threshold.
diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)
# once all RTTM outputs are generated, concatenate individual RTTM files to obtain single RTTM file.
# this is not needed but just staying with the standards.
concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
logger.debug("Concatenating individual RTTM files...")
with open(concate_rttm_file, "w") as cat_file:
for f in glob.glob(out_rttm_dir + "/*.rttm"):
if f == concate_rttm_file:
continue
with open(f, "r") as indi_rttm_file:
shutil.copyfileobj(indi_rttm_file, cat_file)
msg = "The system generated RTTM file for %s set : %s" % (
split_type, concate_rttm_file, )
logger.debug(msg)
return concate_rttm_file
def dev_pval_tuner(full_meta, save_dir, config):
"""Tuning p_value for affinity matrix.
The p_value used so that only p% of the values in each row is retained.
"""
DER_list = []
prange = np.arange(0.002, 0.015, 0.001)
n_lambdas = None # using it as flag later.
for p_v in prange:
# Process whole dataset for value of p_v.
concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
save_dir, config)
ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
"fullref_ami_dev.rttm")
sys_rttm_file = concate_rttm_file
[MS, FA, SER, DER_] = DER(
ref_rttm_file,
sys_rttm_file,
config.ignore_overlap,
config.forgiveness_collar, )
DER_list.append(DER_)
if config.oracle_n_spkrs is True and config.backend == "kmeans":
# no need of p_val search. Note p_val is needed for SC for both oracle and est num of speakers.
# p_val is needed in oracle_n_spkr=False when using kmeans backend.
break
# Take p_val that gave minmum DER on Dev dataset.
tuned_p_val = prange[DER_list.index(min(DER_list))]
return tuned_p_val
def dev_ahc_threshold_tuner(full_meta, save_dir, config):
"""Tuning threshold for affinity matrix. This function is called when AHC is used as backend.
"""
DER_list = []
prange = np.arange(0.0, 1.0, 0.1)
n_lambdas = None # using it as flag later.
# Note: p_val is threshold in case of AHC.
for p_v in prange:
# Process whole dataset for value of p_v.
concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
save_dir, config)
ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
"fullref_ami_dev.rttm")
sys_rttm = concate_rttm_file
[MS, FA, SER, DER_] = DER(
ref_rttm,
sys_rttm,
config.ignore_overlap,
config.forgiveness_collar, )
DER_list.append(DER_)
if config.oracle_n_spkrs is True:
break # no need of threshold search.
# Take p_val that gave minmum DER on Dev dataset.
tuned_p_val = prange[DER_list.index(min(DER_list))]
return tuned_p_val
def dev_nn_tuner(full_meta, split_type, save_dir, config):
"""Tuning n_neighbors on dev set. Assuming oracle num of speakers.
This is used when nn based affinity is selected.
"""
DER_list = []
pval = None
# Now assumming oracle num of speakers.
n_lambdas = 4
for nn in range(5, 15):
# Process whole dataset for value of n_lambdas.
concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
save_dir, config, nn)
ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
"fullref_ami_dev.rttm")
sys_rttm = concate_rttm_file
[MS, FA, SER, DER_] = DER(
ref_rttm,
sys_rttm,
config.ignore_overlap,
config.forgiveness_collar, )
DER_list.append([nn, DER_])
if config.oracle_n_spkrs is True and config.backend == "kmeans":
break
DER_list.sort(key=lambda x: x[1])
tunned_nn = DER_list[0]
return tunned_nn[0]
def dev_tuner(full_meta, split_type, save_dir, config):
"""Tuning n_components on dev set. Used for nn based affinity matrix.
Note: This is a very basic tunning for nn based affinity.
This is work in progress till we find a better way.
"""
DER_list = []
pval = None
for n_lambdas in range(1, config.max_num_spkrs + 1):
# Process whole dataset for value of n_lambdas.
concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
save_dir, config)
ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
"fullref_ami_dev.rttm")
sys_rttm = concate_rttm_file
[MS, FA, SER, DER_] = DER(
ref_rttm,
sys_rttm,
config.ignore_overlap,
config.forgiveness_collar, )
DER_list.append(DER_)
# Take n_lambdas with minmum DER.
tuned_n_lambdas = DER_list.index(min(DER_list)) + 1
return tuned_n_lambdas
def main(args, config):
# AMI Dev Set: Tune hyperparams on dev set.
# Read the embdding file for dev set generated during embdding compute
dev_meta_file = os.path.join(
args.data_dir,
config.meta_data_dir,
"ami_dev." + config.mic_type + ".subsegs.json", )
with open(dev_meta_file, "r") as f:
meta_dev = json.load(f)
full_meta = meta_dev
# Processing starts from here
# Following few lines selects option for different backend and affinity matrices. Finds best values for hyperameters using dev set.
ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
"fullref_ami_dev.rttm")
best_nn = None
if config.affinity == "nn":
logger.info("Tuning for nn (Multiple iterations over AMI Dev set)")
best_nn = dev_nn_tuner(full_meta, args.data_dir, config)
n_lambdas = None
best_pval = None
if config.affinity == "cos" and (config.backend == "SC" or
config.backend == "kmeans"):
# oracle num_spkrs or not, doesn't matter for kmeans and SC backends
# cos: Tune for the best pval for SC /kmeans (for unknown num of spkrs)
logger.info(
"Tuning for p-value for SC (Multiple iterations over AMI Dev set)")
best_pval = dev_pval_tuner(full_meta, args.data_dir, config)
elif config.backend == "AHC":
logger.info("Tuning for threshold-value for AHC")
best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
config)
best_pval = best_threshold
else:
# NN for unknown num of speakers (can be used in future)
if config.oracle_n_spkrs is False:
# nn: Tune num of number of components (to be updated later)
logger.info(
"Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
)
# dev_tuner used for tuning num of components in NN. Can be used in future.
n_lambdas = dev_tuner(full_meta, args.data_dir, config)
# load 'dev' and 'eval' metadata files.
full_meta_dev = full_meta # current full_meta is for 'dev'
eval_meta_file = os.path.join(
args.data_dir,
config.meta_data_dir,
"ami_eval." + config.mic_type + ".subsegs.json", )
with open(eval_meta_file, "r") as f:
full_meta_eval = json.load(f)
# tag to be appended to final output DER files. Writing DER for individual files.
type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
tag = (
type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)
# perform final diarization on 'dev' and 'eval' with best hyperparams.
final_DERs = {}
out_der_dir = os.path.join(args.data_dir, config.der_dir)
if not os.path.exists(out_der_dir):
os.makedirs(out_der_dir)
for split_type in ["dev", "eval"]:
if split_type == "dev":
full_meta = full_meta_dev
else:
full_meta = full_meta_eval
# performing diarization.
msg = "Diarizing using best hyperparams: " + split_type + " set"
logger.info(msg)
out_boundaries = diarize_dataset(
full_meta,
split_type,
n_lambdas=n_lambdas,
pval=best_pval,
n_neighbors=best_nn,
save_dir=args.data_dir,
config=config)
# computing DER.
msg = "Computing DERs for " + split_type + " set"
logger.info(msg)
ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir,
"fullref_ami_" + split_type + ".rttm")
sys_rttm = out_boundaries
[MS, FA, SER, DER_vals] = DER(
ref_rttm,
sys_rttm,
config.ignore_overlap,
config.forgiveness_collar,
individual_file_scores=True, )
# writing DER values to a file. Append tag.
der_file_name = split_type + "_DER_" + tag
out_der_file = os.path.join(out_der_dir, der_file_name)
msg = "Writing DER file to: " + out_der_file
logger.info(msg)
diar.write_ders_file(ref_rttm, DER_vals, out_der_file)
msg = ("AMI " + split_type + " set DER = %s %%\n" %
(str(round(DER_vals[-1], 2))))
logger.info(msg)
final_DERs[split_type] = round(DER_vals[-1], 2)
# final print DERs
msg = (
"Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n"
% (str(final_DERs["dev"]), str(final_DERs["eval"])))
logger.info(msg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
"--config", default=None, type=str, help="configuration file")
parser.add_argument(
"--data-dir",
default="../data/",
type=str,
help="processsed data directory")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
config.freeze()
main(args, config)

@ -0,0 +1,49 @@
#!/bin/bash
stage=0
set=L
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -o pipefail
data_folder=$1
manual_annot_folder=$2
save_folder=$3
pretrained_model_dir=$4
conf_path=$5
device=$6
ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata
if [ ${stage} -le 0 ]; then
echo "AMI Data preparation"
python local/ami_prepare.py --data_folder ${data_folder} \
--manual_annot_folder ${manual_annot_folder} \
--save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
--meta_data_dir ${meta_data_dir}
if [ $? -ne 0 ]; then
echo "Prepare AMI failed. Please check log message."
exit 1
fi
echo "AMI data preparation done."
fi
if [ ${stage} -le 1 ]; then
# extra embddings for dev and eval dataset
for name in dev eval; do
python local/compute_embdding.py --config ${conf_path} \
--data-dir ${save_folder} \
--device ${device} \
--dataset ${name} \
--load-checkpoint ${pretrained_model_dir}
done
fi
if [ ${stage} -le 2 ]; then
# tune hyperparams on dev set
# perform final diarization on 'dev' and 'eval' with best hyperparams
python local/experiment.py --config ${conf_path} \
--data-dir ${save_folder}
fi

@ -1,14 +1,46 @@
#!/bin/bash #!/bin/bash
. path.sh || exit 1; . ./path.sh || exit 1;
set -e set -e
stage=1 stage=0
#TARGET_DIR=${MAIN_ROOT}/dataset/ami
TARGET_DIR=/home/dataset/AMI
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
save_folder=./save
pretraind_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model
conf_path=conf/ecapa_tdnn.yaml
device=gpu
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1; . ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
if [ ${stage} -le 1 ]; then if [ $stage -le 0 ]; then
# prepare data # Prepare data
bash ./local/data.sh || exit -1 # Download AMI corpus, You need around 10GB of free space to get whole data
fi # The signals are too large to package in this way,
# so you need to use the chooser to indicate which ones you wish to download
echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
echo "Annotations: AMI manual annotations v1.6.2 "
echo "Signals: "
echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
echo "2) Select media streams: Just select Headset mix"
fi
if [ $stage -le 1 ]; then
# Download the pretrained model
wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
echo "download the pretrained ECAPA-TDNN Model to path: "${pretraind_model_dir}
fi
if [ $stage -le 2 ]; then
# Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams.
echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path}
bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
${save_folder} ${pretraind_model_dir} ${conf_path} ${device} || exit 1
fi

@ -746,6 +746,77 @@ def merge_ssegs_same_speaker(lol):
return new_lol return new_lol
def write_ders_file(ref_rttm, DER, out_der_file):
"""Write the final DERs for individual recording.
Arguments
---------
ref_rttm : str
Reference RTTM file.
DER : array
Array containing DER values of each recording.
out_der_file : str
File to write the DERs.
"""
rttm = read_rttm(ref_rttm)
spkr_info = list(filter(lambda x: x.startswith("SPKR-INFO"), rttm))
rec_id_list = []
count = 0
with open(out_der_file, "w") as f:
for row in spkr_info:
a = row.split(" ")
rec_id = a[1]
if rec_id not in rec_id_list:
r = [rec_id, str(round(DER[count], 2))]
rec_id_list.append(rec_id)
line_str = " ".join(r)
f.write("%s\n" % line_str)
count += 1
r = ["OVERALL ", str(round(DER[count], 2))]
line_str = " ".join(r)
f.write("%s\n" % line_str)
def get_oracle_num_spkrs(rec_id, spkr_info):
"""
Returns actual number of speakers in a recording from the ground-truth.
This can be used when the condition is oracle number of speakers.
Arguments
---------
rec_id : str
Recording ID for which the number of speakers have to be obtained.
spkr_info : list
Header of the RTTM file. Starting with `SPKR-INFO`.
Example
-------
>>> from speechbrain.processing import diarization as diar
>>> spkr_info = ['SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.A <NA> <NA>',
... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.B <NA> <NA>',
... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.C <NA> <NA>',
... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.D <NA> <NA>',
... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.A <NA> <NA>',
... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.B <NA> <NA>',
... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.C <NA> <NA>']
>>> diar.get_oracle_num_spkrs('ES2011a', spkr_info)
4
>>> diar.get_oracle_num_spkrs('ES2011b', spkr_info)
3
"""
num_spkrs = 0
for line in spkr_info:
if rec_id in line:
# Since rec_id is prefix for each speaker
num_spkrs += 1
return num_spkrs
def distribute_overlap(lol): def distribute_overlap(lol):
""" """
Distributes the overlapped speech equally among the adjacent segments Distributes the overlapped speech equally among the adjacent segments
@ -826,6 +897,29 @@ def distribute_overlap(lol):
return new_lol return new_lol
def read_rttm(rttm_file_path):
"""
Reads and returns RTTM in list format.
Arguments
---------
rttm_file_path : str
Path to the RTTM file to be read.
Returns
-------
rttm : list
List containing rows of RTTM file.
"""
rttm = []
with open(rttm_file_path, "r") as f:
for line in f:
entry = line[:-1]
rttm.append(entry)
return rttm
def write_rttm(segs_list, out_rttm_file): def write_rttm(segs_list, out_rttm_file):
""" """
Writes the segment list in RTTM format (A standard NIST format). Writes the segment list in RTTM format (A standard NIST format).

@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.compliance.librosa import mfcc
@dataclass
class meta_info:
"""the audio meta info in the vector JSONDataset
Args:
id (str): the segment name
duration (float): segment time
wav (str): wav file path
start (int): start point in the original wav file
stop (int): stop point in the original wav file
lab_id (str): the record id
"""
id: str
duration: float
wav: str
start: int
stop: int
record_id: str
# json dataset support feature type
feat_funcs = {
'raw': None,
'melspectrogram': melspectrogram,
'mfcc': mfcc,
}
class JSONDataset(Dataset):
"""
dataset from json file.
"""
def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
"""
Ags:
json_file (:obj:`str`): Data prep JSON file.
labels (:obj:`List[int]`): Labels of audio files.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
if feat_type not in feat_funcs.keys():
raise RuntimeError(
f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
)
self.json_file = json_file
self.feat_type = feat_type
self.feat_config = kwargs
self._data = self._get_data()
super(JSONDataset, self).__init__()
def _get_data(self):
with open(self.json_file, "r") as f:
meta_data = json.load(f)
data = []
for key in meta_data:
sub_seg = meta_data[key]["wav"]
wav = sub_seg["file"]
duration = sub_seg["duration"]
start = sub_seg["start"]
stop = sub_seg["stop"]
rec_id = str(key).rsplit("_", 2)[0]
data.append(
meta_info(
str(key),
float(duration), wav, int(start), int(stop), str(rec_id)))
return data
def _convert_to_record(self, idx: int):
sample = self._data[idx]
record = {}
# To show all fields in a namedtuple
for field in fields(sample):
record[field.name] = getattr(sample, field.name)
waveform, sr = load_audio(record['wav'])
waveform = waveform[record['start']:record['stop']]
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
return record
def __getitem__(self, idx):
return self._convert_to_record(idx)
def __len__(self):
return len(self._data)
Loading…
Cancel
Save