From 378fe5909f839e61c898d0dca3a5dd1cbc9cf9a7 Mon Sep 17 00:00:00 2001 From: ccrrong <1039058843@qq.com> Date: Tue, 5 Apr 2022 11:34:57 +0800 Subject: [PATCH 1/5] add ami diarization pipeline, test=doc --- examples/ami/sd0/conf/ecapa_tdnn.yaml | 71 ++++ examples/ami/sd0/local/ami_dataset.py | 90 +++++ examples/ami/sd0/local/compute_embdding.py | 233 +++++++++++ examples/ami/sd0/local/experiment.py | 439 +++++++++++++++++++++ examples/ami/sd0/local/process.sh | 49 +++ examples/ami/sd0/run.sh | 36 +- paddlespeech/vector/cluster/diarization.py | 94 +++++ utils/compute_der.py | 175 ++++++++ 8 files changed, 1182 insertions(+), 5 deletions(-) create mode 100755 examples/ami/sd0/conf/ecapa_tdnn.yaml create mode 100644 examples/ami/sd0/local/ami_dataset.py create mode 100644 examples/ami/sd0/local/compute_embdding.py create mode 100755 examples/ami/sd0/local/experiment.py create mode 100755 examples/ami/sd0/local/process.sh create mode 100755 utils/compute_der.py diff --git a/examples/ami/sd0/conf/ecapa_tdnn.yaml b/examples/ami/sd0/conf/ecapa_tdnn.yaml new file mode 100755 index 000000000..0f298c35b --- /dev/null +++ b/examples/ami/sd0/conf/ecapa_tdnn.yaml @@ -0,0 +1,71 @@ +# ################################################## +# Model: Speaker Diarization Baseline +# Embeddings: Deep embedding +# Clustering Technique: Spectral clustering +# Authors: Nauman Dawalatabad 2020 +# ################################################# + +seed: 1234 +num_speakers: 7205 + +########################################################### +# AMI DATA PREPARE SETTING # +########################################################### +split_type: 'full_corpus_asr' +skip_TNO: True +# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt' +mic_type: 'Mix-Headset' +vad_type: 'oracle' +max_subseg_dur: 3.0 +overlap: 1.5 +# Some more exp folders (for cleaner structure). +embedding_dir: emb #!ref /emb +meta_data_dir: metadata #!ref /metadata +ref_rttm_dir: ref_rttms #!ref /ref_rttms +sys_rttm_dir: sys_rttms #!ref /sys_rttms +der_dir: DER #!ref /DER + + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +# currently, we only support fbank +sr: 16000 # sample rate +n_mels: 80 +window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400 +hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 +#left_frames: 0 +#right_frames: 0 +#deltas: False + + +########################################################### +# MODEL SETTING # +########################################################### +# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml +# if we want use another model, please choose another configuration yaml file +emb_dim: 192 +batch_size: 16 +model: + input_size: 80 + channels: [1024, 1024, 1024, 1024, 3072] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + attention_channels: 128 + lin_neurons: 192 +# Will automatically download ECAPA-TDNN model (best). 
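+# Note: the pretrained ECAPA-TDNN weights are not fetched from this config;
+# they are loaded from the checkpoint directory passed to
+# local/compute_embdding.py via --load-checkpoint.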
+ +########################################################### +# SPECTRAL CLUSTERING SETTING # +########################################################### +backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity +affinity: 'cos' # options: cos, nn +max_num_spkrs: 10 +oracle_n_spkrs: True + + +########################################################### +# DER EVALUATION SETTING # +########################################################### +ignore_overlap: True +forgiveness_collar: 0.25 diff --git a/examples/ami/sd0/local/ami_dataset.py b/examples/ami/sd0/local/ami_dataset.py new file mode 100644 index 000000000..c44329c83 --- /dev/null +++ b/examples/ami/sd0/local/ami_dataset.py @@ -0,0 +1,90 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import json + +from paddle.io import Dataset + +from paddleaudio.backends import load as load_audio +from paddleaudio.datasets.dataset import feat_funcs + + +class AMIDataset(Dataset): + """ + AMI dataset. + """ + + meta_info = collections.namedtuple( + 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'record_id')) + + def __init__(self, json_file: str, feat_type: str='raw', **kwargs): + """ + Ags: + json_file (:obj:`str`): Data prep JSON file. + labels (:obj:`List[int]`): Labels of audio files. + feat_type (:obj:`str`, `optional`, defaults to `raw`): + It identifies the feature type that user wants to extrace of an audio file. 
+ """ + if feat_type not in feat_funcs.keys(): + raise RuntimeError( + f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" + ) + + self.json_file = json_file + self.feat_type = feat_type + self.feat_config = kwargs + self._data = self._get_data() + super(AMIDataset, self).__init__() + + def _get_data(self): + with open(self.json_file, "r") as f: + meta_data = json.load(f) + data = [] + for key in meta_data: + sub_seg = meta_data[key]["wav"] + wav = sub_seg["file"] + duration = sub_seg["duration"] + start = sub_seg["start"] + stop = sub_seg["stop"] + rec_id = str(key).rsplit("_", 2)[0] + data.append( + self.meta_info( + str(key), + float(duration), wav, int(start), int(stop), str(rec_id))) + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple: `type(sample)._fields` + for field in type(sample)._fields: + record[field] = getattr(sample, field) + + waveform, sr = load_audio(record['wav']) + waveform = waveform[record['start']:record['stop']] + + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + + return record + + def __getitem__(self, idx): + return self._convert_to_record(idx) + + def __len__(self): + return len(self._data) diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py new file mode 100644 index 000000000..e4fd5da2c --- /dev/null +++ b/examples/ami/sd0/local/compute_embdding.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import pickle +import sys + +import numpy as np +import paddle +from ami_dataset import AMIDataset +from paddle.io import BatchSampler +from paddle.io import DataLoader +from tqdm.contrib import tqdm +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.cluster.diarization import EmbeddingMeta +from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn +from paddlespeech.vector.modules.sid_model import SpeakerIdetification +from paddlespeech.vector.training.seeding import seed_everything + +# Logger setup +logger = Log(__name__).getlog() + + +def prepare_subset_json(full_meta_data, rec_id, out_meta_file): + """Prepares metadata for a given recording ID. + + Arguments + --------- + full_meta_data : json + Full meta (json) containing all the recordings + rec_id : str + The recording ID for which meta (json) has to be prepared + out_meta_file : str + Path of the output meta (json) file. 
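+
+    Example
+    -------
+    A minimal usage sketch; the file names below are illustrative only.
+
+    >>> with open("ami_dev.Mix-Headset.subsegs.json") as f:
+    ...     full_meta = json.load(f)
+    >>> prepare_subset_json(full_meta, "ES2011a", "ES2011a.Mix-Headset.json")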
+ """ + + subset = {} + for key in full_meta_data: + k = str(key) + if k.startswith(rec_id): + subset[key] = full_meta_data[key] + + with open(out_meta_file, mode="w") as json_f: + json.dump(subset, json_f, indent=2) + + +def create_dataloader(json_file, batch_size): + """Creates the datasets and their data processing pipelines. + This is used for multi-mic processing. + """ + + # create datasets + dataset = AMIDataset( + json_file=json_file, + feat_type='melspectrogram', + n_mels=config.n_mels, + window_size=config.window_size, + hop_length=config.hop_size) + + # create dataloader + batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True) + dataloader = DataLoader(dataset, + batch_sampler=batch_sampler, + collate_fn=lambda x: batch_feature_normalize( + x, mean_norm=True, std_norm=False), + return_list=True) + + return dataloader + + +def main(args, config): + # set the training device, cpu or gpu + paddle.set_device(args.device) + # set the random seed + seed_everything(config.seed) + + # stage1: build the dnn backbone model network + ecapa_tdnn = EcapaTdnn(**config.model) + + # stage2: build the speaker verification eval instance with backbone model + model = SpeakerIdetification( + backbone=ecapa_tdnn, num_class=config.num_speakers) + + # stage3: load the pre-trained model + # we get the last model from the epoch and save_interval + args.load_checkpoint = os.path.abspath( + os.path.expanduser(args.load_checkpoint)) + + # load model checkpoint to sid model + state_dict = paddle.load( + os.path.join(args.load_checkpoint, 'model.pdparams')) + model.set_state_dict(state_dict) + logger.info(f'Checkpoint loaded from {args.load_checkpoint}') + + # set the model to eval mode + model.eval() + + # load meta data + meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", ) + with open(meta_file, "r") as f: + full_meta = json.load(f) + + # get all the recording IDs in this dataset. + all_keys = full_meta.keys() + A = [word.rstrip().split("_")[0] for word in all_keys] + all_rec_ids = list(set(A[1:])) + all_rec_ids.sort() + split = "AMI_" + args.dataset + i = 1 + + msg = "Extra embdding for " + args.dataset + " set" + logger.info(msg) + + if len(all_rec_ids) <= 0: + msg = "No recording IDs found! Please check if meta_data json file is properly generated." + logger.error(msg) + sys.exit() + + # extra different recordings embdding in a dataset. + for rec_id in tqdm(all_rec_ids): + # This tag will be displayed in the log. + tag = ("[" + str(args.dataset) + ": " + str(i) + "/" + + str(len(all_rec_ids)) + "]") + i = i + 1 + + # log message. + msg = "Embdding %s : %s " % (tag, rec_id) + logger.debug(msg) + + # embedding directory. + if not os.path.exists( + os.path.join(args.data_dir, config.embedding_dir, split)): + os.makedirs( + os.path.join(args.data_dir, config.embedding_dir, split)) + + # file to store embeddings. + emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl" + diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir, + split, emb_file_name) + + # prepare a metadata (json) for one recording. This is basically a subset of full_meta. + # lets keep this meta-info in embedding directory itself. + json_file_name = rec_id + "." + config.mic_type + ".json" + meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir, + split, json_file_name) + + # write subset (meta for one recording) json metadata. 
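+        # the per-recording json is then used to build a dataloader that only
+        # iterates over this recording's sub-segments.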
+ prepare_subset_json(full_meta, rec_id, meta_per_rec_file) + + # prepare data loader. + diary_set_loader = create_dataloader(meta_per_rec_file, + config.batch_size) + + # extract embeddings (skip if already done). + if not os.path.isfile(diary_stat_emb_file): + logger.debug("Extracting deep embeddings") + embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64) + segset = [] + + for batch_idx, batch in enumerate(tqdm(diary_set_loader)): + # extrac the audio embedding + ids, feats, lengths = batch['ids'], batch['feats'], batch[ + 'lengths'] + seg = [x for x in ids] + segset = segset + seg + emb = model.backbone(feats, lengths).squeeze( + -1).numpy() # (N, emb_size, 1) -> (N, emb_size) + embeddings = np.concatenate((embeddings, emb), axis=0) + + segset = np.array(segset, dtype="|O") + stat_obj = EmbeddingMeta( + segset=segset, + stats=embeddings, ) + logger.debug("Saving Embeddings...") + with open(diary_stat_emb_file, "wb") as output: + pickle.dump(stat_obj, output) + + else: + logger.debug("Skipping embedding extraction (as already present).") + + +# Begin experiment! +if __name__ == "__main__": + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + '--device', + default="gpu", + help="Select which device to perform diarization, defaults to gpu.") + parser.add_argument( + "--config", default=None, type=str, help="configuration file") + parser.add_argument( + "--data-dir", + default="../save/", + type=str, + help="processsed data directory") + parser.add_argument( + "--dataset", + choices=['dev', 'eval'], + default="dev", + type=str, + help="Select which dataset to extra embdding, defaults to dev") + parser.add_argument( + "--load-checkpoint", + type=str, + default='', + help="Directory to load model checkpoint to compute embeddings.") + args = parser.parse_args() + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + print(config) + + main(args, config) diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py new file mode 100755 index 000000000..e912a4895 --- /dev/null +++ b/examples/ami/sd0/local/experiment.py @@ -0,0 +1,439 @@ +#!/usr/bin/python3 +"""This recipe implements diarization system using deep embedding extraction followed by spectral clustering. + +To run this recipe: +> python experiment.py hparams/ + e.g., python experiment.py hparams/ecapa_tdnn.yaml + +Condition: Oracle VAD (speech regions taken from the groundtruth). + +Note: There are multiple ways to write this recipe. We iterate over individual recordings. + This approach is less GPU memory demanding and also makes code easy to understand. + +Citation: This recipe is based on the following paper, + N. Dawalatabad, M. Ravanelli, F. Grondin, J. Thienpondt, B. Desplanques, H. Na, + "ECAPA-TDNN Embeddings for Speaker Diarization," arXiv:2104.01466, 2021. + +Authors + * Nauman Dawalatabad 2020 +""" +import argparse +import glob +import json +import os +import pickle +import shutil +import sys + +import numpy as np +from tqdm.contrib import tqdm +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.log import Log +from paddlespeech.vector.cluster import diarization as diar +from utils.compute_der import DER + +# Logger setup +logger = Log(__name__).getlog() + + +def diarize_dataset( + full_meta, + split_type, + n_lambdas, + pval, + save_dir, + config, + n_neighbors=10, ): + """This function diarizes all the recordings in a given dataset. 
It performs + computation of embedding and clusters them using spectral clustering (or other backends). + The output speaker boundary file is stored in the RTTM format. + """ + + # prepare `spkr_info` only once when Oracle num of speakers is selected. + # spkr_info is essential to obtain number of speakers from groundtruth. + if config.oracle_n_spkrs is True: + full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_" + split_type + ".rttm") + rttm = diar.read_rttm(full_ref_rttm_file) + + spkr_info = list( # noqa F841 + filter(lambda x: x.startswith("SPKR-INFO"), rttm)) + + # get all the recording IDs in this dataset. + all_keys = full_meta.keys() + A = [word.rstrip().split("_")[0] for word in all_keys] + all_rec_ids = list(set(A[1:])) + all_rec_ids.sort() + split = "AMI_" + split_type + i = 1 + + # adding tag for directory path. + type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est" + tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend) + + # make out rttm dir + out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type, + split, tag) + if not os.path.exists(out_rttm_dir): + os.makedirs(out_rttm_dir) + + # diarizing different recordings in a dataset. + for rec_id in tqdm(all_rec_ids): + # this tag will be displayed in the log. + if rec_id == "IS1008a": + continue + if rec_id == "ES2011a": + continue + tag = ("[" + str(split_type) + ": " + str(i) + "/" + + str(len(all_rec_ids)) + "]") + i = i + 1 + + # log message. + msg = "Diarizing %s : %s " % (tag, rec_id) + logger.debug(msg) + + # load embeddings. + emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl" + diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir, + split, emb_file_name) + if not os.path.isfile(diary_stat_emb_file): + msg = "Embdding file %s not found! Please check if embdding file is properly generated." % ( + diary_stat_emb_file) + logger.error(msg) + sys.exit() + with open(diary_stat_emb_file, "rb") as in_file: + diary_obj = pickle.load(in_file) + + out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm" + + # processing starts from here. + if config.oracle_n_spkrs is True: + # oracle num of speakers. + num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info) + else: + if config.affinity == "nn": + # num of speakers tunned on dev set (only for nn affinity). + num_spkrs = n_lambdas + else: + # num of speakers will be estimated using max eigen gap for cos based affinity. + # so adding None here. Will use this None later-on. + num_spkrs = None + + if config.backend == "kmeans": + diar.do_kmeans_clustering( + diary_obj, + out_rttm_file, + rec_id, + num_spkrs, + pval, ) + + if config.backend == "SC": + # go for Spectral Clustering (SC). + diar.do_spec_clustering( + diary_obj, + out_rttm_file, + rec_id, + num_spkrs, + pval, + config.affinity, + n_neighbors, ) + + # can used for AHC later. Likewise one can add different backends here. + if config.backend == "AHC": + # call AHC + threshold = pval # pval for AHC is nothing but threshold. + diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold) + + # once all RTTM outputs are generated, concatenate individual RTTM files to obtain single RTTM file. + # this is not needed but just staying with the standards. 
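+    # the concatenated file is what this function returns and what is later
+    # scored with md-eval.pl, which takes a single system RTTM per run.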
+ concate_rttm_file = out_rttm_dir + "/sys_output.rttm" + logger.debug("Concatenating individual RTTM files...") + with open(concate_rttm_file, "w") as cat_file: + for f in glob.glob(out_rttm_dir + "/*.rttm"): + if f == concate_rttm_file: + continue + with open(f, "r") as indi_rttm_file: + shutil.copyfileobj(indi_rttm_file, cat_file) + + msg = "The system generated RTTM file for %s set : %s" % ( + split_type, concate_rttm_file, ) + logger.debug(msg) + + return concate_rttm_file + + +def dev_pval_tuner(full_meta, save_dir, config): + """Tuning p_value for affinity matrix. + The p_value used so that only p% of the values in each row is retained. + """ + + DER_list = [] + prange = np.arange(0.002, 0.015, 0.001) + + n_lambdas = None # using it as flag later. + for p_v in prange: + # Process whole dataset for value of p_v. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config) + + ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm_file = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm_file, + sys_rttm_file, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append(DER_) + + if config.oracle_n_spkrs is True and config.backend == "kmeans": + # no need of p_val search. Note p_val is needed for SC for both oracle and est num of speakers. + # p_val is needed in oracle_n_spkr=False when using kmeans backend. + break + + # Take p_val that gave minmum DER on Dev dataset. + tuned_p_val = prange[DER_list.index(min(DER_list))] + + return tuned_p_val + + +def dev_ahc_threshold_tuner(full_meta, save_dir, config): + """Tuning threshold for affinity matrix. This function is called when AHC is used as backend. + """ + + DER_list = [] + prange = np.arange(0.0, 1.0, 0.1) + + n_lambdas = None # using it as flag later. + + # Note: p_val is threshold in case of AHC. + for p_v in prange: + # Process whole dataset for value of p_v. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config) + + ref_rttm = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append(DER_) + + if config.oracle_n_spkrs is True: + break # no need of threshold search. + + # Take p_val that gave minmum DER on Dev dataset. + tuned_p_val = prange[DER_list.index(min(DER_list))] + + return tuned_p_val + + +def dev_nn_tuner(full_meta, split_type, save_dir, config): + """Tuning n_neighbors on dev set. Assuming oracle num of speakers. + This is used when nn based affinity is selected. + """ + + DER_list = [] + pval = None + + # Now assumming oracle num of speakers. + n_lambdas = 4 + + for nn in range(5, 15): + + # Process whole dataset for value of n_lambdas. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config, nn) + + ref_rttm = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append([nn, DER_]) + + if config.oracle_n_spkrs is True and config.backend == "kmeans": + break + + DER_list.sort(key=lambda x: x[1]) + tunned_nn = DER_list[0] + + return tunned_nn[0] + + +def dev_tuner(full_meta, split_type, save_dir, config): + """Tuning n_components on dev set. Used for nn based affinity matrix. 
+ Note: This is a very basic tunning for nn based affinity. + This is work in progress till we find a better way. + """ + + DER_list = [] + pval = None + for n_lambdas in range(1, config.max_num_spkrs + 1): + + # Process whole dataset for value of n_lambdas. + concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v, + save_dir, config) + + ref_rttm = os.path.join(save_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + sys_rttm = concate_rttm_file + [MS, FA, SER, DER_] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, ) + + DER_list.append(DER_) + + # Take n_lambdas with minmum DER. + tuned_n_lambdas = DER_list.index(min(DER_list)) + 1 + + return tuned_n_lambdas + + +def main(args, config): + # AMI Dev Set: Tune hyperparams on dev set. + # Read the embdding file for dev set generated during embdding compute + dev_meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_dev." + config.mic_type + ".subsegs.json", ) + with open(dev_meta_file, "r") as f: + meta_dev = json.load(f) + + full_meta = meta_dev + + # Processing starts from here + # Following few lines selects option for different backend and affinity matrices. Finds best values for hyperameters using dev set. + ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir, + "fullref_ami_dev.rttm") + best_nn = None + if config.affinity == "nn": + logger.info("Tuning for nn (Multiple iterations over AMI Dev set)") + best_nn = dev_nn_tuner(full_meta, args.data_dir, config) + + n_lambdas = None + best_pval = None + + if config.affinity == "cos" and (config.backend == "SC" or + config.backend == "kmeans"): + # oracle num_spkrs or not, doesn't matter for kmeans and SC backends + # cos: Tune for the best pval for SC /kmeans (for unknown num of spkrs) + logger.info( + "Tuning for p-value for SC (Multiple iterations over AMI Dev set)") + best_pval = dev_pval_tuner(full_meta, args.data_dir, config) + + elif config.backend == "AHC": + logger.info("Tuning for threshold-value for AHC") + best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir, + config) + best_pval = best_threshold + else: + # NN for unknown num of speakers (can be used in future) + if config.oracle_n_spkrs is False: + # nn: Tune num of number of components (to be updated later) + logger.info( + "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)" + ) + # dev_tuner used for tuning num of components in NN. Can be used in future. + n_lambdas = dev_tuner(full_meta, args.data_dir, config) + + # load 'dev' and 'eval' metadata files. + full_meta_dev = full_meta # current full_meta is for 'dev' + eval_meta_file = os.path.join( + args.data_dir, + config.meta_data_dir, + "ami_eval." + config.mic_type + ".subsegs.json", ) + with open(eval_meta_file, "r") as f: + full_meta_eval = json.load(f) + + # tag to be appended to final output DER files. Writing DER for individual files. + type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est" + tag = ( + type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type) + + # perform final diarization on 'dev' and 'eval' with best hyperparams. + final_DERs = {} + out_der_dir = os.path.join(args.data_dir, config.der_dir) + if not os.path.exists(out_der_dir): + os.makedirs(out_der_dir) + + for split_type in ["dev", "eval"]: + if split_type == "dev": + full_meta = full_meta_dev + else: + full_meta = full_meta_eval + + # performing diarization. 
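+        # best_pval / best_nn / n_lambdas come from the dev-set tuning above
+        # and are reused unchanged for the eval split.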
+ msg = "Diarizing using best hyperparams: " + split_type + " set" + logger.info(msg) + out_boundaries = diarize_dataset( + full_meta, + split_type, + n_lambdas=n_lambdas, + pval=best_pval, + n_neighbors=best_nn, + save_dir=args.data_dir, + config=config) + + # computing DER. + msg = "Computing DERs for " + split_type + " set" + logger.info(msg) + ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir, + "fullref_ami_" + split_type + ".rttm") + sys_rttm = out_boundaries + [MS, FA, SER, DER_vals] = DER( + ref_rttm, + sys_rttm, + config.ignore_overlap, + config.forgiveness_collar, + individual_file_scores=True, ) + + # writing DER values to a file. Append tag. + der_file_name = split_type + "_DER_" + tag + out_der_file = os.path.join(out_der_dir, der_file_name) + msg = "Writing DER file to: " + out_der_file + logger.info(msg) + diar.write_ders_file(ref_rttm, DER_vals, out_der_file) + + msg = ("AMI " + split_type + " set DER = %s %%\n" % + (str(round(DER_vals[-1], 2)))) + logger.info(msg) + final_DERs[split_type] = round(DER_vals[-1], 2) + + # final print DERs + msg = ( + "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n" + % (str(final_DERs["dev"]), str(final_DERs["eval"]))) + logger.info(msg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(__doc__) + parser.add_argument( + "--config", default=None, type=str, help="configuration file") + parser.add_argument( + "--data-dir", + default="../data/", + type=str, + help="processsed data directory") + args = parser.parse_args() + config = CfgNode(new_allowed=True) + if args.config: + config.merge_from_file(args.config) + + config.freeze() + print(config) + + main(args, config) diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh new file mode 100755 index 000000000..1b5ed5bd1 --- /dev/null +++ b/examples/ami/sd0/local/process.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +stage=2 +set=L + +. ${MAIN_ROOT}/utils/parse_options.sh || exit 1; +set -u +set -o pipefail + +data_folder=$1 +manual_annot_folder=$2 +save_folder=$3 +pretrained_model_dir=$4 +conf_path=$5 + +ref_rttm_dir=${save_folder}/ref_rttms +meta_data_dir=${save_folder}/metadata + +if [ ${stage} -le 0 ]; then + echo "AMI Data preparation" + python local/ami_prepare.py --data_folder ${data_folder} \ + --manual_annot_folder ${manual_annot_folder} \ + --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \ + --meta_data_dir ${meta_data_dir} + + if [ $? -ne 0 ]; then + echo "Prepare AMI failed. Please check log message." + exit 1 + fi + echo "AMI data preparation done." +fi + +if [ ${stage} -le 1 ]; then + # extra embddings for dev and eval dataset + for name in dev eval; do + python local/compute_embdding.py --config ${conf_path} \ + --data-dir ${save_folder} \ + --device gpu:0 \ + --dataset ${name} \ + --load-checkpoint ${pretrained_model_dir} + done +fi + +if [ ${stage} -le 2 ]; then + # tune hyperparams on dev set + # perform final diarization on 'dev' and 'eval' with best hyperparams + python local/experiment.py --config ${conf_path} \ + --data-dir ${save_folder} +fi diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh index 91d4b706a..fc6a91cc3 100644 --- a/examples/ami/sd0/run.sh +++ b/examples/ami/sd0/run.sh @@ -1,14 +1,40 @@ #!/bin/bash -. path.sh || exit 1; +. 
./path.sh || exit 1; set -e stage=1 +stop_stage=50 + +#TARGET_DIR=${MAIN_ROOT}/dataset/ami +TARGET_DIR=/home/dataset/AMI +data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/ +manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/ + +save_folder=./save +pretraind_model_dir=${save_folder}/model + +conf_path=conf/ecapa_tdnn.yaml . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; -if [ ${stage} -le 1 ]; then - # prepare data - bash ./local/data.sh || exit -1 -fi \ No newline at end of file +if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # Prepare data and model + # Download AMI corpus, You need around 10GB of free space to get whole data + # The signals are too large to package in this way, + # so you need to use the chooser to indicate which ones you wish to download + echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data." + echo "Annotations: AMI manual annotations v1.6.2 " + echo "Signals: " + echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py" + echo "2) Select media streams: Just select Headset mix" + # Download the pretrained Model from HuggingFace or other pretrained model + echo "Please download the pretrained ECAPA-TDNN Model and put the pretrainde model in given path: "${pretraind_model_dir} +fi + +if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams. + bash ./local/process.sh ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path} || exit 1 +fi + diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py index 597aa4807..5b2157257 100644 --- a/paddlespeech/vector/cluster/diarization.py +++ b/paddlespeech/vector/cluster/diarization.py @@ -746,6 +746,77 @@ def merge_ssegs_same_speaker(lol): return new_lol +def write_ders_file(ref_rttm, DER, out_der_file): + """Write the final DERs for individual recording. + + Arguments + --------- + ref_rttm : str + Reference RTTM file. + DER : array + Array containing DER values of each recording. + out_der_file : str + File to write the DERs. + """ + + rttm = read_rttm(ref_rttm) + spkr_info = list(filter(lambda x: x.startswith("SPKR-INFO"), rttm)) + + rec_id_list = [] + count = 0 + + with open(out_der_file, "w") as f: + for row in spkr_info: + a = row.split(" ") + rec_id = a[1] + if rec_id not in rec_id_list: + r = [rec_id, str(round(DER[count], 2))] + rec_id_list.append(rec_id) + line_str = " ".join(r) + f.write("%s\n" % line_str) + count += 1 + r = ["OVERALL ", str(round(DER[count], 2))] + line_str = " ".join(r) + f.write("%s\n" % line_str) + + +def get_oracle_num_spkrs(rec_id, spkr_info): + """ + Returns actual number of speakers in a recording from the ground-truth. + This can be used when the condition is oracle number of speakers. + + Arguments + --------- + rec_id : str + Recording ID for which the number of speakers have to be obtained. + spkr_info : list + Header of the RTTM file. Starting with `SPKR-INFO`. + + Example + ------- + >>> from speechbrain.processing import diarization as diar + >>> spkr_info = ['SPKR-INFO ES2011a 0 unknown ES2011a.A ', + ... 'SPKR-INFO ES2011a 0 unknown ES2011a.B ', + ... 'SPKR-INFO ES2011a 0 unknown ES2011a.C ', + ... 'SPKR-INFO ES2011a 0 unknown ES2011a.D ', + ... 'SPKR-INFO ES2011b 0 unknown ES2011b.A ', + ... 'SPKR-INFO ES2011b 0 unknown ES2011b.B ', + ... 
'SPKR-INFO ES2011b 0 unknown ES2011b.C '] + >>> diar.get_oracle_num_spkrs('ES2011a', spkr_info) + 4 + >>> diar.get_oracle_num_spkrs('ES2011b', spkr_info) + 3 + """ + + num_spkrs = 0 + for line in spkr_info: + if rec_id in line: + # Since rec_id is prefix for each speaker + num_spkrs += 1 + + return num_spkrs + + def distribute_overlap(lol): """ Distributes the overlapped speech equally among the adjacent segments @@ -826,6 +897,29 @@ def distribute_overlap(lol): return new_lol +def read_rttm(rttm_file_path): + """ + Reads and returns RTTM in list format. + + Arguments + --------- + rttm_file_path : str + Path to the RTTM file to be read. + + Returns + ------- + rttm : list + List containing rows of RTTM file. + """ + + rttm = [] + with open(rttm_file_path, "r") as f: + for line in f: + entry = line[:-1] + rttm.append(entry) + return rttm + + def write_rttm(segs_list, out_rttm_file): """ Writes the segment list in RTTM format (A standard NIST format). diff --git a/utils/compute_der.py b/utils/compute_der.py new file mode 100755 index 000000000..d22f6a7d9 --- /dev/null +++ b/utils/compute_der.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS), +False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation. + +Credits +This code is adapted from https://github.com/speechbrain/speechbrain +""" +import argparse +import os +import re +import subprocess + +import numpy as np + +FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") +SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+") +MISS_SPEAKER_TIME = re.compile(r"(?<=MISSED SPEAKER TIME =)[\d.]+") +FA_SPEAKER_TIME = re.compile(r"(?<=FALARM SPEAKER TIME =)[\d.]+") +ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+") + + +def rectify(arr): + """Corrects corner cases and converts scores into percentage. + """ + + # Numerator and denominator both 0. + arr[np.isnan(arr)] = 0 + + # Numerator > 0, but denominator = 0. + arr[np.isinf(arr)] = 1 + arr *= 100.0 + + return arr + + +def DER( + ref_rttm, + sys_rttm, + ignore_overlap=False, + collar=0.25, + individual_file_scores=False, ): + """Computes Missed Speaker percentage (MS), False Alarm (FA), + Speaker Error Rate (SER), and Diarization Error Rate (DER). + + Arguments + --------- + ref_rttm : str + The path of reference/groundtruth RTTM file. + sys_rttm : str + The path of the system generated RTTM file. + individual_file_scores : bool + If True, returns scores for each file in order. + collar : float + Forgiveness collar. + ignore_overlap : bool + If True, ignores overlapping speech during evaluation. + + Returns + ------- + MS : float array + Missed Speech. + FA : float array + False Alarms. + SER : float array + Speaker Error Rates. + DER : float array + Diarization Error Rates. 
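+
+    Example
+    -------
+    A minimal usage sketch; the RTTM paths are illustrative only.
+
+    >>> MS, FA, SER, DER_ = DER("fullref_ami_dev.rttm", "sys_output.rttm",
+    ...                         ignore_overlap=True, collar=0.25)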
+ """ + + curr = os.path.abspath(os.path.dirname(__file__)) + mdEval = os.path.join(curr, "./md-eval.pl") + + cmd = [ + mdEval, + "-af", + "-r", + ref_rttm, + "-s", + sys_rttm, + "-c", + str(collar), + ] + print(cmd) + if ignore_overlap: + cmd.append("-1") + + try: + stdout = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + except subprocess.CalledProcessError as ex: + stdout = ex.output + + else: + stdout = stdout.decode("utf-8") + + # Get all recording IDs + file_ids = [m.strip() for m in FILE_IDS.findall(stdout)] + file_ids = [ + file_id[2:] if file_id.startswith("f=") else file_id + for file_id in file_ids + ] + + scored_speaker_times = np.array( + [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)]) + + miss_speaker_times = np.array( + [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)]) + + fa_speaker_times = np.array( + [float(m) for m in FA_SPEAKER_TIME.findall(stdout)]) + + error_speaker_times = np.array( + [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)]) + + with np.errstate(invalid="ignore", divide="ignore"): + tot_error_times = ( + miss_speaker_times + fa_speaker_times + error_speaker_times) + miss_speaker_frac = miss_speaker_times / scored_speaker_times + fa_speaker_frac = fa_speaker_times / scored_speaker_times + sers_frac = error_speaker_times / scored_speaker_times + ders_frac = tot_error_times / scored_speaker_times + + # Values in percentage of scored_speaker_time + miss_speaker = rectify(miss_speaker_frac) + fa_speaker = rectify(fa_speaker_frac) + sers = rectify(sers_frac) + ders = rectify(ders_frac) + + if individual_file_scores: + return miss_speaker, fa_speaker, sers, ders + else: + return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1] + + +def main(): + parser = argparse.ArgumentParser(description="Compute DER") + parser.add_argument( + "--ref_rttm", + type=str, + help="the path of reference/groundtruth RTTM file") + parser.add_argument( + "--sys_rttm", + type=str, + help="the path of the system generated RTTM file.") + parser.add_argument( + "--individual_file_scores", + type=bool, + help="whether returns scores for each file in order.") + parser.add_argument("--collar", type=float, help="forgiveness collar.") + parser.add_argument( + "--ignore_overlap", + type=bool, + help="whether ignores overlapping speech during evaluation.") + + args = parser.parse_args() + + Scores = DER(args.ref_rttm, args.sys_rttm, args.ignore_overlap, args.collar, + args.individual_file_scores) + print(Scores) + + +if __name__ == "__main__": + main() From 7a03f36548aae74964d273af91dc943cc9175a4a Mon Sep 17 00:00:00 2001 From: ccrrong <1039058843@qq.com> Date: Tue, 5 Apr 2022 19:49:44 +0800 Subject: [PATCH 2/5] code format, test=doc --- examples/ami/sd0/conf/ecapa_tdnn.yaml | 11 +---- examples/ami/sd0/local/compute_embdding.py | 3 +- examples/ami/sd0/local/data.sh | 49 ---------------------- examples/ami/sd0/local/experiment.py | 37 ++++++---------- examples/ami/sd0/local/process.sh | 2 +- 5 files changed, 16 insertions(+), 86 deletions(-) delete mode 100755 examples/ami/sd0/local/data.sh diff --git a/examples/ami/sd0/conf/ecapa_tdnn.yaml b/examples/ami/sd0/conf/ecapa_tdnn.yaml index 0f298c35b..319e44976 100755 --- a/examples/ami/sd0/conf/ecapa_tdnn.yaml +++ b/examples/ami/sd0/conf/ecapa_tdnn.yaml @@ -1,13 +1,3 @@ -# ################################################## -# Model: Speaker Diarization Baseline -# Embeddings: Deep embedding -# Clustering Technique: Spectral clustering -# Authors: Nauman Dawalatabad 2020 -# 
################################################# - -seed: 1234 -num_speakers: 7205 - ########################################################### # AMI DATA PREPARE SETTING # ########################################################### @@ -44,6 +34,7 @@ hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160 ########################################################### # currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml # if we want use another model, please choose another configuration yaml file +seed: 1234 emb_dim: 192 batch_size: 16 model: diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py index e4fd5da2c..30d49d511 100644 --- a/examples/ami/sd0/local/compute_embdding.py +++ b/examples/ami/sd0/local/compute_embdding.py @@ -94,7 +94,7 @@ def main(args, config): # stage2: build the speaker verification eval instance with backbone model model = SpeakerIdetification( - backbone=ecapa_tdnn, num_class=config.num_speakers) + backbone=ecapa_tdnn, num_class=1) # stage3: load the pre-trained model # we get the last model from the epoch and save_interval @@ -228,6 +228,5 @@ if __name__ == "__main__": config.merge_from_file(args.config) config.freeze() - print(config) main(args, config) diff --git a/examples/ami/sd0/local/data.sh b/examples/ami/sd0/local/data.sh deleted file mode 100755 index 478ec432d..000000000 --- a/examples/ami/sd0/local/data.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash - -stage=1 - -TARGET_DIR=${MAIN_ROOT}/dataset/ami -data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/ -manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/ - -save_folder=${MAIN_ROOT}/examples/ami/sd0/data -ref_rttm_dir=${save_folder}/ref_rttms -meta_data_dir=${save_folder}/metadata - -set=L - -. ${MAIN_ROOT}/utils/parse_options.sh || exit 1; -set -u -set -o pipefail - -mkdir -p ${save_folder} - -if [ ${stage} -le 0 ]; then - # Download AMI corpus, You need around 10GB of free space to get whole data - # The signals are too large to package in this way, - # so you need to use the chooser to indicate which ones you wish to download - echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data." - echo "Annotations: AMI manual annotations v1.6.2 " - echo "Signals: " - echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py" - echo "2) Select media streams: Just select Headset mix" - exit 0; -fi - -if [ ${stage} -le 1 ]; then - echo "AMI Data preparation" - - python local/ami_prepare.py --data_folder ${data_folder} \ - --manual_annot_folder ${manual_annot_folder} \ - --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \ - --meta_data_dir ${meta_data_dir} - - if [ $? -ne 0 ]; then - echo "Prepare AMI failed. Please check log message." - exit 1 - fi - -fi - -echo "AMI data preparation done." -exit 0 diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py index e912a4895..5bb406d1e 100755 --- a/examples/ami/sd0/local/experiment.py +++ b/examples/ami/sd0/local/experiment.py @@ -1,22 +1,16 @@ -#!/usr/bin/python3 -"""This recipe implements diarization system using deep embedding extraction followed by spectral clustering. - -To run this recipe: -> python experiment.py hparams/ - e.g., python experiment.py hparams/ecapa_tdnn.yaml - -Condition: Oracle VAD (speech regions taken from the groundtruth). - -Note: There are multiple ways to write this recipe. We iterate over individual recordings. 
- This approach is less GPU memory demanding and also makes code easy to understand. - -Citation: This recipe is based on the following paper, - N. Dawalatabad, M. Ravanelli, F. Grondin, J. Thienpondt, B. Desplanques, H. Na, - "ECAPA-TDNN Embeddings for Speaker Diarization," arXiv:2104.01466, 2021. - -Authors - * Nauman Dawalatabad 2020 -""" +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import glob import json @@ -81,10 +75,6 @@ def diarize_dataset( # diarizing different recordings in a dataset. for rec_id in tqdm(all_rec_ids): # this tag will be displayed in the log. - if rec_id == "IS1008a": - continue - if rec_id == "ES2011a": - continue tag = ("[" + str(split_type) + ": " + str(i) + "/" + str(len(all_rec_ids)) + "]") i = i + 1 @@ -434,6 +424,5 @@ if __name__ == "__main__": config.merge_from_file(args.config) config.freeze() - print(config) main(args, config) diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh index 1b5ed5bd1..72c58b10a 100755 --- a/examples/ami/sd0/local/process.sh +++ b/examples/ami/sd0/local/process.sh @@ -1,6 +1,6 @@ #!/bin/bash -stage=2 +stage=0 set=L . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; From bc53f726fece7f1536ad5c0d049c79686af1caee Mon Sep 17 00:00:00 2001 From: ccrrong <1039058843@qq.com> Date: Fri, 8 Apr 2022 21:34:16 +0800 Subject: [PATCH 3/5] convert dataset format to paddlespeech, test=doc --- examples/ami/sd0/local/compute_embdding.py | 7 +- examples/ami/sd0/local/process.sh | 4 +- examples/ami/sd0/run.sh | 28 +++-- paddlespeech/vector/io/dataset_from_json.py | 116 ++++++++++++++++++++ 4 files changed, 138 insertions(+), 17 deletions(-) create mode 100644 paddlespeech/vector/io/dataset_from_json.py diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py index 30d49d511..dc824d7ca 100644 --- a/examples/ami/sd0/local/compute_embdding.py +++ b/examples/ami/sd0/local/compute_embdding.py @@ -19,7 +19,6 @@ import sys import numpy as np import paddle -from ami_dataset import AMIDataset from paddle.io import BatchSampler from paddle.io import DataLoader from tqdm.contrib import tqdm @@ -28,6 +27,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.cluster.diarization import EmbeddingMeta from paddlespeech.vector.io.batch import batch_feature_normalize +from paddlespeech.vector.io.dataset_from_json import JSONDataset from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn from paddlespeech.vector.modules.sid_model import SpeakerIdetification from paddlespeech.vector.training.seeding import seed_everything @@ -65,7 +65,7 @@ def create_dataloader(json_file, batch_size): """ # create datasets - dataset = AMIDataset( + dataset = JSONDataset( json_file=json_file, feat_type='melspectrogram', n_mels=config.n_mels, @@ -93,8 +93,7 @@ def main(args, config): ecapa_tdnn = EcapaTdnn(**config.model) # stage2: build the speaker verification eval instance with 
backbone model - model = SpeakerIdetification( - backbone=ecapa_tdnn, num_class=1) + model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1) # stage3: load the pre-trained model # we get the last model from the epoch and save_interval diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh index 72c58b10a..1dfd11b86 100755 --- a/examples/ami/sd0/local/process.sh +++ b/examples/ami/sd0/local/process.sh @@ -4,7 +4,6 @@ stage=0 set=L . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; -set -u set -o pipefail data_folder=$1 @@ -12,6 +11,7 @@ manual_annot_folder=$2 save_folder=$3 pretrained_model_dir=$4 conf_path=$5 +device=$6 ref_rttm_dir=${save_folder}/ref_rttms meta_data_dir=${save_folder}/metadata @@ -35,7 +35,7 @@ if [ ${stage} -le 1 ]; then for name in dev eval; do python local/compute_embdding.py --config ${conf_path} \ --data-dir ${save_folder} \ - --device gpu:0 \ + --device ${device} \ --dataset ${name} \ --load-checkpoint ${pretrained_model_dir} done diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh index fc6a91cc3..9035f5955 100644 --- a/examples/ami/sd0/run.sh +++ b/examples/ami/sd0/run.sh @@ -3,8 +3,7 @@ . ./path.sh || exit 1; set -e -stage=1 -stop_stage=50 +stage=0 #TARGET_DIR=${MAIN_ROOT}/dataset/ami TARGET_DIR=/home/dataset/AMI @@ -12,15 +11,14 @@ data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/ manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/ save_folder=./save -pretraind_model_dir=${save_folder}/model - +pretraind_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model conf_path=conf/ecapa_tdnn.yaml - +device=gpu . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; -if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # Prepare data and model +if [ $stage -le 0 ]; then + # Prepare data # Download AMI corpus, You need around 10GB of free space to get whole data # The signals are too large to package in this way, # so you need to use the chooser to indicate which ones you wish to download @@ -29,12 +27,20 @@ if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then echo "Signals: " echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py" echo "2) Select media streams: Just select Headset mix" - # Download the pretrained Model from HuggingFace or other pretrained model - echo "Please download the pretrained ECAPA-TDNN Model and put the pretrainde model in given path: "${pretraind_model_dir} fi -if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then +if [ $stage -le 1 ]; then + # Download the pretrained model + wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz + mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder} + rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz + echo "download the pretrained ECAPA-TDNN Model to path: "${pretraind_model_dir} +fi + +if [ $stage -le 2 ]; then # Tune hyperparams on dev set and perform final diarization on dev and eval with best hyperparams. 
- bash ./local/process.sh ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path} || exit 1 + echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretraind_model_dir} ${conf_path} + bash ./local/process.sh ${data_folder} ${manual_annot_folder} \ + ${save_folder} ${pretraind_model_dir} ${conf_path} ${device} || exit 1 fi diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py new file mode 100644 index 000000000..5ffd2c186 --- /dev/null +++ b/paddlespeech/vector/io/dataset_from_json.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from dataclasses import dataclass +from dataclasses import fields +from paddle.io import Dataset + +from paddleaudio import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram +from paddleaudio.compliance.librosa import mfcc + + +@dataclass +class meta_info: + """the audio meta info in the vector JSONDataset + Args: + id (str): the segment name + duration (float): segment time + wav (str): wav file path + start (int): start point in the original wav file + stop (int): stop point in the original wav file + lab_id (str): the record id + """ + id: str + duration: float + wav: str + start: int + stop: int + record_id: str + + +# json dataset support feature type +feat_funcs = { + 'raw': None, + 'melspectrogram': melspectrogram, + 'mfcc': mfcc, +} + + +class JSONDataset(Dataset): + """ + dataset from json file. + """ + + def __init__(self, json_file: str, feat_type: str='raw', **kwargs): + """ + Ags: + json_file (:obj:`str`): Data prep JSON file. + labels (:obj:`List[int]`): Labels of audio files. + feat_type (:obj:`str`, `optional`, defaults to `raw`): + It identifies the feature type that user wants to extrace of an audio file. 
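+
+        Example:
+            A minimal usage sketch; the JSON path and feature settings are
+            illustrative only.
+
+            >>> dataset = JSONDataset("ES2011a.Mix-Headset.json",
+            ...                       feat_type="melspectrogram", n_mels=80)
+            >>> record = dataset[0]  # keys: id, duration, wav, start, stop, record_id, feat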
+ """ + if feat_type not in feat_funcs.keys(): + raise RuntimeError( + f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" + ) + + self.json_file = json_file + self.feat_type = feat_type + self.feat_config = kwargs + self._data = self._get_data() + super(JSONDataset, self).__init__() + + def _get_data(self): + with open(self.json_file, "r") as f: + meta_data = json.load(f) + data = [] + for key in meta_data: + sub_seg = meta_data[key]["wav"] + wav = sub_seg["file"] + duration = sub_seg["duration"] + start = sub_seg["start"] + stop = sub_seg["stop"] + rec_id = str(key).rsplit("_", 2)[0] + data.append( + meta_info( + str(key), + float(duration), wav, int(start), int(stop), str(rec_id))) + return data + + def _convert_to_record(self, idx: int): + sample = self._data[idx] + + record = {} + # To show all fields in a namedtuple + for field in fields(sample): + record[field.name] = getattr(sample, field.name) + + waveform, sr = load_audio(record['wav']) + waveform = waveform[record['start']:record['stop']] + + feat_func = feat_funcs[self.feat_type] + feat = feat_func( + waveform, sr=sr, **self.feat_config) if feat_func else waveform + + record.update({'feat': feat}) + + return record + + def __getitem__(self, idx): + return self._convert_to_record(idx) + + def __len__(self): + return len(self._data) From e4c5b5d16786f9d7495c6af76bd210d1380cde04 Mon Sep 17 00:00:00 2001 From: ccrrong <1039058843@qq.com> Date: Fri, 8 Apr 2022 21:43:06 +0800 Subject: [PATCH 4/5] delete unused file ami_dataset.py, test=doc --- examples/ami/sd0/local/ami_dataset.py | 90 --------------------------- 1 file changed, 90 deletions(-) delete mode 100644 examples/ami/sd0/local/ami_dataset.py diff --git a/examples/ami/sd0/local/ami_dataset.py b/examples/ami/sd0/local/ami_dataset.py deleted file mode 100644 index c44329c83..000000000 --- a/examples/ami/sd0/local/ami_dataset.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections -import json - -from paddle.io import Dataset - -from paddleaudio.backends import load as load_audio -from paddleaudio.datasets.dataset import feat_funcs - - -class AMIDataset(Dataset): - """ - AMI dataset. - """ - - meta_info = collections.namedtuple( - 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'record_id')) - - def __init__(self, json_file: str, feat_type: str='raw', **kwargs): - """ - Ags: - json_file (:obj:`str`): Data prep JSON file. - labels (:obj:`List[int]`): Labels of audio files. - feat_type (:obj:`str`, `optional`, defaults to `raw`): - It identifies the feature type that user wants to extrace of an audio file. 
- """ - if feat_type not in feat_funcs.keys(): - raise RuntimeError( - f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" - ) - - self.json_file = json_file - self.feat_type = feat_type - self.feat_config = kwargs - self._data = self._get_data() - super(AMIDataset, self).__init__() - - def _get_data(self): - with open(self.json_file, "r") as f: - meta_data = json.load(f) - data = [] - for key in meta_data: - sub_seg = meta_data[key]["wav"] - wav = sub_seg["file"] - duration = sub_seg["duration"] - start = sub_seg["start"] - stop = sub_seg["stop"] - rec_id = str(key).rsplit("_", 2)[0] - data.append( - self.meta_info( - str(key), - float(duration), wav, int(start), int(stop), str(rec_id))) - return data - - def _convert_to_record(self, idx: int): - sample = self._data[idx] - - record = {} - # To show all fields in a namedtuple: `type(sample)._fields` - for field in type(sample)._fields: - record[field] = getattr(sample, field) - - waveform, sr = load_audio(record['wav']) - waveform = waveform[record['start']:record['stop']] - - feat_func = feat_funcs[self.feat_type] - feat = feat_func( - waveform, sr=sr, **self.feat_config) if feat_func else waveform - - record.update({'feat': feat}) - - return record - - def __getitem__(self, idx): - return self._convert_to_record(idx) - - def __len__(self): - return len(self._data) From 995436c6f17307c91030f7fc2cfb189f0532f4a8 Mon Sep 17 00:00:00 2001 From: ccrrong <1039058843@qq.com> Date: Fri, 8 Apr 2022 21:43:06 +0800 Subject: [PATCH 5/5] delete unused file ami_dataset.py, compute_der.py, test=doc --- examples/ami/sd0/local/ami_dataset.py | 90 ------------- examples/ami/sd0/local/experiment.py | 2 +- utils/compute_der.py | 175 -------------------------- 3 files changed, 1 insertion(+), 266 deletions(-) delete mode 100644 examples/ami/sd0/local/ami_dataset.py delete mode 100755 utils/compute_der.py diff --git a/examples/ami/sd0/local/ami_dataset.py b/examples/ami/sd0/local/ami_dataset.py deleted file mode 100644 index c44329c83..000000000 --- a/examples/ami/sd0/local/ami_dataset.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections -import json - -from paddle.io import Dataset - -from paddleaudio.backends import load as load_audio -from paddleaudio.datasets.dataset import feat_funcs - - -class AMIDataset(Dataset): - """ - AMI dataset. - """ - - meta_info = collections.namedtuple( - 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'record_id')) - - def __init__(self, json_file: str, feat_type: str='raw', **kwargs): - """ - Ags: - json_file (:obj:`str`): Data prep JSON file. - labels (:obj:`List[int]`): Labels of audio files. - feat_type (:obj:`str`, `optional`, defaults to `raw`): - It identifies the feature type that user wants to extrace of an audio file. 
- """ - if feat_type not in feat_funcs.keys(): - raise RuntimeError( - f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}" - ) - - self.json_file = json_file - self.feat_type = feat_type - self.feat_config = kwargs - self._data = self._get_data() - super(AMIDataset, self).__init__() - - def _get_data(self): - with open(self.json_file, "r") as f: - meta_data = json.load(f) - data = [] - for key in meta_data: - sub_seg = meta_data[key]["wav"] - wav = sub_seg["file"] - duration = sub_seg["duration"] - start = sub_seg["start"] - stop = sub_seg["stop"] - rec_id = str(key).rsplit("_", 2)[0] - data.append( - self.meta_info( - str(key), - float(duration), wav, int(start), int(stop), str(rec_id))) - return data - - def _convert_to_record(self, idx: int): - sample = self._data[idx] - - record = {} - # To show all fields in a namedtuple: `type(sample)._fields` - for field in type(sample)._fields: - record[field] = getattr(sample, field) - - waveform, sr = load_audio(record['wav']) - waveform = waveform[record['start']:record['stop']] - - feat_func = feat_funcs[self.feat_type] - feat = feat_func( - waveform, sr=sr, **self.feat_config) if feat_func else waveform - - record.update({'feat': feat}) - - return record - - def __getitem__(self, idx): - return self._convert_to_record(idx) - - def __len__(self): - return len(self._data) diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py index 5bb406d1e..298228376 100755 --- a/examples/ami/sd0/local/experiment.py +++ b/examples/ami/sd0/local/experiment.py @@ -25,7 +25,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.utils.log import Log from paddlespeech.vector.cluster import diarization as diar -from utils.compute_der import DER +from utils.DER import DER # Logger setup logger = Log(__name__).getlog() diff --git a/utils/compute_der.py b/utils/compute_der.py deleted file mode 100755 index d22f6a7d9..000000000 --- a/utils/compute_der.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS), -False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation. - -Credits -This code is adapted from https://github.com/speechbrain/speechbrain -""" -import argparse -import os -import re -import subprocess - -import numpy as np - -FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") -SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+") -MISS_SPEAKER_TIME = re.compile(r"(?<=MISSED SPEAKER TIME =)[\d.]+") -FA_SPEAKER_TIME = re.compile(r"(?<=FALARM SPEAKER TIME =)[\d.]+") -ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+") - - -def rectify(arr): - """Corrects corner cases and converts scores into percentage. - """ - - # Numerator and denominator both 0. - arr[np.isnan(arr)] = 0 - - # Numerator > 0, but denominator = 0. 
- arr[np.isinf(arr)] = 1 - arr *= 100.0 - - return arr - - -def DER( - ref_rttm, - sys_rttm, - ignore_overlap=False, - collar=0.25, - individual_file_scores=False, ): - """Computes Missed Speaker percentage (MS), False Alarm (FA), - Speaker Error Rate (SER), and Diarization Error Rate (DER). - - Arguments - --------- - ref_rttm : str - The path of reference/groundtruth RTTM file. - sys_rttm : str - The path of the system generated RTTM file. - individual_file_scores : bool - If True, returns scores for each file in order. - collar : float - Forgiveness collar. - ignore_overlap : bool - If True, ignores overlapping speech during evaluation. - - Returns - ------- - MS : float array - Missed Speech. - FA : float array - False Alarms. - SER : float array - Speaker Error Rates. - DER : float array - Diarization Error Rates. - """ - - curr = os.path.abspath(os.path.dirname(__file__)) - mdEval = os.path.join(curr, "./md-eval.pl") - - cmd = [ - mdEval, - "-af", - "-r", - ref_rttm, - "-s", - sys_rttm, - "-c", - str(collar), - ] - print(cmd) - if ignore_overlap: - cmd.append("-1") - - try: - stdout = subprocess.check_output(cmd, stderr=subprocess.STDOUT) - - except subprocess.CalledProcessError as ex: - stdout = ex.output - - else: - stdout = stdout.decode("utf-8") - - # Get all recording IDs - file_ids = [m.strip() for m in FILE_IDS.findall(stdout)] - file_ids = [ - file_id[2:] if file_id.startswith("f=") else file_id - for file_id in file_ids - ] - - scored_speaker_times = np.array( - [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)]) - - miss_speaker_times = np.array( - [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)]) - - fa_speaker_times = np.array( - [float(m) for m in FA_SPEAKER_TIME.findall(stdout)]) - - error_speaker_times = np.array( - [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)]) - - with np.errstate(invalid="ignore", divide="ignore"): - tot_error_times = ( - miss_speaker_times + fa_speaker_times + error_speaker_times) - miss_speaker_frac = miss_speaker_times / scored_speaker_times - fa_speaker_frac = fa_speaker_times / scored_speaker_times - sers_frac = error_speaker_times / scored_speaker_times - ders_frac = tot_error_times / scored_speaker_times - - # Values in percentage of scored_speaker_time - miss_speaker = rectify(miss_speaker_frac) - fa_speaker = rectify(fa_speaker_frac) - sers = rectify(sers_frac) - ders = rectify(ders_frac) - - if individual_file_scores: - return miss_speaker, fa_speaker, sers, ders - else: - return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1] - - -def main(): - parser = argparse.ArgumentParser(description="Compute DER") - parser.add_argument( - "--ref_rttm", - type=str, - help="the path of reference/groundtruth RTTM file") - parser.add_argument( - "--sys_rttm", - type=str, - help="the path of the system generated RTTM file.") - parser.add_argument( - "--individual_file_scores", - type=bool, - help="whether returns scores for each file in order.") - parser.add_argument("--collar", type=float, help="forgiveness collar.") - parser.add_argument( - "--ignore_overlap", - type=bool, - help="whether ignores overlapping speech during evaluation.") - - args = parser.parse_args() - - Scores = DER(args.ref_rttm, args.sys_rttm, args.ignore_overlap, args.collar, - args.individual_file_scores) - print(Scores) - - -if __name__ == "__main__": - main()
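Note on scoring: md-eval.pl reports missed speech, false alarm, and speaker-error time per recording; DER is their sum divided by the scored speaker time, expressed as a percentage. The short, self-contained sketch below is illustrative only (the times are made-up numbers, not AMI results); it reproduces that arithmetic together with the corner cases rectify() guards against, where 0/0 scores as 0 % and a non-zero error time over zero scored time scores as 100 %.

import numpy as np


def rectify(arr):
    """Correct corner cases and convert fractions into percentages."""
    arr[np.isnan(arr)] = 0  # numerator and denominator both 0 -> 0 %
    arr[np.isinf(arr)] = 1  # numerator > 0 but denominator 0 -> 100 %
    arr *= 100.0
    return arr


# Hypothetical per-recording times in seconds (illustration only).
scored = np.array([100.0, 80.0])  # scored speaker time
miss = np.array([3.0, 2.0])       # missed speaker time
falarm = np.array([1.0, 0.5])     # false alarm speaker time
spk_err = np.array([2.0, 1.5])    # speaker error time

with np.errstate(invalid="ignore", divide="ignore"):
    ders = rectify((miss + falarm + spk_err) / scored)

print(ders)  # [6. 5.] -> 6.0 % and 5.0 % DER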