commit 8d1ee8262e
@ -0,0 +1,115 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.testclient import TestClient
from vpr_search import app

from utils.utility import download
from utils.utility import unpack

client = TestClient(app)


def download_audio_data():
    """
    Download the example audio data.
    """
    url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
    md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
    target_dir = "./"
    filepath = download(url, md5sum, target_dir)
    unpack(filepath, target_dir, True)


def test_drop():
    """
    Drop the table in MySQL.
    """
    response = client.post("/vpr/drop")
    assert response.status_code == 200


def test_enroll_local(spk: str, audio: str):
    """
    Enroll a local audio file with the given spk_id into MySQL.
    """
    response = client.post("/vpr/enroll/local?spk_id=" + spk +
                           "&audio_path=.%2Fexample_audio%2F" + audio + ".wav")
    assert response.status_code == 200
    assert response.json() == {
        'status': True,
        'msg': "Successfully enrolled data!"
    }


def test_search_local():
    """
    Search for the speaker in MySQL by audio.
    """
    response = client.post(
        "/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
    assert response.status_code == 200


def test_list():
    """
    Get all records in MySQL.
    """
    response = client.get("/vpr/list")
    assert response.status_code == 200


def test_data(spk: str):
    """
    Get the audio file by spk_id from MySQL.
    """
    response = client.get("/vpr/data?spk_id=" + spk)
    assert response.status_code == 200


def test_del(spk: str):
    """
    Delete the record in MySQL by spk_id.
    """
    response = client.post("/vpr/del?spk_id=" + spk)
    assert response.status_code == 200


def test_count():
    """
    Get the number of speakers in MySQL.
    """
    response = client.get("/vpr/count")
    assert response.status_code == 200


if __name__ == "__main__":
    download_audio_data()

    test_enroll_local("spk1", "arms_strikes")
    test_enroll_local("spk2", "sword_wielding")
    test_enroll_local("spk3", "test")
    test_list()
    test_data("spk1")
    test_count()
    test_search_local()

    test_del("spk1")
    test_count()
    test_search_local()

    test_enroll_local("spk1", "arms_strikes")
    test_count()
    test_search_local()

    test_drop()
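For readability (not part of the commit): %2F in the query strings above is the URL-encoded '/', so test_enroll_local("spk1", "arms_strikes") issues:

POST /vpr/enroll/local?spk_id=spk1&audio_path=./example_audio/arms_strikes.wav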
@ -0,0 +1,206 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import uvicorn
from config import UPLOAD_PATH
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr
from operations.count import do_get
from operations.count import do_list
from operations.drop import do_delete
from operations.drop import do_drop_vpr
from operations.load import do_enroll
from operations.search import do_search_vpr
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])

MYSQL_CLI = MySQLHelper()

# Create the upload directory, e.g. 'tmp/audio-data'
if not os.path.exists(UPLOAD_PATH):
    os.makedirs(UPLOAD_PATH)
    LOGGER.info(f"Created the upload path: {UPLOAD_PATH}")


@app.post('/vpr/enroll')
async def vpr_enroll(table_name: str=None,
                     spk_id: str=None,
                     audio: UploadFile=File(...)):
    # Enroll the uploaded audio with the given spk_id into MySQL.
    try:
        # Save the uploaded data on the server.
        content = await audio.read()
        audio_path = os.path.join(UPLOAD_PATH, audio.filename)
        with open(audio_path, "wb+") as f:
            f.write(content)
        do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
        LOGGER.info(f"Successfully enrolled {spk_id} online!")
        return {'status': True, 'msg': "Successfully enrolled data!"}
    except Exception as e:
        LOGGER.error(e)
        # Return the error message as a string so it is JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400


@app.post('/vpr/enroll/local')
async def vpr_enroll_local(table_name: str=None,
                           spk_id: str=None,
                           audio_path: str=None):
    # Enroll a local audio file with the given spk_id into MySQL.
    try:
        do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
        LOGGER.info(f"Successfully enrolled {spk_id} locally!")
        return {'status': True, 'msg': "Successfully enrolled data!"}
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.post('/vpr/recog')
async def vpr_recog(request: Request,
                    table_name: str=None,
                    audio: UploadFile=File(...)):
    # Voiceprint recognition on an uploaded audio file.
    try:
        # Save the uploaded data on the server.
        content = await audio.read()
        query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
        with open(query_audio_path, "wb+") as f:
            f.write(content)
        host = request.headers['host']
        spk_ids, paths, scores = do_search_vpr(host, table_name,
                                               query_audio_path, MYSQL_CLI)
        for spk_id, path, score in zip(spk_ids, paths, scores):
            LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}")
        res = dict(zip(spk_ids, zip(paths, scores)))
        # Sort the results by similarity score, highest first.
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully performed speaker recognition online!")
        return res
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.post('/vpr/recog/local')
async def vpr_recog_local(request: Request,
                          table_name: str=None,
                          audio_path: str=None):
    # Voiceprint recognition on a local audio file.
    try:
        host = request.headers['host']
        spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
                                               MYSQL_CLI)
        for spk_id, path, score in zip(spk_ids, paths, scores):
            LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}")
        res = dict(zip(spk_ids, zip(paths, scores)))
        # Sort the results by similarity score, highest first.
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully performed speaker recognition locally!")
        return res
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.post('/vpr/del')
async def vpr_del(table_name: str=None, spk_id: str=None):
    # Delete a record by spk_id in MySQL.
    try:
        do_delete(table_name, spk_id, MYSQL_CLI)
        LOGGER.info("Successfully deleted the record by spk_id in MySQL")
        return {'status': True, 'msg': "Successfully deleted data!"}
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.get('/vpr/list')
async def vpr_list(table_name: str=None):
    # Get all records in MySQL.
    try:
        spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
        for i in range(len(spk_ids)):
            LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
        LOGGER.info("Successfully listed all records from MySQL!")
        return spk_ids, audio_paths
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.get('/vpr/data')
async def vpr_data(
        table_name: str=None,
        spk_id: str=None, ):
    # Get the audio file from the path stored by spk_id in MySQL.
    try:
        audio_path = do_get(table_name, spk_id, MYSQL_CLI)
        LOGGER.info(f"Successfully got audio path {audio_path}!")
        return FileResponse(audio_path)
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.get('/vpr/count')
async def vpr_count(table_name: str=None):
    # Get the total number of speakers in MySQL.
    try:
        num = do_count_vpr(table_name, MYSQL_CLI)
        LOGGER.info("Successfully counted the number of speakers!")
        return num
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.post('/vpr/drop')
async def drop_tables(table_name: str=None):
    # Drop the table in MySQL.
    try:
        do_drop_vpr(table_name, MYSQL_CLI)
        LOGGER.info("Successfully dropped tables in MySQL!")
        return {'status': True, 'msg': "Successfully dropped tables!"}
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': str(e)}, 400


@app.get('/data')
def audio_path(audio_path):
    # Get the audio file from a path.
    try:
        LOGGER.info(f"Successfully got audio: {audio_path}")
        return FileResponse(audio_path)
    except Exception as e:
        LOGGER.error(f"get audio error: {e}")
        return {'status': False, 'msg': str(e)}, 400


if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=8002)
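For illustration only (not part of the commit): a minimal client sketch in Python, assuming the service above is running locally on port 8002 and the example audio from the test script has been downloaded.

import requests

BASE_URL = "http://127.0.0.1:8002"  # assumed local deployment

# Enroll a local audio file for speaker "spk1" (paths are illustrative).
resp = requests.post(
    BASE_URL + "/vpr/enroll/local",
    params={"spk_id": "spk1",
            "audio_path": "./example_audio/arms_strikes.wav"})
print(resp.json())  # {'status': True, 'msg': 'Successfully enrolled data!'}

# Recognize a query utterance against the enrolled speakers.
resp = requests.post(
    BASE_URL + "/vpr/recog/local",
    params={"audio_path": "./example_audio/test.wav"})
print(resp.json())  # [[spk_id, [audio_path, score]], ...] sorted by score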
@ -1,6 +1,9 @@
#!/bin/bash

wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/123456789.wav

# vector
paddlespeech vector --task spk --input ./85236145389.wav

paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav"
@ -0,0 +1,62 @@
###########################################################
#                AMI DATA PREPARE SETTING                 #
###########################################################
split_type: 'full_corpus_asr'
skip_TNO: True
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
mic_type: 'Mix-Headset'
vad_type: 'oracle'
max_subseg_dur: 3.0
overlap: 1.5
# Some more exp folders (for a cleaner structure).
embedding_dir: emb  #!ref <save_folder>/emb
meta_data_dir: metadata  #!ref <save_folder>/metadata
ref_rttm_dir: ref_rttms  #!ref <save_folder>/ref_rttms
sys_rttm_dir: sys_rttms  #!ref <save_folder>/sys_rttms
der_dir: DER  #!ref <save_folder>/DER


###########################################################
#              FEATURE EXTRACTION SETTING                 #
###########################################################
# currently, we only support fbank
sr: 16000  # sample rate
n_mels: 80
window_size: 400  # 25 ms; at a 16000 sample rate: 25 * 16000 / 1000 = 400
hop_size: 160  # 10 ms; at a 16000 sample rate: 10 * 16000 / 1000 = 160
#left_frames: 0
#right_frames: 0
#deltas: False


###########################################################
#                     MODEL SETTING                       #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if you want to use another model, please choose another configuration yaml file
seed: 1234
emb_dim: 192
batch_size: 16
model:
  input_size: 80
  channels: [1024, 1024, 1024, 1024, 3072]
  kernel_sizes: [5, 3, 3, 3, 1]
  dilations: [1, 2, 3, 4, 1]
  attention_channels: 128
  lin_neurons: 192
# Will automatically download the ECAPA-TDNN model (best).

###########################################################
#              SPECTRAL CLUSTERING SETTING                #
###########################################################
backend: 'SC'  # options: 'kmeans' # Note: kmeans goes only with cos affinity
affinity: 'cos'  # options: cos, nn
max_num_spkrs: 10
oracle_n_spkrs: True


###########################################################
#                DER EVALUATION SETTING                   #
###########################################################
ignore_overlap: True
forgiveness_collar: 0.25
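A worked example of the sub-segmentation settings above (not part of the commit; this assumes overlap is the overlap between consecutive sub-segments): with max_subseg_dur 3.0 and overlap 1.5, the window shift is 3.0 - 1.5 = 1.5 s, so a 6-second speech segment yields the sub-segments [0.0, 3.0], [1.5, 4.5], and [3.0, 6.0].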
@ -0,0 +1,231 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import pickle
import sys

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset_from_json import JSONDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

# Logger setup
logger = Log(__name__).getlog()


def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
    """Prepares metadata for a given recording ID.

    Arguments
    ---------
    full_meta_data : json
        Full meta (json) containing all the recordings
    rec_id : str
        The recording ID for which the meta (json) has to be prepared
    out_meta_file : str
        Path of the output meta (json) file.
    """
    subset = {}
    for key in full_meta_data:
        k = str(key)
        if k.startswith(rec_id):
            subset[key] = full_meta_data[key]

    with open(out_meta_file, mode="w") as json_f:
        json.dump(subset, json_f, indent=2)


def create_dataloader(json_file, batch_size):
    """Creates the dataset and its data processing pipeline.
    This is used for multi-mic processing.
    Note: relies on the module-level `config` set in the __main__ block.
    """
    # create the dataset
    dataset = JSONDataset(
        json_file=json_file,
        feat_type='melspectrogram',
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_size)

    # create the dataloader
    batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
    dataloader = DataLoader(dataset,
                            batch_sampler=batch_sampler,
                            collate_fn=lambda x: batch_feature_normalize(
                                x, mean_norm=True, std_norm=False),
                            return_list=True)

    return dataloader


def main(args, config):
    # set the training device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with the backbone model
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)

    # stage3: load the pre-trained model
    # we get the last model from the epoch and save_interval
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load the model checkpoint into the sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # set the model to eval mode
    model.eval()

    # load the meta data
    meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
    with open(meta_file, "r") as f:
        full_meta = json.load(f)

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + args.dataset
    i = 1

    msg = "Extracting embeddings for the " + args.dataset + " set"
    logger.info(msg)

    if len(all_rec_ids) <= 0:
        msg = "No recording IDs found! Please check if the meta_data json file is properly generated."
        logger.error(msg)
        sys.exit()

    # extract embeddings for each recording in the dataset.
    for rec_id in tqdm(all_rec_ids):
        # This tag will be displayed in the log.
        tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Embedding %s : %s" % (tag, rec_id)
        logger.debug(msg)

        # embedding directory.
        if not os.path.exists(
                os.path.join(args.data_dir, config.embedding_dir, split)):
            os.makedirs(
                os.path.join(args.data_dir, config.embedding_dir, split))

        # file to store the embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
                                           split, emb_file_name)

        # prepare a metadata (json) file for one recording. This is basically a subset of full_meta.
        # let's keep this meta-info in the embedding directory itself.
        json_file_name = rec_id + "." + config.mic_type + ".json"
        meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
                                         split, json_file_name)

        # write the subset (meta for one recording) json metadata.
        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)

        # prepare the data loader.
        diary_set_loader = create_dataloader(meta_per_rec_file,
                                             config.batch_size)

        # extract embeddings (skip if already done).
        if not os.path.isfile(diary_stat_emb_file):
            logger.debug("Extracting deep embeddings")
            embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
            segset = []

            for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
                # extract the audio embeddings
                ids, feats, lengths = batch['ids'], batch['feats'], batch[
                    'lengths']
                seg = [x for x in ids]
                segset = segset + seg
                emb = model.backbone(feats, lengths).squeeze(
                    -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
                embeddings = np.concatenate((embeddings, emb), axis=0)

            segset = np.array(segset, dtype="|O")
            stat_obj = EmbeddingMeta(
                segset=segset,
                stats=embeddings, )
            logger.debug("Saving embeddings...")
            with open(diary_stat_emb_file, "wb") as output:
                pickle.dump(stat_obj, output)

        else:
            logger.debug("Skipping embedding extraction (already present).")


# Begin experiment!
if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        '--device',
        default="gpu",
        help="Select which device to perform diarization on, defaults to gpu.")
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../save/",
        type=str,
        help="processed data directory")
    parser.add_argument(
        "--dataset",
        choices=['dev', 'eval'],
        default="dev",
        type=str,
        help="Select which dataset to extract embeddings from, defaults to dev")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default='',
        help="Directory to load the model checkpoint from to compute embeddings.")
    args = parser.parse_args()
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()

    main(args, config)
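For illustration only (not part of the commit): reading back one of the pickled EmbeddingMeta files written above. The recording ID in the path is hypothetical; the file name follows the rec_id + "." + mic_type + ".emb_stat.pkl" pattern used in the loop.

import pickle

# Hypothetical file name following the emb_stat.pkl pattern above.
with open("ES2011a.Mix-Headset.emb_stat.pkl", "rb") as f:
    stat_obj = pickle.load(f)

print(stat_obj.segset.shape)  # one entry per sub-segment id
print(stat_obj.stats.shape)   # (num_subsegments, emb_dim) embedding matrix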
@ -1,49 +0,0 @@
#!/bin/bash

stage=1

TARGET_DIR=${MAIN_ROOT}/dataset/ami
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/

save_folder=${MAIN_ROOT}/examples/ami/sd0/data
ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata

set=L

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -u
set -o pipefail

mkdir -p ${save_folder}

if [ ${stage} -le 0 ]; then
    # Download the AMI corpus. You need around 10 GB of free space to get the whole dataset.
    # The signals are too large to package in one archive,
    # so you need to use the chooser to indicate which ones you wish to download.
    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
    echo "Annotations: AMI manual annotations v1.6.2 "
    echo "Signals: "
    echo "1) Select one or more AMI meetings: for the IDs please follow ./ami_split.py"
    echo "2) Select media streams: just select Headset mix"
    exit 0;
fi

if [ ${stage} -le 1 ]; then
    echo "AMI Data preparation"

    python local/ami_prepare.py --data_folder ${data_folder} \
            --manual_annot_folder ${manual_annot_folder} \
            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
            --meta_data_dir ${meta_data_dir}

    if [ $? -ne 0 ]; then
        echo "Prepare AMI failed. Please check the log message."
        exit 1
    fi

fi

echo "AMI data preparation done."
exit 0
@ -0,0 +1,428 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import pickle
import shutil
import sys

import numpy as np
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster import diarization as diar
from utils.DER import DER

# Logger setup
logger = Log(__name__).getlog()


def diarize_dataset(
        full_meta,
        split_type,
        n_lambdas,
        pval,
        save_dir,
        config,
        n_neighbors=10, ):
    """This function diarizes all the recordings in a given dataset. It
    computes embeddings and clusters them using spectral clustering (or other backends).
    The output speaker boundary file is stored in the RTTM format.
    """

    # prepare `spkr_info` only once when the oracle num of speakers is selected.
    # spkr_info is essential to obtain the number of speakers from the groundtruth.
    if config.oracle_n_spkrs is True:
        full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
                                          "fullref_ami_" + split_type + ".rttm")
        rttm = diar.read_rttm(full_ref_rttm_file)

        spkr_info = list(  # noqa F841
            filter(lambda x: x.startswith("SPKR-INFO"), rttm))

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + split_type
    i = 1

    # add a tag for the directory path.
    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
    tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)

    # make the output rttm dir
    out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
                                split, tag)
    if not os.path.exists(out_rttm_dir):
        os.makedirs(out_rttm_dir)

    # diarize each recording in the dataset.
    for rec_id in tqdm(all_rec_ids):
        # this tag will be displayed in the log.
        tag = ("[" + str(split_type) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Diarizing %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # load the embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
                                           split, emb_file_name)
        if not os.path.isfile(diary_stat_emb_file):
            msg = "Embedding file %s not found! Please check that the embedding file was properly generated." % (
                diary_stat_emb_file)
            logger.error(msg)
            sys.exit()
        with open(diary_stat_emb_file, "rb") as in_file:
            diary_obj = pickle.load(in_file)

        out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"

        # processing starts from here.
        if config.oracle_n_spkrs is True:
            # oracle num of speakers.
            num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
        else:
            if config.affinity == "nn":
                # num of speakers tuned on the dev set (only for nn affinity).
                num_spkrs = n_lambdas
            else:
                # the num of speakers will be estimated using the max eigen gap for cos-based affinity,
                # so add None here. This None will be used later on.
                num_spkrs = None

        if config.backend == "kmeans":
            diar.do_kmeans_clustering(
                diary_obj,
                out_rttm_file,
                rec_id,
                num_spkrs,
                pval, )

        if config.backend == "SC":
            # go for Spectral Clustering (SC).
            diar.do_spec_clustering(
                diary_obj,
                out_rttm_file,
                rec_id,
                num_spkrs,
                pval,
                config.affinity,
                n_neighbors, )

        # can be used for AHC later. Likewise, one can add different backends here.
        if config.backend == "AHC":
            # call AHC
            threshold = pval  # pval for AHC is nothing but the threshold.
            diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)

    # once all RTTM outputs are generated, concatenate the individual RTTM files into a single RTTM file.
    # this is not strictly needed, but we stay with the standards.
    concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
    logger.debug("Concatenating individual RTTM files...")
    with open(concate_rttm_file, "w") as cat_file:
        for f in glob.glob(out_rttm_dir + "/*.rttm"):
            if f == concate_rttm_file:
                continue
            with open(f, "r") as indi_rttm_file:
                shutil.copyfileobj(indi_rttm_file, cat_file)

    msg = "The system generated RTTM file for the %s set : %s" % (
        split_type, concate_rttm_file, )
    logger.debug(msg)

    return concate_rttm_file


def dev_pval_tuner(full_meta, save_dir, config):
    """Tune the p_value for the affinity matrix.
    The p_value is used so that only p% of the values in each row are retained.
    """

    DER_list = []
    prange = np.arange(0.002, 0.015, 0.001)

    n_lambdas = None  # using it as a flag later.
    for p_v in prange:
        # Process the whole dataset for this value of p_v.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
                                            save_dir, config)

        ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
                                     "fullref_ami_dev.rttm")
        sys_rttm_file = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm_file,
            sys_rttm_file,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

        if config.oracle_n_spkrs is True and config.backend == "kmeans":
            # no need for a p_val search. Note p_val is needed for SC for both oracle and est num of speakers;
            # p_val is needed with oracle_n_spkrs=False when using the kmeans backend.
            break

    # Take the p_val that gave the minimum DER on the dev dataset.
    tuned_p_val = prange[DER_list.index(min(DER_list))]

    return tuned_p_val


def dev_ahc_threshold_tuner(full_meta, save_dir, config):
    """Tune the threshold for the affinity matrix. This function is called when AHC is used as the backend.
    """

    DER_list = []
    prange = np.arange(0.0, 1.0, 0.1)

    n_lambdas = None  # using it as a flag later.

    # Note: p_val is the threshold in the case of AHC.
    for p_v in prange:
        # Process the whole dataset for this value of p_v.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
                                            save_dir, config)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

        if config.oracle_n_spkrs is True:
            break  # no need for a threshold search.

    # Take the p_val that gave the minimum DER on the dev dataset.
    tuned_p_val = prange[DER_list.index(min(DER_list))]

    return tuned_p_val


def dev_nn_tuner(full_meta, save_dir, config):
    """Tune n_neighbors on the dev set, assuming the oracle num of speakers.
    This is used when nn-based affinity is selected.
    """

    DER_list = []
    pval = None

    # Here we assume the oracle num of speakers.
    n_lambdas = 4

    for nn in range(5, 15):

        # Process the whole dataset for this value of n_neighbors.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
                                            save_dir, config, nn)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append([nn, DER_])

        if config.oracle_n_spkrs is True and config.backend == "kmeans":
            break

    DER_list.sort(key=lambda x: x[1])
    tuned_nn = DER_list[0]

    return tuned_nn[0]


def dev_tuner(full_meta, save_dir, config):
    """Tune n_components on the dev set. Used for the nn-based affinity matrix.
    Note: This is a very basic tuning for nn-based affinity.
    This is work in progress until we find a better way.
    """

    DER_list = []
    pval = None
    for n_lambdas in range(1, config.max_num_spkrs + 1):

        # Process the whole dataset for this value of n_lambdas.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
                                            save_dir, config)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

    # Take the n_lambdas with the minimum DER.
    tuned_n_lambdas = DER_list.index(min(DER_list)) + 1

    return tuned_n_lambdas


def main(args, config):
    # AMI dev set: tune hyperparams on the dev set.
    # Read the metadata file for the dev set generated during embedding computation.
    dev_meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_dev." + config.mic_type + ".subsegs.json", )
    with open(dev_meta_file, "r") as f:
        meta_dev = json.load(f)

    full_meta = meta_dev

    # Processing starts from here.
    # The following lines select options for the different backends and affinity
    # matrices and find the best hyperparameter values using the dev set.
    ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
                                 "fullref_ami_dev.rttm")
    best_nn = None
    if config.affinity == "nn":
        logger.info("Tuning for nn (multiple iterations over the AMI dev set)")
        best_nn = dev_nn_tuner(full_meta, args.data_dir, config)

    n_lambdas = None
    best_pval = None

    if config.affinity == "cos" and (config.backend == "SC" or
                                     config.backend == "kmeans"):
        # oracle num_spkrs or not doesn't matter for the kmeans and SC backends.
        # cos: tune for the best pval for SC/kmeans (for an unknown num of speakers).
        logger.info(
            "Tuning the p-value for SC (multiple iterations over the AMI dev set)")
        best_pval = dev_pval_tuner(full_meta, args.data_dir, config)

    elif config.backend == "AHC":
        logger.info("Tuning the threshold value for AHC")
        best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
                                                 config)
        best_pval = best_threshold
    else:
        # NN for an unknown num of speakers (can be used in the future).
        if config.oracle_n_spkrs is False:
            # nn: tune the number of components (to be updated later).
            logger.info(
                "Tuning the number of eigen components for NN (multiple iterations over the AMI dev set)"
            )
            # dev_tuner is used for tuning the num of components in NN. Can be used in the future.
            n_lambdas = dev_tuner(full_meta, args.data_dir, config)

    # load the 'dev' and 'eval' metadata files.
    full_meta_dev = full_meta  # the current full_meta is for 'dev'
    eval_meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_eval." + config.mic_type + ".subsegs.json", )
    with open(eval_meta_file, "r") as f:
        full_meta_eval = json.load(f)

    # tag to be appended to the final output DER files. DER is written per individual file.
    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
    tag = (
        type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)

    # perform the final diarization on 'dev' and 'eval' with the best hyperparams.
    final_DERs = {}
    out_der_dir = os.path.join(args.data_dir, config.der_dir)
    if not os.path.exists(out_der_dir):
        os.makedirs(out_der_dir)

    for split_type in ["dev", "eval"]:
        if split_type == "dev":
            full_meta = full_meta_dev
        else:
            full_meta = full_meta_eval

        # perform diarization.
        msg = "Diarizing using the best hyperparams: " + split_type + " set"
        logger.info(msg)
        out_boundaries = diarize_dataset(
            full_meta,
            split_type,
            n_lambdas=n_lambdas,
            pval=best_pval,
            n_neighbors=best_nn,
            save_dir=args.data_dir,
            config=config)

        # compute the DER.
        msg = "Computing DERs for the " + split_type + " set"
        logger.info(msg)
        ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir,
                                "fullref_ami_" + split_type + ".rttm")
        sys_rttm = out_boundaries
        [MS, FA, SER, DER_vals] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar,
            individual_file_scores=True, )

        # write the DER values to a file, appending the tag.
        der_file_name = split_type + "_DER_" + tag
        out_der_file = os.path.join(out_der_dir, der_file_name)
        msg = "Writing the DER file to: " + out_der_file
        logger.info(msg)
        diar.write_ders_file(ref_rttm, DER_vals, out_der_file)

        msg = ("AMI " + split_type + " set DER = %s %%\n" %
               (str(round(DER_vals[-1], 2))))
        logger.info(msg)
        final_DERs[split_type] = round(DER_vals[-1], 2)

    # finally, print the DERs.
    msg = (
        "Final Diarization Error Rate (%%) on the AMI corpus: Dev = %s %% | Eval = %s %%\n"
        % (str(final_DERs["dev"]), str(final_DERs["eval"])))
    logger.info(msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../data/",
        type=str,
        help="processed data directory")
    args = parser.parse_args()
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()

    main(args, config)
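For reference (not part of the commit): the diarization error rate reported above conventionally decomposes into missed speech (MS), false alarm (FA), and speaker error (SER), each normalized by the total scored speech time. A minimal sketch of that accounting, assuming the three components are already expressed as percentages:

def combine_der(ms: float, fa: float, ser: float) -> float:
    # DER = missed speech + false alarm + speaker error,
    # each as a percentage of the total scored speech time.
    return ms + fa + ser

assert combine_der(2.0, 1.5, 4.0) == 7.5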
@ -0,0 +1,49 @@
#!/bin/bash

stage=0
set=L

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -o pipefail

data_folder=$1
manual_annot_folder=$2
save_folder=$3
pretrained_model_dir=$4
conf_path=$5
device=$6

ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata

if [ ${stage} -le 0 ]; then
    echo "AMI Data preparation"
    python local/ami_prepare.py --data_folder ${data_folder} \
            --manual_annot_folder ${manual_annot_folder} \
            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
            --meta_data_dir ${meta_data_dir}

    if [ $? -ne 0 ]; then
        echo "Prepare AMI failed. Please check the log message."
        exit 1
    fi
    echo "AMI data preparation done."
fi

if [ ${stage} -le 1 ]; then
    # extract embeddings for the dev and eval datasets
    for name in dev eval; do
        python local/compute_embdding.py --config ${conf_path} \
                --data-dir ${save_folder} \
                --device ${device} \
                --dataset ${name} \
                --load-checkpoint ${pretrained_model_dir}
    done
fi

if [ ${stage} -le 2 ]; then
    # tune hyperparams on the dev set
    # perform the final diarization on 'dev' and 'eval' with the best hyperparams
    python local/experiment.py --config ${conf_path} \
            --data-dir ${save_folder}
fi
@ -1,14 +1,46 @@
#!/bin/bash

. ./path.sh || exit 1;
set -e

stage=0

#TARGET_DIR=${MAIN_ROOT}/dataset/ami
TARGET_DIR=/home/dataset/AMI
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/

save_folder=./save
pretrained_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model
conf_path=conf/ecapa_tdnn.yaml
device=gpu

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

if [ $stage -le 0 ]; then
    # Prepare data.
    # Download the AMI corpus. You need around 10 GB of free space to get the whole dataset.
    # The signals are too large to package in one archive,
    # so you need to use the chooser to indicate which ones you wish to download.
    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
    echo "Annotations: AMI manual annotations v1.6.2 "
    echo "Signals: "
    echo "1) Select one or more AMI meetings: for the IDs please follow ./ami_split.py"
    echo "2) Select media streams: just select Headset mix"
fi

if [ $stage -le 1 ]; then
    # Download the pretrained model.
    wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
    mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
    rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
    echo "Downloaded the pretrained ECAPA-TDNN model to path: "${pretrained_model_dir}
fi

if [ $stage -le 2 ]; then
    # Tune hyperparams on the dev set and perform the final diarization on dev and eval with the best hyperparams.
    echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretrained_model_dir} ${conf_path}
    bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
         ${save_folder} ${pretrained_model_dir} ${conf_path} ${device} || exit 1
fi
@ -0,0 +1,32 @@
train_output_path=$1

stage=0
stop_stage=0

# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!

# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../ort_predict.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=speedyspeech_csmsc \
        --voc=hifigan_csmsc \
        --test_metadata=dump/test/norm/metadata.jsonl \
        --output_dir=${train_output_path}/onnx_infer_out \
        --device=cpu \
        --cpu_threads=2
fi

# e2e, synthesize from text
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    python3 ${BIN_DIR}/../ort_predict_e2e.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=speedyspeech_csmsc \
        --voc=hifigan_csmsc \
        --output_dir=${train_output_path}/onnx_infer_out_e2e \
        --text=${BIN_DIR}/../csmsc_test.txt \
        --phones_dict=dump/phone_id_map.txt \
        --tones_dict=dump/tone_id_map.txt \
        --device=cpu \
        --cpu_threads=2
fi
@ -0,0 +1 @@
../../tts3/local/paddle2onnx.sh
@ -0,0 +1,47 @@
#!/bin/bash

train_output_path=$1

stage=0
stop_stage=0

# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../inference_streaming.py \
        --inference_dir=${train_output_path}/inference_streaming \
        --am=fastspeech2_csmsc \
        --am_stat=dump/train/speech_stats.npy \
        --voc=pwgan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out_streaming \
        --phones_dict=dump/phone_id_map.txt \
        --am_streaming=True
fi

# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    python3 ${BIN_DIR}/../inference_streaming.py \
        --inference_dir=${train_output_path}/inference_streaming \
        --am=fastspeech2_csmsc \
        --am_stat=dump/train/speech_stats.npy \
        --voc=mb_melgan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out_streaming \
        --phones_dict=dump/phone_id_map.txt \
        --am_streaming=True
fi

# hifigan
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    python3 ${BIN_DIR}/../inference_streaming.py \
        --inference_dir=${train_output_path}/inference_streaming \
        --am=fastspeech2_csmsc \
        --am_stat=dump/train/speech_stats.npy \
        --voc=hifigan_csmsc \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/pd_infer_out_streaming \
        --phones_dict=dump/phone_id_map.txt \
        --am_streaming=True
fi
@ -0,0 +1,31 @@
train_output_path=$1

stage=0
stop_stage=0

# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!

# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../ort_predict.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=fastspeech2_csmsc \
        --voc=hifigan_csmsc \
        --test_metadata=dump/test/norm/metadata.jsonl \
        --output_dir=${train_output_path}/onnx_infer_out \
        --device=cpu \
        --cpu_threads=2
fi

# e2e, synthesize from text
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    python3 ${BIN_DIR}/../ort_predict_e2e.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=fastspeech2_csmsc \
        --voc=hifigan_csmsc \
        --output_dir=${train_output_path}/onnx_infer_out_e2e \
        --text=${BIN_DIR}/../csmsc_test.txt \
        --phones_dict=dump/phone_id_map.txt \
        --device=cpu \
        --cpu_threads=2
fi
@ -0,0 +1,19 @@
train_output_path=$1

stage=0
stop_stage=0

# e2e, synthesize from text
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../ort_predict_streaming.py \
        --inference_dir=${train_output_path}/inference_onnx_streaming \
        --am=fastspeech2_csmsc \
        --am_stat=dump/train/speech_stats.npy \
        --voc=hifigan_csmsc \
        --output_dir=${train_output_path}/onnx_infer_out_streaming \
        --text=${BIN_DIR}/../csmsc_test.txt \
        --phones_dict=dump/phone_id_map.txt \
        --device=cpu \
        --cpu_threads=2 \
        --am_streaming=True
fi
@ -0,0 +1,23 @@
train_output_path=$1
model_dir=$2
output_dir=$3
model=$4

enable_dev_version=True

model_name=${model%_*}
echo model_name: ${model_name}

if [ ${model_name} = 'mb_melgan' ] ;then
    enable_dev_version=False
fi

mkdir -p ${train_output_path}/${output_dir}

paddle2onnx \
    --model_dir ${train_output_path}/${model_dir} \
    --model_filename ${model}.pdmodel \
    --params_filename ${model}.pdiparams \
    --save_file ${train_output_path}/${output_dir}/${model}.onnx \
    --opset_version 11 \
    --enable_dev_version ${enable_dev_version}
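A hypothetical invocation of the conversion script above (the output path and model name are illustrative, not from the commit):

# usage: ./local/paddle2onnx.sh <train_output_path> <model_dir> <output_dir> <model>
./local/paddle2onnx.sh exp/default inference inference_onnx fastspeech2_csmsc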
@ -0,0 +1,9 @@
# iwslt2012

## Ernie

|           | COMMA    | PERIOD   | QUESTION | OVERALL  |
|:---------:|:--------:|:--------:|:--------:|:--------:|
| Precision | 0.510955 | 0.526462 | 0.820755 | 0.619391 |
| Recall    | 0.517433 | 0.564179 | 0.861386 | 0.647666 |
| F1        | 0.514173 | 0.544669 | 0.840580 | 0.633141 |
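As a consistency check (not part of the commit), the F1 row follows from precision and recall via F1 = 2PR / (P + R); for COMMA, 2 * 0.510955 * 0.517433 / (0.510955 + 0.517433) ≈ 0.514173, matching the table.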
@ -0,0 +1,60 @@
###########################################
#                   Data                  #
###########################################
augment: True
batch_size: 32
num_workers: 2
num_speakers: 1211  # 1211 for vox1, 5994 for vox2, 7205 for vox1+2; test speakers: 41
shuffle: True
skip_prep: False
split_ratio: 0.9
chunk_duration: 3.0  # seconds
random_chunk: True
verification_file: data/vox1/veri_test2.txt

###########################################################
#              FEATURE EXTRACTION SETTING                 #
###########################################################
# currently, we only support fbank
sr: 16000  # sample rate
n_mels: 80
window_size: 400  # 25 ms; at a 16000 sample rate: 25 * 16000 / 1000 = 400
hop_size: 160  # 10 ms; at a 16000 sample rate: 10 * 16000 / 1000 = 160

###########################################################
#                     MODEL SETTING                       #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if you want to use another model, please choose another configuration yaml file
model:
  input_size: 80
  channels: [512, 512, 512, 512, 1536]
  kernel_sizes: [5, 3, 3, 3, 1]
  dilations: [1, 2, 3, 4, 1]
  attention_channels: 128
  lin_neurons: 192

###########################################
#                 Training                #
###########################################
seed: 1986  # following the SpeechBrain configuration
epochs: 100
save_interval: 10
log_interval: 10
learning_rate: 1e-8
max_lr: 1e-3
step_size: 140000

###########################################
#                   loss                  #
###########################################
margin: 0.2
scale: 30

###########################################
#                  Testing                #
###########################################
global_embedding_norm: True
embedding_mean_norm: True
embedding_std_norm: False
@ -0,0 +1,167 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment.
|
||||
Currently, Speaker Identificaton Training process use csv format.
|
||||
"""
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import tqdm
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddleaudio import load as load_audio
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
from paddlespeech.vector.utils.vector_utils import get_chunks
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
def get_chunks_list(wav_file: str,
|
||||
split_chunks: bool,
|
||||
base_path: str,
|
||||
chunk_duration: float=3.0) -> List[List[str]]:
|
||||
"""Get the single audio file info
|
||||
|
||||
Args:
|
||||
wav_file (list): the wav audio file and get this audio segment info list
|
||||
split_chunks (bool): audio split flag
|
||||
base_path (str): the audio base path
|
||||
chunk_duration (float): the chunk duration.
|
||||
if set the split_chunks, we split the audio into multi-chunks segment.
|
||||
"""
|
||||
waveform, sr = load_audio(wav_file)
|
||||
audio_id = wav_file.split("/rir_noise/")[-1].split(".")[0]
|
||||
audio_duration = waveform.shape[0] / sr
|
||||
|
||||
ret = []
|
||||
if split_chunks and audio_duration > chunk_duration: # Split into pieces of self.chunk_duration seconds.
|
||||
uniq_chunks_list = get_chunks(chunk_duration, audio_id, audio_duration)
|
||||
|
||||
for idx, chunk in enumerate(uniq_chunks_list):
|
||||
s, e = chunk.split("_")[-2:] # Timestamps of start and end
|
||||
start_sample = int(float(s) * sr)
|
||||
end_sample = int(float(e) * sr)
|
||||
|
||||
            # currently, all vector csv data uses one format:
            # id, duration, wav, start, stop, label
            # in the RIRS noise dataset, every label is 'noise'
            # the label is a string here; it is converted to an integer id during training
            ret.append([
                chunk, audio_duration, wav_file, start_sample, end_sample,
                "noise"
            ])
    else:  # Keep the whole audio.
        ret.append(
            [audio_id, audio_duration, wav_file, 0, waveform.shape[0], "noise"])
    return ret


def generate_csv(wav_files,
                 output_file: str,
                 base_path: str,
                 split_chunks: bool=True):
    """Prepare the csv file according to the wav files

    Args:
        wav_files (list): all the audio list to prepare the csv file
        output_file (str): the output csv file
        base_path (str): the base path of the dataset
        split_chunks (bool): audio split flag
    """
    logger.info(f'Generating csv: {output_file}')
    header = ["utt_id", "duration", "wav", "start", "stop", "label"]
    csv_lines = []
    for item in tqdm.tqdm(wav_files):
        csv_lines.extend(
            get_chunks_list(
                item, base_path=base_path, split_chunks=split_chunks))

    if not os.path.exists(os.path.dirname(output_file)):
        os.makedirs(os.path.dirname(output_file))

    with open(output_file, mode="w") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(header)
        for line in csv_lines:
            csv_writer.writerow(line)


def prepare_data(args, config):
    """Prepare the rir.csv and noise.csv files from the RIRS_NOISES dataset lists

    Args:
        args (argparse.Namespace): scripts args
        config (CfgNode): yaml configuration content
    """
    # if the external config sets the skip_prep flag, we do nothing
    if config.skip_prep:
        return

    base_path = args.noise_dir
    wav_path = os.path.join(base_path, "RIRS_NOISES")
    logger.info(f"base path: {base_path}")
    logger.info(f"wav path: {wav_path}")
    rir_list = os.path.join(wav_path, "real_rirs_isotropic_noises", "rir_list")
    rir_files = []
    with open(rir_list, 'r') as f:
        for line in f.readlines():
            rir_file = line.strip().split(' ')[-1]
            rir_files.append(os.path.join(base_path, rir_file))

    noise_list = os.path.join(wav_path, "pointsource_noises", "noise_list")
    noise_files = []
    with open(noise_list, 'r') as f:
        for line in f.readlines():
            noise_file = line.strip().split(' ')[-1]
            noise_files.append(os.path.join(base_path, noise_file))

    csv_path = os.path.join(args.data_dir, 'csv')
    logger.info(f"csv path: {csv_path}")
    generate_csv(
        rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path)
    generate_csv(
        noise_files, os.path.join(csv_path, 'noise.csv'), base_path=base_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--noise_dir",
        default=None,
        required=True,
        help="The noise dataset directory.")
    parser.add_argument(
        "--data_dir",
        default=None,
        required=True,
        help="The target directory that stores the csv files.")
    parser.add_argument(
        "--config",
        default=None,
        required=True,
        type=str,
        help="configuration file")
    args = parser.parse_args()

    # parse the yaml config file
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # prepare the rir and noise csv files
    prepare_data(args, config)
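
# Example invocation (the script name and paths are illustrative; the flags
# are the ones defined above):
#   python rirs_noises_prep.py \
#       --noise_dir ./RIRS_NOISES_root \
#       --data_dir ./data \
#       --config ./conf/config.yaml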
@ -0,0 +1,251 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert the PaddleSpeech jsonline format data to csv format data in the voxceleb experiment.
Currently, the Speaker Identification training process uses the csv format.
"""
import argparse
import csv
import json
import os
import random

import tqdm
from yacs.config import CfgNode

from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks

logger = Log(__name__).getlog()


def prepare_csv(wav_files, output_file, config, split_chunks=True):
    """Prepare the csv file according to the wav files

    Args:
        wav_files (list): all the audio list to prepare the csv file
        output_file (str): the output csv file
        config (CfgNode): yaml configuration content
        split_chunks (bool, optional): audio split flag. Defaults to True.
    """
    if not os.path.exists(os.path.dirname(output_file)):
        os.makedirs(os.path.dirname(output_file))
    csv_lines = []
    header = ["utt_id", "duration", "wav", "start", "stop", "label"]
    # voxceleb meta info for each training utterance segment
    # we extract a segment from an utterance to train,
    # and the segment's period lies between the start and stop time points in the original wav file
    # each field in the meta info means the following:
    # utt_id: the utterance segment name, which is unique in the training dataset
    # duration: the total utterance time
    # wav: utterance file path, which should be an absolute path
    # start: start point in the original wav file, in sample points
    # stop: stop point in the original wav file, in sample points
    # label: the utterance segment's label name,
    #        which is the speaker name in the speaker verification domain
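    # For illustration only, a chunk row might look like (hypothetical values,
    # assuming get_chunks names chunks as "<utt_id>_<start>_<end>"):
    #   id10001-1zcIwhmdeo4-00001_0.0_3.0, 8.12, /data/wav/id10001/00001.wav, 0, 48000, id10001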
    for item in tqdm.tqdm(wav_files, total=len(wav_files)):
        item = json.loads(item.strip())
        audio_id = item['utt'].replace(".wav",
                                       "")  # we remove the wav suffix name
        audio_duration = item['feat_shape'][0]
        wav_file = item['feat']
        label = audio_id.split('-')[
            0]  # speaker name in the speaker verification domain
        waveform, sr = load_audio(wav_file)
        if split_chunks:
            uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
                                          audio_duration)
            for chunk in uniq_chunks_list:
                s, e = chunk.split("_")[-2:]  # timestamps of start and end
                start_sample = int(float(s) * sr)
                end_sample = int(float(e) * sr)
                # id, duration, wav, start, stop, label
                # here the label is the speaker id
                csv_lines.append([
                    chunk, audio_duration, wav_file, start_sample, end_sample,
                    label
                ])
        else:
            csv_lines.append([
                audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
            ])

    with open(output_file, mode="w") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(header)
        for line in csv_lines:
            csv_writer.writerow(line)
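

# Not the library implementation -- a minimal sketch of the behavior the loop
# above assumes of paddlespeech.vector.utils.vector_utils.get_chunks: split
# [0, audio_duration] into fixed-length windows named "<audio_id>_<start>_<end>".
def _get_chunks_sketch(chunk_duration, audio_id, audio_duration):
    num_chunks = int(audio_duration / chunk_duration)  # drop the tail remainder
    return [
        f"{audio_id}_{i * chunk_duration}_{(i + 1) * chunk_duration}"
        for i in range(num_chunks)
    ]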


def get_enroll_test_list(dataset_list, verification_file):
    """Get the enroll and test utterance list from all the voxceleb1 test utterance dataset.
    Generally, we get the enroll and test utterances from the verification file.
    The verification file format is as follows:
    target/nontarget enroll-utt test-utt,
    we set 0 as nontarget and 1 as target, e.g.:
    0 a.wav b.wav
    1 a.wav a.wav

    Args:
        dataset_list (list): all the dataset to get the test utterances
        verification_file (str): voxceleb1 trial file
    """
    logger.info(f"verification file: {verification_file}")
    enroll_audios = set()
    test_audios = set()
    with open(verification_file, 'r') as f:
        for line in f:
            _, enroll_file, test_file = line.strip().split(' ')
            enroll_audios.add('-'.join(enroll_file.split('/')))
            test_audios.add('-'.join(test_file.split('/')))

    enroll_files = []
    test_files = []
    for dataset in dataset_list:
        with open(dataset, 'r') as f:
            for line in f:
                # audio_id may be in enroll and test at the same time
                # e.g.: 1 a.wav a.wav
                # the audio a.wav is an enroll and a test file at the same time
                audio_id = json.loads(line.strip())['utt']
                if audio_id in enroll_audios:
                    enroll_files.append(line)
                if audio_id in test_audios:
                    test_files.append(line)

    enroll_files = sorted(enroll_files)
    test_files = sorted(test_files)

    return enroll_files, test_files


def get_train_dev_list(dataset_list, target_dir, split_ratio):
    """Get the train and dev utterance list from all the training utterance dataset.
    Generally, we use the split_ratio as the train dataset ratio,
    and the remaining utterances (ratio is 1 - split_ratio) form the dev dataset.

    Args:
        dataset_list (list): all the dataset to get all the utterances
        target_dir (str): the target train and dev directory,
                          we will create the csv directory to store the {train,dev}.csv file
        split_ratio (float): train dataset ratio in the all-utterance list
    """
    logger.info("start to get train and dev utt list")
    if not os.path.exists(os.path.join(target_dir, "meta")):
        os.makedirs(os.path.join(target_dir, "meta"))

    audio_files = []
    speakers = set()
    for dataset in dataset_list:
        with open(dataset, 'r') as f:
            for line in f:
                # the label is the speaker name
                label_name = json.loads(line.strip())['utt2spk']
                speakers.add(label_name)
                audio_files.append(line.strip())
    speakers = sorted(speakers)
    logger.info(f"we get {len(speakers)} speakers from all the train datasets")

    with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f:
        for label_id, label_name in enumerate(speakers):
            f.write(f'{label_name} {label_id}\n')
    logger.info(
        f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}'
    )

    # the split_ratio is for the train dataset
    # the remaining is for the dev dataset
    split_idx = int(split_ratio * len(audio_files))
    audio_files = sorted(audio_files)
    random.shuffle(audio_files)
    train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:]
    logger.info(
        f"we get train utterances: {len(train_files)}, dev utterances: {len(dev_files)}"
    )
    return train_files, dev_files
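
# Worked example (hypothetical numbers): with split_ratio = 0.9 and 1000
# utterances, split_idx = int(0.9 * 1000) = 900, so after the seeded shuffle
# the first 900 lines go to train.csv and the remaining 100 go to dev.csv.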


def prepare_data(args, config):
    """Convert the jsonline format to csv format

    Args:
        args (argparse.Namespace): scripts args
        config (CfgNode): yaml configuration content
    """
    # stage 0: set the random seed
    random.seed(config.seed)

    # if the external config sets the skip_prep flag, we do nothing
    if config.skip_prep:
        return

    # stage 1: prepare the enroll and test csv files,
    #          and generate the speaker-to-label file label2id.txt
    logger.info("start to prepare the enroll and test csv files")
    enroll_files, test_files = get_enroll_test_list(
        [args.test], verification_file=config.verification_file)
    prepare_csv(
        enroll_files,
        os.path.join(args.target_dir, "csv", "enroll.csv"),
        config,
        split_chunks=False)
    prepare_csv(
        test_files,
        os.path.join(args.target_dir, "csv", "test.csv"),
        config,
        split_chunks=False)

    # stage 2: prepare the train and dev csv files
    #          we take config.split_ratio of the data as the train dataset
    #          and the remaining as the dev dataset
    logger.info("start to prepare the train and dev csv files")
    train_files, dev_files = get_train_dev_list(
        args.train, target_dir=args.target_dir, split_ratio=config.split_ratio)
    prepare_csv(train_files,
                os.path.join(args.target_dir, "csv", "train.csv"), config)
    prepare_csv(dev_files,
                os.path.join(args.target_dir, "csv", "dev.csv"), config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--train",
        required=True,
        nargs='+',
        help="The jsonline files list for train.")
    parser.add_argument(
        "--test", required=True, help="The jsonline file for test")
    parser.add_argument(
        "--target_dir",
        default=None,
        required=True,
        help="The target directory that stores the csv files and meta file.")
    parser.add_argument(
        "--config",
        default=None,
        required=True,
        type=str,
        help="configuration file")
    args = parser.parse_args()

    # parse the yaml config file
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # prepare the csv files from the jsonline files
    prepare_data(args, config)
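
# Example invocation (the script name and file paths are illustrative; the
# flags are the ones defined above):
#   python vox_csv_prep.py \
#       --train train.jsonline dev.jsonline \
#       --test test.jsonline \
#       --target_dir ./data/vox \
#       --config ./conf/config.yaml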
@ -1,63 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

import mcd.metrics_fast as mt
import numpy as np
from mcd import dtw

__all__ = [
    'mcd_distance',
]


def mcd_distance(xs: np.ndarray,
                 ys: np.ndarray,
                 cost_fn: Callable=mt.logSpecDbDist) -> float:
    """Mel cepstral distortion (MCD), dtw distance.

    Dynamic Time Warping uses dynamic programming to compute:

    Examples:
        .. code-block:: python

            wps[i, j] = cost_fn(xs[i], ys[j]) + min(
                wps[i-1, j  ],  # vertical / insertion / expansion
                wps[i  , j-1],  # horizontal / deletion / compression
                wps[i-1, j-1])  # diagonal / match

            dtw = sqrt(wps[-1, -1])

    Cost Function:
    Examples:
        .. code-block:: python

            logSpecDbConst = 10.0 / math.log(10.0) * math.sqrt(2.0)

            def logSpecDbDist(x, y):
                diff = x - y
                return logSpecDbConst * math.sqrt(np.inner(diff, diff))

    Args:
        xs (np.ndarray): ref sequence, [T,D]
        ys (np.ndarray): hyp sequence, [T,D]
        cost_fn (Callable, optional): Cost function. Defaults to mt.logSpecDbDist.

    Returns:
        float: dtw distance
    """
    min_cost, path = dtw.dtw(xs, ys, cost_fn)
    return min_cost
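
# A minimal usage sketch (shapes and values are illustrative):
#   ref = np.random.randn(120, 25)  # [T, D] reference mel-cepstra
#   hyp = np.random.randn(130, 25)  # [T, D] synthesized mel-cepstra
#   distance = mcd_distance(ref, hyp)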
@ -0,0 +1,30 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


def pcm16to32(audio: np.ndarray) -> np.ndarray:
    """pcm int16 to float32

    Args:
        audio (np.ndarray): Waveform with dtype of int16.

    Returns:
        np.ndarray: Waveform with dtype of float32.
    """
    if audio.dtype == np.int16:
        audio = audio.astype("float32")
        bits = np.iinfo(np.int16).bits
        audio = audio / (2**(bits - 1))
    return audio
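
# A minimal usage sketch (values are illustrative): samples are scaled by
# 1 / 2**15, so the full int16 range maps into [-1.0, 1.0).
#   pcm16to32(np.array([0, 16384, -32768], dtype=np.int16))
#   # -> array([ 0. ,  0.5, -1. ], dtype=float32)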
@ -0,0 +1,97 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pretrained_models = {
    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
    # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
    # Command line and Python API use "{model_name}[_{dataset}]" as --model, usage:
    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
    "conformer_wenetspeech-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        '76cb19ed857e6623856b7cd7ebbfeda4',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/conformer/checkpoints/wenetspeech',
    },
    "transformer_librispeech-en-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        '2c667da24922aad391eacafe37bc1660',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/transformer/checkpoints/avg_10',
    },
    "deepspeech2offline_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
        'md5':
        '932c3593d62fe5c741b59b31318aa314',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
        'md5':
        '23e16c69730a1cb5d735c98c83c21e16',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2offline_librispeech-en-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        'f5666c81ad015c8de03aac2bc92e5762',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
        'lm_md5':
        '099a601759d467cd0a8523ff939819c5'
    },
}

model_alias = {
    "deepspeech2offline":
    "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
    "deepspeech2online":
    "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
    "conformer":
    "paddlespeech.s2t.models.u2:U2Model",
    "conformer_online":
    "paddlespeech.s2t.models.u2:U2Model",
    "transformer":
    "paddlespeech.s2t.models.u2:U2Model",
    "wenetspeech":
    "paddlespeech.s2t.models.u2:U2Model",
}
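

# A minimal sketch (not PaddleSpeech's actual loader; resolve_alias is a
# hypothetical helper) of how a "module:Class" string from model_alias
# could be resolved into a class object:
from importlib import import_module


def resolve_alias(alias: str):
    """Hypothetical helper: import the module and return the named class."""
    module_name, class_name = alias.split(":")
    return getattr(import_module(module_name), class_name)


# e.g. resolve_alias(model_alias["conformer"]) would return the U2Model class.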
@ -0,0 +1,47 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pretrained_models = {
    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
    # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
    # Command line and Python API use "{model_name}[_{dataset}]" as --model, usage:
    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
    "panns_cnn6-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
        'md5': '4cf09194a95df024fd12f84712cf0f9c',
        'cfg_path': 'panns.yaml',
        'ckpt_path': 'cnn6.pdparams',
        'label_file': 'audioset_labels.txt',
    },
    "panns_cnn10-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
        'md5': 'cb8427b22176cc2116367d14847f5413',
        'cfg_path': 'panns.yaml',
        'ckpt_path': 'cnn10.pdparams',
        'label_file': 'audioset_labels.txt',
    },
    "panns_cnn14-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
        'md5': 'e3b9b5614a1595001161d0ab95edee97',
        'cfg_path': 'panns.yaml',
        'ckpt_path': 'cnn14.pdparams',
        'label_file': 'audioset_labels.txt',
    },
}

model_alias = {
    "panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
    "panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
    "panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}
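
# Illustrative lookup: selecting the CNN14 checkpoint metadata by its tag.
#   info = pretrained_models["panns_cnn14-32k"]
#   info['ckpt_path']  # -> 'cnn14.pdparams'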
@ -0,0 +1,35 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pretrained_models = {
    "fat_st_ted-en-zh": {
        "url":
        "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
        "md5":
        "d62063f35a16d91210a71081bd2dd557",
        "cfg_path":
        "model.yaml",
        "ckpt_path":
        "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
    }
}

model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}

kaldi_bins = {
    "url":
    "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
    "md5":
    "c0682303b3f3393dbf6ed4c4e35a53eb",
}