diff --git a/utils/DER.py b/utils/DER.py new file mode 100755 index 00000000..4c67c4c7 --- /dev/null +++ b/utils/DER.py @@ -0,0 +1,152 @@ +"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS), +False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation. + +Authors + * Neville Ryant 2018 + * Nauman Dawalatabad 2020 + +Credits + This code is adapted from https://github.com/nryant/dscore +""" + +import os +import re +import subprocess +import numpy as np + +FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") +SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+") +MISS_SPEAKER_TIME = re.compile(r"(?<=MISSED SPEAKER TIME =)[\d.]+") +FA_SPEAKER_TIME = re.compile(r"(?<=FALARM SPEAKER TIME =)[\d.]+") +ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+") + + +def rectify(arr): + """Corrects corner cases and converts scores into percentage. + """ + + # Numerator and denominator both 0. + arr[np.isnan(arr)] = 0 + + # Numerator > 0, but denominator = 0. + arr[np.isinf(arr)] = 1 + arr *= 100.0 + + return arr + + +def DER( + ref_rttm, + sys_rttm, + ignore_overlap=False, + collar=0.25, + individual_file_scores=False, +): + """Computes Missed Speaker percentage (MS), False Alarm (FA), + Speaker Error Rate (SER), and Diarization Error Rate (DER). + + Arguments + --------- + ref_rttm : str + The path of reference/groundtruth RTTM file. + sys_rttm : str + The path of the system generated RTTM file. + individual_file : bool + If True, returns scores for each file in order. + collar : float + Forgiveness collar. + ignore_overlap : bool + If True, ignores overlapping speech during evaluation. + + Returns + ------- + MS : float array + Missed Speech. + FA : float array + False Alarms. + SER : float array + Speaker Error Rates. + DER : float array + Diarization Error Rates. + + Example + ------- + >>> import pytest + >>> pytest.skip('Skipping because of Perl dependency') + >>> ref_rttm = "../../samples/rttm_samples/ref_rttm/ES2014c.rttm" + >>> sys_rttm = "../../samples/rttm_samples/sys_rttm/ES2014c.rttm" + >>> ignore_overlap = True + >>> collar = 0.25 + >>> individual_file_scores = True + >>> Scores = DER(ref_rttm, sys_rttm, ignore_overlap, collar, individual_file_scores) + >>> print (Scores) + (array([0., 0.]), array([0., 0.]), array([7.16923618, 7.16923618]), array([7.16923618, 7.16923618])) + """ + + curr = os.path.abspath(os.path.dirname(__file__)) + mdEval = os.path.join(curr, "../../tools/der_eval/md-eval.pl") + + cmd = [ + mdEval, + "-af", + "-r", + ref_rttm, + "-s", + sys_rttm, + "-c", + str(collar), + ] + if ignore_overlap: + cmd.append("-1") + + try: + stdout = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + except subprocess.CalledProcessError as ex: + stdout = ex.output + + else: + stdout = stdout.decode("utf-8") + + # Get all recording IDs + file_ids = [m.strip() for m in FILE_IDS.findall(stdout)] + file_ids = [ + file_id[2:] if file_id.startswith("f=") else file_id + for file_id in file_ids + ] + + scored_speaker_times = np.array( + [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)] + ) + + miss_speaker_times = np.array( + [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)] + ) + + fa_speaker_times = np.array( + [float(m) for m in FA_SPEAKER_TIME.findall(stdout)] + ) + + error_speaker_times = np.array( + [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)] + ) + + with np.errstate(invalid="ignore", divide="ignore"): + tot_error_times = ( + miss_speaker_times + fa_speaker_times + error_speaker_times + ) + miss_speaker_frac = miss_speaker_times / scored_speaker_times + fa_speaker_frac = fa_speaker_times / scored_speaker_times + sers_frac = error_speaker_times / scored_speaker_times + ders_frac = tot_error_times / scored_speaker_times + + # Values in percentage of scored_speaker_time + miss_speaker = rectify(miss_speaker_frac) + fa_speaker = rectify(fa_speaker_frac) + sers = rectify(sers_frac) + ders = rectify(ders_frac) + + if individual_file_scores: + return miss_speaker, fa_speaker, sers, ders + else: + return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1]