You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
232 lines
7.9 KiB
232 lines
7.9 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
import json
|
|
import os
|
|
import pickle
|
|
import sys
|
|
|
|
import numpy as np
|
|
import paddle
|
|
from paddle.io import BatchSampler
|
|
from paddle.io import DataLoader
|
|
from tqdm.contrib import tqdm
|
|
from yacs.config import CfgNode
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
|
|
from paddlespeech.vector.io.batch import batch_feature_normalize
|
|
from paddlespeech.vector.io.dataset_from_json import JSONDataset
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
|
|
from paddlespeech.vector.training.seeding import seed_everything
|
|
|
|
# Module-level logger shared by every function in this script; it is created
# by passing this module's name to the project's Log helper.
logger = Log(__name__).getlog()
|
|
|
|
|
|
def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
    """Write the metadata entries that belong to one recording to a JSON file.

    Arguments
    ---------
    full_meta_data : dict
        Full meta (parsed json) containing all the recordings.
    rec_id : str
        Recording ID. An entry is kept when its key, converted to ``str``,
        starts with this value.
    out_meta_file : str
        Path of the output meta (json) file; it is overwritten if present.
    """
    # Keep the original keys and values; only the prefix test uses str(key).
    selected = {
        entry_id: full_meta_data[entry_id]
        for entry_id in full_meta_data
        if str(entry_id).startswith(rec_id)
    }

    with open(out_meta_file, mode="w") as json_f:
        json_f.write(json.dumps(selected, indent=2))
|
|
|
|
|
|
def create_dataloader(json_file, batch_size):
    """Build a shuffled, batched DataLoader over the entries of a metadata file.

    Arguments
    ---------
    json_file : str
        Path of the meta (json) file handed to ``JSONDataset``.
    batch_size : int
        Number of items per batch, forwarded to ``BatchSampler``.

    Returns
    -------
    DataLoader
        Loader whose batches are collated by ``batch_feature_normalize``
        with ``mean_norm=True`` and ``std_norm=False``.
    """
    # NOTE(review): ``config`` is not a parameter of this function. It is
    # looked up as a module-level name, which the visible part of this file
    # only binds inside the ``__main__`` section -- confirm before calling
    # this function from another module.
    feature_options = dict(
        feat_type='melspectrogram',
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_size)
    dataset = JSONDataset(json_file=json_file, **feature_options)

    def collate(samples):
        # Mean normalization only; std normalization is switched off.
        return batch_feature_normalize(samples, mean_norm=True, std_norm=False)

    sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        collate_fn=collate,
        return_list=True)
|
|
|
|
|
|
def main(args, config):
    """Extract speaker embeddings for every recording of one AMI subset.

    For each recording ID found in the subset's ``*.subsegs.json`` meta file,
    a per-recording meta (json) file is written and, unless an embedding file
    for that recording already exists, the backbone model is run over all of
    its entries and the result is pickled as an ``EmbeddingMeta`` object.

    Arguments
    ---------
    args : argparse.Namespace
        Reads ``device``, ``load_checkpoint``, ``data_dir`` and ``dataset``.
        ``load_checkpoint`` is rewritten in place to its absolute path.
    config : CfgNode
        Reads ``seed``, ``model``, ``meta_data_dir``, ``mic_type``,
        ``embedding_dir``, ``batch_size`` and ``emb_dim``.

    Raises
    ------
    SystemExit
        With exit status 1 when the meta file yields no recording IDs.
    """
    # set the training device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)

    # stage3: load the pre-trained model from the checkpoint directory
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # set the model to eval mode
    model.eval()

    # load meta data
    meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json")
    with open(meta_file, "r") as f:
        full_meta = json.load(f)

    # Get all the recording IDs in this dataset: the part of each key before
    # the first underscore. Every key is considered; skipping the first one
    # would lose a recording whose only entry happens to come first.
    all_rec_ids = sorted({key.rstrip().split("_")[0] for key in full_meta})
    split = "AMI_" + args.dataset

    msg = "Extra embdding for " + args.dataset + " set"
    logger.info(msg)

    if not all_rec_ids:
        msg = "No recording IDs found! Please check if meta_data json file is properly generated."
        logger.error(msg)
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)

    # The embedding directory is the same for every recording: create it once.
    emb_dir = os.path.join(args.data_dir, config.embedding_dir, split)
    os.makedirs(emb_dir, exist_ok=True)

    num_recs = len(all_rec_ids)
    for rec_idx, rec_id in enumerate(tqdm(all_rec_ids), start=1):
        # This tag will be displayed in the log.
        tag = ("[" + str(args.dataset) + ": " + str(rec_idx) + "/" +
               str(num_recs) + "]")
        msg = "Embdding %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # file to store embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(emb_dir, emb_file_name)

        # Meta (json) for this recording only, a subset of full_meta; it is
        # kept in the embedding directory and rewritten on every run.
        json_file_name = rec_id + "." + config.mic_type + ".json"
        meta_per_rec_file = os.path.join(emb_dir, json_file_name)
        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)

        # extract embeddings (skip if already done).
        if os.path.isfile(diary_stat_emb_file):
            logger.debug("Skipping embedding extraction (as already present).")
            continue

        # The data loader is only needed when embeddings must be computed.
        diary_set_loader = create_dataloader(meta_per_rec_file,
                                             config.batch_size)

        logger.debug("Extracting deep embeddings")
        segset = []
        # Seed with an empty float64 array so the final dtype and the result
        # for an empty loader match per-batch concatenation, while the data
        # itself is copied only once at the end.
        emb_chunks = [np.empty(shape=[0, config.emb_dim], dtype=np.float64)]

        # Inference only: gradients are not needed.
        with paddle.no_grad():
            for batch in tqdm(diary_set_loader):
                ids, feats, lengths = (batch['ids'], batch['feats'],
                                       batch['lengths'])
                segset.extend(ids)
                # (N, emb_size, 1) -> (N, emb_size)
                emb = model.backbone(feats, lengths).squeeze(-1).numpy()
                emb_chunks.append(emb)

        embeddings = np.concatenate(emb_chunks, axis=0)
        segset = np.array(segset, dtype="|O")
        stat_obj = EmbeddingMeta(
            segset=segset,
            stats=embeddings)

        logger.debug("Saving Embeddings...")
        # Write to a temporary name and rename afterwards, so an interrupted
        # run cannot leave a truncated file that later runs would treat as
        # an existing result and skip.
        tmp_emb_file = diary_stat_emb_file + ".tmp"
        with open(tmp_emb_file, "wb") as output:
            pickle.dump(stat_obj, output)
        os.replace(tmp_emb_file, diary_stat_emb_file)
|
|
|
|
|
|
# Begin experiment!
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, load the yacs
    # configuration file if one was given, and run the extraction.
    # The module docstring is passed as ``description``; the first positional
    # parameter of ArgumentParser is ``prog``, not the description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--device',
        default="gpu",
        help="Select which device to perform diarization, defaults to gpu.")
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../save/",
        type=str,
        help="processsed data directory")
    parser.add_argument(
        "--dataset",
        choices=['dev', 'eval'],
        default="dev",
        type=str,
        help="Select which dataset to extra embdding, defaults to dev")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default='',
        help="Directory to load model checkpoint to compute embeddings.")
    args = parser.parse_args()

    # new_allowed=True lets the file introduce keys that the empty base
    # node does not declare.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # Make the configuration read-only before handing it to main().
    config.freeze()

    main(args, config)
|