diff --git a/dataset/ami/ami_prepare.py b/dataset/ami/ami_prepare.py index 8c0fc62dc..44993cead 100644 --- a/dataset/ami/ami_prepare.py +++ b/dataset/ami/ami_prepare.py @@ -1,3 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Data preparation. @@ -21,26 +34,24 @@ from distutils.util import strtobool from utils.dataio import ( load_pkl, -    save_pkl, -) +    save_pkl, ) logger = logging.getLogger(__name__) SAMPLERATE = 16000 def prepare_ami( -    data_folder, -    manual_annot_folder, -    save_folder, -    ref_rttm_dir, -    meta_data_dir, -    split_type="full_corpus_asr", -    skip_TNO=True, -    mic_type="Mix-Headset", -    vad_type="oracle", -    max_subseg_dur=3.0, -    overlap=1.5, -): +        data_folder, +        manual_annot_folder, +        save_folder, +        ref_rttm_dir, +        meta_data_dir, +        split_type="full_corpus_asr", +        skip_TNO=True, +        mic_type="Mix-Headset", +        vad_type="oracle", +        max_subseg_dur=3.0, +        overlap=1.5, ): """ Prepares reference RTTM and JSON files for the AMI dataset. @@ -72,12 +83,12 @@ def prepare_ami( Example ------- -    >>> from recipes.AMI.ami_prepare import prepare_ami -    >>> data_folder = '/network/datasets/ami/amicorpus/' -    >>> manual_annot_folder = '/home/mila/d/dawalatn/nauman/ami_public_manual/' -    >>> save_folder = 'results/save/' +    >>> from dataset.ami.ami_prepare import prepare_ami +    >>> data_folder = '/home/data/ami/amicorpus/' +    >>> manual_annot_folder = '/home/data/ami/ami_public_manual/' +    >>> save_folder = './results/' >>> split_type = 'full_corpus_asr' -    >>> mic_type = 'Lapel' +    >>> mic_type = 'Mix-Headset' >>> prepare_ami(data_folder, manual_annot_folder, save_folder, split_type, mic_type) """ @@ -112,8 +123,7 @@ def prepare_ami( # Check if this phase is already done (if so, skip it) if skip(save_folder, conf, meta_files, opt_file): logger.info( -            "Skipping data preparation, as it was completed in previous run." -        ) +            "Skipping data preparation, as it was completed in previous run.") return msg = "\tCreating meta-data file for the AMI Dataset.." @@ -138,8 +148,7 @@ def prepare_ami( data_folder, manual_annot_folder, i, -                skip_TNO, -            ) +                skip_TNO, ) if i == "dev": prepare_segs_for_RTTM( dev_set, @@ -147,8 +156,7 @@ def prepare_ami( data_folder, manual_annot_folder, i, -                skip_TNO, -            ) +                skip_TNO, ) if i == "eval": prepare_segs_for_RTTM( eval_set, @@ -156,8 +164,7 @@ def prepare_ami( data_folder, manual_annot_folder, i, -                skip_TNO, -            ) +                skip_TNO, ) # Create meta_files for splits meta_data_dir = meta_data_dir @@ -174,8 +181,7 @@ def prepare_ami( meta_filename_prefix, max_subseg_dur, overlap, -            mic_type, -        ) +            mic_type, ) save_opt_file = os.path.join(save_folder, opt_file) save_pkl(conf, save_opt_file) @@ -190,13 +196,8 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id): # Prepare header for spkr_id in spkrs_list: # e.g. SPKR-INFO ES2008c 0 <NA> <NA> <NA> unknown ES2008c.A_PM <NA> <NA>
-        line = ( -            "SPKR-INFO " -            + rec_id -            + " 0 <NA> <NA> <NA> unknown " -            + spkr_id -            + " <NA> <NA>" -        ) +        line = ("SPKR-INFO " + rec_id + " 0 <NA> <NA> <NA> unknown " + +                spkr_id + " <NA> <NA>") rttm.append(line) # Append remaining lines @@ -206,57 +207,35 @@ def get_RTTM_per_rec(segs, spkrs_list, rec_id): if float(row[1]) < float(row[0]): msg1 = ( "Possibly Incorrect Annotation Found!! transcriber_start (%s) > transcriber_end (%s)" -                % (row[0], row[1]) -            ) +                % (row[0], row[1])) msg2 = ( -                "Excluding this incorrect row from the RTTM : %s, %s, %s, %s" -                % ( -                    rec_id, -                    row[0], -                    str(round(float(row[1]) - float(row[0]), 4)), -                    str(row[2]), -                ) -            ) +                "Excluding this incorrect row from the RTTM : %s, %s, %s, %s" % +                (rec_id, row[0], str(round(float(row[1]) - float(row[0]), 4)), +                 str(row[2]), )) logger.info(msg1) logger.info(msg2) continue -        line = ( -            "SPEAKER " -            + rec_id -            + " 0 " -            + str(round(float(row[0]), 4)) -            + " " -            + str(round(float(row[1]) - float(row[0]), 4)) -            + " <NA> <NA> " -            + str(row[2]) -            + " <NA> <NA>" -        ) +        line = ("SPEAKER " + rec_id + " 0 " + str(round(float(row[0]), 4)) + " " +                + str(round(float(row[1]) - float(row[0]), 4)) + " <NA> <NA> " +                + str(row[2]) + " <NA> <NA>") rttm.append(line) return rttm -def prepare_segs_for_RTTM( -    list_ids, out_rttm_file, audio_dir, annot_dir, split_type, skip_TNO -): +def prepare_segs_for_RTTM(list_ids, out_rttm_file, audio_dir, annot_dir, +                          split_type, skip_TNO): RTTM = []  # Stores all RTTMs clubbed together for a given dataset split for main_meet_id in list_ids: # Skip TNO meetings from dev and eval sets -        if ( -            main_meet_id.startswith("TS") -            and split_type != "train" -            and skip_TNO is True -        ): -            msg = ( -                "Skipping TNO meeting in AMI " -                + str(split_type) -                + " set : " -                + str(main_meet_id) -            ) +        if (main_meet_id.startswith("TS") and split_type != "train" and +                skip_TNO is True): +            msg = ("Skipping TNO meeting in AMI " + str(split_type) + " set : " +                   + str(main_meet_id)) logger.info(msg) continue @@ -271,8 +250,7 @@ def prepare_segs_for_RTTM( list_spkr_xmls.sort()  # A, B, C, D, E etc (Speakers) segs = [] spkrs_list = ( -                [] -            )  # Since non-scenario recordings contains 3-5 speakers +                [])  # Since non-scenario recordings contains 3-5 speakers for spkr_xml_file in list_spkr_xmls: @@ -286,14 +264,11 @@ def prepare_segs_for_RTTM( root = tree.getroot() # Start, end and speaker_ID from xml file -                segs = segs + [ -                    [ -                        elem.attrib["transcriber_start"], -                        elem.attrib["transcriber_end"], -                        spkr_ID, -                    ] -                    for elem in root.iter("segment") -                ] +                segs = segs + [[ +                    elem.attrib["transcriber_start"], +                    elem.attrib["transcriber_end"], +                    spkr_ID, +                ] for elem in root.iter("segment")] # Sort rows as per the start time (per recording) segs.sort(key=lambda x: float(x[0])) @@ -404,9 +379,8 @@ def get_subsegments(merged_segs, max_subseg_dur=3.0, overlap=1.5): return subsegments -def prepare_metadata( -    rttm_file, save_dir, data_dir, filename, max_subseg_dur, overlap, mic_type -): +def prepare_metadata(rttm_file, save_dir, data_dir, filename, max_subseg_dur, +                     overlap, mic_type): # Read RTTM, get unique meeting_IDs (from RTTM headers) # For each MeetingID. select that meetID -> merge -> subsegment -> json -> append @@ -425,15 +399,13 @@ def prepare_metadata( MERGED_SEGMENTS = [] SUBSEGMENTS = [] for rec_id in rec_ids: -        segs_iter = filter( -            lambda x: x.startswith("SPEAKER " + str(rec_id)), RTTM -        ) +        segs_iter = filter(lambda x: x.startswith("SPEAKER " + str(rec_id)), +                           RTTM) gt_rttm_segs = [row.split(" ") for row in segs_iter] # Merge, subsegment and then convert to json format. 
merged_segs = merge_rttm_intervals( - gt_rttm_segs - ) # We lose speaker_ID after merging + gt_rttm_segs) # We lose speaker_ID after merging MERGED_SEGMENTS = MERGED_SEGMENTS + merged_segs # Divide segments into smaller sub-segments @@ -467,16 +439,8 @@ def prepare_metadata( # If multi-mic audio is selected if mic_type == "Array1": - wav_file_base_path = ( - data_dir - + "/" - + rec_id - + "/audio/" - + rec_id - + "." - + mic_type - + "-" - ) + wav_file_base_path = (data_dir + "/" + rec_id + "/audio/" + rec_id + + "." + mic_type + "-") f = [] # adding all 8 mics for i in range(8): @@ -494,16 +458,8 @@ def prepare_metadata( } else: # Single mic audio - wav_file_path = ( - data_dir - + "/" - + rec_id - + "/audio/" - + rec_id - + "." - + mic_type - + ".wav" - ) + wav_file_path = (data_dir + "/" + rec_id + "/audio/" + rec_id + "." + + mic_type + ".wav") # Note: key "file" without 's' is used for single-mic json_dict[subsegment_ID] = { @@ -554,6 +510,7 @@ def skip(save_folder, conf, meta_files, opt_file): return skip + if __name__ == '__main__': parser = argparse.ArgumentParser( @@ -561,42 +518,56 @@ if __name__ == '__main__': --manual_annot_folder /home/data/ami/ami_public_manual_1.6.2 \ --save_folder ./results/ --ref_rttm_dir ./results/ref_rttms \ --meta_data_dir ./results/metadata', - description='AMI Data preparation') + description='AMI Data preparation') parser.add_argument( - '--data_folder', required=True, help='Path to the folder where the original amicorpus is stored') + '--data_folder', + required=True, + help='Path to the folder where the original amicorpus is stored') parser.add_argument( - '--manual_annot_folder', required=True, help='Directory where the manual annotations are stored') + '--manual_annot_folder', + required=True, + help='Directory where the manual annotations are stored') parser.add_argument( '--save_folder', required=True, help='The save directory in results') parser.add_argument( - '--ref_rttm_dir', required=True, help='Directory to store reference RTTM files') + '--ref_rttm_dir', + required=True, + help='Directory to store reference RTTM files') parser.add_argument( - '--meta_data_dir', required=True, help='Directory to store the meta data (json) files') + '--meta_data_dir', + required=True, + help='Directory to store the meta data (json) files') parser.add_argument( - '--split_type', - default="full_corpus_asr", + '--split_type', + default="full_corpus_asr", help='Standard dataset split. See ami_splits.py for more information') parser.add_argument( - '--skip_TNO', default=True, type=strtobool, help='Skips TNO meeting recordings if True') + '--skip_TNO', + default=True, + type=strtobool, + help='Skips TNO meeting recordings if True') parser.add_argument( - '--mic_type', default="Mix-Headset", help='Type of microphone to be used') + '--mic_type', + default="Mix-Headset", + help='Type of microphone to be used') parser.add_argument( - '--vad_type', default="oracle", help='Type of VAD. Kept for future when VAD will be added') + '--vad_type', + default="oracle", + help='Type of VAD. 
Kept for future when VAD will be added') parser.add_argument( - '--max_subseg_dur', - default=3.0, - type=float, - help='Duration in seconds of a subsegments to be prepared from larger segments') + '--max_subseg_dur', + default=3.0, + type=float, + help='Duration in seconds of a subsegments to be prepared from larger segments' + ) parser.add_argument( - '--overlap', default=1.5, type=float, help='Overlap duration in seconds between adjacent subsegments') + '--overlap', + default=1.5, + type=float, + help='Overlap duration in seconds between adjacent subsegments') args = parser.parse_args() print(args) - prepare_ami( - args.data_folder, - args.manual_annot_folder, - args.save_folder, - args.ref_rttm_dir, - args.meta_data_dir - ) \ No newline at end of file + prepare_ami(args.data_folder, args.manual_annot_folder, args.save_folder, + args.ref_rttm_dir, args.meta_data_dir) diff --git a/dataset/ami/ami_splits.py b/dataset/ami/ami_splits.py index 52a8ddc04..010638a39 100644 --- a/dataset/ami/ami_splits.py +++ b/dataset/ami/ami_splits.py @@ -1,3 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ AMI corpus contained 100 hours of meeting recording. This script returns the standard train, dev and eval split for AMI corpus. @@ -29,8 +42,7 @@ def get_AMI_split(split_option): if split_option not in ALLOWED_OPTIONS: print( f'Invalid split "{split_option}" requested!\nValid split_options are: ', - ALLOWED_OPTIONS, - ) + ALLOWED_OPTIONS, ) return if split_option == "scenario_only": diff --git a/utils/dataio.py b/dataset/ami/dataio.py similarity index 71% rename from utils/dataio.py rename to dataset/ami/dataio.py index 48f792052..f7fe88157 100644 --- a/utils/dataio.py +++ b/dataset/ami/dataio.py @@ -1,3 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Data reading and writing. @@ -5,10 +18,10 @@ Authors * qingenz123@126.com (Qingen ZHAO) 2022 """ - import os import pickle + def save_pkl(obj, file): """Save an object in pkl format. @@ -31,6 +44,7 @@ def save_pkl(obj, file): with open(file, "wb") as f: pickle.dump(obj, f) + def load_pickle(pickle_path): """Utility function for loading .pkl pickle files. @@ -48,6 +62,7 @@ def load_pickle(pickle_path): out = pickle.load(f) return out + def load_pkl(file): """Loads a pkl file. 
@@ -79,4 +94,4 @@ def load_pkl(file): return pickle.load(f) finally: if os.path.isfile(file + ".lock"): - os.remove(file + ".lock") \ No newline at end of file + os.remove(file + ".lock") diff --git a/utils/DER.py b/paddlespeech/vector/utils/DER.py similarity index 80% rename from utils/DER.py rename to paddlespeech/vector/utils/DER.py index 11dfa8cfa..5b62094df 100755 --- a/utils/DER.py +++ b/paddlespeech/vector/utils/DER.py @@ -1,3 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Calculates Diarization Error Rate (DER) which is the sum of Missed Speaker (MS), False Alarm (FA), and Speaker Error Rate (SER) using md-eval-22.pl from NIST RT Evaluation. @@ -26,7 +39,6 @@ ERROR_SPEAKER_TIME = re.compile(r"(?<=SPEAKER ERROR TIME =)[\d.]+") def rectify(arr): """Corrects corner cases and converts scores into percentage. """ - # Numerator and denominator both 0. arr[np.isnan(arr)] = 0 @@ -38,12 +50,11 @@ def rectify(arr): def DER( - ref_rttm, - sys_rttm, - ignore_overlap=False, - collar=0.25, - individual_file_scores=False, -): + ref_rttm, + sys_rttm, + ignore_overlap=False, + collar=0.25, + individual_file_scores=False, ): """Computes Missed Speaker percentage (MS), False Alarm (FA), Speaker Error Rate (SER), and Diarization Error Rate (DER). 
@@ -118,25 +129,20 @@ def DER( ] scored_speaker_times = np.array( - [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)] - ) + [float(m) for m in SCORED_SPEAKER_TIME.findall(stdout)]) miss_speaker_times = np.array( - [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)] - ) + [float(m) for m in MISS_SPEAKER_TIME.findall(stdout)]) fa_speaker_times = np.array( - [float(m) for m in FA_SPEAKER_TIME.findall(stdout)] - ) + [float(m) for m in FA_SPEAKER_TIME.findall(stdout)]) error_speaker_times = np.array( - [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)] - ) + [float(m) for m in ERROR_SPEAKER_TIME.findall(stdout)]) with np.errstate(invalid="ignore", divide="ignore"): tot_error_times = ( - miss_speaker_times + fa_speaker_times + error_speaker_times - ) + miss_speaker_times + fa_speaker_times + error_speaker_times) miss_speaker_frac = miss_speaker_times / scored_speaker_times fa_speaker_frac = fa_speaker_times / scored_speaker_times sers_frac = error_speaker_times / scored_speaker_times @@ -153,13 +159,19 @@ def DER( else: return miss_speaker[-1], fa_speaker[-1], sers[-1], ders[-1] + if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Compute Diarization Error Rate') + parser = argparse.ArgumentParser( + description='Compute Diarization Error Rate') parser.add_argument( - '--ref_rttm', required=True, help='the path of reference/groundtruth RTTM file') + '--ref_rttm', + required=True, + help='the path of reference/groundtruth RTTM file') parser.add_argument( - '--sys_rttm', required=True, help='the path of the system generated RTTM file') + '--sys_rttm', + required=True, + help='the path of the system generated RTTM file') parser.add_argument( '--individual_file', default=False, @@ -176,4 +188,5 @@ if __name__ == '__main__': print(args) der = DER(args.ref_rttm, args.sys_rttm) - print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" % (der[0], der[1], der[2], der[-1])) \ No newline at end of file + print("miss_speaker: %.3f%% fa_speaker: %.3f%% sers: %.3f%% ders: %.3f%%" % + (der[0], der[1], der[2], der[-1])) diff --git a/utils/md-eval.pl b/paddlespeech/vector/utils/md-eval.pl similarity index 100% rename from utils/md-eval.pl rename to paddlespeech/vector/utils/md-eval.pl
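Usage sketch (illustrative only, not part of the patch): the snippet below drives the two relocated entry points from Python, mirroring the argparse usage strings above. The AMI paths, the importable module path paddlespeech.vector.utils.DER, and the RTTM file names are assumptions for illustration; scoring also requires md-eval.pl to be reachable by DER().

    # A minimal sketch; every path and RTTM file name below is a placeholder.
    from dataset.ami.ami_prepare import prepare_ami
    from paddlespeech.vector.utils.DER import DER  # assumes the moved module is on PYTHONPATH

    # Step 1: build reference RTTMs and JSON sub-segment metadata for the AMI splits.
    prepare_ami(
        data_folder="/home/data/ami/amicorpus",
        manual_annot_folder="/home/data/ami/ami_public_manual_1.6.2",
        save_folder="./results/",
        ref_rttm_dir="./results/ref_rttms",
        meta_data_dir="./results/metadata")

    # Step 2: score a diarization run against the reference annotation.
    miss, fa, ser, der = DER(
        ref_rttm="./results/ref_rttms/fullref_ami_dev.rttm",  # placeholder name
        sys_rttm="./results/sys_ami_dev.rttm",  # placeholder name
        collar=0.25,
        ignore_overlap=False)
    print("MS %.3f%%  FA %.3f%%  SER %.3f%%  DER %.3f%%" % (miss, fa, ser, der))

Equivalently, ami_prepare.py and DER.py can be run directly with the command-line flags defined in their argparse blocks (e.g. --data_folder/--ref_rttm_dir for preparation, --ref_rttm/--sys_rttm for scoring).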