[vector] add AMI data preparation scripts

pull/1335/head
qingen 4 years ago
parent 98788ca27e
commit 14d9e80b0d

@ -1,3 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Data preparation. Data preparation.
@ -21,26 +34,24 @@ from distutils.util import strtobool
from utils.dataio import ( from utils.dataio import (
load_pkl, load_pkl,
save_pkl, save_pkl, )
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SAMPLERATE = 16000 SAMPLERATE = 16000
def prepare_ami( def prepare_ami(
data_folder, data_folder,
manual_annot_folder, manual_annot_folder,
save_folder, save_folder,
ref_rttm_dir, ref_rttm_dir,
meta_data_dir, meta_data_dir,
split_type="full_corpus_asr", split_type="full_corpus_asr",
skip_TNO=True, skip_TNO=True,
mic_type="Mix-Headset", mic_type="Mix-Headset",
vad_type="oracle", vad_type="oracle",
max_subseg_dur=3.0, max_subseg_dur=3.0,
overlap=1.5, overlap=1.5, ):
):
""" """
Prepares reference RTTM and JSON files for the AMI dataset. Prepares reference RTTM and JSON files for the AMI dataset.
@ -72,12 +83,12 @@ def prepare_ami(
Example Example
------- -------
>>> from recipes.AMI.ami_prepare import prepare_ami >>> from dataset.ami.ami_prepare import prepare_ami
>>> data_folder = '/network/datasets/ami/amicorpus/' >>> data_folder = '/home/data/ami/amicorpus/'
>>> manual_annot_folder = '/home/mila/d/dawalatn/nauman/ami_public_manual/' >>> manual_annot_folder = '/home/data/ami/ami_public_manual/'
>>> save_folder = 'results/save/' >>> save_folder = './results/
>>> split_type = 'full_corpus_asr' >>> split_type = 'full_corpus_asr'
>>> mic_type = 'Lapel' >>> mic_type = 'Mix-Headset'
>>> prepare_ami(data_folder, manual_annot_folder, save_folder, split_type, mic_type) >>> prepare_ami(data_folder, manual_annot_folder, save_folder, split_type, mic_type)
""" """
@ -112,8 +123,7 @@ def prepare_ami(
# Check if this phase is already done (if so, skip it) # Check if this phase is already done (if so, skip it)
if skip(save_folder, conf, meta_files, opt_file): if skip(save_folder, conf, meta_files, opt_file):
logger.info( logger.info(
"Skipping data preparation, as it was completed in previous run." "Skipping data preparation, as it was completed in previous run.")
)
return return
msg = "\tCreating meta-data file for the AMI Dataset.." msg = "\tCreating meta-data file for the AMI Dataset.."
@ -138,8 +148,7 @@ def prepare_ami(
data_folder, data_folder,
manual_annot_folder, manual_annot_folder,
i, i,
skip_TNO, skip_TNO, )
)
if i == "dev": if i == "dev":
prepare_segs_for_RTTM( prepare_segs_for_RTTM(
dev_set, dev_set,
@ -147,8 +156,7 @@ def prepare_ami(
data_folder, data_folder,
manual_annot_folder, manual_annot_folder,
i, i,
skip_TNO, skip_TNO, )
)
if i == "eval": if i == "eval":
prepare_segs_for_RTTM( prepare_segs_for_RTTM(
eval_set, eval_set,
@ -156,8 +164,7 @@ def prepare_ami(
data_folder, data_folder,
manual_annot_folder, manual_annot_folder,
i, i,
skip_TNO, skip_TNO, )
)
# Create meta_files for splits # Create meta_files for splits
meta_data_dir = meta_data_dir meta_data_dir = meta_data_dir
@ -174,8 +181,7 @@ def prepare_ami(
meta_filename_prefix, meta_filename_prefix,
max_subseg_dur, max_subseg_dur,
overlap, overlap,
mic_type, mic_type, )
)
save_opt_file = os.path.join(save_folder, opt_file) save_opt_file = os.path.join(save_folder, opt_file)
save_pkl(conf, save_opt_file) save_pkl(conf, save_opt_file)
@ -190,13 +196,8 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
# Prepare header # Prepare header
for spkr_id in spkrs_list: for spkr_id in spkrs_list:
# e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA> # e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA>
line = ( line = ("SPKR-INFO " + rec_id + " 0 <NA> <NA> <NA> unknown " + spkr_id +
"SPKR-INFO " " <NA> <NA>")
+ rec_id
+ " 0 <NA> <NA> <NA> unknown "
+ spkr_id
+ " <NA> <NA>"
)
rttm.append(line) rttm.append(line)
# Append remaining lines # Append remaining lines
@ -206,57 +207,35 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id):
if float(row[1]) < float(row[0]): if float(row[1]) < float(row[0]):
msg1 = ( msg1 = (
"Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)" "Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)"
% (row[0], row[1]) % (row[0], row[1]))
)
msg2 = ( msg2 = (
"Excluding this incorrect row from the RTTM : %s, %s, %s, %s" "Excluding this incorrect row from the RTTM : %s, %s, %s, %s" %
% ( (rec_id, row[0], str(round(float(row[1]) - float(row[0]), 4)),
rec_id, str(row[2]), ))
row[0],
str(round(float(row[1]) - float(row[0]), 4)),
str(row[2]),
)
)
logger.info(msg1) logger.info(msg1)
logger.info(msg2) logger.info(msg2)
continue continue
line = ( line = ("SPEAKER " + rec_id + " 0 " + str(round(float(row[0]), 4)) + " "
"SPEAKER " + str(round(float(row[1]) - float(row[0]), 4)) + " <NA> <NA> " +
+ rec_id str(row[2]) + " <NA> <NA>")
+ " 0 "
+ str(round(float(row[0]), 4))
+ " "
+ str(round(float(row[1]) - float(row[0]), 4))
+ " <NA> <NA> "
+ str(row[2])
+ " <NA> <NA>"
)
rttm.append(line) rttm.append(line)
return rttm return rttm
def prepare_segs_for_RTTM( def prepare_segs_for_RTTM(list_ids, out_rttm_file, audio_dir, annot_dir,
list_ids, out_rttm_file, audio_dir, annot_dir, split_type, skip_TNO split_type, skip_TNO):
):
RTTM = [] # Stores all RTTMs clubbed together for a given dataset split RTTM = [] # Stores all RTTMs clubbed together for a given dataset split
for main_meet_id in list_ids: for main_meet_id in list_ids:
# Skip TNO meetings from dev and eval sets # Skip TNO meetings from dev and eval sets
if ( if (main_meet_id.startswith("TS") and split_type != "train" and
main_meet_id.startswith("TS") skip_TNO is True):
and split_type != "train" msg = ("Skipping TNO meeting in AMI " + str(split_type) + " set : "
and skip_TNO is True + str(main_meet_id))
):
msg = (
"Skipping TNO meeting in AMI "
+ str(split_type)
+ " set : "
+ str(main_meet_id)
)
logger.info(msg) logger.info(msg)
continue continue
@ -271,8 +250,7 @@ def prepare_segs_for_RTTM(
list_spkr_xmls.sort() # A, B, C, D, E etc (Speakers) list_spkr_xmls.sort() # A, B, C, D, E etc (Speakers)
segs = [] segs = []
spkrs_list = ( spkrs_list = (
[] []) # Since non-scenario recordings contains 3-5 speakers
) # Since non-scenario recordings contains 3-5 speakers
for spkr_xml_file in list_spkr_xmls: for spkr_xml_file in list_spkr_xmls:
@ -286,14 +264,11 @@ def prepare_segs_for_RTTM(
root = tree.getroot() root = tree.getroot()
# Start, end and speaker_ID from xml file # Start, end and speaker_ID from xml file
segs = segs + [ segs = segs + [[
[ elem.attrib["transcriber_start"],
elem.attrib["transcriber_start"], elem.attrib["transcriber_end"],
elem.attrib["transcriber_end"], spkr_ID,
spkr_ID, ] for elem in root.iter("segment")]
]
for elem in root.iter("segment")
]
# Sort rows as per the start time (per recording) # Sort rows as per the start time (per recording)
segs.sort(key=lambda x: float(x[0])) segs.sort(key=lambda x: float(x[0]))
@ -404,9 +379,8 @@ def get_subsegments(merged_segs, max_subseg_dur=3.0, overlap=1.5):
return subsegments return subsegments
def prepare_metadata( def prepare_metadata(rttm_file, save_dir, data_dir, filename, max_subseg_dur,
rttm_file, save_dir, data_dir, filename, max_subseg_dur, overlap, mic_type overlap, mic_type):
):
# Read RTTM, get unique meeting_IDs (from RTTM headers) # Read RTTM, get unique meeting_IDs (from RTTM headers)
# For each MeetingID. select that meetID -> merge -> subsegment -> json -> append # For each MeetingID. select that meetID -> merge -> subsegment -> json -> append
@ -425,15 +399,13 @@ def prepare_metadata(
MERGED_SEGMENTS = [] MERGED_SEGMENTS = []
SUBSEGMENTS = [] SUBSEGMENTS = []
for rec_id in rec_ids: for rec_id in rec_ids:
segs_iter = filter( segs_iter = filter(lambda x: x.startswith("SPEAKER " + str(rec_id)),
lambda x: x.startswith("SPEAKER " + str(rec_id)), RTTM RTTM)
)
gt_rttm_segs = [row.split(" ") for row in segs_iter] gt_rttm_segs = [row.split(" ") for row in segs_iter]
# Merge, subsegment and then convert to json format. # Merge, subsegment and then convert to json format.
merged_segs = merge_rttm_intervals( merged_segs = merge_rttm_intervals(
gt_rttm_segs gt_rttm_segs) # We lose speaker_ID after merging
) # We lose speaker_ID after merging
MERGED_SEGMENTS = MERGED_SEGMENTS + merged_segs MERGED_SEGMENTS = MERGED_SEGMENTS + merged_segs
# Divide segments into smaller sub-segments # Divide segments into smaller sub-segments
@ -467,16 +439,8 @@ def prepare_metadata(
# If multi-mic audio is selected # If multi-mic audio is selected
if mic_type == "Array1": if mic_type == "Array1":
wav_file_base_path = ( wav_file_base_path = (data_dir + "/" + rec_id + "/audio/" + rec_id +
data_dir "." + mic_type + "-")
+ "/"
+ rec_id
+ "/audio/"
+ rec_id
+ "."
+ mic_type
+ "-"
)
f = [] # adding all 8 mics f = [] # adding all 8 mics
for i in range(8): for i in range(8):
@ -494,16 +458,8 @@ def prepare_metadata(
} }
else: else:
# Single mic audio # Single mic audio
wav_file_path = ( wav_file_path = (data_dir + "/" + rec_id + "/audio/" + rec_id + "."
data_dir + mic_type + ".wav")
+ "/"
+ rec_id
+ "/audio/"
+ rec_id
+ "."
+ mic_type
+ ".wav"
)
# Note: key "file" without 's' is used for single-mic # Note: key "file" without 's' is used for single-mic
json_dict[subsegment_ID] = { json_dict[subsegment_ID] = {
@ -554,6 +510,7 @@ def skip(save_folder, conf, meta_files, opt_file):
return skip return skip
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -561,42 +518,56 @@ if __name__ == '__main__':
--manual_annot_folder /home/data/ami/ami_public_manual_1.6.2 \ --manual_annot_folder /home/data/ami/ami_public_manual_1.6.2 \
--save_folder ./results/ --ref_rttm_dir ./results/ref_rttms \ --save_folder ./results/ --ref_rttm_dir ./results/ref_rttms \
--meta_data_dir ./results/metadata', --meta_data_dir ./results/metadata',
description='AMI Data preparation') description='AMI Data preparation')
parser.add_argument( parser.add_argument(
'--data_folder', required=True, help='Path to the folder where the original amicorpus is stored') '--data_folder',
required=True,
help='Path to the folder where the original amicorpus is stored')
parser.add_argument( parser.add_argument(
'--manual_annot_folder', required=True, help='Directory where the manual annotations are stored') '--manual_annot_folder',
required=True,
help='Directory where the manual annotations are stored')
parser.add_argument( parser.add_argument(
'--save_folder', required=True, help='The save directory in results') '--save_folder', required=True, help='The save directory in results')
parser.add_argument( parser.add_argument(
'--ref_rttm_dir', required=True, help='Directory to store reference RTTM files') '--ref_rttm_dir',
required=True,
help='Directory to store reference RTTM files')
parser.add_argument( parser.add_argument(
'--meta_data_dir', required=True, help='Directory to store the meta data (json) files') '--meta_data_dir',
required=True,
help='Directory to store the meta data (json) files')
parser.add_argument( parser.add_argument(
'--split_type', '--split_type',
default="full_corpus_asr", default="full_corpus_asr",
help='Standard dataset split. See ami_splits.py for more information') help='Standard dataset split. See ami_splits.py for more information')
parser.add_argument( parser.add_argument(
'--skip_TNO', default=True, type=strtobool, help='Skips TNO meeting recordings if True') '--skip_TNO',
default=True,
type=strtobool,
help='Skips TNO meeting recordings if True')
parser.add_argument( parser.add_argument(
'--mic_type', default="Mix-Headset", help='Type of microphone to be used') '--mic_type',
default="Mix-Headset",
help='Type of microphone to be used')
parser.add_argument( parser.add_argument(
'--vad_type', default="oracle", help='Type of VAD. Kept for future when VAD will be added') '--vad_type',
default="oracle",
help='Type of VAD. Kept for future when VAD will be added')
parser.add_argument( parser.add_argument(
'--max_subseg_dur', '--max_subseg_dur',
default=3.0, default=3.0,
type=float, type=float,
help='Duration in seconds of a subsegments to be prepared from larger segments') help='Duration in seconds of a subsegments to be prepared from larger segments'
)
parser.add_argument( parser.add_argument(
'--overlap', default=1.5, type=float, help='Overlap duration in seconds between adjacent subsegments') '--overlap',
default=1.5,
type=float,
help='Overlap duration in seconds between adjacent subsegments')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
prepare_ami( prepare_ami(args.data_folder, args.manual_annot_folder, args.save_folder,
args.data_folder, args.ref_rttm_dir, args.meta_data_dir)
args.manual_annot_folder,
args.save_folder,
args.ref_rttm_dir,
args.meta_data_dir
)

@ -1,3 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
AMI corpus contained 100 hours of meeting recording. AMI corpus contained 100 hours of meeting recording.
This script returns the standard train, dev and eval split for AMI corpus. This script returns the standard train, dev and eval split for AMI corpus.
@ -29,8 +42,7 @@ def get_AMI_split(split_option):
if split_option not in ALLOWED_OPTIONS: if split_option not in ALLOWED_OPTIONS:
print( print(
f'Invalid split "{split_option}" requested!\nValid split_options are: ', f'Invalid split "{split_option}" requested!\nValid split_options are: ',
ALLOWED_OPTIONS, ALLOWED_OPTIONS, )
)
return return
if split_option == "scenario_only": if split_option == "scenario_only":

@ -1,3 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Data reading and writing. Data reading and writing.
@ -5,10 +18,10 @@ Authors
* qingenz123@126.com (Qingen ZHAO) 2022 * qingenz123@126.com (Qingen ZHAO) 2022
""" """
import os import os
import pickle import pickle
def save_pkl(obj, file): def save_pkl(obj, file):
"""Save an object in pkl format. """Save an object in pkl format.
@ -31,6 +44,7 @@ def save_pkl(obj, file):
with open(file, "wb") as f: with open(file, "wb") as f:
pickle.dump(obj, f) pickle.dump(obj, f)
def load_pickle(pickle_path): def load_pickle(pickle_path):
"""Utility function for loading .pkl pickle files. """Utility function for loading .pkl pickle files.
@ -48,6 +62,7 @@ def load_pickle(pickle_path):
out = pickle.load(f) out = pickle.load(f)
return out return out
def load_pkl(file): def load_pkl(file):
"""Loads a pkl file. """Loads a pkl file.
@ -79,4 +94,4 @@ def load_pkl(file):
return pickle.load(f) return pickle.load(f)
finally: finally:
if os.path.isfile(file + ".lock"): if os.path.isfile(file + ".lock"):
os.remove(file + ".lock") os.remove(file + ".lock")

@ -1,3 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS), """Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS),
False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation. False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation.
@ -26,7 +39,6 @@ ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+")
def rectify(arr): def rectify(arr):
"""Corrects corner cases and converts scores into percentage. """Corrects corner cases and converts scores into percentage.
""" """
# Numerator and denominator both 0. # Numerator and denominator both 0.
arr[np.isnan(arr)] = 0 arr[np.isnan(arr)] = 0
@ -38,12 +50,11 @@ def rectify(arr):
def DER( def DER(
ref_rttm, ref_rttm,
sys_rttm, sys_rttm,
ignore_overlap=False, ignore_overlap=False,
collar=0.25, collar=0.25,
individual_file_scores=False, individual_file_scores=False, ):
):
"""Computes Missed Speaker percentage (MS), False Alarm (FA), """Computes Missed Speaker percentage (MS), False Alarm (FA),
Speaker Error Rate (SER), and Diarization Error Rate (DER). Speaker Error Rate (SER), and Diarization Error Rate (DER).
@ -118,25 +129,20 @@ def DER(
] ]
scored_speaker_times = np.array( scored_speaker_times = np.array(
[float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)] [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)])
)
miss_speaker_times = np.array( miss_speaker_times = np.array(
[float(m) for m in MISS_SPEAKER_TIME.findall(stdout)] [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)])
)
fa_speaker_times = np.array( fa_speaker_times = np.array(
[float(m) for m in FA_SPEAKER_TIME.findall(stdout)] [float(m) for m in FA_SPEAKER_TIME.findall(stdout)])
)
error_speaker_times = np.array( error_speaker_times = np.array(
[float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)] [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)])
)
with np.errstate(invalid="ignore", divide="ignore"): with np.errstate(invalid="ignore", divide="ignore"):
tot_error_times = ( tot_error_times = (
miss_speaker_times + fa_speaker_times + error_speaker_times miss_speaker_times + fa_speaker_times + error_speaker_times)
)
miss_speaker_frac = miss_speaker_times / scored_speaker_times miss_speaker_frac = miss_speaker_times / scored_speaker_times
fa_speaker_frac = fa_speaker_times / scored_speaker_times fa_speaker_frac = fa_speaker_times / scored_speaker_times
sers_frac = error_speaker_times / scored_speaker_times sers_frac = error_speaker_times / scored_speaker_times
@ -153,13 +159,19 @@ def DER(
else: else:
return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1] return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1]
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Compute Diarization Error Rate') parser = argparse.ArgumentParser(
description='Compute Diarization Error Rate')
parser.add_argument( parser.add_argument(
'--ref_rttm', required=True, help='the path of reference/groundtruth RTTM file') '--ref_rttm',
required=True,
help='the path of reference/groundtruth RTTM file')
parser.add_argument( parser.add_argument(
'--sys_rttm', required=True, help='the path of the system generated RTTM file') '--sys_rttm',
required=True,
help='the path of the system generated RTTM file')
parser.add_argument( parser.add_argument(
'--individual_file', '--individual_file',
default=False, default=False,
@ -176,4 +188,5 @@ if __name__ == '__main__':
print(args) print(args)
der = DER(args.ref_rttm, args.sys_rttm) der = DER(args.ref_rttm, args.sys_rttm)
print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" % (der[0], der[1], der[2], der[-1])) print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" %
(der[0], der[1], der[2], der[-1]))
Loading…
Cancel
Save