diff --git a/examples/ami/sd0/conf/ecapa_tdnn.yaml b/examples/ami/sd0/conf/ecapa_tdnn.yaml
new file mode 100755
index 000000000..319e44976
--- /dev/null
+++ b/examples/ami/sd0/conf/ecapa_tdnn.yaml
@@ -0,0 +1,62 @@
+###########################################################
+#                AMI DATA PREPARE SETTING                 #
+###########################################################
+split_type: 'full_corpus_asr'
+skip_TNO: True
+# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
+mic_type: 'Mix-Headset'
+vad_type: 'oracle'
+max_subseg_dur: 3.0
+overlap: 1.5
+# Some more exp folders (for a cleaner structure).
+embedding_dir: emb #!ref <save_folder>/emb
+meta_data_dir: metadata #!ref <save_folder>/metadata
+ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
+sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
+der_dir: DER #!ref <save_folder>/DER
+
+
+###########################################################
+#               FEATURE EXTRACTION SETTING                #
+###########################################################
+# currently, we only support fbank
+sr: 16000 # sample rate
+n_mels: 80
+window_size: 400 # 25 ms at a 16 kHz sample rate: 25 * 16000 / 1000 = 400
+hop_size: 160 # 10 ms at a 16 kHz sample rate: 10 * 16000 / 1000 = 160
+#left_frames: 0
+#right_frames: 0
+#deltas: False
+
+
+###########################################################
+#                     MODEL SETTING                       #
+###########################################################
+# Currently, only ECAPA-TDNN is supported by this configuration file.
+# To use another model, choose the corresponding configuration yaml file.
+seed: 1234
+emb_dim: 192
+batch_size: 16
+model:
+  input_size: 80
+  channels: [1024, 1024, 1024, 1024, 3072]
+  kernel_sizes: [5, 3, 3, 3, 1]
+  dilations: [1, 2, 3, 4, 1]
+  attention_channels: 128
+  lin_neurons: 192
+# The pre-trained ECAPA-TDNN model is downloaded automatically (see run.sh).
+
+###########################################################
+#               SPECTRAL CLUSTERING SETTING               #
+###########################################################
+backend: 'SC' # options: 'SC', 'kmeans', 'AHC'. Note: kmeans goes only with cos affinity
+affinity: 'cos' # options: 'cos', 'nn'
+max_num_spkrs: 10
+oracle_n_spkrs: True
+
+
+###########################################################
+#                 DER EVALUATION SETTING                  #
+###########################################################
+ignore_overlap: True
+forgiveness_collar: 0.25
diff --git a/examples/ami/sd0/local/compute_embdding.py b/examples/ami/sd0/local/compute_embdding.py
new file mode 100644
index 000000000..dc824d7ca
--- /dev/null
+++ b/examples/ami/sd0/local/compute_embdding.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
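+# This script extracts speaker embeddings for the AMI subsegments with a
+# pre-trained ECAPA-TDNN backbone and caches them on disk, one EmbeddingMeta
+# pickle per recording, for the clustering stage in local/experiment.py.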
+import argparse
+import json
+import os
+import pickle
+import sys
+
+import numpy as np
+import paddle
+from paddle.io import BatchSampler
+from paddle.io import DataLoader
+from tqdm.contrib import tqdm
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.cluster.diarization import EmbeddingMeta
+from paddlespeech.vector.io.batch import batch_feature_normalize
+from paddlespeech.vector.io.dataset_from_json import JSONDataset
+from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
+from paddlespeech.vector.modules.sid_model import SpeakerIdetification
+from paddlespeech.vector.training.seeding import seed_everything
+
+# Logger setup
+logger = Log(__name__).getlog()
+
+
+def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
+    """Prepares metadata for a given recording ID.
+
+    Arguments
+    ---------
+    full_meta_data : json
+        Full meta (json) containing all the recordings
+    rec_id : str
+        The recording ID for which meta (json) has to be prepared
+    out_meta_file : str
+        Path of the output meta (json) file.
+    """
+
+    subset = {}
+    for key in full_meta_data:
+        k = str(key)
+        if k.startswith(rec_id):
+            subset[key] = full_meta_data[key]
+
+    with open(out_meta_file, mode="w") as json_f:
+        json.dump(subset, json_f, indent=2)
+
+
+def create_dataloader(json_file, batch_size):
+    """Creates the dataset and its data processing pipeline.
+    This is used for multi-mic processing.
+    """
+
+    # create the dataset (note: this relies on the module-level `config`
+    # set in the __main__ block below).
+    dataset = JSONDataset(
+        json_file=json_file,
+        feat_type='melspectrogram',
+        n_mels=config.n_mels,
+        window_size=config.window_size,
+        hop_length=config.hop_size)
+
+    # create the dataloader
+    batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
+    dataloader = DataLoader(dataset,
+                            batch_sampler=batch_sampler,
+                            collate_fn=lambda x: batch_feature_normalize(
+                                x, mean_norm=True, std_norm=False),
+                            return_list=True)
+
+    return dataloader
+
+
+def main(args, config):
+    # set the training device, cpu or gpu
+    paddle.set_device(args.device)
+    # set the random seed
+    seed_everything(config.seed)
+
+    # stage1: build the dnn backbone model network
+    ecapa_tdnn = EcapaTdnn(**config.model)
+
+    # stage2: build the speaker verification eval instance with the backbone model
+    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)
+
+    # stage3: load the pre-trained model
+    # we get the last model from the epoch and save_interval
+    args.load_checkpoint = os.path.abspath(
+        os.path.expanduser(args.load_checkpoint))
+
+    # load the model checkpoint into the sid model
+    state_dict = paddle.load(
+        os.path.join(args.load_checkpoint, 'model.pdparams'))
+    model.set_state_dict(state_dict)
+    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
+
+    # set the model to eval mode
+    model.eval()
+
+    # load the meta data
+    meta_file = os.path.join(
+        args.data_dir,
+        config.meta_data_dir,
+        "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
+    with open(meta_file, "r") as f:
+        full_meta = json.load(f)
+
+    # get all the recording IDs in this dataset.
+    all_keys = full_meta.keys()
+    A = [word.rstrip().split("_")[0] for word in all_keys]
+    all_rec_ids = list(set(A[1:]))
+    all_rec_ids.sort()
+    split = "AMI_" + args.dataset
+    i = 1
+
+    msg = "Extracting embeddings for the " + args.dataset + " set"
+    logger.info(msg)
+
+    if len(all_rec_ids) <= 0:
+        msg = "No recording IDs found! Please check if the meta_data json file is properly generated."
+        logger.error(msg)
+        sys.exit()
+
+    # extract embeddings for each recording in the dataset.
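+    # Subsegment keys in the metadata JSON begin with the recording ID, so a
+    # key shaped like "ES2011a_<start>_<stop>" (illustrative only; the exact
+    # format comes from ami_prepare.py) maps to rec_id "ES2011a".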
+    for rec_id in tqdm(all_rec_ids):
+        # This tag will be displayed in the log.
+        tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
+               str(len(all_rec_ids)) + "]")
+        i = i + 1
+
+        # log message.
+        msg = "Embedding %s : %s " % (tag, rec_id)
+        logger.debug(msg)
+
+        # embedding directory.
+        if not os.path.exists(
+                os.path.join(args.data_dir, config.embedding_dir, split)):
+            os.makedirs(
+                os.path.join(args.data_dir, config.embedding_dir, split))
+
+        # file to store embeddings.
+        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
+        diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
+                                           split, emb_file_name)
+
+        # prepare a metadata (json) for one recording. This is basically a subset of full_meta.
+        # let's keep this meta-info in the embedding directory itself.
+        json_file_name = rec_id + "." + config.mic_type + ".json"
+        meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
+                                         split, json_file_name)
+
+        # write subset (meta for one recording) json metadata.
+        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)
+
+        # prepare data loader.
+        diary_set_loader = create_dataloader(meta_per_rec_file,
+                                             config.batch_size)
+
+        # extract embeddings (skip if already done).
+        if not os.path.isfile(diary_stat_emb_file):
+            logger.debug("Extracting deep embeddings")
+            embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
+            segset = []
+
+            for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
+                # extract the audio embedding
+                ids, feats, lengths = batch['ids'], batch['feats'], batch[
+                    'lengths']
+                seg = [x for x in ids]
+                segset = segset + seg
+                emb = model.backbone(feats, lengths).squeeze(
+                    -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
+                embeddings = np.concatenate((embeddings, emb), axis=0)
+
+            segset = np.array(segset, dtype="|O")
+            stat_obj = EmbeddingMeta(
+                segset=segset,
+                stats=embeddings, )
+            logger.debug("Saving Embeddings...")
+            with open(diary_stat_emb_file, "wb") as output:
+                pickle.dump(stat_obj, output)
+
+        else:
+            logger.debug("Skipping embedding extraction (as already present).")
+
+
+# Begin experiment!
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(__doc__)
+    parser.add_argument(
+        '--device',
+        default="gpu",
+        help="Select the device to perform diarization on, defaults to gpu.")
+    parser.add_argument(
+        "--config", default=None, type=str, help="configuration file")
+    parser.add_argument(
+        "--data-dir",
+        default="../save/",
+        type=str,
+        help="processed data directory")
+    parser.add_argument(
+        "--dataset",
+        choices=['dev', 'eval'],
+        default="dev",
+        type=str,
+        help="Select the dataset to extract embeddings from, defaults to dev")
+    parser.add_argument(
+        "--load-checkpoint",
+        type=str,
+        default='',
+        help="Directory to load the model checkpoint from to compute embeddings.")
+    args = parser.parse_args()
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
+    config.freeze()
+
+    main(args, config)
diff --git a/examples/ami/sd0/local/data.sh b/examples/ami/sd0/local/data.sh
deleted file mode 100755
index 478ec432d..000000000
--- a/examples/ami/sd0/local/data.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-
-stage=1
-
-TARGET_DIR=${MAIN_ROOT}/dataset/ami
-data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
-manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
-
-save_folder=${MAIN_ROOT}/examples/ami/sd0/data
-ref_rttm_dir=${save_folder}/ref_rttms
-meta_data_dir=${save_folder}/metadata
-
-set=L
-
-. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
-set -u
-set -o pipefail
-
-mkdir -p ${save_folder}
-
-if [ ${stage} -le 0 ]; then
-    # Download AMI corpus, You need around 10GB of free space to get whole data
-    # The signals are too large to package in this way,
-    # so you need to use the chooser to indicate which ones you wish to download
-    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
-    echo "Annotations: AMI manual annotations v1.6.2 "
-    echo "Signals: "
-    echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
-    echo "2) Select media streams: Just select Headset mix"
-    exit 0;
-fi
-
-if [ ${stage} -le 1 ]; then
-    echo "AMI Data preparation"
-
-    python local/ami_prepare.py --data_folder ${data_folder} \
-            --manual_annot_folder ${manual_annot_folder} \
-            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
-            --meta_data_dir ${meta_data_dir}
-
-    if [ $? -ne 0 ]; then
-        echo "Prepare AMI failed. Please check log message."
-        exit 1
-    fi
-
-fi
-
-echo "AMI data preparation done."
-exit 0
diff --git a/examples/ami/sd0/local/experiment.py b/examples/ami/sd0/local/experiment.py
new file mode 100755
index 000000000..298228376
--- /dev/null
+++ b/examples/ami/sd0/local/experiment.py
@@ -0,0 +1,428 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import json
+import os
+import pickle
+import shutil
+import sys
+
+import numpy as np
+from tqdm.contrib import tqdm
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.utils.log import Log
+from paddlespeech.vector.cluster import diarization as diar
+from utils.DER import DER
+
+# Logger setup
+logger = Log(__name__).getlog()
+
+
+def diarize_dataset(
+        full_meta,
+        split_type,
+        n_lambdas,
+        pval,
+        save_dir,
+        config,
+        n_neighbors=10, ):
+    """This function diarizes all the recordings in a given dataset. It
+    computes embeddings and clusters them using spectral clustering (or
+    another backend). The output speaker boundary file is stored in the
+    RTTM format.
+    """
+
+    # prepare `spkr_info` only once when the oracle number of speakers is selected.
+    # spkr_info is essential to obtain the number of speakers from the ground truth.
+    if config.oracle_n_spkrs is True:
+        full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
+                                          "fullref_ami_" + split_type + ".rttm")
+        rttm = diar.read_rttm(full_ref_rttm_file)
+
+        spkr_info = list(  # noqa F841
+            filter(lambda x: x.startswith("SPKR-INFO"), rttm))
+
+    # get all the recording IDs in this dataset.
+    all_keys = full_meta.keys()
+    A = [word.rstrip().split("_")[0] for word in all_keys]
+    all_rec_ids = list(set(A[1:]))
+    all_rec_ids.sort()
+    split = "AMI_" + split_type
+    i = 1
+
+    # add a tag for the directory path.
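+    # e.g. with oracle_n_spkrs=True, affinity='cos' and backend='SC' the tag
+    # below becomes "oracle_cos_SC", so system RTTMs land under
+    # <save_dir>/<sys_rttm_dir>/<mic_type>/AMI_<split_type>/oracle_cos_SC/.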
+    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
+    tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)
+
+    # make the output rttm dir
+    out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
+                                split, tag)
+    if not os.path.exists(out_rttm_dir):
+        os.makedirs(out_rttm_dir)
+
+    # diarize the different recordings in the dataset.
+    for rec_id in tqdm(all_rec_ids):
+        # this tag will be displayed in the log.
+        tag = ("[" + str(split_type) + ": " + str(i) + "/" +
+               str(len(all_rec_ids)) + "]")
+        i = i + 1
+
+        # log message.
+        msg = "Diarizing %s : %s " % (tag, rec_id)
+        logger.debug(msg)
+
+        # load embeddings.
+        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
+        diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
+                                           split, emb_file_name)
+        if not os.path.isfile(diary_stat_emb_file):
+            msg = "Embedding file %s not found! Please check if the embedding file was properly generated." % (
+                diary_stat_emb_file)
+            logger.error(msg)
+            sys.exit()
+        with open(diary_stat_emb_file, "rb") as in_file:
+            diary_obj = pickle.load(in_file)
+
+        out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"
+
+        # processing starts from here.
+        if config.oracle_n_spkrs is True:
+            # oracle number of speakers.
+            num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
+        else:
+            if config.affinity == "nn":
+                # number of speakers tuned on the dev set (only for nn affinity).
+                num_spkrs = n_lambdas
+            else:
+                # the number of speakers will be estimated using the max eigen gap for cos based affinity.
+                # so adding None here. Will use this None later on.
+                num_spkrs = None
+
+        if config.backend == "kmeans":
+            diar.do_kmeans_clustering(
+                diary_obj,
+                out_rttm_file,
+                rec_id,
+                num_spkrs,
+                pval, )
+
+        if config.backend == "SC":
+            # go for Spectral Clustering (SC).
+            diar.do_spec_clustering(
+                diary_obj,
+                out_rttm_file,
+                rec_id,
+                num_spkrs,
+                pval,
+                config.affinity,
+                n_neighbors, )
+
+        # can be used for AHC later. Likewise, one can add different backends here.
+        if config.backend == "AHC":
+            # call AHC
+            threshold = pval  # pval for AHC is nothing but the threshold.
+            diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)
+
+    # once all RTTM outputs are generated, concatenate the individual RTTM files into a single RTTM file.
+    # this is not strictly needed, but it follows the standard convention.
+    concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
+    logger.debug("Concatenating individual RTTM files...")
+    with open(concate_rttm_file, "w") as cat_file:
+        for f in glob.glob(out_rttm_dir + "/*.rttm"):
+            if f == concate_rttm_file:
+                continue
+            with open(f, "r") as indi_rttm_file:
+                shutil.copyfileobj(indi_rttm_file, cat_file)
+
+    msg = "The system generated RTTM file for the %s set : %s" % (
+        split_type, concate_rttm_file, )
+    logger.debug(msg)
+
+    return concate_rttm_file
+
+
+def dev_pval_tuner(full_meta, save_dir, config):
+    """Tuning the p_value for the affinity matrix.
+    The p_value is used so that only the top p% of the values in each row are retained.
+    """
+
+    DER_list = []
+    prange = np.arange(0.002, 0.015, 0.001)
+
+    n_lambdas = None  # using it as a flag later.
+    for p_v in prange:
+        # Process the whole dataset for this value of p_v.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
+                                            save_dir, config)
+
+        ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
+                                     "fullref_ami_dev.rttm")
+        sys_rttm_file = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm_file,
+            sys_rttm_file,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append(DER_)
+
+        if config.oracle_n_spkrs is True and config.backend == "kmeans":
+            # no need for a p_val search. Note p_val is needed for SC for both oracle and est num of speakers.
+            # p_val is needed with oracle_n_spkrs=False when using the kmeans backend.
+            break
+
+    # Take the p_val that gave the minimum DER on the dev dataset.
+    tuned_p_val = prange[DER_list.index(min(DER_list))]
+
+    return tuned_p_val
+
+
+def dev_ahc_threshold_tuner(full_meta, save_dir, config):
+    """Tuning the threshold for the affinity matrix. This function is called when AHC is used as the backend.
+    """
+
+    DER_list = []
+    prange = np.arange(0.0, 1.0, 0.1)
+
+    n_lambdas = None  # using it as a flag later.
+
+    # Note: p_val is the threshold in the case of AHC.
+    for p_v in prange:
+        # Process the whole dataset for this value of p_v.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
+                                            save_dir, config)
+
+        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+                                "fullref_ami_dev.rttm")
+        sys_rttm = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append(DER_)
+
+        if config.oracle_n_spkrs is True:
+            break  # no need for a threshold search.
+
+    # Take the p_val that gave the minimum DER on the dev dataset.
+    tuned_p_val = prange[DER_list.index(min(DER_list))]
+
+    return tuned_p_val
+
+
+def dev_nn_tuner(full_meta, save_dir, config):
+    """Tuning n_neighbors on the dev set. Assuming the oracle number of speakers.
+    This is used when nn based affinity is selected.
+    """
+
+    DER_list = []
+    pval = None
+
+    # Now assuming the oracle number of speakers.
+    n_lambdas = 4
+
+    for nn in range(5, 15):
+
+        # Process the whole dataset for this value of n_neighbors.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
+                                            save_dir, config, nn)
+
+        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+                                "fullref_ami_dev.rttm")
+        sys_rttm = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append([nn, DER_])
+
+        if config.oracle_n_spkrs is True and config.backend == "kmeans":
+            break
+
+    DER_list.sort(key=lambda x: x[1])
+    tuned_nn = DER_list[0]
+
+    return tuned_nn[0]
+
+
+def dev_tuner(full_meta, save_dir, config):
+    """Tuning n_components on the dev set. Used for the nn based affinity matrix.
+    Note: This is a very basic tuning for nn based affinity.
+    This is work in progress till we find a better way.
+    """
+
+    DER_list = []
+    pval = None
+    for n_lambdas in range(1, config.max_num_spkrs + 1):
+
+        # Process the whole dataset for this value of n_lambdas.
+        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
+                                            save_dir, config)
+
+        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
+                                "fullref_ami_dev.rttm")
+        sys_rttm = concate_rttm_file
+        [MS, FA, SER, DER_] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar, )
+
+        DER_list.append(DER_)
+
+    # Take the n_lambdas with the minimum DER.
+    tuned_n_lambdas = DER_list.index(min(DER_list)) + 1
+
+    return tuned_n_lambdas
+
+
+def main(args, config):
+    # AMI Dev Set: Tune hyperparams on the dev set.
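+    # Tuning strategy, as implemented below: pick the affinity/backend from the
+    # config, search the matching hyperparameter on the dev set (p-value for
+    # cos + SC/kmeans, threshold for AHC, n_neighbors or n_components for nn),
+    # then diarize both dev and eval with the best value and report the DER.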
+    # Read the subsegment metadata (json) file for the dev set.
+    dev_meta_file = os.path.join(
+        args.data_dir,
+        config.meta_data_dir,
+        "ami_dev." + config.mic_type + ".subsegs.json", )
+    with open(dev_meta_file, "r") as f:
+        meta_dev = json.load(f)
+
+    full_meta = meta_dev
+
+    # Processing starts from here.
+    # The next few lines select options for the different backends and affinity
+    # matrices, and find the best hyperparameter values using the dev set.
+    ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
+                                 "fullref_ami_dev.rttm")
+    best_nn = None
+    if config.affinity == "nn":
+        logger.info("Tuning for nn (Multiple iterations over AMI Dev set)")
+        best_nn = dev_nn_tuner(full_meta, args.data_dir, config)
+
+    n_lambdas = None
+    best_pval = None
+
+    if config.affinity == "cos" and (config.backend == "SC" or
+                                     config.backend == "kmeans"):
+        # oracle num_spkrs or not, doesn't matter for the kmeans and SC backends
+        # cos: Tune for the best pval for SC/kmeans (for an unknown number of speakers)
+        logger.info(
+            "Tuning for p-value for SC (Multiple iterations over AMI Dev set)")
+        best_pval = dev_pval_tuner(full_meta, args.data_dir, config)
+
+    elif config.backend == "AHC":
+        logger.info("Tuning for threshold-value for AHC")
+        best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
+                                                 config)
+        best_pval = best_threshold
+    else:
+        # NN for an unknown number of speakers (can be used in the future)
+        if config.oracle_n_spkrs is False:
+            # nn: Tune the number of components (to be updated later)
+            logger.info(
+                "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
+            )
+            # dev_tuner is used for tuning the number of components in NN. Can be used in the future.
+            n_lambdas = dev_tuner(full_meta, args.data_dir, config)
+
+    # load the 'dev' and 'eval' metadata files.
+    full_meta_dev = full_meta  # the current full_meta is for 'dev'
+    eval_meta_file = os.path.join(
+        args.data_dir,
+        config.meta_data_dir,
+        "ami_eval." + config.mic_type + ".subsegs.json", )
+    with open(eval_meta_file, "r") as f:
+        full_meta_eval = json.load(f)
+
+    # tag to be appended to the final output DER files. Writing the DER for individual files.
+    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
+    tag = (
+        type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)
+
+    # perform the final diarization on 'dev' and 'eval' with the best hyperparams.
+    final_DERs = {}
+    out_der_dir = os.path.join(args.data_dir, config.der_dir)
+    if not os.path.exists(out_der_dir):
+        os.makedirs(out_der_dir)
+
+    for split_type in ["dev", "eval"]:
+        if split_type == "dev":
+            full_meta = full_meta_dev
+        else:
+            full_meta = full_meta_eval
+
+        # performing diarization.
+        msg = "Diarizing using the best hyperparams: " + split_type + " set"
+        logger.info(msg)
+        out_boundaries = diarize_dataset(
+            full_meta,
+            split_type,
+            n_lambdas=n_lambdas,
+            pval=best_pval,
+            n_neighbors=best_nn,
+            save_dir=args.data_dir,
+            config=config)
+
+        # computing the DER.
+        msg = "Computing DERs for the " + split_type + " set"
+        logger.info(msg)
+        ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir,
+                                "fullref_ami_" + split_type + ".rttm")
+        sys_rttm = out_boundaries
+        [MS, FA, SER, DER_vals] = DER(
+            ref_rttm,
+            sys_rttm,
+            config.ignore_overlap,
+            config.forgiveness_collar,
+            individual_file_scores=True, )
+
+        # write the DER values to a file, appending the tag.
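+        # e.g. with the default config this produces
+        # "dev_DER_oracle_cos.Mix-Headset" and "eval_DER_oracle_cos.Mix-Headset"
+        # under <data_dir>/DER/.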
+        der_file_name = split_type + "_DER_" + tag
+        out_der_file = os.path.join(out_der_dir, der_file_name)
+        msg = "Writing DER file to: " + out_der_file
+        logger.info(msg)
+        diar.write_ders_file(ref_rttm, DER_vals, out_der_file)
+
+        msg = ("AMI " + split_type + " set DER = %s %%\n" %
+               (str(round(DER_vals[-1], 2))))
+        logger.info(msg)
+        final_DERs[split_type] = round(DER_vals[-1], 2)
+
+    # finally, print the DERs.
+    msg = (
+        "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n"
+        % (str(final_DERs["dev"]), str(final_DERs["eval"])))
+    logger.info(msg)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(__doc__)
+    parser.add_argument(
+        "--config", default=None, type=str, help="configuration file")
+    parser.add_argument(
+        "--data-dir",
+        default="../data/",
+        type=str,
+        help="processed data directory")
+    args = parser.parse_args()
+    config = CfgNode(new_allowed=True)
+    if args.config:
+        config.merge_from_file(args.config)
+
+    config.freeze()
+
+    main(args, config)
diff --git a/examples/ami/sd0/local/process.sh b/examples/ami/sd0/local/process.sh
new file mode 100755
index 000000000..1dfd11b86
--- /dev/null
+++ b/examples/ami/sd0/local/process.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+stage=0
+set=L
+
+. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+set -o pipefail
+
+data_folder=$1
+manual_annot_folder=$2
+save_folder=$3
+pretrained_model_dir=$4
+conf_path=$5
+device=$6
+
+ref_rttm_dir=${save_folder}/ref_rttms
+meta_data_dir=${save_folder}/metadata
+
+if [ ${stage} -le 0 ]; then
+    echo "AMI Data preparation"
+    python local/ami_prepare.py --data_folder ${data_folder} \
+            --manual_annot_folder ${manual_annot_folder} \
+            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
+            --meta_data_dir ${meta_data_dir}
+
+    if [ $? -ne 0 ]; then
+        echo "Prepare AMI failed. Please check log message."
+        exit 1
+    fi
+    echo "AMI data preparation done."
+fi
+
+if [ ${stage} -le 1 ]; then
+    # extract embeddings for the dev and eval datasets
+    for name in dev eval; do
+        python local/compute_embdding.py --config ${conf_path} \
+                --data-dir ${save_folder} \
+                --device ${device} \
+                --dataset ${name} \
+                --load-checkpoint ${pretrained_model_dir}
+    done
+fi
+
+if [ ${stage} -le 2 ]; then
+    # tune hyperparams on the dev set
+    # perform the final diarization on 'dev' and 'eval' with the best hyperparams
+    python local/experiment.py --config ${conf_path} \
+            --data-dir ${save_folder}
+fi
diff --git a/examples/ami/sd0/run.sh b/examples/ami/sd0/run.sh
index 91d4b706a..9035f5955 100644
--- a/examples/ami/sd0/run.sh
+++ b/examples/ami/sd0/run.sh
@@ -1,14 +1,46 @@
 #!/bin/bash
 
-. path.sh || exit 1;
+. ./path.sh || exit 1;
 set -e
 
-stage=1
+stage=0
 
+#TARGET_DIR=${MAIN_ROOT}/dataset/ami
+TARGET_DIR=/home/dataset/AMI
+data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
+manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
+
+save_folder=./save
+pretrained_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model
+conf_path=conf/ecapa_tdnn.yaml
+device=gpu
 
 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
-if [ ${stage} -le 1 ]; then
-    # prepare data
-    bash ./local/data.sh || exit -1
-fi
\ No newline at end of file
+if [ $stage -le 0 ]; then
+    # Prepare the data.
+    # To download the whole AMI corpus you need around 10 GB of free space.
+    # The signals are too large to package in one archive,
+    # so you need to use the chooser to indicate which ones you wish to download.
+    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
+    echo "Annotations: AMI manual annotations v1.6.2 "
+    echo "Signals: "
+    echo "1) Select one or more AMI meetings: for the IDs please follow ./ami_split.py"
+    echo "2) Select media streams: just select Headset mix"
+fi
+
+if [ $stage -le 1 ]; then
+    # Download the pretrained model.
+    wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
+    rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    echo "Downloaded the pretrained ECAPA-TDNN model to path: "${pretrained_model_dir}
+fi
+
+if [ $stage -le 2 ]; then
+    # Tune hyperparams on the dev set and perform the final diarization on dev and eval with the best hyperparams.
+    echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretrained_model_dir} ${conf_path}
+    bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
+        ${save_folder} ${pretrained_model_dir} ${conf_path} ${device} || exit 1
+fi
+
diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py
index 597aa4807..5b2157257 100644
--- a/paddlespeech/vector/cluster/diarization.py
+++ b/paddlespeech/vector/cluster/diarization.py
@@ -746,6 +746,77 @@ def merge_ssegs_same_speaker(lol):
     return new_lol
 
 
+def write_ders_file(ref_rttm, DER, out_der_file):
+    """Write the final DERs for individual recordings.
+
+    Arguments
+    ---------
+    ref_rttm : str
+        Reference RTTM file.
+    DER : array
+        Array containing the DER value of each recording.
+    out_der_file : str
+        File to write the DERs to.
+    """
+
+    rttm = read_rttm(ref_rttm)
+    spkr_info = list(filter(lambda x: x.startswith("SPKR-INFO"), rttm))
+
+    rec_id_list = []
+    count = 0
+
+    with open(out_der_file, "w") as f:
+        for row in spkr_info:
+            a = row.split(" ")
+            rec_id = a[1]
+            if rec_id not in rec_id_list:
+                r = [rec_id, str(round(DER[count], 2))]
+                rec_id_list.append(rec_id)
+                line_str = " ".join(r)
+                f.write("%s\n" % line_str)
+                count += 1
+        r = ["OVERALL ", str(round(DER[count], 2))]
+        line_str = " ".join(r)
+        f.write("%s\n" % line_str)
+
+
+def get_oracle_num_spkrs(rec_id, spkr_info):
+    """
+    Returns the actual number of speakers in a recording from the ground truth.
+    This can be used when the condition is the oracle number of speakers.
+
+    Arguments
+    ---------
+    rec_id : str
+        Recording ID for which the number of speakers has to be obtained.
+    spkr_info : list
+        Header of the RTTM file. Starting with `SPKR-INFO`.
+
+    Example
+    -------
+    >>> from paddlespeech.vector.cluster import diarization as diar
+    >>> spkr_info = ['SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.A <NA> <NA>',
+    ...     'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.B <NA> <NA>',
+    ...     'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.C <NA> <NA>',
+    ...     'SPKR-INFO ES2011a 0 <NA> <NA> <NA> unknown ES2011a.D <NA> <NA>',
+    ...     'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.A <NA> <NA>',
+    ...     'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.B <NA> <NA>',
+    ...     'SPKR-INFO ES2011b 0 <NA> <NA> <NA> unknown ES2011b.C <NA> <NA>']
+    >>> diar.get_oracle_num_spkrs('ES2011a', spkr_info)
+    4
+    >>> diar.get_oracle_num_spkrs('ES2011b', spkr_info)
+    3
+    """
+
+    num_spkrs = 0
+    for line in spkr_info:
+        if rec_id in line:
+            # Since rec_id is a prefix for each speaker
+            num_spkrs += 1
+
+    return num_spkrs
+
+
 def distribute_overlap(lol):
     """
     Distributes the overlapped speech equally among the adjacent segments
@@ -826,6 +897,29 @@ def distribute_overlap(lol):
     return new_lol
 
 
+def read_rttm(rttm_file_path):
+    """
+    Reads and returns the RTTM in list format.
+
+    Arguments
+    ---------
+    rttm_file_path : str
+        Path to the RTTM file to be read.
+
+    Returns
+    -------
+    rttm : list
+        List containing the rows of the RTTM file.
+    """
+
+    rttm = []
+    with open(rttm_file_path, "r") as f:
+        for line in f:
+            entry = line[:-1]
+            rttm.append(entry)
+    return rttm
+
+
 def write_rttm(segs_list, out_rttm_file):
     """
     Writes the segment list in RTTM format (A standard NIST format).
diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py
new file mode 100644
index 000000000..5ffd2c186
--- /dev/null
+++ b/paddlespeech/vector/io/dataset_from_json.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from dataclasses import dataclass
+from dataclasses import fields
+from paddle.io import Dataset
+
+from paddleaudio import load as load_audio
+from paddleaudio.compliance.librosa import melspectrogram
+from paddleaudio.compliance.librosa import mfcc
+
+
+@dataclass
+class meta_info:
+    """The audio meta info in the vector JSONDataset
+    Args:
+        id (str): the segment name
+        duration (float): segment duration
+        wav (str): wav file path
+        start (int): start point in the original wav file
+        stop (int): stop point in the original wav file
+        record_id (str): the recording id
+    """
+    id: str
+    duration: float
+    wav: str
+    start: int
+    stop: int
+    record_id: str
+
+
+# feature types supported by the json dataset
+feat_funcs = {
+    'raw': None,
+    'melspectrogram': melspectrogram,
+    'mfcc': mfcc,
+}
+
+
+class JSONDataset(Dataset):
+    """
+    dataset from a json file.
+    """
+
+    def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
+        """
+        Args:
+            json_file (:obj:`str`): Data prep JSON file.
+            feat_type (:obj:`str`, `optional`, defaults to `raw`):
+                It identifies the feature type that the user wants to extract from an audio file.
+        """
+        if feat_type not in feat_funcs.keys():
+            raise RuntimeError(
+                f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
+            )
+
+        self.json_file = json_file
+        self.feat_type = feat_type
+        self.feat_config = kwargs
+        self._data = self._get_data()
+        super(JSONDataset, self).__init__()
+
+    def _get_data(self):
+        with open(self.json_file, "r") as f:
+            meta_data = json.load(f)
+        data = []
+        for key in meta_data:
+            sub_seg = meta_data[key]["wav"]
+            wav = sub_seg["file"]
+            duration = sub_seg["duration"]
+            start = sub_seg["start"]
+            stop = sub_seg["stop"]
+            rec_id = str(key).rsplit("_", 2)[0]
+            data.append(
+                meta_info(
+                    str(key),
+                    float(duration), wav, int(start), int(stop), str(rec_id)))
+        return data
+
+    def _convert_to_record(self, idx: int):
+        sample = self._data[idx]
+
+        record = {}
+        # expose all fields of the meta_info dataclass as a dict
+        for field in fields(sample):
+            record[field.name] = getattr(sample, field.name)
+
+        waveform, sr = load_audio(record['wav'])
+        waveform = waveform[record['start']:record['stop']]
+
+        feat_func = feat_funcs[self.feat_type]
+        feat = feat_func(
+            waveform, sr=sr, **self.feat_config) if feat_func else waveform
+
+        record.update({'feat': feat})
+
+        return record
+
+    def __getitem__(self, idx):
+        return self._convert_to_record(idx)
+
+    def __len__(self):
+        return len(self._data)