Merge pull request #1335 from qingen/test-pr
[vector] add DER scripts, AMI data preparation scriptspull/1385/head
commit
7413c9e48a
@ -0,0 +1,3 @@
|
||||
# Speaker Diarization on AMI corpus
|
||||
|
||||
* sd0 - speaker diarization by AHC,SC base on x-vectors
|
@ -0,0 +1 @@
|
||||
results
|
@ -0,0 +1,13 @@
|
||||
# Speaker Diarization on AMI corpus
|
||||
|
||||
## About the AMI corpus:
|
||||
"The AMI Meeting Corpus consists of 100 hours of meeting recordings. The recordings use a range of signals synchronized to a common timeline. These include close-talking and far-field microphones, individual and room-view video cameras, and output from a slide projector and an electronic whiteboard. During the meetings, the participants also have unsynchronized pens available to them that record what is written. The meetings were recorded in English using three different rooms with different acoustic properties, and include mostly non-native speakers." See [ami overview](http://groups.inf.ed.ac.uk/ami/corpus/overview.shtml) for more details.
|
||||
|
||||
## About the example
|
||||
The script performs diarization using x-vectors(TDNN,ECAPA-TDNN) on the AMI mix-headset data. We demonstrate the use of different clustering methods: AHC, spectral.
|
||||
|
||||
## How to Run
|
||||
Use the following command to run diarization on AMI corpus.
|
||||
`bash ./run.sh`
|
||||
|
||||
## Results (DER) coming soon! :)
|
@ -0,0 +1,572 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Data preparation.
|
||||
|
||||
Download: http://groups.inf.ed.ac.uk/ami/download/
|
||||
|
||||
Prepares metadata files (JSON) from manual annotations "segments/" using RTTM format (Oracle VAD).
|
||||
|
||||
Authors
|
||||
* qingenz123@126.com (Qingen ZHAO) 2022
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import xml.etree.ElementTree as et
|
||||
import glob
|
||||
import json
|
||||
from ami_splits import get_AMI_split
|
||||
from distutils.util import strtobool
|
||||
|
||||
from dataio import (
|
||||
load_pkl,
|
||||
save_pkl, )
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SAMPLERATE = 16000
|
||||
|
||||
|
||||
def prepare_ami(
|
||||
data_folder,
|
||||
manual_annot_folder,
|
||||
save_folder,
|
||||
ref_rttm_dir,
|
||||
meta_data_dir,
|
||||
split_type="full_corpus_asr",
|
||||
skip_TNO=True,
|
||||
mic_type="Mix-Headset",
|
||||
vad_type="oracle",
|
||||
max_subseg_dur=3.0,
|
||||
overlap=1.5, ):
|
||||
"""
|
||||
Prepares reference RTTM and JSON files for the AMI dataset.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data_folder : str
|
||||
Path to the folder where the original amicorpus is stored.
|
||||
manual_annot_folder : str
|
||||
Directory where the manual annotations are stored.
|
||||
save_folder : str
|
||||
The save directory in results.
|
||||
ref_rttm_dir : str
|
||||
Directory to store reference RTTM files.
|
||||
meta_data_dir : str
|
||||
Directory to store the meta data (json) files.
|
||||
split_type : str
|
||||
Standard dataset split. See ami_splits.py for more information.
|
||||
Allowed split_type: "scenario_only", "full_corpus" or "full_corpus_asr"
|
||||
skip_TNO: bool
|
||||
Skips TNO meeting recordings if True.
|
||||
mic_type : str
|
||||
Type of microphone to be used.
|
||||
vad_type : str
|
||||
Type of VAD. Kept for future when VAD will be added.
|
||||
max_subseg_dur : float
|
||||
Duration in seconds of a subsegments to be prepared from larger segments.
|
||||
overlap : float
|
||||
Overlap duration in seconds between adjacent subsegments
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from dataset.ami.ami_prepare import prepare_ami
|
||||
>>> data_folder = '/home/data/ami/amicorpus/'
|
||||
>>> manual_annot_folder = '/home/data/ami/ami_public_manual/'
|
||||
>>> save_folder = './results/
|
||||
>>> split_type = 'full_corpus_asr'
|
||||
>>> mic_type = 'Mix-Headset'
|
||||
>>> prepare_ami(data_folder, manual_annot_folder, save_folder, split_type, mic_type)
|
||||
"""
|
||||
|
||||
# Meta files
|
||||
meta_files = [
|
||||
os.path.join(meta_data_dir, "ami_train." + mic_type + ".subsegs.json"),
|
||||
os.path.join(meta_data_dir, "ami_dev." + mic_type + ".subsegs.json"),
|
||||
os.path.join(meta_data_dir, "ami_eval." + mic_type + ".subsegs.json"),
|
||||
]
|
||||
|
||||
# Create configuration for easily skipping data_preparation stage
|
||||
conf = {
|
||||
"data_folder": data_folder,
|
||||
"save_folder": save_folder,
|
||||
"ref_rttm_dir": ref_rttm_dir,
|
||||
"meta_data_dir": meta_data_dir,
|
||||
"split_type": split_type,
|
||||
"skip_TNO": skip_TNO,
|
||||
"mic_type": mic_type,
|
||||
"vad": vad_type,
|
||||
"max_subseg_dur": max_subseg_dur,
|
||||
"overlap": overlap,
|
||||
"meta_files": meta_files,
|
||||
}
|
||||
|
||||
if not os.path.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
|
||||
# Setting output option files.
|
||||
opt_file = "opt_ami_prepare." + mic_type + ".pkl"
|
||||
|
||||
# Check if this phase is already done (if so, skip it)
|
||||
if skip(save_folder, conf, meta_files, opt_file):
|
||||
logger.info(
|
||||
"Skipping data preparation, as it was completed in previous run.")
|
||||
return
|
||||
|
||||
msg = "\tCreating meta-data file for the AMI Dataset.."
|
||||
logger.debug(msg)
|
||||
|
||||
# Get the split
|
||||
train_set, dev_set, eval_set = get_AMI_split(split_type)
|
||||
|
||||
# Prepare RTTM from XML(manual annot) and store are groundtruth
|
||||
# Create ref_RTTM directory
|
||||
if not os.path.exists(ref_rttm_dir):
|
||||
os.makedirs(ref_rttm_dir)
|
||||
|
||||
# Create reference RTTM files
|
||||
splits = ["train", "dev", "eval"]
|
||||
for i in splits:
|
||||
rttm_file = ref_rttm_dir + "/fullref_ami_" + i + ".rttm"
|
||||
if i == "train":
|
||||
prepare_segs_for_RTTM(
|
||||
train_set,
|
||||
rttm_file,
|
||||
data_folder,
|
||||
manual_annot_folder,
|
||||
i,
|
||||
skip_TNO, )
|
||||
if i == "dev":
|
||||
prepare_segs_for_RTTM(
|
||||
dev_set,
|
||||
rttm_file,
|
||||
data_folder,
|
||||
manual_annot_folder,
|
||||
i,
|
||||
skip_TNO, )
|
||||
if i == "eval":
|
||||
prepare_segs_for_RTTM(
|
||||
eval_set,
|
||||
rttm_file,
|
||||
data_folder,
|
||||
manual_annot_folder,
|
||||
i,
|
||||
skip_TNO, )
|
||||
|
||||
# Create meta_files for splits
|
||||
meta_data_dir = meta_data_dir
|
||||
if not os.path.exists(meta_data_dir):
|
||||
os.makedirs(meta_data_dir)
|
||||
|
||||
for i in splits:
|
||||
rttm_file = ref_rttm_dir + "/fullref_ami_" + i + ".rttm"
|
||||
meta_filename_prefix = "ami_" + i
|
||||
prepare_metadata(
|
||||
rttm_file,
|
||||
meta_data_dir,
|
||||
data_folder,
|
||||
meta_filename_prefix,
|
||||
max_subseg_dur,
|
||||
overlap,
|
||||
mic_type, )
|
||||
|
||||
save_opt_file = os.path.join(save_folder, opt_file)
|
||||
save_pkl(conf, save_opt_file)
|
||||
|
||||
|
||||
def get_RTTM_per_rec(segs, spkrs_list, rec_id):
|
||||
"""Prepares rttm for each recording
|
||||
"""
|
||||
|
||||
rttm = []
|
||||
|
||||
# Prepare header
|
||||
for spkr_id in spkrs_list:
|
||||
# e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA>
|
||||
line = ("SPKR-INFO " + rec_id + " 0 <NA> <NA> <NA> unknown " + spkr_id +
|
||||
" <NA> <NA>")
|
||||
rttm.append(line)
|
||||
|
||||
# Append remaining lines
|
||||
for row in segs:
|
||||
# e.g. SPEAKER ES2008c 0 37.880 0.590 <NA> <NA> ES2008c.A_PM <NA> <NA>
|
||||
|
||||
if float(row[1]) < float(row[0]):
|
||||
msg1 = (
|
||||
"Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)"
|
||||
% (row[0], row[1]))
|
||||
msg2 = (
|
||||
"Excluding this incorrect row from the RTTM : %s, %s, %s, %s" %
|
||||
(rec_id, row[0], str(round(float(row[1]) - float(row[0]), 4)),
|
||||
str(row[2]), ))
|
||||
logger.info(msg1)
|
||||
logger.info(msg2)
|
||||
continue
|
||||
|
||||
line = ("SPEAKER " + rec_id + " 0 " + str(round(float(row[0]), 4)) + " "
|
||||
+ str(round(float(row[1]) - float(row[0]), 4)) + " <NA> <NA> " +
|
||||
str(row[2]) + " <NA> <NA>")
|
||||
rttm.append(line)
|
||||
|
||||
return rttm
|
||||
|
||||
|
||||
def prepare_segs_for_RTTM(list_ids, out_rttm_file, audio_dir, annot_dir,
|
||||
split_type, skip_TNO):
|
||||
|
||||
RTTM = [] # Stores all RTTMs clubbed together for a given dataset split
|
||||
|
||||
for main_meet_id in list_ids:
|
||||
|
||||
# Skip TNO meetings from dev and eval sets
|
||||
if (main_meet_id.startswith("TS") and split_type != "train" and
|
||||
skip_TNO is True):
|
||||
msg = ("Skipping TNO meeting in AMI " + str(split_type) + " set : "
|
||||
+ str(main_meet_id))
|
||||
logger.info(msg)
|
||||
continue
|
||||
|
||||
list_sessions = glob.glob(audio_dir + "/" + main_meet_id + "*")
|
||||
list_sessions.sort()
|
||||
|
||||
for sess in list_sessions:
|
||||
rec_id = os.path.basename(sess)
|
||||
path = annot_dir + "/segments/" + rec_id
|
||||
f = path + ".*.segments.xml"
|
||||
list_spkr_xmls = glob.glob(f)
|
||||
list_spkr_xmls.sort() # A, B, C, D, E etc (Speakers)
|
||||
segs = []
|
||||
spkrs_list = (
|
||||
[]) # Since non-scenario recordings contains 3-5 speakers
|
||||
|
||||
for spkr_xml_file in list_spkr_xmls:
|
||||
|
||||
# Speaker ID
|
||||
spkr = os.path.basename(spkr_xml_file).split(".")[1]
|
||||
spkr_ID = rec_id + "." + spkr
|
||||
spkrs_list.append(spkr_ID)
|
||||
|
||||
# Parse xml tree
|
||||
tree = et.parse(spkr_xml_file)
|
||||
root = tree.getroot()
|
||||
|
||||
# Start, end and speaker_ID from xml file
|
||||
segs = segs + [[
|
||||
elem.attrib["transcriber_start"],
|
||||
elem.attrib["transcriber_end"],
|
||||
spkr_ID,
|
||||
] for elem in root.iter("segment")]
|
||||
|
||||
# Sort rows as per the start time (per recording)
|
||||
segs.sort(key=lambda x: float(x[0]))
|
||||
|
||||
rttm_per_rec = get_RTTM_per_rec(segs, spkrs_list, rec_id)
|
||||
RTTM = RTTM + rttm_per_rec
|
||||
|
||||
# Write one RTTM as groundtruth. For example, "fullref_eval.rttm"
|
||||
with open(out_rttm_file, "w") as f:
|
||||
for item in RTTM:
|
||||
f.write("%s\n" % item)
|
||||
|
||||
|
||||
def is_overlapped(end1, start2):
|
||||
"""Returns True if the two segments overlap
|
||||
|
||||
Arguments
|
||||
---------
|
||||
end1 : float
|
||||
End time of the first segment.
|
||||
start2 : float
|
||||
Start time of the second segment.
|
||||
"""
|
||||
|
||||
if start2 > end1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def merge_rttm_intervals(rttm_segs):
|
||||
"""Merges adjacent segments in rttm if they overlap.
|
||||
"""
|
||||
# For one recording
|
||||
# rec_id = rttm_segs[0][1]
|
||||
rttm_segs.sort(key=lambda x: float(x[3]))
|
||||
|
||||
# first_seg = rttm_segs[0] # first interval.. as it is
|
||||
merged_segs = [rttm_segs[0]]
|
||||
strt = float(rttm_segs[0][3])
|
||||
end = float(rttm_segs[0][3]) + float(rttm_segs[0][4])
|
||||
|
||||
for row in rttm_segs[1:]:
|
||||
s = float(row[3])
|
||||
e = float(row[3]) + float(row[4])
|
||||
|
||||
if is_overlapped(end, s):
|
||||
# Update only end. The strt will be same as in last segment
|
||||
# Just update last row in the merged_segs
|
||||
end = max(end, e)
|
||||
merged_segs[-1][3] = str(round(strt, 4))
|
||||
merged_segs[-1][4] = str(round((end - strt), 4))
|
||||
merged_segs[-1][7] = "overlap" # previous_row[7] + '-'+ row[7]
|
||||
else:
|
||||
# Add a new disjoint segment
|
||||
strt = s
|
||||
end = e
|
||||
merged_segs.append(row) # this will have 1 spkr ID
|
||||
|
||||
return merged_segs
|
||||
|
||||
|
||||
def get_subsegments(merged_segs, max_subseg_dur=3.0, overlap=1.5):
|
||||
"""Divides bigger segments into smaller sub-segments
|
||||
"""
|
||||
|
||||
shift = max_subseg_dur - overlap
|
||||
subsegments = []
|
||||
|
||||
# These rows are in RTTM format
|
||||
for row in merged_segs:
|
||||
seg_dur = float(row[4])
|
||||
rec_id = row[1]
|
||||
|
||||
if seg_dur > max_subseg_dur:
|
||||
num_subsegs = int(seg_dur / shift)
|
||||
# Taking 0.01 sec as small step
|
||||
seg_start = float(row[3])
|
||||
seg_end = seg_start + seg_dur
|
||||
|
||||
# Now divide this segment (new_row) in smaller subsegments
|
||||
for i in range(num_subsegs):
|
||||
subseg_start = seg_start + i * shift
|
||||
subseg_end = min(subseg_start + max_subseg_dur - 0.01, seg_end)
|
||||
subseg_dur = subseg_end - subseg_start
|
||||
|
||||
new_row = [
|
||||
"SPEAKER",
|
||||
rec_id,
|
||||
"0",
|
||||
str(round(float(subseg_start), 4)),
|
||||
str(round(float(subseg_dur), 4)),
|
||||
"<NA>",
|
||||
"<NA>",
|
||||
row[7],
|
||||
"<NA>",
|
||||
"<NA>",
|
||||
]
|
||||
|
||||
subsegments.append(new_row)
|
||||
|
||||
# Break if exceeding the boundary
|
||||
if subseg_end >= seg_end:
|
||||
break
|
||||
else:
|
||||
subsegments.append(row)
|
||||
|
||||
return subsegments
|
||||
|
||||
|
||||
def prepare_metadata(rttm_file, save_dir, data_dir, filename, max_subseg_dur,
|
||||
overlap, mic_type):
|
||||
# Read RTTM, get unique meeting_IDs (from RTTM headers)
|
||||
# For each MeetingID. select that meetID -> merge -> subsegment -> json -> append
|
||||
|
||||
# Read RTTM
|
||||
RTTM = []
|
||||
with open(rttm_file, "r") as f:
|
||||
for line in f:
|
||||
entry = line[:-1]
|
||||
RTTM.append(entry)
|
||||
|
||||
spkr_info = filter(lambda x: x.startswith("SPKR-INFO"), RTTM)
|
||||
rec_ids = list(set([row.split(" ")[1] for row in spkr_info]))
|
||||
rec_ids.sort() # sorting just to make JSON look in proper sequence
|
||||
|
||||
# For each recording merge segments and then perform subsegmentation
|
||||
MERGED_SEGMENTS = []
|
||||
SUBSEGMENTS = []
|
||||
for rec_id in rec_ids:
|
||||
segs_iter = filter(lambda x: x.startswith("SPEAKER " + str(rec_id)),
|
||||
RTTM)
|
||||
gt_rttm_segs = [row.split(" ") for row in segs_iter]
|
||||
|
||||
# Merge, subsegment and then convert to json format.
|
||||
merged_segs = merge_rttm_intervals(
|
||||
gt_rttm_segs) # We lose speaker_ID after merging
|
||||
MERGED_SEGMENTS = MERGED_SEGMENTS + merged_segs
|
||||
|
||||
# Divide segments into smaller sub-segments
|
||||
subsegs = get_subsegments(merged_segs, max_subseg_dur, overlap)
|
||||
SUBSEGMENTS = SUBSEGMENTS + subsegs
|
||||
|
||||
# Write segment AND sub-segments (in RTTM format)
|
||||
segs_file = save_dir + "/" + filename + ".segments.rttm"
|
||||
subsegment_file = save_dir + "/" + filename + ".subsegments.rttm"
|
||||
|
||||
with open(segs_file, "w") as f:
|
||||
for row in MERGED_SEGMENTS:
|
||||
line_str = " ".join(row)
|
||||
f.write("%s\n" % line_str)
|
||||
|
||||
with open(subsegment_file, "w") as f:
|
||||
for row in SUBSEGMENTS:
|
||||
line_str = " ".join(row)
|
||||
f.write("%s\n" % line_str)
|
||||
|
||||
# Create JSON from subsegments
|
||||
json_dict = {}
|
||||
for row in SUBSEGMENTS:
|
||||
rec_id = row[1]
|
||||
strt = str(round(float(row[3]), 4))
|
||||
end = str(round((float(row[3]) + float(row[4])), 4))
|
||||
subsegment_ID = rec_id + "_" + strt + "_" + end
|
||||
dur = row[4]
|
||||
start_sample = int(float(strt) * SAMPLERATE)
|
||||
end_sample = int(float(end) * SAMPLERATE)
|
||||
|
||||
# If multi-mic audio is selected
|
||||
if mic_type == "Array1":
|
||||
wav_file_base_path = (data_dir + "/" + rec_id + "/audio/" + rec_id +
|
||||
"." + mic_type + "-")
|
||||
|
||||
f = [] # adding all 8 mics
|
||||
for i in range(8):
|
||||
f.append(wav_file_base_path + str(i + 1).zfill(2) + ".wav")
|
||||
audio_files_path_list = f
|
||||
|
||||
# Note: key "files" with 's' is used for multi-mic
|
||||
json_dict[subsegment_ID] = {
|
||||
"wav": {
|
||||
"files": audio_files_path_list,
|
||||
"duration": float(dur),
|
||||
"start": int(start_sample),
|
||||
"stop": int(end_sample),
|
||||
},
|
||||
}
|
||||
else:
|
||||
# Single mic audio
|
||||
wav_file_path = (data_dir + "/" + rec_id + "/audio/" + rec_id + "."
|
||||
+ mic_type + ".wav")
|
||||
|
||||
# Note: key "file" without 's' is used for single-mic
|
||||
json_dict[subsegment_ID] = {
|
||||
"wav": {
|
||||
"file": wav_file_path,
|
||||
"duration": float(dur),
|
||||
"start": int(start_sample),
|
||||
"stop": int(end_sample),
|
||||
},
|
||||
}
|
||||
|
||||
out_json_file = save_dir + "/" + filename + "." + mic_type + ".subsegs.json"
|
||||
with open(out_json_file, mode="w") as json_f:
|
||||
json.dump(json_dict, json_f, indent=2)
|
||||
|
||||
msg = "%s JSON prepared" % (out_json_file)
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
def skip(save_folder, conf, meta_files, opt_file):
|
||||
"""
|
||||
Detects if the AMI data_preparation has been already done.
|
||||
If the preparation has been done, we can skip it.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
if True, the preparation phase can be skipped.
|
||||
if False, it must be done.
|
||||
"""
|
||||
# Checking if meta (json) files are available
|
||||
skip = True
|
||||
for file_path in meta_files:
|
||||
if not os.path.isfile(file_path):
|
||||
skip = False
|
||||
|
||||
# Checking saved options
|
||||
save_opt_file = os.path.join(save_folder, opt_file)
|
||||
if skip is True:
|
||||
if os.path.isfile(save_opt_file):
|
||||
opts_old = load_pkl(save_opt_file)
|
||||
if opts_old == conf:
|
||||
skip = True
|
||||
else:
|
||||
skip = False
|
||||
else:
|
||||
skip = False
|
||||
|
||||
return skip
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='python ami_prepare.py --data_folder /home/data/ami/amicorpus \
|
||||
--manual_annot_folder /home/data/ami/ami_public_manual_1.6.2 \
|
||||
--save_folder ./results/ --ref_rttm_dir ./results/ref_rttms \
|
||||
--meta_data_dir ./results/metadata',
|
||||
description='AMI Data preparation')
|
||||
parser.add_argument(
|
||||
'--data_folder',
|
||||
required=True,
|
||||
help='Path to the folder where the original amicorpus is stored')
|
||||
parser.add_argument(
|
||||
'--manual_annot_folder',
|
||||
required=True,
|
||||
help='Directory where the manual annotations are stored')
|
||||
parser.add_argument(
|
||||
'--save_folder', required=True, help='The save directory in results')
|
||||
parser.add_argument(
|
||||
'--ref_rttm_dir',
|
||||
required=True,
|
||||
help='Directory to store reference RTTM files')
|
||||
parser.add_argument(
|
||||
'--meta_data_dir',
|
||||
required=True,
|
||||
help='Directory to store the meta data (json) files')
|
||||
parser.add_argument(
|
||||
'--split_type',
|
||||
default="full_corpus_asr",
|
||||
help='Standard dataset split. See ami_splits.py for more information')
|
||||
parser.add_argument(
|
||||
'--skip_TNO',
|
||||
default=True,
|
||||
type=strtobool,
|
||||
help='Skips TNO meeting recordings if True')
|
||||
parser.add_argument(
|
||||
'--mic_type',
|
||||
default="Mix-Headset",
|
||||
help='Type of microphone to be used')
|
||||
parser.add_argument(
|
||||
'--vad_type',
|
||||
default="oracle",
|
||||
help='Type of VAD. Kept for future when VAD will be added')
|
||||
parser.add_argument(
|
||||
'--max_subseg_dur',
|
||||
default=3.0,
|
||||
type=float,
|
||||
help='Duration in seconds of a subsegments to be prepared from larger segments'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--overlap',
|
||||
default=1.5,
|
||||
type=float,
|
||||
help='Overlap duration in seconds between adjacent subsegments')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prepare_ami(args.data_folder, args.manual_annot_folder, args.save_folder,
|
||||
args.ref_rttm_dir, args.meta_data_dir)
|
@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
AMI corpus contained 100 hours of meeting recording.
|
||||
This script returns the standard train, dev and eval split for AMI corpus.
|
||||
For more information on dataset please refer to http://groups.inf.ed.ac.uk/ami/corpus/datasets.shtml
|
||||
|
||||
Authors
|
||||
* qingenz123@126.com (Qingen ZHAO) 2022
|
||||
|
||||
"""
|
||||
|
||||
ALLOWED_OPTIONS = ["scenario_only", "full_corpus", "full_corpus_asr"]
|
||||
|
||||
|
||||
def get_AMI_split(split_option):
|
||||
"""
|
||||
Prepares train, dev, and test sets for given split_option
|
||||
|
||||
Arguments
|
||||
---------
|
||||
split_option: str
|
||||
The standard split option.
|
||||
Allowed options: "scenario_only", "full_corpus", "full_corpus_asr"
|
||||
|
||||
Returns
|
||||
-------
|
||||
Meeting IDs for train, dev, and test sets for given split_option
|
||||
"""
|
||||
|
||||
if split_option not in ALLOWED_OPTIONS:
|
||||
print(
|
||||
f'Invalid split "{split_option}" requested!\nValid split_options are: ',
|
||||
ALLOWED_OPTIONS, )
|
||||
return
|
||||
|
||||
if split_option == "scenario_only":
|
||||
|
||||
train_set = [
|
||||
"ES2002",
|
||||
"ES2005",
|
||||
"ES2006",
|
||||
"ES2007",
|
||||
"ES2008",
|
||||
"ES2009",
|
||||
"ES2010",
|
||||
"ES2012",
|
||||
"ES2013",
|
||||
"ES2015",
|
||||
"ES2016",
|
||||
"IS1000",
|
||||
"IS1001",
|
||||
"IS1002",
|
||||
"IS1003",
|
||||
"IS1004",
|
||||
"IS1005",
|
||||
"IS1006",
|
||||
"IS1007",
|
||||
"TS3005",
|
||||
"TS3008",
|
||||
"TS3009",
|
||||
"TS3010",
|
||||
"TS3011",
|
||||
"TS3012",
|
||||
]
|
||||
|
||||
dev_set = [
|
||||
"ES2003",
|
||||
"ES2011",
|
||||
"IS1008",
|
||||
"TS3004",
|
||||
"TS3006",
|
||||
]
|
||||
|
||||
test_set = [
|
||||
"ES2004",
|
||||
"ES2014",
|
||||
"IS1009",
|
||||
"TS3003",
|
||||
"TS3007",
|
||||
]
|
||||
|
||||
if split_option == "full_corpus":
|
||||
# List of train: SA (TRAINING PART OF SEEN DATA)
|
||||
train_set = [
|
||||
"ES2002",
|
||||
"ES2005",
|
||||
"ES2006",
|
||||
"ES2007",
|
||||
"ES2008",
|
||||
"ES2009",
|
||||
"ES2010",
|
||||
"ES2012",
|
||||
"ES2013",
|
||||
"ES2015",
|
||||
"ES2016",
|
||||
"IS1000",
|
||||
"IS1001",
|
||||
"IS1002",
|
||||
"IS1003",
|
||||
"IS1004",
|
||||
"IS1005",
|
||||
"IS1006",
|
||||
"IS1007",
|
||||
"TS3005",
|
||||
"TS3008",
|
||||
"TS3009",
|
||||
"TS3010",
|
||||
"TS3011",
|
||||
"TS3012",
|
||||
"EN2001",
|
||||
"EN2003",
|
||||
"EN2004",
|
||||
"EN2005",
|
||||
"EN2006",
|
||||
"EN2009",
|
||||
"IN1001",
|
||||
"IN1002",
|
||||
"IN1005",
|
||||
"IN1007",
|
||||
"IN1008",
|
||||
"IN1009",
|
||||
"IN1012",
|
||||
"IN1013",
|
||||
"IN1014",
|
||||
"IN1016",
|
||||
]
|
||||
|
||||
# List of dev: SB (DEV PART OF SEEN DATA)
|
||||
dev_set = [
|
||||
"ES2003",
|
||||
"ES2011",
|
||||
"IS1008",
|
||||
"TS3004",
|
||||
"TS3006",
|
||||
"IB4001",
|
||||
"IB4002",
|
||||
"IB4003",
|
||||
"IB4004",
|
||||
"IB4010",
|
||||
"IB4011",
|
||||
]
|
||||
|
||||
# List of test: SC (UNSEEN DATA FOR EVALUATION)
|
||||
# Note that IB4005 does not appear because it has speakers in common with two sets of data.
|
||||
test_set = [
|
||||
"ES2004",
|
||||
"ES2014",
|
||||
"IS1009",
|
||||
"TS3003",
|
||||
"TS3007",
|
||||
"EN2002",
|
||||
]
|
||||
|
||||
if split_option == "full_corpus_asr":
|
||||
train_set = [
|
||||
"ES2002",
|
||||
"ES2003",
|
||||
"ES2005",
|
||||
"ES2006",
|
||||
"ES2007",
|
||||
"ES2008",
|
||||
"ES2009",
|
||||
"ES2010",
|
||||
"ES2012",
|
||||
"ES2013",
|
||||
"ES2014",
|
||||
"ES2015",
|
||||
"ES2016",
|
||||
"IS1000",
|
||||
"IS1001",
|
||||
"IS1002",
|
||||
"IS1003",
|
||||
"IS1004",
|
||||
"IS1005",
|
||||
"IS1006",
|
||||
"IS1007",
|
||||
"TS3005",
|
||||
"TS3006",
|
||||
"TS3007",
|
||||
"TS3008",
|
||||
"TS3009",
|
||||
"TS3010",
|
||||
"TS3011",
|
||||
"TS3012",
|
||||
"EN2001",
|
||||
"EN2003",
|
||||
"EN2004",
|
||||
"EN2005",
|
||||
"EN2006",
|
||||
"EN2009",
|
||||
"IN1001",
|
||||
"IN1002",
|
||||
"IN1005",
|
||||
"IN1007",
|
||||
"IN1008",
|
||||
"IN1009",
|
||||
"IN1012",
|
||||
"IN1013",
|
||||
"IN1014",
|
||||
"IN1016",
|
||||
]
|
||||
|
||||
dev_set = [
|
||||
"ES2011",
|
||||
"IS1008",
|
||||
"TS3004",
|
||||
"IB4001",
|
||||
"IB4002",
|
||||
"IB4003",
|
||||
"IB4004",
|
||||
"IB4010",
|
||||
"IB4011",
|
||||
]
|
||||
|
||||
test_set = [
|
||||
"ES2004",
|
||||
"IS1009",
|
||||
"TS3003",
|
||||
"EN2002",
|
||||
]
|
||||
|
||||
return train_set, dev_set, test_set
|
@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
|
||||
stage=1
|
||||
|
||||
TARGET_DIR=${MAIN_ROOT}/dataset/ami
|
||||
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
|
||||
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
|
||||
|
||||
save_folder=${MAIN_ROOT}/examples/ami/sd0/data
|
||||
ref_rttm_dir=${save_folder}/ref_rttms
|
||||
meta_data_dir=${save_folder}/metadata
|
||||
|
||||
set=L
|
||||
|
||||
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
set -u
|
||||
set -o pipefail
|
||||
|
||||
mkdir -p ${save_folder}
|
||||
|
||||
if [ ${stage} -le 0 ]; then
|
||||
# Download AMI corpus, You need around 10GB of free space to get whole data
|
||||
# The signals are too large to package in this way,
|
||||
# so you need to use the chooser to indicate which ones you wish to download
|
||||
echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
|
||||
echo "Annotations: AMI manual annotations v1.6.2 "
|
||||
echo "Signals: "
|
||||
echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
|
||||
echo "2) Select media streams: Just select Headset mix"
|
||||
exit 0;
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ]; then
|
||||
echo "AMI Data preparation"
|
||||
|
||||
python local/ami_prepare.py --data_folder ${data_folder} \
|
||||
--manual_annot_folder ${manual_annot_folder} \
|
||||
--save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
|
||||
--meta_data_dir ${meta_data_dir}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Prepare AMI failed. Please check log message."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
echo "AMI data preparation done."
|
||||
exit 0
|
@ -0,0 +1,97 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Data reading and writing.
|
||||
|
||||
Authors
|
||||
* qingenz123@126.com (Qingen ZHAO) 2022
|
||||
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
def save_pkl(obj, file):
|
||||
"""Save an object in pkl format.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
obj : object
|
||||
Object to save in pkl format
|
||||
file : str
|
||||
Path to the output file
|
||||
sampling_rate : int
|
||||
Sampling rate of the audio file, TODO: this is not used?
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> tmpfile = os.path.join(getfixture('tmpdir'), "example.pkl")
|
||||
>>> save_pkl([1, 2, 3, 4, 5], tmpfile)
|
||||
>>> load_pkl(tmpfile)
|
||||
[1, 2, 3, 4, 5]
|
||||
"""
|
||||
with open(file, "wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
|
||||
def load_pickle(pickle_path):
|
||||
"""Utility function for loading .pkl pickle files.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
pickle_path : str
|
||||
Path to pickle file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : object
|
||||
Python object loaded from pickle.
|
||||
"""
|
||||
with open(pickle_path, "rb") as f:
|
||||
out = pickle.load(f)
|
||||
return out
|
||||
|
||||
|
||||
def load_pkl(file):
|
||||
"""Loads a pkl file.
|
||||
|
||||
For an example, see `save_pkl`.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
file : str
|
||||
Path to the input pkl file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The loaded object.
|
||||
"""
|
||||
|
||||
# Deals with the situation where two processes are trying
|
||||
# to access the same label dictionary by creating a lock
|
||||
count = 100
|
||||
while count > 0:
|
||||
if os.path.isfile(file + ".lock"):
|
||||
time.sleep(1)
|
||||
count -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
open(file + ".lock", "w").close()
|
||||
with open(file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
finally:
|
||||
if os.path.isfile(file + ".lock"):
|
||||
os.remove(file + ".lock")
|
@ -0,0 +1,15 @@
|
||||
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||
|
||||
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||
export LC_ALL=C
|
||||
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||
|
||||
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
|
||||
|
||||
# model exp
|
||||
#MODEL=ECAPA_TDNN
|
||||
#export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}/bin
|
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
. path.sh || exit 1;
|
||||
set -e
|
||||
|
||||
stage=1
|
||||
|
||||
|
||||
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
if [ ${stage} -le 1 ]; then
|
||||
# prepare data
|
||||
bash ./local/data.sh || exit -1
|
||||
fi
|
@ -0,0 +1 @@
|
||||
../../../utils
|
@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
|
||||
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
|
||||
|
||||
Authors
|
||||
* Neville Ryant 2018
|
||||
* Nauman Dawalatabad 2020
|
||||
* qingenz123@126.com (Qingen ZHAO) 2022
|
||||
|
||||
Credits
|
||||
This code is adapted from https://github.com/nryant/dscore
|
||||
"""
|
||||
import argparse
|
||||
from distutils.util import strtobool
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import numpy as np
|
||||
|
||||
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
|
||||
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
|
||||
MISS_SPEAKER_TIME = re.compile(r"(?<=MISSED SPEAKER TIME =)[\d.]+")
|
||||
FA_SPEAKER_TIME = re.compile(r"(?<=FALARM SPEAKER TIME =)[\d.]+")
|
||||
ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
|
||||
|
||||
|
||||
def rectify(arr):
|
||||
"""Corrects corner cases and converts scores into percentage.
|
||||
"""
|
||||
# Numerator and denominator both 0.
|
||||
arr[np.isnan(arr)] = 0
|
||||
|
||||
# Numerator > 0, but denominator = 0.
|
||||
arr[np.isinf(arr)] = 1
|
||||
arr *= 100.0
|
||||
|
||||
return arr
|
||||
|
||||
|
||||
def DER(
|
||||
ref_rttm,
|
||||
sys_rttm,
|
||||
ignore_overlap=False,
|
||||
collar=0.25,
|
||||
individual_file_scores=False, ):
|
||||
"""Computes Missed Speaker percentage (MS), False Alarm (FA),
|
||||
Speaker Error Rate (SER), and Diarization Error Rate (DER).
|
||||
|
||||
Arguments
|
||||
---------
|
||||
ref_rttm : str
|
||||
The path of reference/groundtruth RTTM file.
|
||||
sys_rttm : str
|
||||
The path of the system generated RTTM file.
|
||||
individual_file : bool
|
||||
If True, returns scores for each file in order.
|
||||
collar : float
|
||||
Forgiveness collar.
|
||||
ignore_overlap : bool
|
||||
If True, ignores overlapping speech during evaluation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
MS : float array
|
||||
Missed Speech.
|
||||
FA : float array
|
||||
False Alarms.
|
||||
SER : float array
|
||||
Speaker Error Rates.
|
||||
DER : float array
|
||||
Diarization Error Rates.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import pytest
|
||||
>>> pytest.skip('Skipping because of Perl dependency')
|
||||
>>> ref_rttm = "../../samples/rttm_samples/ref_rttm/ES2014c.rttm"
|
||||
>>> sys_rttm = "../../samples/rttm_samples/sys_rttm/ES2014c.rttm"
|
||||
>>> ignore_overlap = True
|
||||
>>> collar = 0.25
|
||||
>>> individual_file_scores = True
|
||||
>>> Scores = DER(ref_rttm, sys_rttm, ignore_overlap, collar, individual_file_scores)
|
||||
>>> print (Scores)
|
||||
(array([0., 0.]), array([0., 0.]), array([7.16923618, 7.16923618]), array([7.16923618, 7.16923618]))
|
||||
"""
|
||||
|
||||
curr = os.path.abspath(os.path.dirname(__file__))
|
||||
mdEval = os.path.join(curr, "./md-eval.pl")
|
||||
|
||||
cmd = [
|
||||
mdEval,
|
||||
"-af",
|
||||
"-r",
|
||||
ref_rttm,
|
||||
"-s",
|
||||
sys_rttm,
|
||||
"-c",
|
||||
str(collar),
|
||||
]
|
||||
if ignore_overlap:
|
||||
cmd.append("-1")
|
||||
|
||||
try:
|
||||
stdout = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
|
||||
|
||||
except subprocess.CalledProcessError as ex:
|
||||
stdout = ex.output
|
||||
|
||||
else:
|
||||
stdout = stdout.decode("utf-8")
|
||||
|
||||
# Get all recording IDs
|
||||
file_ids = [m.strip() for m in FILE_IDS.findall(stdout)]
|
||||
file_ids = [
|
||||
file_id[2:] if file_id.startswith("f=") else file_id
|
||||
for file_id in file_ids
|
||||
]
|
||||
|
||||
scored_speaker_times = np.array(
|
||||
[float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)])
|
||||
|
||||
miss_speaker_times = np.array(
|
||||
[float(m) for m in MISS_SPEAKER_TIME.findall(stdout)])
|
||||
|
||||
fa_speaker_times = np.array(
|
||||
[float(m) for m in FA_SPEAKER_TIME.findall(stdout)])
|
||||
|
||||
error_speaker_times = np.array(
|
||||
[float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)])
|
||||
|
||||
with np.errstate(invalid="ignore", divide="ignore"):
|
||||
tot_error_times = (
|
||||
miss_speaker_times + fa_speaker_times + error_speaker_times)
|
||||
miss_speaker_frac = miss_speaker_times / scored_speaker_times
|
||||
fa_speaker_frac = fa_speaker_times / scored_speaker_times
|
||||
sers_frac = error_speaker_times / scored_speaker_times
|
||||
ders_frac = tot_error_times / scored_speaker_times
|
||||
|
||||
# Values in percentage of scored_speaker_time
|
||||
miss_speaker = rectify(miss_speaker_frac)
|
||||
fa_speaker = rectify(fa_speaker_frac)
|
||||
sers = rectify(sers_frac)
|
||||
ders = rectify(ders_frac)
|
||||
|
||||
if individual_file_scores:
|
||||
return miss_speaker, fa_speaker, sers, ders
|
||||
else:
|
||||
return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Compute Diarization Error Rate')
|
||||
parser.add_argument(
|
||||
'--ref_rttm',
|
||||
required=True,
|
||||
help='the path of reference/groundtruth RTTM file')
|
||||
parser.add_argument(
|
||||
'--sys_rttm',
|
||||
required=True,
|
||||
help='the path of the system generated RTTM file')
|
||||
parser.add_argument(
|
||||
'--individual_file',
|
||||
default=False,
|
||||
type=strtobool,
|
||||
help='if True, returns scores for each file in order')
|
||||
parser.add_argument(
|
||||
'--collar', default=0.25, type=float, help='forgiveness collar')
|
||||
parser.add_argument(
|
||||
'--ignore_overlap',
|
||||
default=False,
|
||||
type=strtobool,
|
||||
help='if True, ignores overlapping speech during evaluation')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
der = DER(args.ref_rttm, args.sys_rttm)
|
||||
print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" %
|
||||
(der[0], der[1], der[2], der[-1]))
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue