From 378fe5909f839e61c898d0dca3a5dd1cbc9cf9a7 Mon Sep 17 00:00:00 2001
From: ccrrong <1039058843@qq.com>
Date: Tue, 5 Apr 2022 11:34:57 +0800
Subject: [PATCH] add ami diarization pipeline, test=doc

---
 examples/ami/sd0/conf/ecapa_tdnn.yaml      |  71 ++++
 examples/ami/sd0/local/ami_dataset.py      |  90 +++++
 examples/ami/sd0/local/compute_embdding.py | 233 +++++++++++
 examples/ami/sd0/local/experiment.py       | 439 +++++++++++++++++++++
 examples/ami/sd0/local/process.sh          |  49 +++
 examples/ami/sd0/run.sh                    |  36 +-
 paddlespeech/vector/cluster/diarization.py |  94 +++++
 utils/compute_der.py                       | 175 ++++++++
 8 files changed, 1182 insertions(+), 5 deletions(-)
 create mode 100755 examples/ami/sd0/conf/ecapa_tdnn.yaml
 create mode 100644 examples/ami/sd0/local/ami_dataset.py
 create mode 100644 examples/ami/sd0/local/compute_embdding.py
 create mode 100755 examples/ami/sd0/local/experiment.py
 create mode 100755 examples/ami/sd0/local/process.sh
 create mode 100755 utils/compute_der.py

diff --git a/examples/ami/sd0/conf/ecapa_tdnn.yaml b/examples/ami/sd0/conf/ecapa_tdnn.yaml
new file mode 100755
index 000000000..0f298c35b
--- /dev/null
+++ b/examples/ami/sd0/conf/ecapa_tdnn.yaml
@@ -0,0 +1,71 @@
+# ##################################################
+# Model: Speaker Diarization Baseline
+# Embeddings: Deep embedding
+# Clustering Technique: Spectral clustering
+# Authors: Nauman Dawalatabad 2020
+# ##################################################
+
+seed: 1234
+num_speakers: 7205
+
+###########################################################
+#                AMI DATA PREPARE SETTING                 #
+###########################################################
+split_type: 'full_corpus_asr'
+skip_TNO: True
+# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
+mic_type: 'Mix-Headset'
+vad_type: 'oracle'
+max_subseg_dur: 3.0
+overlap: 1.5
+# Some more exp folders (for cleaner structure).
+embedding_dir: emb #!ref <save_folder>/emb
+meta_data_dir: metadata #!ref <save_folder>/metadata
+ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
+sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
+der_dir: DER #!ref <save_folder>/DER
+
+
+###########################################################
+#               FEATURE EXTRACTION SETTING                #
+###########################################################
+# currently, we only support fbank
+sr: 16000 # sample rate
+n_mels: 80
+window_size: 400 # 25 ms, sample rate 16000, 25 * 16000 / 1000 = 400
+hop_size: 160 # 10 ms, sample rate 16000, 10 * 16000 / 1000 = 160
+#left_frames: 0
+#right_frames: 0
+#deltas: False
+
+
+###########################################################
+#                     MODEL SETTING                       #
+###########################################################
+# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
+# if you want to use another model, please choose another configuration yaml file
+emb_dim: 192
+batch_size: 16
+model:
+  input_size: 80
+  channels: [1024, 1024, 1024, 1024, 3072]
+  kernel_sizes: [5, 3, 3, 3, 1]
+  dilations: [1, 2, 3, 4, 1]
+  attention_channels: 128
+  lin_neurons: 192
+# The pretrained ECAPA-TDNN checkpoint is loaded from the path passed via
+# --load-checkpoint (see run.sh); it is not downloaded automatically.
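+# note: emb_dim is the dimensionality of the utterance embedding consumed by
+# the clustering stage; it must match lin_neurons in the model config above.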
+
+
+###########################################################
+#              SPECTRAL CLUSTERING SETTING                #
+###########################################################
+backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
+affinity: 'cos' # options: cos, nn
+max_num_spkrs: 10
+oracle_n_spkrs: True
+
+
+###########################################################
+#                 DER EVALUATION SETTING                  #
+###########################################################
+ignore_overlap: True
+forgiveness_collar: 0.25

diff --git a/examples/ami/sd0/local/ami_dataset.py b/examples/ami/sd0/local/ami_dataset.py
new file mode 100644
index 000000000..c44329c83
--- /dev/null
+++ b/examples/ami/sd0/local/ami_dataset.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import json
+
+from paddle.io import Dataset
+
+from paddleaudio.backends import load as load_audio
+from paddleaudio.datasets.dataset import feat_funcs
+
+
+class AMIDataset(Dataset):
+    """
+    AMI dataset.
+    """
+
+    meta_info = collections.namedtuple(
+        'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'record_id'))
+
+    def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
+        """
+        Args:
+            json_file (:obj:`str`): Data prep JSON file.
+            feat_type (:obj:`str`, `optional`, defaults to `raw`):
+                It identifies the feature type that the user wants to extract from an audio file.
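+
+        Example:
+            A hypothetical instantiation (the JSON path is a placeholder):
+
+                dataset = AMIDataset(
+                    json_file="save/emb/AMI_dev/ES2011a.Mix-Headset.json",
+                    feat_type="melspectrogram",
+                    n_mels=80)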
+ """ + if feat_type not in feat_funcs.keys(): + raise RuntimeError( + f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" + ) + + self.json_file = json_file + self.feat_type = feat_type + self.feat_config = kwargs + self._data = self._get_data() + super(AMIDataset, self).__init__() + + def _get_data(self): + with open(self.json_file, "r") as f: + meta_data = json.load(f) + data = [] + for key in meta_data: + sub_seg = meta_data[key]["wav"] + wav = sub_seg["file"] + duration = sub_seg["duration"] + start = sub_seg["start"] + stop = sub_seg["stop"] + rec_id = str(key).rsplit("_", 2)[0] + data.append( + self.meta_info( + str(key), + float(duration), wav, int(start), int(stop), str(rec_id))) + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple: `type(sample)._fields` + for field in type(sample)._fields: + record[field] = getattr(sample, field) + + waveform, sr = load_audio(record['wav']) + waveform = waveform[record['start']:record['stop']] + + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + + return record + + def __getitem__(self, idx): + return self._convert_to_record(idx) + + def __len__(self): + return len(self._data) diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py new file mode 100644 index 000000000..e4fd5da2c --- /dev/null +++ b/examples/ami/sd0/local/compute_embdding.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import pickle +import sys + +import numpy as np +import paddle +from ami_dataset import AMIDataset +from paddle.io import BatchSampler +from paddle.io import DataLoader +from tqdm.contrib import tqdm +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.cluster.diarization import EmbeddingMeta +from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn +from paddlespeech.vector.modules.sid_model import SpeakerIdetification +from paddlespeech.vector.training.seeding import seed_everything + +# Logger setup +logger = Log(__name__).getlog() + + +def prepare_subset_json(full_meta_data, rec_id, out_meta_file): + """Prepares metadata for a given recording ID. + + Arguments + --------- + full_meta_data : json + Full meta (json) containing all the recordings + rec_id : str + The recording ID for which meta (json) has to be prepared + out_meta_file : str + Path of the output meta (json) file. 
+ """ + + subset = {} + for key in full_meta_data: + k = str(key) + if k.startswith(rec_id): + subset[key] = full_meta_data[key] + + with open(out_meta_file, mode="w") as json_f: + json.dump(subset, json_f, indent=2) + + +def create_dataloader(json_file, batch_size): + """Creates the datasets and their data processing pipelines. + This is used for multi-mic processing. + """ + + # create datasets + dataset = AMIDataset( + json_file=json_file, + feat_type='melspectrogram', + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + + # create dataloader + batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True) + dataloader = DataLoader(dataset, + batch_sampler=batch_sampler, + collate_fn=lambda x: batch_feature_normalize( + x, mean_norm=True, std_norm=False), + return_list=True) + + return dataloader + + +def main(args, config): + # set the training device, cpu or gpu + paddle.set_device(args.device) + # set the random seed + seed_everything(config.seed) + + # stage1: build the dnn backbone model network + ecapa_tdnn = EcapaTdnn(**config.model) + + # stage2: build the speaker verification eval instance with backbone model + model = SpeakerIdetification( + backbone=ecapa_tdnn, num_class=config.num_speakers) + + # stage3: load the pre-trained model + # we get the last model from the epoch and save_interval + args.load_checkpoint = os.path.abspath( + os.path.expanduser(args.load_checkpoint)) + + # load model checkpoint to sid model + state_dict = paddle.load( + os.path.join(args.load_checkpoint, 'model.pdparams')) + model.set_state_dict(state_dict) + logger.info(f'Checkpoint loaded from {args.load_checkpoint}') + + # set the model to eval mode + model.eval() + + # load meta data + meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", ) + with open(meta_file, "r") as f: + full_meta = json.load(f) + + # get all the recording IDs in this dataset. + all_keys = full_meta.keys() + A = [word.rstrip().split("_")[0] for word in all_keys] + all_rec_ids = list(set(A[1:])) + all_rec_ids.sort() + split = "AMI_" + args.dataset + i = 1 + + msg = "Extra embdding for " + args.dataset + " set" + logger.info(msg) + + if len(all_rec_ids) <= 0: + msg = "No recording IDs found! Please check if meta_data json file is properly generated." + logger.error(msg) + sys.exit() + + # extra different recordings embdding in a dataset. + for rec_id in tqdm(all_rec_ids): + # This tag will be displayed in the log. + tag = ("[" + str(args.dataset) + ": " + str(i) + "/" + + str(len(all_rec_ids)) + "]") + i = i + 1 + + # log message. + msg = "Embdding %s : %s " % (tag, rec_id) + logger.debug(msg) + + # embedding directory. + if not os.path.exists( + os.path.join(args.data_dir, config.embedding_dir, split)): + os.makedirs( + os.path.join(args.data_dir, config.embedding_dir, split)) + + # file to store embeddings. + emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl" + diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir, + split, emb_file_name) + + # prepare a metadata (json) for one recording. This is basically a subset of full_meta. + # lets keep this meta-info in embedding directory itself. + json_file_name = rec_id + "." + config.mic_type + ".json" + meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir, + split, json_file_name) + + # write subset (meta for one recording) json metadata. 
+        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)
+
+        # prepare data loader.
+        diary_set_loader = create_dataloader(meta_per_rec_file,
+                                             config.batch_size)
+
+        # extract embeddings (skip if already done).
+        if not os.path.isfile(diary_stat_emb_file):
+            logger.debug("Extracting deep embeddings")
+            embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
+            segset = []
+
+            for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
+                # extract the audio embedding
+                ids, feats, lengths = batch['ids'], batch['feats'], batch[
+                    'lengths']
+                seg = [x for x in ids]
+                segset = segset + seg
+                emb = model.backbone(feats, lengths).squeeze(
+                    -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
+                embeddings = np.concatenate((embeddings, emb), axis=0)
+
+            segset = np.array(segset, dtype="|O")
+            stat_obj = EmbeddingMeta(
+                segset=segset,
+                stats=embeddings, )
+            logger.debug("Saving Embeddings...")
+            with open(diary_stat_emb_file, "wb") as output:
+                pickle.dump(stat_obj, output)
+
+        else:
+            logger.debug("Skipping embedding extraction (as already present).")
+
+
+# Begin experiment!
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(__doc__)
+    parser.add_argument(
+        '--device',
+        default="gpu",
+        help="Select which device to perform diarization, defaults to gpu.")
+    parser.add_argument(
+        "--config", default=None, type=str, help="configuration file")
+    parser.add_argument(
+        "--data-dir",
+        default="../save/",
+        type=str,
+        help="processed data directory")
+    parser.add_argument(
+        "--dataset",
+        choices=['dev', 'eval'],
+        default="dev",
+        type=str,
+        help="Select which dataset to extract embeddings from, defaults to dev")
+    parser.add_argument(
+        "--load-checkpoint",
+        type=str,
+        default='',
+        help="Directory to load model checkpoint to compute embeddings.")
+    args = parser.parse_args()
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
+    config.freeze()
+    print(config)
+
+    main(args, config)

diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py
new file mode 100755
index 000000000..e912a4895
--- /dev/null
+++ b/examples/ami/sd0/local/experiment.py
@@ -0,0 +1,439 @@
+#!/usr/bin/python3
+"""This recipe implements a diarization system using deep embedding extraction
+followed by spectral clustering.
+
+To run this recipe:
+> python experiment.py --config conf/ecapa_tdnn.yaml --data-dir ../save/
+
+Condition: Oracle VAD (speech regions taken from the groundtruth).
+
+Note: There are multiple ways to write this recipe. We iterate over individual recordings.
+    This approach is less GPU memory demanding and also makes the code easy to understand.
+
+Citation: This recipe is based on the following paper,
+    N. Dawalatabad, M. Ravanelli, F. Grondin, J. Thienpondt, B. Desplanques, H. Na,
+    "ECAPA-TDNN Embeddings for Speaker Diarization," arXiv:2104.01466, 2021.
+
+Authors
+    * Nauman Dawalatabad 2020
+"""
+import argparse
+import glob
+import json
+import os
+import pickle
+import shutil
+import sys
+
+import numpy as np
+from tqdm.contrib import tqdm
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.cluster import diarization as diar
+from utils.compute_der import DER
+
+# Logger setup
+logger = Log(__name__).getlog()
+
+
+def diarize_dataset(
+        full_meta,
+        split_type,
+        n_lambdas,
+        pval,
+        save_dir,
+        config,
+        n_neighbors=10, ):
+    """This function diarizes all the recordings in a given dataset.
+    It performs embedding computation and clusters the embeddings using
+    spectral clustering (or another configured backend). The output speaker
+    boundary file is stored in the RTTM format.
+    """
+
+    # prepare `spkr_info` only once when Oracle num of speakers is selected.
+    # spkr_info is essential to obtain the number of speakers from the groundtruth.
+    if config.oracle_n_spkrs is True:
+        full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
+                                          "fullref_ami_" + split_type + ".rttm")
+        rttm = diar.read_rttm(full_ref_rttm_file)
+
+        spkr_info = list(  # noqa F841
+            filter(lambda x: x.startswith("SPKR-INFO"), rttm))
+
+    # get all the recording IDs in this dataset.
+    all_keys = full_meta.keys()
+    A = [word.rstrip().split("_")[0] for word in all_keys]
+    all_rec_ids = list(set(A[1:]))
+    all_rec_ids.sort()
+    split = "AMI_" + split_type
+    i = 1
+
+    # adding tag for directory path.
+    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
+    tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)
+
+    # make out rttm dir
+    out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
+                                split, tag)
+    if not os.path.exists(out_rttm_dir):
+        os.makedirs(out_rttm_dir)
+
+    # diarizing different recordings in a dataset.
+    for rec_id in tqdm(all_rec_ids):
+        if rec_id == "IS1008a":
+            continue
+        if rec_id == "ES2011a":
+            continue
+
+        # this tag will be displayed in the log.
+        tag = ("[" + str(split_type) + ": " + str(i) + "/" +
+               str(len(all_rec_ids)) + "]")
+        i = i + 1
+
+        # log message.
+        msg = "Diarizing %s : %s " % (tag, rec_id)
+        logger.debug(msg)
+
+        # load embeddings.
+        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
+        diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
+                                           split, emb_file_name)
+        if not os.path.isfile(diary_stat_emb_file):
+            msg = "Embedding file %s not found! Please check if the embedding file is properly generated." % (
+                diary_stat_emb_file)
+            logger.error(msg)
+            sys.exit()
+        with open(diary_stat_emb_file, "rb") as in_file:
+            diary_obj = pickle.load(in_file)
+
+        out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"
+
+        # processing starts from here.
+        if config.oracle_n_spkrs is True:
+            # oracle num of speakers.
+            num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
+        else:
+            if config.affinity == "nn":
+                # num of speakers tuned on dev set (only for nn affinity).
+                num_spkrs = n_lambdas
+            else:
+                # num of speakers will be estimated using max eigen gap for cos based affinity.
+                # so adding None here. Will use this None later-on.
+                num_spkrs = None
+
+        if config.backend == "kmeans":
+            diar.do_kmeans_clustering(
+                diary_obj,
+                out_rttm_file,
+                rec_id,
+                num_spkrs,
+                pval, )
+
+        if config.backend == "SC":
+            # go for Spectral Clustering (SC).
+            diar.do_spec_clustering(
+                diary_obj,
+                out_rttm_file,
+                rec_id,
+                num_spkrs,
+                pval,
+                config.affinity,
+                n_neighbors, )
+
+        # can be used for AHC later. Likewise one can add different backends here.
+        if config.backend == "AHC":
+            # call AHC
+            threshold = pval  # pval for AHC is nothing but the threshold.
+            diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)
+
+    # once all RTTM outputs are generated, concatenate the individual RTTM files into a single RTTM file.
+    # this is not strictly needed, but it follows the standard convention.
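+    # each line of the concatenated system RTTM roughly follows the NIST
+    # convention (a sketch):
+    #   SPEAKER <rec_id> <channel> <onset> <duration> <NA> <NA> <spkr_label> <NA> <NA>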
+    concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
+    logger.debug("Concatenating individual RTTM files...")
+    with open(concate_rttm_file, "w") as cat_file:
+        for f in glob.glob(out_rttm_dir + "/*.rttm"):
+            if f == concate_rttm_file:
+                continue
+            with open(f, "r") as indi_rttm_file:
+                shutil.copyfileobj(indi_rttm_file, cat_file)
+
+    msg = "The system generated RTTM file for %s set : %s" % (
+        split_type, concate_rttm_file, )
+    logger.debug(msg)
+
+    return concate_rttm_file
+
+
+def dev_pval_tuner(full_meta, save_dir, config):
+    """Tuning the p_value for the affinity matrix.
+    The p_value is used so that only the top p% of the values in each row are retained.
+    """
+
+    DER_list = []
+    prange = np.arange(0.002, 0.015, 0.001)
+
+    n_lambdas = None  # using it as flag later.
+    for p_v in prange:
+        # Process the whole dataset for this value of p_v.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
+                                            save_dir, config)
+
+        ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
+                                     "fullref_ami_dev.rttm")
+        sys_rttm_file = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm_file,
+            sys_rttm_file,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append(DER_)
+
+        if config.oracle_n_spkrs is True and config.backend == "kmeans":
+            # no need of p_val search. Note p_val is needed for SC for both oracle and est num of speakers.
+            # p_val is needed in oracle_n_spkr=False when using the kmeans backend.
+            break
+
+    # Take the p_val that gave the minimum DER on the dev dataset.
+    tuned_p_val = prange[DER_list.index(min(DER_list))]
+
+    return tuned_p_val
+
+
+def dev_ahc_threshold_tuner(full_meta, save_dir, config):
+    """Tuning the threshold for the affinity matrix. This function is called when AHC is used as the backend.
+    """
+
+    DER_list = []
+    prange = np.arange(0.0, 1.0, 0.1)
+
+    n_lambdas = None  # using it as flag later.
+
+    # Note: p_val is the threshold in case of AHC.
+    for p_v in prange:
+        # Process the whole dataset for this value of p_v.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
+                                            save_dir, config)
+
+        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+                                "fullref_ami_dev.rttm")
+        sys_rttm = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append(DER_)
+
+        if config.oracle_n_spkrs is True:
+            break  # no need of threshold search.
+
+    # Take the p_val that gave the minimum DER on the dev dataset.
+    tuned_p_val = prange[DER_list.index(min(DER_list))]
+
+    return tuned_p_val
+
+
+def dev_nn_tuner(full_meta, save_dir, config):
+    """Tuning n_neighbors on the dev set. Assuming oracle num of speakers.
+    This is used when nn based affinity is selected.
+    """
+
+    DER_list = []
+    pval = None
+
+    # Now assuming oracle num of speakers.
+    n_lambdas = 4
+
+    for nn in range(5, 15):
+
+        # Process the whole dataset for this value of n_neighbors.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
+                                            save_dir, config, nn)
+
+        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+                                "fullref_ami_dev.rttm")
+        sys_rttm = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append([nn, DER_])
+
+        if config.oracle_n_spkrs is True and config.backend == "kmeans":
+            break
+
+    DER_list.sort(key=lambda x: x[1])
+    tuned_nn = DER_list[0]
+
+    return tuned_nn[0]
+
+
+def dev_tuner(full_meta, save_dir, config):
+    """Tuning n_components on the dev set. Used for nn based affinity matrix.
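+    It sweeps n_lambdas from 1 to config.max_num_spkrs and returns the value
+    with the lowest DER on the dev set.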
+    Note: This is a very basic tuning for nn based affinity.
+    This is work in progress till we find a better way.
+    """
+
+    DER_list = []
+    pval = None
+    for n_lambdas in range(1, config.max_num_spkrs + 1):
+
+        # Process the whole dataset for this value of n_lambdas.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
+                                            save_dir, config)
+
+        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+                                "fullref_ami_dev.rttm")
+        sys_rttm = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append(DER_)
+
+    # Take the n_lambdas with minimum DER.
+    tuned_n_lambdas = DER_list.index(min(DER_list)) + 1
+
+    return tuned_n_lambdas
+
+
+def main(args, config):
+    # AMI Dev Set: Tune hyperparams on the dev set.
+    # Read the embedding file for the dev set generated during embedding computation.
+    dev_meta_file = os.path.join(
+        args.data_dir,
+        config.meta_data_dir,
+        "ami_dev." + config.mic_type + ".subsegs.json", )
+    with open(dev_meta_file, "r") as f:
+        meta_dev = json.load(f)
+
+    full_meta = meta_dev
+
+    # Processing starts from here.
+    # The following lines select options for the different backends and affinity
+    # matrices. The best hyperparameter values are found using the dev set.
+    ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
+                                 "fullref_ami_dev.rttm")
+    best_nn = None
+    if config.affinity == "nn":
+        logger.info("Tuning for nn (Multiple iterations over AMI Dev set)")
+        best_nn = dev_nn_tuner(full_meta, args.data_dir, config)
+
+    n_lambdas = None
+    best_pval = None
+
+    if config.affinity == "cos" and (config.backend == "SC" or
+                                     config.backend == "kmeans"):
+        # oracle num_spkrs or not, doesn't matter for kmeans and SC backends
+        # cos: Tune for the best pval for SC/kmeans (for unknown num of spkrs)
+        logger.info(
+            "Tuning for p-value for SC (Multiple iterations over AMI Dev set)")
+        best_pval = dev_pval_tuner(full_meta, args.data_dir, config)
+
+    elif config.backend == "AHC":
+        logger.info("Tuning for threshold-value for AHC")
+        best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
+                                                 config)
+        best_pval = best_threshold
+    else:
+        # NN for unknown num of speakers (can be used in the future)
+        if config.oracle_n_spkrs is False:
+            # nn: Tune the number of components (to be updated later)
+            logger.info(
+                "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
+            )
+            # dev_tuner is used for tuning the num of components in NN. Can be used in the future.
+            n_lambdas = dev_tuner(full_meta, args.data_dir, config)
+
+    # load 'dev' and 'eval' metadata files.
+    full_meta_dev = full_meta  # current full_meta is for 'dev'
+    eval_meta_file = os.path.join(
+        args.data_dir,
+        config.meta_data_dir,
+        "ami_eval." + config.mic_type + ".subsegs.json", )
+    with open(eval_meta_file, "r") as f:
+        full_meta_eval = json.load(f)
+
+    # tag to be appended to the final output DER files. Writing DER for individual files.
+    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
+    tag = (
+        type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)
+
+    # perform final diarization on 'dev' and 'eval' with the best hyperparams.
+    final_DERs = {}
+    out_der_dir = os.path.join(args.data_dir, config.der_dir)
+    if not os.path.exists(out_der_dir):
+        os.makedirs(out_der_dir)
+
+    for split_type in ["dev", "eval"]:
+        if split_type == "dev":
+            full_meta = full_meta_dev
+        else:
+            full_meta = full_meta_eval
+
+        # performing diarization.
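+        # note: n_lambdas / best_pval / best_nn may be None at this point;
+        # diarize_dataset only consults the ones that match the configured
+        # affinity and backend.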
+ msg = "Diarizing using best hyperparams: " + split_type + " set" + logger.info(msg) + out_boundaries = diarize_dataset( + full_meta, + split_type, + n_lambdas=n_lambdas, + pval=best_pval, + n_neighbors=best_nn, + save_dir=args.data_dir, + config=config) + + # computing DER. + msg = "Computing DERs for " + split_type + " set" + logger.info(msg) + ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir, + "fullref_ami_" + split_type + ".rttm") + sys_rttm = out_boundaries + [MS, FA, SER, DER_vals] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, + individual_file_scores=True, ) + + # writing DER values to a file. Append tag. + der_file_name = split_type + "_DER_" + tag + out_der_file = os.path.join(out_der_dir, der_file_name) + msg = "Writing DER file to: " + out_der_file + logger.info(msg) + diar.write_ders_file(ref_rttm, DER_vals, out_der_file) + + msg = ("AMI " + split_type + " set DER = %s %%\n" % + (str(round(DER_vals[-1], 2)))) + logger.info(msg) + final_DERs[split_type] = round(DER_vals[-1], 2) + + # final print DERs + msg = ( + "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n" + % (str(final_DERs["dev"]), str(final_DERs["eval"]))) + logger.info(msg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + "--config", default=None, type=str, help="configuration file") + parser.add_argument( + "--data-dir", + default="../data/", + type=str, + help="processsed data directory") + args = parser.parse_args() + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + print(config) + + main(args, config) diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh new file mode 100755 index 000000000..1b5ed5bd1 --- /dev/null +++ b/examples/ami/sd0/local/process.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +stage=2 +set=L + +. ${MAIN_ROOT}/utils/parse_options.sh || exit 1; +set -u +set -o pipefail + +data_folder=$1 +manual_annot_folder=$2 +save_folder=$3 +pretrained_model_dir=$4 +conf_path=$5 + +ref_rttm_dir=${save_folder}/ref_rttms +meta_data_dir=${save_folder}/metadata + +if [ ${stage} -le 0 ]; then + echo "AMI Data preparation" + python local/ami_prepare.py --data_folder ${data_folder} \ + --manual_annot_folder ${manual_annot_folder} \ + --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \ + --meta_data_dir ${meta_data_dir} + + if [ $? -ne 0 ]; then + echo "Prepare AMI failed. Please check log message." + exit 1 + fi + echo "AMI data preparation done." +fi + +if [ ${stage} -le 1 ]; then + # extra embddings for dev and eval dataset + for name in dev eval; do + python local/compute_embdding.py --config ${conf_path} \ + --data-dir ${save_folder} \ + --device gpu:0 \ + --dataset ${name} \ + --load-checkpoint ${pretrained_model_dir} + done +fi + +if [ ${stage} -le 2 ]; then + # tune hyperparams on dev set + # perform final diarization on 'dev' and 'eval' with best hyperparams + python local/experiment.py --config ${conf_path} \ + --data-dir ${save_folder} +fi diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh index 91d4b706a..fc6a91cc3 100644 --- a/examples/ami/sd0/run.sh +++ b/examples/ami/sd0/run.sh @@ -1,14 +1,40 @@ #!/bin/bash -. path.sh || exit 1; +. 
 set -e
 
 stage=1
+stop_stage=50
+
+#TARGET_DIR=${MAIN_ROOT}/dataset/ami
+TARGET_DIR=/home/dataset/AMI
+data_folder=${TARGET_DIR}/amicorpus # e.g., /path/to/amicorpus/
+manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 # e.g., /path/to/ami_public_manual_1.6.2/
+
+save_folder=./save
+pretrained_model_dir=${save_folder}/model
+
+conf_path=conf/ecapa_tdnn.yaml
 
 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
-if [ ${stage} -le 1 ]; then
-    # prepare data
-    bash ./local/data.sh || exit -1
-fi
\ No newline at end of file
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # Prepare data and model
+    # Download the AMI corpus. You need around 10 GB of free space for the whole dataset.
+    # The signals are too large to package into a single archive,
+    # so you need to use the download chooser to indicate which ones you wish to download.
+    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
+    echo "Annotations: AMI manual annotations v1.6.2"
+    echo "Signals:"
+    echo "1) Select one or more AMI meetings: for the meeting IDs, please follow ./ami_split.py"
+    echo "2) Select media streams: Just select Headset mix"
+    # Download a pretrained ECAPA-TDNN model (e.g., from HuggingFace or another source).
+    echo "Please download the pretrained ECAPA-TDNN model and put it in the given path: "${pretrained_model_dir}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # Tune hyperparams on the dev set and perform final diarization on dev and eval with the best hyperparams.
+    bash ./local/process.sh ${data_folder} ${manual_annot_folder} ${save_folder} ${pretrained_model_dir} ${conf_path} || exit 1
+fi
+

diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py
index 597aa4807..5b2157257 100644
--- a/paddlespeech/vector/cluster/diarization.py
+++ b/paddlespeech/vector/cluster/diarization.py
@@ -746,6 +746,77 @@ def merge_ssegs_same_speaker(lol):
     return new_lol
 
 
+def write_ders_file(ref_rttm, DER, out_der_file):
+    """Write the final DERs for individual recordings.
+
+    Arguments
+    ---------
+    ref_rttm : str
+        Reference RTTM file.
+    DER : array
+        Array containing the DER value of each recording.
+    out_der_file : str
+        File to write the DERs to.
+    """
+
+    rttm = read_rttm(ref_rttm)
+    spkr_info = list(filter(lambda x: x.startswith("SPKR-INFO"), rttm))
+
+    rec_id_list = []
+    count = 0
+
+    with open(out_der_file, "w") as f:
+        for row in spkr_info:
+            a = row.split(" ")
+            rec_id = a[1]
+            if rec_id not in rec_id_list:
+                r = [rec_id, str(round(DER[count], 2))]
+                rec_id_list.append(rec_id)
+                line_str = " ".join(r)
+                f.write("%s\n" % line_str)
+                count += 1
+        r = ["OVERALL ", str(round(DER[count], 2))]
+        line_str = " ".join(r)
+        f.write("%s\n" % line_str)
+
+
+def get_oracle_num_spkrs(rec_id, spkr_info):
+    """
+    Returns the actual number of speakers in a recording from the ground-truth.
+    This can be used when the condition is oracle number of speakers.
+
+    Arguments
+    ---------
+    rec_id : str
+        Recording ID for which the number of speakers has to be obtained.
+    spkr_info : list
+        Header of the RTTM file. Starting with `SPKR-INFO`.
+
+    Example
+    -------
+    >>> from paddlespeech.vector.cluster import diarization as diar
+    >>> spkr_info = ['SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.A <NA> <NA>',
+    ... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.B <NA> <NA>',
+    ... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.C <NA> <NA>',
+    ... 'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.D <NA> <NA>',
+    ... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.A <NA> <NA>',
+    ... 'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.B <NA> <NA>',
+    ...
'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.C <NA> <NA>']
+    >>> diar.get_oracle_num_spkrs('ES2011a', spkr_info)
+    4
+    >>> diar.get_oracle_num_spkrs('ES2011b', spkr_info)
+    3
+    """
+
+    num_spkrs = 0
+    for line in spkr_info:
+        if rec_id in line:
+            # Since rec_id is a prefix for each speaker
+            num_spkrs += 1
+
+    return num_spkrs
+
+
 def distribute_overlap(lol):
     """
     Distributes the overlapped speech equally among the adjacent segments
@@ -826,6 +897,29 @@ def distribute_overlap(lol):
     return new_lol
 
 
+def read_rttm(rttm_file_path):
+    """
+    Reads and returns the RTTM in list format.
+
+    Arguments
+    ---------
+    rttm_file_path : str
+        Path to the RTTM file to be read.
+
+    Returns
+    -------
+    rttm : list
+        List containing the rows of the RTTM file.
+    """
+
+    rttm = []
+    with open(rttm_file_path, "r") as f:
+        for line in f:
+            # strip only the trailing newline (robust to a missing one on the last line)
+            entry = line.rstrip("\n")
+            rttm.append(entry)
+    return rttm
+
+
 def write_rttm(segs_list, out_rttm_file):
     """
     Writes the segment list in RTTM format (A standard NIST format).

diff --git a/utils/compute_der.py b/utils/compute_der.py
new file mode 100755
index 000000000..d22f6a7d9
--- /dev/null
+++ b/utils/compute_der.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Calculates the Diarization Error Rate (DER), which is the sum of Missed Speaker (MS),
+False Alarm (FA), and Speaker Error Rate (SER), using md-eval.pl from the NIST RT Evaluation.
+
+Credits
+-------
+This code is adapted from https://github.com/speechbrain/speechbrain
+"""
+import argparse
+import os
+import re
+import subprocess
+
+import numpy as np
+
+FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
+SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
+MISS_SPEAKER_TIME = re.compile(r"(?<=MISSED SPEAKER TIME =)[\d.]+")
+FA_SPEAKER_TIME = re.compile(r"(?<=FALARM SPEAKER TIME =)[\d.]+")
+ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
+
+
+def rectify(arr):
+    """Corrects corner cases and converts scores into percentages.
+    """
+
+    # Numerator and denominator both 0.
+    arr[np.isnan(arr)] = 0
+
+    # Numerator > 0, but denominator = 0.
+    arr[np.isinf(arr)] = 1
+    arr *= 100.0
+
+    return arr
+
+
+def DER(
+        ref_rttm,
+        sys_rttm,
+        ignore_overlap=False,
+        collar=0.25,
+        individual_file_scores=False, ):
+    """Computes the Missed Speaker percentage (MS), False Alarm (FA),
+    Speaker Error Rate (SER), and Diarization Error Rate (DER).
+
+    Arguments
+    ---------
+    ref_rttm : str
+        The path of the reference/groundtruth RTTM file.
+    sys_rttm : str
+        The path of the system-generated RTTM file.
+    ignore_overlap : bool
+        If True, ignores overlapping speech during evaluation.
+    collar : float
+        Forgiveness collar.
+    individual_file_scores : bool
+        If True, returns scores for each file in order.
+
+    Returns
+    -------
+    MS : float array
+        Missed Speech.
+    FA : float array
+        False Alarms.
+    SER : float array
+        Speaker Error Rates.
+    DER : float array
+        Diarization Error Rates.
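+
+    Example
+    -------
+    A hypothetical call (the RTTM paths are placeholders):
+
+    >>> MS, FA, SER, DER_ = DER("ref.rttm", "sys.rttm",
+    ...                         ignore_overlap=True, collar=0.25)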
+ """ + + curr = os.path.abspath(os.path.dirname(__file__)) + mdEval = os.path.join(curr, "./md-eval.pl") + + cmd = [ + mdEval, + "-af", + "-r", + ref_rttm, + "-s", + sys_rttm, + "-c", + str(collar), + ] + print(cmd) + if ignore_overlap: + cmd.append("-1") + + try: + stdout = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + except subprocess.CalledProcessError as ex: + stdout = ex.output + + else: + stdout = stdout.decode("utf-8") + + # Get all recording IDs + file_ids = [m.strip() for m in FILE_IDS.findall(stdout)] + file_ids = [ + file_id[2:] if file_id.startswith("f=") else file_id + for file_id in file_ids + ] + + scored_speaker_times = np.array( + [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)]) + + miss_speaker_times = np.array( + [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)]) + + fa_speaker_times = np.array( + [float(m) for m in FA_SPEAKER_TIME.findall(stdout)]) + + error_speaker_times = np.array( + [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)]) + + with np.errstate(invalid="ignore", divide="ignore"): + tot_error_times = ( + miss_speaker_times + fa_speaker_times + error_speaker_times) + miss_speaker_frac = miss_speaker_times / scored_speaker_times + fa_speaker_frac = fa_speaker_times / scored_speaker_times + sers_frac = error_speaker_times / scored_speaker_times + ders_frac = tot_error_times / scored_speaker_times + + # Values in percentage of scored_speaker_time + miss_speaker = rectify(miss_speaker_frac) + fa_speaker = rectify(fa_speaker_frac) + sers = rectify(sers_frac) + ders = rectify(ders_frac) + + if individual_file_scores: + return miss_speaker, fa_speaker, sers, ders + else: + return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1] + + +def main(): + parser = argparse.ArgumentParser(description="Compute DER") + parser.add_argument( + "--ref_rttm", + type=str, + help="the path of reference/groundtruth RTTM file") + parser.add_argument( + "--sys_rttm", + type=str, + help="the path of the system generated RTTM file.") + parser.add_argument( + "--individual_file_scores", + type=bool, + help="whether returns scores for each file in order.") + parser.add_argument("--collar", type=float, help="forgiveness collar.") + parser.add_argument( + "--ignore_overlap", + type=bool, + help="whether ignores overlapping speech during evaluation.") + + args = parser.parse_args() + + Scores = DER(args.ref_rttm, args.sys_rttm, args.ignore_overlap, args.collar, + args.individual_file_scores) + print(Scores) + + +if __name__ == "__main__": + main()