pull/1696/head
commit d3f8715b0a
@@ -1,6 +1,9 @@
 #!/bin/bash
 
 wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
+wget -c https://paddlespeech.bj.bcebos.com/vector/audio/123456789.wav
 
-# asr
+# vector
 paddlespeech vector --task spk --input ./85236145389.wav
+
+paddlespeech vector --task score --input "./85236145389.wav ./123456789.wav"
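For context, the new `--task score` command compares two utterances by scoring their speaker embeddings against each other. A minimal sketch of the underlying idea, assuming the embeddings have already been extracted as numpy vectors (the names here are illustrative, not PaddleSpeech API):

import numpy as np

def cosine_score(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """Cosine similarity between two speaker embeddings."""
    return float(np.dot(emb1, emb2) /
                 (np.linalg.norm(emb1) * np.linalg.norm(emb2)))

# Illustrative 192-dim embeddings (emb_dim: 192 in the configs below).
emb_a = np.random.randn(192)
emb_b = np.random.randn(192)
print(cosine_score(emb_a, emb_b))  # close to 1.0 for the same speaker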
@@ -0,0 +1,62 @@
###########################################################
#                AMI DATA PREPARE SETTING                 #
###########################################################
split_type: 'full_corpus_asr'
skip_TNO: True
# Options for mic_type: 'Mix-Lapel', 'Mix-Headset', 'Array1', 'Array1-01', 'BeamformIt'
mic_type: 'Mix-Headset'
vad_type: 'oracle'
max_subseg_dur: 3.0
overlap: 1.5
# Some more exp folders (for cleaner structure).
embedding_dir: emb #!ref <save_folder>/emb
meta_data_dir: metadata #!ref <save_folder>/metadata
ref_rttm_dir: ref_rttms #!ref <save_folder>/ref_rttms
sys_rttm_dir: sys_rttms #!ref <save_folder>/sys_rttms
der_dir: DER #!ref <save_folder>/DER


###########################################################
#              FEATURE EXTRACTION SETTING                 #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 # 25 ms; at 16 kHz, 25 * 16000 / 1000 = 400 samples
hop_size: 160 # 10 ms; at 16 kHz, 10 * 16000 / 1000 = 160 samples
#left_frames: 0
#right_frames: 0
#deltas: False


###########################################################
#                     MODEL SETTING                       #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if you want to use another model, please choose another configuration yaml file
seed: 1234
emb_dim: 192
batch_size: 16
model:
  input_size: 80
  channels: [1024, 1024, 1024, 1024, 3072]
  kernel_sizes: [5, 3, 3, 3, 1]
  dilations: [1, 2, 3, 4, 1]
  attention_channels: 128
  lin_neurons: 192
# Will automatically download the ECAPA-TDNN model (best).

###########################################################
#               SPECTRAL CLUSTERING SETTING               #
###########################################################
backend: 'SC' # options: 'kmeans' # Note: kmeans goes only with cos affinity
affinity: 'cos' # options: cos, nn
max_num_spkrs: 10
oracle_n_spkrs: True


###########################################################
#                 DER EVALUATION SETTING                  #
###########################################################
ignore_overlap: True
forgiveness_collar: 0.25
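The `max_subseg_dur` and `overlap` values above control how each VAD speech segment is cut into overlapping sub-segments before embedding extraction. A minimal sketch of that windowing logic, assuming segments arrive as (start, end) times in seconds; this helper is illustrative, not the library's own code:

def split_subsegments(start, end, max_dur=3.0, overlap=1.5):
    """Cut [start, end) into max_dur-second windows shifted by (max_dur - overlap)."""
    shift = max_dur - overlap  # 1.5 s shift for a 3.0 s window with 1.5 s overlap
    subsegs = []
    t = start
    while t + max_dur < end:
        subsegs.append((t, t + max_dur))
        t += shift
    subsegs.append((t, end))  # last (possibly shorter) window
    return subsegs

print(split_subsegments(0.0, 7.0))  # [(0.0, 3.0), (1.5, 4.5), (3.0, 6.0), (4.5, 7.0)]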
@@ -0,0 +1,231 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import pickle
import sys

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster.diarization import EmbeddingMeta
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset_from_json import JSONDataset
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

# Logger setup
logger = Log(__name__).getlog()


def prepare_subset_json(full_meta_data, rec_id, out_meta_file):
    """Prepares metadata for a given recording ID.

    Arguments
    ---------
    full_meta_data : json
        Full meta (json) containing all the recordings
    rec_id : str
        The recording ID for which meta (json) has to be prepared
    out_meta_file : str
        Path of the output meta (json) file.
    """
    subset = {}
    for key in full_meta_data:
        k = str(key)
        if k.startswith(rec_id):
            subset[key] = full_meta_data[key]

    with open(out_meta_file, mode="w") as json_f:
        json.dump(subset, json_f, indent=2)


def create_dataloader(json_file, batch_size):
    """Creates the datasets and their data processing pipelines.
    This is used for multi-mic processing.
    """
    # create datasets
    dataset = JSONDataset(
        json_file=json_file,
        feat_type='melspectrogram',
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_size)

    # create dataloader
    batch_sampler = BatchSampler(dataset, batch_size=batch_size, shuffle=True)
    dataloader = DataLoader(dataset,
                            batch_sampler=batch_sampler,
                            collate_fn=lambda x: batch_feature_normalize(
                                x, mean_norm=True, std_norm=False),
                            return_list=True)

    return dataloader


def main(args, config):
    # set the training device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1)

    # stage3: load the pre-trained model
    # we get the last model from the epoch and save_interval
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load model checkpoint to sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # set the model to eval mode
    model.eval()

    # load meta data
    meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_" + args.dataset + "." + config.mic_type + ".subsegs.json", )
    with open(meta_file, "r") as f:
        full_meta = json.load(f)

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + args.dataset
    i = 1

    msg = "Extracting embeddings for the " + args.dataset + " set"
    logger.info(msg)

    if len(all_rec_ids) <= 0:
        msg = "No recording IDs found! Please check if the meta_data json file is properly generated."
        logger.error(msg)
        sys.exit()

    # extract embeddings for the different recordings in the dataset.
    for rec_id in tqdm(all_rec_ids):
        # This tag will be displayed in the log.
        tag = ("[" + str(args.dataset) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Embedding %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # embedding directory.
        if not os.path.exists(
                os.path.join(args.data_dir, config.embedding_dir, split)):
            os.makedirs(
                os.path.join(args.data_dir, config.embedding_dir, split))

        # file to store embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(args.data_dir, config.embedding_dir,
                                           split, emb_file_name)

        # prepare a metadata (json) for one recording. This is basically a subset of full_meta.
        # let's keep this meta-info in the embedding directory itself.
        json_file_name = rec_id + "." + config.mic_type + ".json"
        meta_per_rec_file = os.path.join(args.data_dir, config.embedding_dir,
                                         split, json_file_name)

        # write subset (meta for one recording) json metadata.
        prepare_subset_json(full_meta, rec_id, meta_per_rec_file)

        # prepare data loader.
        diary_set_loader = create_dataloader(meta_per_rec_file,
                                             config.batch_size)

        # extract embeddings (skip if already done).
        if not os.path.isfile(diary_stat_emb_file):
            logger.debug("Extracting deep embeddings")
            embeddings = np.empty(shape=[0, config.emb_dim], dtype=np.float64)
            segset = []

            for batch_idx, batch in enumerate(tqdm(diary_set_loader)):
                # extract the audio embedding
                ids, feats, lengths = batch['ids'], batch['feats'], batch[
                    'lengths']
                seg = [x for x in ids]
                segset = segset + seg
                emb = model.backbone(feats, lengths).squeeze(
                    -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)
                embeddings = np.concatenate((embeddings, emb), axis=0)

            segset = np.array(segset, dtype="|O")
            stat_obj = EmbeddingMeta(
                segset=segset,
                stats=embeddings, )
            logger.debug("Saving Embeddings...")
            with open(diary_stat_emb_file, "wb") as output:
                pickle.dump(stat_obj, output)

        else:
            logger.debug("Skipping embedding extraction (already present).")


# Begin experiment!
if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        '--device',
        default="gpu",
        help="Select which device to perform diarization, defaults to gpu.")
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../save/",
        type=str,
        help="processed data directory")
    parser.add_argument(
        "--dataset",
        choices=['dev', 'eval'],
        default="dev",
        type=str,
        help="Select which dataset to extract embeddings from, defaults to dev")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default='',
        help="Directory to load the model checkpoint from to compute embeddings.")
    args = parser.parse_args()
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()

    main(args, config)
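Each recording's embeddings land in a per-recording pickle. A minimal sketch of inspecting one of those `.emb_stat.pkl` files afterwards, assuming the attribute names `segset` and `stats` match the `EmbeddingMeta` constructor arguments used above (inferred from this script, not verified against the library); the file path is a hypothetical example of the naming scheme:

import pickle

# Hypothetical path following the scheme above:
# <data_dir>/<embedding_dir>/AMI_dev/<rec_id>.Mix-Headset.emb_stat.pkl
with open("ES2011a.Mix-Headset.emb_stat.pkl", "rb") as f:
    stat_obj = pickle.load(f)

print(stat_obj.segset.shape)  # one segment id per sub-segment
print(stat_obj.stats.shape)   # (num_subsegments, emb_dim), e.g. (N, 192)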
@@ -1,49 +0,0 @@
#!/bin/bash

stage=1

TARGET_DIR=${MAIN_ROOT}/dataset/ami
data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/

save_folder=${MAIN_ROOT}/examples/ami/sd0/data
ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata

set=L

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -u
set -o pipefail

mkdir -p ${save_folder}

if [ ${stage} -le 0 ]; then
    # Download AMI corpus. You need around 10GB of free space to get the whole data
    # The signals are too large to package in this way,
    # so you need to use the chooser to indicate which ones you wish to download
    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
    echo "Annotations: AMI manual annotations v1.6.2 "
    echo "Signals: "
    echo "1) Select one or more AMI meetings: the IDs please follow ./ami_split.py"
    echo "2) Select media streams: Just select Headset mix"
    exit 0;
fi

if [ ${stage} -le 1 ]; then
    echo "AMI Data preparation"

    python local/ami_prepare.py --data_folder ${data_folder} \
            --manual_annot_folder ${manual_annot_folder} \
            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
            --meta_data_dir ${meta_data_dir}

    if [ $? -ne 0 ]; then
        echo "Prepare AMI failed. Please check log message."
        exit 1
    fi
fi

echo "AMI data preparation done."
exit 0
@@ -0,0 +1,428 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
import pickle
import shutil
import sys

import numpy as np
from tqdm.contrib import tqdm
from yacs.config import CfgNode

from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.cluster import diarization as diar
from utils.DER import DER

# Logger setup
logger = Log(__name__).getlog()


def diarize_dataset(
        full_meta,
        split_type,
        n_lambdas,
        pval,
        save_dir,
        config,
        n_neighbors=10, ):
    """This function diarizes all the recordings in a given dataset. It computes
    embeddings and clusters them using spectral clustering (or other backends).
    The output speaker boundary file is stored in the RTTM format.
    """

    # prepare `spkr_info` only once when the oracle num of speakers is selected.
    # spkr_info is essential to obtain the number of speakers from the groundtruth.
    if config.oracle_n_spkrs is True:
        full_ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
                                          "fullref_ami_" + split_type + ".rttm")
        rttm = diar.read_rttm(full_ref_rttm_file)

        spkr_info = list(  # noqa F841
            filter(lambda x: x.startswith("SPKR-INFO"), rttm))

    # get all the recording IDs in this dataset.
    all_keys = full_meta.keys()
    A = [word.rstrip().split("_")[0] for word in all_keys]
    all_rec_ids = list(set(A[1:]))
    all_rec_ids.sort()
    split = "AMI_" + split_type
    i = 1

    # adding tag for the directory path.
    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
    tag = (type_of_num_spkr + "_" + str(config.affinity) + "_" + config.backend)

    # make the output rttm dir
    out_rttm_dir = os.path.join(save_dir, config.sys_rttm_dir, config.mic_type,
                                split, tag)
    if not os.path.exists(out_rttm_dir):
        os.makedirs(out_rttm_dir)

    # diarizing the different recordings in the dataset.
    for rec_id in tqdm(all_rec_ids):
        # this tag will be displayed in the log.
        tag = ("[" + str(split_type) + ": " + str(i) + "/" +
               str(len(all_rec_ids)) + "]")
        i = i + 1

        # log message.
        msg = "Diarizing %s : %s " % (tag, rec_id)
        logger.debug(msg)

        # load embeddings.
        emb_file_name = rec_id + "." + config.mic_type + ".emb_stat.pkl"
        diary_stat_emb_file = os.path.join(save_dir, config.embedding_dir,
                                           split, emb_file_name)
        if not os.path.isfile(diary_stat_emb_file):
            msg = "Embedding file %s not found! Please check if the embedding file is properly generated." % (
                diary_stat_emb_file)
            logger.error(msg)
            sys.exit()
        with open(diary_stat_emb_file, "rb") as in_file:
            diary_obj = pickle.load(in_file)

        out_rttm_file = out_rttm_dir + "/" + rec_id + ".rttm"

        # processing starts from here.
        if config.oracle_n_spkrs is True:
            # oracle num of speakers.
            num_spkrs = diar.get_oracle_num_spkrs(rec_id, spkr_info)
        else:
            if config.affinity == "nn":
                # num of speakers tuned on the dev set (only for nn affinity).
                num_spkrs = n_lambdas
            else:
                # the num of speakers will be estimated using the max eigen gap for cos based affinity,
                # so we add None here and use it later on.
                num_spkrs = None

        if config.backend == "kmeans":
            diar.do_kmeans_clustering(
                diary_obj,
                out_rttm_file,
                rec_id,
                num_spkrs,
                pval, )

        if config.backend == "SC":
            # go for Spectral Clustering (SC).
            diar.do_spec_clustering(
                diary_obj,
                out_rttm_file,
                rec_id,
                num_spkrs,
                pval,
                config.affinity,
                n_neighbors, )

        # can be used for AHC later. Likewise, one can add different backends here.
        if config.backend == "AHC":
            # call AHC
            threshold = pval  # pval for AHC is nothing but the threshold.
            diar.do_AHC(diary_obj, out_rttm_file, rec_id, num_spkrs, threshold)

    # once all RTTM outputs are generated, concatenate the individual RTTM files into a single RTTM file.
    # this is not strictly needed, but it follows the standards.
    concate_rttm_file = out_rttm_dir + "/sys_output.rttm"
    logger.debug("Concatenating individual RTTM files...")
    with open(concate_rttm_file, "w") as cat_file:
        for f in glob.glob(out_rttm_dir + "/*.rttm"):
            if f == concate_rttm_file:
                continue
            with open(f, "r") as indi_rttm_file:
                shutil.copyfileobj(indi_rttm_file, cat_file)

    msg = "The system generated RTTM file for the %s set : %s" % (
        split_type, concate_rttm_file, )
    logger.debug(msg)

    return concate_rttm_file


def dev_pval_tuner(full_meta, save_dir, config):
    """Tuning the p_value for the affinity matrix.
    The p_value is used so that only p% of the values in each row are retained.
    """

    DER_list = []
    prange = np.arange(0.002, 0.015, 0.001)

    n_lambdas = None  # using it as a flag later.
    for p_v in prange:
        # Process the whole dataset for this value of p_v.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
                                            save_dir, config)

        ref_rttm_file = os.path.join(save_dir, config.ref_rttm_dir,
                                     "fullref_ami_dev.rttm")
        sys_rttm_file = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm_file,
            sys_rttm_file,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

        if config.oracle_n_spkrs is True and config.backend == "kmeans":
            # no need for a p_val search. Note p_val is needed for SC for both oracle and est num of speakers.
            # p_val is needed for oracle_n_spkrs=False when using the kmeans backend.
            break

    # Take the p_val that gave the minimum DER on the dev dataset.
    tuned_p_val = prange[DER_list.index(min(DER_list))]

    return tuned_p_val


def dev_ahc_threshold_tuner(full_meta, save_dir, config):
    """Tuning the threshold for the affinity matrix. This function is called when AHC is used as the backend.
    """

    DER_list = []
    prange = np.arange(0.0, 1.0, 0.1)

    n_lambdas = None  # using it as a flag later.

    # Note: p_val is the threshold in the case of AHC.
    for p_v in prange:
        # Process the whole dataset for this value of p_v.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, p_v,
                                            save_dir, config)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

        if config.oracle_n_spkrs is True:
            break  # no need for a threshold search.

    # Take the p_val that gave the minimum DER on the dev dataset.
    tuned_p_val = prange[DER_list.index(min(DER_list))]

    return tuned_p_val


def dev_nn_tuner(full_meta, save_dir, config):
    """Tuning n_neighbors on the dev set. Assuming an oracle num of speakers.
    This is used when nn based affinity is selected.
    """

    DER_list = []
    pval = None

    # For now assuming an oracle num of speakers.
    n_lambdas = 4

    for nn in range(5, 15):

        # Process the whole dataset for this value of n_neighbors.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
                                            save_dir, config, nn)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append([nn, DER_])

        if config.oracle_n_spkrs is True and config.backend == "kmeans":
            break

    DER_list.sort(key=lambda x: x[1])
    tuned_nn = DER_list[0]

    return tuned_nn[0]


def dev_tuner(full_meta, save_dir, config):
    """Tuning n_components on the dev set. Used for nn based affinity matrices.
    Note: This is a very basic tuning for nn based affinity.
    This is work in progress until we find a better way.
    """

    DER_list = []
    pval = None
    for n_lambdas in range(1, config.max_num_spkrs + 1):

        # Process the whole dataset for this value of n_lambdas.
        concate_rttm_file = diarize_dataset(full_meta, "dev", n_lambdas, pval,
                                            save_dir, config)

        ref_rttm = os.path.join(save_dir, config.ref_rttm_dir,
                                "fullref_ami_dev.rttm")
        sys_rttm = concate_rttm_file
        [MS, FA, SER, DER_] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar, )

        DER_list.append(DER_)

    # Take the n_lambdas with the minimum DER.
    tuned_n_lambdas = DER_list.index(min(DER_list)) + 1

    return tuned_n_lambdas


def main(args, config):
    # AMI Dev Set: Tune hyperparams on the dev set.
    # Read the embedding file for the dev set generated during the embedding computation
    dev_meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_dev." + config.mic_type + ".subsegs.json", )
    with open(dev_meta_file, "r") as f:
        meta_dev = json.load(f)

    full_meta = meta_dev

    # Processing starts from here.
    # The following lines select options for the different backends and affinity matrices
    # and find the best hyperparameter values using the dev set.
    ref_rttm_file = os.path.join(args.data_dir, config.ref_rttm_dir,
                                 "fullref_ami_dev.rttm")
    best_nn = None
    if config.affinity == "nn":
        logger.info("Tuning for nn (Multiple iterations over AMI Dev set)")
        best_nn = dev_nn_tuner(full_meta, args.data_dir, config)

    n_lambdas = None
    best_pval = None

    if config.affinity == "cos" and (config.backend == "SC" or
                                     config.backend == "kmeans"):
        # oracle num_spkrs or not, it doesn't matter for the kmeans and SC backends
        # cos: Tune for the best pval for SC/kmeans (for an unknown num of spkrs)
        logger.info(
            "Tuning for p-value for SC (Multiple iterations over AMI Dev set)")
        best_pval = dev_pval_tuner(full_meta, args.data_dir, config)

    elif config.backend == "AHC":
        logger.info("Tuning for threshold-value for AHC")
        best_threshold = dev_ahc_threshold_tuner(full_meta, args.data_dir,
                                                 config)
        best_pval = best_threshold
    else:
        # NN for an unknown num of speakers (can be used in the future)
        if config.oracle_n_spkrs is False:
            # nn: Tune the number of eigen components (to be updated later)
            logger.info(
                "Tuning for number of eigen components for NN (Multiple iterations over AMI Dev set)"
            )
            # dev_tuner is used for tuning the num of components in NN. Can be used in the future.
            n_lambdas = dev_tuner(full_meta, args.data_dir, config)

    # load the 'dev' and 'eval' metadata files.
    full_meta_dev = full_meta  # the current full_meta is for 'dev'
    eval_meta_file = os.path.join(
        args.data_dir,
        config.meta_data_dir,
        "ami_eval." + config.mic_type + ".subsegs.json", )
    with open(eval_meta_file, "r") as f:
        full_meta_eval = json.load(f)

    # tag to be appended to the final output DER files. Writing DERs for individual files.
    type_of_num_spkr = "oracle" if config.oracle_n_spkrs else "est"
    tag = (
        type_of_num_spkr + "_" + str(config.affinity) + "." + config.mic_type)

    # perform the final diarization on 'dev' and 'eval' with the best hyperparams.
    final_DERs = {}
    out_der_dir = os.path.join(args.data_dir, config.der_dir)
    if not os.path.exists(out_der_dir):
        os.makedirs(out_der_dir)

    for split_type in ["dev", "eval"]:
        if split_type == "dev":
            full_meta = full_meta_dev
        else:
            full_meta = full_meta_eval

        # performing diarization.
        msg = "Diarizing using best hyperparams: " + split_type + " set"
        logger.info(msg)
        out_boundaries = diarize_dataset(
            full_meta,
            split_type,
            n_lambdas=n_lambdas,
            pval=best_pval,
            n_neighbors=best_nn,
            save_dir=args.data_dir,
            config=config)

        # computing DER.
        msg = "Computing DERs for " + split_type + " set"
        logger.info(msg)
        ref_rttm = os.path.join(args.data_dir, config.ref_rttm_dir,
                                "fullref_ami_" + split_type + ".rttm")
        sys_rttm = out_boundaries
        [MS, FA, SER, DER_vals] = DER(
            ref_rttm,
            sys_rttm,
            config.ignore_overlap,
            config.forgiveness_collar,
            individual_file_scores=True, )

        # writing DER values to a file. Append the tag.
        der_file_name = split_type + "_DER_" + tag
        out_der_file = os.path.join(out_der_dir, der_file_name)
        msg = "Writing DER file to: " + out_der_file
        logger.info(msg)
        diar.write_ders_file(ref_rttm, DER_vals, out_der_file)

        msg = ("AMI " + split_type + " set DER = %s %%\n" %
               (str(round(DER_vals[-1], 2))))
        logger.info(msg)
        final_DERs[split_type] = round(DER_vals[-1], 2)

    # final print of the DERs
    msg = (
        "Final Diarization Error Rate (%%) on AMI corpus: Dev = %s %% | Eval = %s %%\n"
        % (str(final_DERs["dev"]), str(final_DERs["eval"])))
    logger.info(msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        "--config", default=None, type=str, help="configuration file")
    parser.add_argument(
        "--data-dir",
        default="../data/",
        type=str,
        help="processed data directory")
    args = parser.parse_args()
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    config.freeze()

    main(args, config)
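dev_pval_tuner above sweeps the p-value that controls how aggressively the affinity matrix is pruned: per its docstring, only p% of the values in each row are retained. A minimal numpy sketch of that row-wise pruning idea, written from the description rather than taken from the library:

import numpy as np

def prune_affinity_rows(affinity: np.ndarray, pval: float) -> np.ndarray:
    """Keep only the largest ceil(pval * n) entries in each row; zero the rest."""
    n = affinity.shape[0]
    n_keep = max(1, int(np.ceil(pval * n)))
    pruned = np.zeros_like(affinity)
    for i, row in enumerate(affinity):
        keep_idx = np.argsort(row)[-n_keep:]  # indices of the largest entries
        pruned[i, keep_idx] = row[keep_idx]
    return pruned

A = np.random.rand(6, 6)
print(prune_affinity_rows(A, pval=0.01))  # with n=6, keeps 1 entry per row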
@@ -0,0 +1,49 @@
#!/bin/bash

stage=0
set=L

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -o pipefail

data_folder=$1
manual_annot_folder=$2
save_folder=$3
pretrained_model_dir=$4
conf_path=$5
device=$6

ref_rttm_dir=${save_folder}/ref_rttms
meta_data_dir=${save_folder}/metadata

if [ ${stage} -le 0 ]; then
    echo "AMI Data preparation"
    python local/ami_prepare.py --data_folder ${data_folder} \
            --manual_annot_folder ${manual_annot_folder} \
            --save_folder ${save_folder} --ref_rttm_dir ${ref_rttm_dir} \
            --meta_data_dir ${meta_data_dir}

    if [ $? -ne 0 ]; then
        echo "Prepare AMI failed. Please check log message."
        exit 1
    fi
    echo "AMI data preparation done."
fi

if [ ${stage} -le 1 ]; then
    # extract embeddings for the dev and eval datasets
    for name in dev eval; do
        python local/compute_embdding.py --config ${conf_path} \
                --data-dir ${save_folder} \
                --device ${device} \
                --dataset ${name} \
                --load-checkpoint ${pretrained_model_dir}
    done
fi

if [ ${stage} -le 2 ]; then
    # tune hyperparams on the dev set
    # perform the final diarization on 'dev' and 'eval' with the best hyperparams
    python local/experiment.py --config ${conf_path} \
            --data-dir ${save_folder}
fi
@@ -1,14 +1,46 @@
 #!/bin/bash
 
-. path.sh || exit 1;
+. ./path.sh || exit 1;
 set -e
 
-stage=1
+stage=0
+
+#TARGET_DIR=${MAIN_ROOT}/dataset/ami
+TARGET_DIR=/home/dataset/AMI
+data_folder=${TARGET_DIR}/amicorpus #e.g., /path/to/amicorpus/
+manual_annot_folder=${TARGET_DIR}/ami_public_manual_1.6.2 #e.g., /path/to/ami_public_manual_1.6.2/
+
+save_folder=./save
+pretrained_model_dir=${save_folder}/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1/model
+conf_path=conf/ecapa_tdnn.yaml
+device=gpu
 
 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
-if [ ${stage} -le 1 ]; then
-    # prepare data
-    bash ./local/data.sh || exit -1
-fi
+if [ $stage -le 0 ]; then
+    # Prepare data
+    # Download the AMI corpus; you need around 10GB of free space to get the whole data
+    # The signals are too large to package in this way,
+    # so you need to use the chooser to indicate which ones you wish to download
+    echo "Please follow https://groups.inf.ed.ac.uk/ami/download/ to download the data."
+    echo "Annotations: AMI manual annotations v1.6.2"
+    echo "Signals: "
+    echo "1) Select one or more AMI meetings: for the IDs please follow ./ami_split.py"
+    echo "2) Select media streams: just select Headset mix"
+fi
+
+if [ $stage -le 1 ]; then
+    # Download the pretrained model
+    wget https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    mkdir -p ${save_folder} && tar -xvf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz -C ${save_folder}
+    rm -rf sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz
+    echo "Downloaded the pretrained ECAPA-TDNN model to: "${pretrained_model_dir}
+fi
+
+if [ $stage -le 2 ]; then
+    # Tune hyperparams on the dev set and perform the final diarization on dev and eval with the best hyperparams.
+    echo ${data_folder} ${manual_annot_folder} ${save_folder} ${pretrained_model_dir} ${conf_path}
+    bash ./local/process.sh ${data_folder} ${manual_annot_folder} \
+        ${save_folder} ${pretrained_model_dir} ${conf_path} ${device} || exit 1
+fi
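The pipeline's headline metric is the DER that local/experiment.py prints at the end of stage 2. A small sketch of the standard DER definition that those [MS, FA, SER, DER] numbers follow; this is illustrative arithmetic, not the project's utils.DER scorer:

def diarization_error_rate(missed, false_alarm, speaker_error, total_speech):
    """DER = (missed + false alarm + speaker confusion) / total scored speech time, in %."""
    return 100.0 * (missed + false_alarm + speaker_error) / total_speech

# e.g. 30 s missed, 20 s false alarm, 50 s confused out of 1000 s of speech
print(diarization_error_rate(30.0, 20.0, 50.0, 1000.0))  # 10.0 (%)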
@@ -0,0 +1,31 @@
train_output_path=$1

stage=0
stop_stage=0

# only support default_fastspeech2 + hifigan/mb_melgan now!

# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../ort_predict.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=fastspeech2_csmsc \
        --voc=hifigan_csmsc \
        --test_metadata=dump/test/norm/metadata.jsonl \
        --output_dir=${train_output_path}/onnx_infer_out \
        --device=cpu \
        --cpu_threads=2
fi

# e2e, synthesize from text
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    python3 ${BIN_DIR}/../ort_predict_e2e.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=fastspeech2_csmsc \
        --voc=hifigan_csmsc \
        --output_dir=${train_output_path}/onnx_infer_out_e2e \
        --text=${BIN_DIR}/../csmsc_test.txt \
        --phones_dict=dump/phone_id_map.txt \
        --device=cpu \
        --cpu_threads=2
fi
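Both scripts above run the exported ONNX models on CPU with two threads. A minimal sketch of what such CPU-bound ONNX inference looks like with onnxruntime; the model path, input name, and dummy input here are placeholders, not the actual exported FastSpeech2 graph:

import numpy as np
import onnxruntime as ort

opts = ort.SessionOptions()
opts.intra_op_num_threads = 2  # mirrors --cpu_threads=2

sess = ort.InferenceSession("fastspeech2_csmsc.onnx", sess_options=opts,
                            providers=["CPUExecutionProvider"])

input_name = sess.get_inputs()[0].name  # placeholder: depends on the export
phone_ids = np.array([1, 2, 3, 4], dtype=np.int64)  # dummy phone id sequence
outputs = sess.run(None, {input_name: phone_ids})
print([o.shape for o in outputs])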
@@ -0,0 +1,22 @@
train_output_path=$1
model_dir=$2
output_dir=$3
model=$4

enable_dev_version=True

model_name=${model%_*}
echo model_name: ${model_name}

if [ ${model_name} = 'mb_melgan' ]; then
    enable_dev_version=False
fi

mkdir -p ${train_output_path}/${output_dir}

paddle2onnx \
    --model_dir ${train_output_path}/${model_dir} \
    --model_filename ${model}.pdmodel \
    --params_filename ${model}.pdiparams \
    --save_file ${train_output_path}/${output_dir}/${model}.onnx \
    --enable_dev_version ${enable_dev_version}
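`model_name=${model%_*}` is bash suffix removal: it strips the last `_<part>`, so `mb_melgan_csmsc` becomes `mb_melgan`, which is what the vocoder check keys on. After the export, the resulting graph can be sanity-checked with the onnx package; a minimal sketch with an illustrative file path:

import onnx

# Path follows the script's --save_file layout; adjust to your run.
model = onnx.load("exp/default/inference_onnx/fastspeech2_csmsc.onnx")
onnx.checker.check_model(model)  # raises if the graph is structurally invalid
print(model.graph.name, len(model.graph.node), "nodes")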
@@ -0,0 +1,9 @@
# iwslt2012

## Ernie

|          | COMMA    | PERIOD   | QUESTION | OVERALL  |
|:--------:|:--------:|:--------:|:--------:|:--------:|
|Precision |0.510955  |0.526462  |0.820755  |0.619391  |
|Recall    |0.517433  |0.564179  |0.861386  |0.647666  |
|F1        |0.514173  |0.544669  |0.840580  |0.633141  |
@@ -0,0 +1,53 @@
###########################################
#                  Data                   #
###########################################
augment: True
batch_size: 16
num_workers: 2
num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
skip_prep: False
split_ratio: 0.9
chunk_duration: 3.0 # seconds
random_chunk: True
verification_file: data/vox1/veri_test2.txt

###########################################################
#              FEATURE EXTRACTION SETTING                 #
###########################################################
# currently, we only support fbank
sr: 16000 # sample rate
n_mels: 80
window_size: 400 # 25 ms; at 16 kHz, 25 * 16000 / 1000 = 400 samples
hop_size: 160 # 10 ms; at 16 kHz, 10 * 16000 / 1000 = 160 samples

###########################################################
#                     MODEL SETTING                       #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if you want to use another model, please choose another configuration yaml file
model:
  input_size: 80
  channels: [512, 512, 512, 512, 1536]
  kernel_sizes: [5, 3, 3, 3, 1]
  dilations: [1, 2, 3, 4, 1]
  attention_channels: 128
  lin_neurons: 192

###########################################
#                Training                 #
###########################################
seed: 1986 # following the speechbrain configuration
epochs: 100
save_interval: 10
log_interval: 10
learning_rate: 1e-8


###########################################
#                 Testing                 #
###########################################
global_embedding_norm: True
embedding_mean_norm: True
embedding_std_norm: False
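The Testing block enables a global mean normalization of embeddings before scoring. A minimal numpy sketch of the mean (and optional std) normalization that those three flags suggest; this illustrates the idea and is not the trainer's exact implementation:

import numpy as np

def normalize_embeddings(embs, mean_norm=True, std_norm=False):
    """Normalize a batch of embeddings with a shared mean/std, as in the Testing flags."""
    if mean_norm:
        embs = embs - embs.mean(axis=0, keepdims=True)
    if std_norm:
        embs = embs / (embs.std(axis=0, keepdims=True) + 1e-8)
    return embs

embs = np.random.randn(8, 192)  # lin_neurons: 192 gives 192-dim embeddings
print(normalize_embeddings(embs).mean(axis=0)[:3])  # ~0 after mean norm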
@@ -0,0 +1,167 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert the PaddleSpeech jsonline format data to csv format data for the voxceleb experiment.
Currently, the Speaker Identification training process uses the csv format.
"""
import argparse
import csv
import os
from typing import List

import tqdm
from yacs.config import CfgNode

from paddleaudio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks

logger = Log(__name__).getlog()


def get_chunks_list(wav_file: str,
                    split_chunks: bool,
                    base_path: str,
                    chunk_duration: float=3.0) -> List[List[str]]:
    """Get a single audio file's segment info list.

    Args:
        wav_file (str): the wav audio file to get the segment info list for
        split_chunks (bool): audio split flag
        base_path (str): the audio base path
        chunk_duration (float): the chunk duration.
            If split_chunks is set, we split the audio into multi-chunk segments.
    """
    waveform, sr = load_audio(wav_file)
    audio_id = wav_file.split("/rir_noise/")[-1].split(".")[0]
    audio_duration = waveform.shape[0] / sr

    ret = []
    if split_chunks and audio_duration > chunk_duration:  # Split into pieces of chunk_duration seconds.
        uniq_chunks_list = get_chunks(chunk_duration, audio_id, audio_duration)

        for idx, chunk in enumerate(uniq_chunks_list):
            s, e = chunk.split("_")[-2:]  # Timestamps of start and end
            start_sample = int(float(s) * sr)
            end_sample = int(float(e) * sr)

            # currently, all vector csv data uses one representation:
            # id, duration, wav, start, stop, label
            # in rirs noise, every label name is 'noise';
            # the label is string type and we will convert it to integer type in training
            ret.append([
                chunk, audio_duration, wav_file, start_sample, end_sample,
                "noise"
            ])
    else:  # Keep whole audio.
        ret.append(
            [audio_id, audio_duration, wav_file, 0, waveform.shape[0], "noise"])
    return ret


def generate_csv(wav_files,
                 output_file: str,
                 base_path: str,
                 split_chunks: bool=True):
    """Prepare the csv file according to the wav files

    Args:
        wav_files (list): all the audio files to prepare the csv file for
        output_file (str): the output csv file
        base_path (str): the audio base path
        split_chunks (bool): audio split flag
    """
    logger.info(f'Generating csv: {output_file}')
    header = ["utt_id", "duration", "wav", "start", "stop", "label"]
    csv_lines = []
    for item in tqdm.tqdm(wav_files):
        csv_lines.extend(
            get_chunks_list(
                item, base_path=base_path, split_chunks=split_chunks))

    if not os.path.exists(os.path.dirname(output_file)):
        os.makedirs(os.path.dirname(output_file))

    with open(output_file, mode="w") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
        csv_writer.writerow(header)
        for line in csv_lines:
            csv_writer.writerow(line)


def prepare_data(args, config):
    """Convert the jsonline format to csv format

    Args:
        args (argparse.Namespace): script args
        config (CfgNode): yaml configuration content
    """
    # if the external config sets the skip_prep flag, we do nothing
    if config.skip_prep:
        return

    base_path = args.noise_dir
    wav_path = os.path.join(base_path, "RIRS_NOISES")
    logger.info(f"base path: {base_path}")
    logger.info(f"wav path: {wav_path}")
    rir_list = os.path.join(wav_path, "real_rirs_isotropic_noises", "rir_list")
    rir_files = []
    with open(rir_list, 'r') as f:
        for line in f.readlines():
            rir_file = line.strip().split(' ')[-1]
            rir_files.append(os.path.join(base_path, rir_file))

    noise_list = os.path.join(wav_path, "pointsource_noises", "noise_list")
    noise_files = []
    with open(noise_list, 'r') as f:
        for line in f.readlines():
            noise_file = line.strip().split(' ')[-1]
            noise_files.append(os.path.join(base_path, noise_file))

    csv_path = os.path.join(args.data_dir, 'csv')
    logger.info(f"csv path: {csv_path}")
    generate_csv(
        rir_files, os.path.join(csv_path, 'rir.csv'), base_path=base_path)
    generate_csv(
        noise_files, os.path.join(csv_path, 'noise.csv'), base_path=base_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--noise_dir",
        default=None,
        required=True,
        help="The noise dataset directory.")
    parser.add_argument(
        "--data_dir",
        default=None,
        required=True,
        help="The target directory that stores the csv files")
    parser.add_argument(
        "--config",
        default=None,
        required=True,
        type=str,
        help="configuration file")
    args = parser.parse_args()

    # parse the yaml config file
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # prepare the csv files
    prepare_data(args, config)
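The code above relies on `get_chunks(chunk_duration, audio_id, audio_duration)` returning chunk ids whose last two `_`-separated fields are the start and end times, since that is how `chunk.split("_")[-2:]` consumes them. A sketch of what such a helper plausibly looks like, inferred from that usage rather than copied from paddlespeech.vector.utils.vector_utils:

def get_chunks(chunk_duration: float, audio_id: str, audio_duration: float):
    """Yield ids like '<audio_id>_<start>_<end>' covering the audio in fixed-length chunks."""
    num_chunks = int(audio_duration / chunk_duration)  # all chunks the same length
    return [
        f"{audio_id}_{i * chunk_duration}_{(i + 1) * chunk_duration}"
        for i in range(num_chunks)
    ]

print(get_chunks(3.0, "noise-free-sound-0232", 10.2))
# ['noise-free-sound-0232_0.0_3.0', 'noise-free-sound-0232_3.0_6.0', 'noise-free-sound-0232_6.0_9.0']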
@ -0,0 +1,251 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Convert the PaddleSpeech jsonline format data to csv format data in voxceleb experiment.
|
||||||
|
Currently, Speaker Identificaton Training process use csv format.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddleaudio import load as load_audio
|
||||||
|
from paddlespeech.s2t.utils.log import Log
|
||||||
|
from paddlespeech.vector.utils.vector_utils import get_chunks
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_csv(wav_files, output_file, config, split_chunks=True):
|
||||||
|
"""Prepare the csv file according the wav files
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wav_files (list): all the audio list to prepare the csv file
|
||||||
|
output_file (str): the output csv file
|
||||||
|
config (CfgNode): yaml configuration content
|
||||||
|
split_chunks (bool, optional): audio split flag. Defaults to True.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(os.path.dirname(output_file)):
|
||||||
|
os.makedirs(os.path.dirname(output_file))
|
||||||
|
csv_lines = []
|
||||||
|
header = ["utt_id", "duration", "wav", "start", "stop", "label"]
|
||||||
|
# voxceleb meta info for each training utterance segment
|
||||||
|
# we extract a segment from a utterance to train
|
||||||
|
# and the segment' period is between start and stop time point in the original wav file
|
||||||
|
# each field in the meta info means as follows:
|
||||||
|
# utt_id: the utterance segment name, which is uniq in training dataset
|
||||||
|
# duration: the total utterance time
|
||||||
|
# wav: utterance file path, which should be absoulute path
|
||||||
|
# start: start point in the original wav file sample point range
|
||||||
|
# stop: stop point in the original wav file sample point range
|
||||||
|
# label: the utterance segment's label name,
|
||||||
|
# which is speaker name in speaker verification domain
|
||||||
|
for item in tqdm.tqdm(wav_files, total=len(wav_files)):
|
||||||
|
item = json.loads(item.strip())
|
||||||
|
audio_id = item['utt'].replace(".wav",
|
||||||
|
"") # we remove the wav suffix name
|
||||||
|
audio_duration = item['feat_shape'][0]
|
||||||
|
wav_file = item['feat']
|
||||||
|
label = audio_id.split('-')[
|
||||||
|
0] # speaker name in speaker verification domain
|
||||||
|
waveform, sr = load_audio(wav_file)
|
||||||
|
if split_chunks:
|
||||||
|
uniq_chunks_list = get_chunks(config.chunk_duration, audio_id,
|
||||||
|
audio_duration)
|
||||||
|
for chunk in uniq_chunks_list:
|
||||||
|
s, e = chunk.split("_")[-2:] # Timestamps of start and end
|
||||||
|
start_sample = int(float(s) * sr)
|
||||||
|
end_sample = int(float(e) * sr)
|
||||||
|
# id, duration, wav, start, stop, label
|
||||||
|
# in vector, the label in speaker id
|
||||||
|
csv_lines.append([
|
||||||
|
chunk, audio_duration, wav_file, start_sample, end_sample,
|
||||||
|
label
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
csv_lines.append([
|
||||||
|
audio_id, audio_duration, wav_file, 0, waveform.shape[0], label
|
||||||
|
])
|
||||||
|
|
||||||
|
with open(output_file, mode="w") as csv_f:
|
||||||
|
csv_writer = csv.writer(
|
||||||
|
csv_f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||||
|
csv_writer.writerow(header)
|
||||||
|
for line in csv_lines:
|
||||||
|
csv_writer.writerow(line)
|
||||||
|
|
||||||
|
|
||||||
|
def get_enroll_test_list(dataset_list, verification_file):
|
||||||
|
"""Get the enroll and test utterance list from all the voxceleb1 test utterance dataset.
|
||||||
|
Generally, we get the enroll and test utterances from the verfification file.
|
||||||
|
The verification file format as follows:
|
||||||
|
target/nontarget enroll-utt test-utt,
|
||||||
|
we set 0 as nontarget and 1 as target, eg:
|
||||||
|
0 a.wav b.wav
|
||||||
|
1 a.wav a.wav
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_list (list): all the dataset to get the test utterances
|
||||||
|
verification_file (str): voxceleb1 trial file
|
||||||
|
"""
|
||||||
|
logger.info(f"verification file: {verification_file}")
|
||||||
|
enroll_audios = set()
|
||||||
|
test_audios = set()
|
||||||
|
with open(verification_file, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
_, enroll_file, test_file = line.strip().split(' ')
|
||||||
|
enroll_audios.add('-'.join(enroll_file.split('/')))
|
||||||
|
test_audios.add('-'.join(test_file.split('/')))
|
||||||
|
|
||||||
|
enroll_files = []
|
||||||
|
test_files = []
|
||||||
|
for dataset in dataset_list:
|
||||||
|
with open(dataset, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
# audio_id may be in enroll and test at the same time
|
||||||
|
# eg: 1 a.wav a.wav
|
||||||
|
# the audio a.wav is enroll and test file at the same time
|
||||||
|
audio_id = json.loads(line.strip())['utt']
|
||||||
|
if audio_id in enroll_audios:
|
||||||
|
enroll_files.append(line)
|
||||||
|
if audio_id in test_audios:
|
||||||
|
test_files.append(line)
|
||||||
|
|
||||||
|
enroll_files = sorted(enroll_files)
|
||||||
|
test_files = sorted(test_files)
|
||||||
|
|
||||||
|
return enroll_files, test_files
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_dev_list(dataset_list, target_dir, split_ratio):
|
||||||
|
"""Get the train and dev utterance list from all the training utterance dataset.
|
||||||
|
Generally, we use the split_ratio as the train dataset ratio,
|
||||||
|
    and the remaining utterances (ratio is 1 - split_ratio) make up the dev dataset

    Args:
        dataset_list (list): all the datasets from which to collect the utterances
        target_dir (str): the target train and dev directory,
                          we will create the csv directory to store the {train,dev}.csv files
        split_ratio (float): train dataset ratio in the overall utterance list
    """
    logger.info("start to get train and dev utt list")
    if not os.path.exists(os.path.join(target_dir, "meta")):
        os.makedirs(os.path.join(target_dir, "meta"))

    audio_files = []
    speakers = set()
    for dataset in dataset_list:
        with open(dataset, 'r') as f:
            for line in f:
                # the label is the speaker name
                label_name = json.loads(line.strip())['utt2spk']
                speakers.add(label_name)
                audio_files.append(line.strip())
    speakers = sorted(speakers)
    logger.info(f"we get {len(speakers)} speakers from all the train datasets")

    with open(os.path.join(target_dir, "meta", "label2id.txt"), 'w') as f:
        for label_id, label_name in enumerate(speakers):
            f.write(f'{label_name} {label_id}\n')
    logger.info(
        f'we store the speakers to {os.path.join(target_dir, "meta", "label2id.txt")}'
    )

    # the split_ratio is for the train dataset
    # the remainder is for the dev dataset
    split_idx = int(split_ratio * len(audio_files))
    audio_files = sorted(audio_files)
    random.shuffle(audio_files)
    train_files, dev_files = audio_files[:split_idx], audio_files[split_idx:]
    logger.info(
        f"we get train utterances: {len(train_files)}, dev utterances: {len(dev_files)}"
    )
    return train_files, dev_files


def prepare_data(args, config):
    """Convert the jsonline format to csv format

    Args:
        args (argparse.Namespace): scripts args
        config (CfgNode): yaml configuration content
    """
    # stage 0: set the random seed
    random.seed(config.seed)

    # if the external config sets the skip_prep flag, we do nothing
    if config.skip_prep:
        return

    # stage 1: prepare the enroll and test csv files
    # and generate the speaker-to-label map file label2id.txt
    logger.info("start to prepare the data csv file")
    enroll_files, test_files = get_enroll_test_list(
        [args.test], verification_file=config.verification_file)
    prepare_csv(
        enroll_files,
        os.path.join(args.target_dir, "csv", "enroll.csv"),
        config,
        split_chunks=False)
    prepare_csv(
        test_files,
        os.path.join(args.target_dir, "csv", "test.csv"),
        config,
        split_chunks=False)

    # stage 2: prepare the train and dev csv files
    # we take config.split_ratio of the utterances as the train dataset
    # and the remainder as the dev dataset
    logger.info("start to prepare the data csv file")
    train_files, dev_files = get_train_dev_list(
        args.train, target_dir=args.target_dir, split_ratio=config.split_ratio)
    prepare_csv(train_files,
                os.path.join(args.target_dir, "csv", "train.csv"), config)
    prepare_csv(dev_files,
                os.path.join(args.target_dir, "csv", "dev.csv"), config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--train",
        required=True,
        nargs='+',
        help="The jsonline files list for train.")
    parser.add_argument(
        "--test", required=True, help="The jsonline file for test.")
    parser.add_argument(
        "--target_dir",
        default=None,
        required=True,
        help="The target directory that stores the csv files and meta files.")
    parser.add_argument(
        "--config",
        default=None,
        required=True,
        type=str,
        help="configuration file")
    args = parser.parse_args()

    # parse the yaml config file
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # prepare the csv files from the jsonline files
    prepare_data(args, config)
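# Note: a minimal sketch of the expected jsonline input, inferred from the
# 'utt2spk' access in get_train_dev_list; the other field names here are
# assumptions for illustration only:
#   {"utt": "id10001-00001", "feat": "/data/wav/id10001-00001.wav", "utt2spk": "id10001"}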
@ -0,0 +1,30 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


def pcm16to32(audio: np.ndarray) -> np.ndarray:
    """Convert pcm int16 audio to float32.

    Args:
        audio (np.ndarray): Waveform with dtype of int16.

    Returns:
        np.ndarray: Waveform with dtype of float32, scaled to [-1.0, 1.0).
    """
    if audio.dtype == np.int16:
        audio = audio.astype("float32")
        bits = np.iinfo(np.int16).bits
        audio = audio / (2**(bits - 1))
    return audio
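# A quick sanity check of the scaling (assumed usage, not part of the file):
#   >>> pcm16to32(np.array([0, 16384, -32768], dtype=np.int16))
#   array([ 0. ,  0.5, -1. ], dtype=float32)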
@ -0,0 +1,46 @@
# This is the parameter configuration file for PaddleSpeech Serving.

#################################################################################
#                             SERVER SETTING                                    #
#################################################################################
host: 127.0.0.1
port: 8092

# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
protocol: 'http'
engine_list: ['tts_online']


#################################################################################
#                                ENGINE CONFIG                                  #
#################################################################################

################################### TTS #########################################
################### speech task: tts; engine_type: online #######################
tts_online:
    # am (acoustic model) choices=['fastspeech2_csmsc']
    am: 'fastspeech2_csmsc'
    am_config:
    am_ckpt:
    am_stat:
    phones_dict:
    tones_dict:
    speaker_dict:
    spk_id: 0

    # voc (vocoder) choices=['mb_melgan_csmsc']
    voc: 'mb_melgan_csmsc'
    voc_config:
    voc_ckpt:
    voc_stat:

    # others
    lang: 'zh'
    device:  # set 'gpu:id' or 'cpu'
    am_block: 42
    am_pad: 12
    voc_block: 14
    voc_pad: 14
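# To start the server with this config (assumed CLI invocation; the config
# file path here is illustrative):
#   paddlespeech_server start --config_file ./conf/tts_online_application.yaml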
@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -0,0 +1,220 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import time

import numpy as np
import paddle

from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import get_chunks

__all__ = ['TTSEngine']


class TTSServerExecutor(TTSExecutor):
    def __init__(self):
        super().__init__()

    @paddle.no_grad()
    def infer(
            self,
            text: str,
            lang: str='zh',
            am: str='fastspeech2_csmsc',
            spk_id: int=0,
            am_block: int=42,
            am_pad: int=12,
            voc_block: int=14,
            voc_pad: int=14, ):
        """
        Model inference; streams the synthesized sub-waveforms.
        """
        am_name = am[:am.rindex('_')]
        am_dataset = am[am.rindex('_') + 1:]
        get_tone_ids = False
        merge_sentences = False
        frontend_st = time.time()
        if lang == 'zh':
            input_ids = self.frontend.get_input_ids(
                text,
                merge_sentences=merge_sentences,
                get_tone_ids=get_tone_ids)
            phone_ids = input_ids["phone_ids"]
            if get_tone_ids:
                tone_ids = input_ids["tone_ids"]
        elif lang == 'en':
            input_ids = self.frontend.get_input_ids(
                text, merge_sentences=merge_sentences)
            phone_ids = input_ids["phone_ids"]
        else:
            print("lang should be in {'zh', 'en'}!")
        self.frontend_time = time.time() - frontend_st

        for i in range(len(phone_ids)):
            am_st = time.time()
            part_phone_ids = phone_ids[i]
            # am
            if am_name == 'speedyspeech':
                part_tone_ids = tone_ids[i]
                mel = self.am_inference(part_phone_ids, part_tone_ids)
            # fastspeech2
            else:
                # multi speaker
                if am_dataset in {"aishell3", "vctk"}:
                    mel = self.am_inference(
                        part_phone_ids, spk_id=paddle.to_tensor(spk_id))
                else:
                    mel = self.am_inference(part_phone_ids)
            am_et = time.time()

            # voc streaming
            voc_upsample = self.voc_config.n_shift
            mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
            chunk_num = len(mel_chunks)
            voc_st = time.time()
            for i, mel_chunk in enumerate(mel_chunks):
                sub_wav = self.voc_inference(mel_chunk)
                front_pad = min(i * voc_block, voc_pad)

                if i == 0:
                    sub_wav = sub_wav[:voc_block * voc_upsample]
                elif i == chunk_num - 1:
                    sub_wav = sub_wav[front_pad * voc_upsample:]
                else:
                    sub_wav = sub_wav[front_pad * voc_upsample:(
                        front_pad + voc_block) * voc_upsample]

                yield sub_wav


class TTSEngine(BaseEngine):
    """TTS server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self, name=None):
        """Initialize the TTS server engine
        """
        super(TTSEngine, self).__init__()

    def init(self, config: dict) -> bool:
        self.executor = TTSServerExecutor()
        self.config = config
        assert "fastspeech2_csmsc" in config.am and (
            config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
        ), 'Please check config: am supports fastspeech2, voc supports hifigan_csmsc-zh or mb_melgan_csmsc.'
        try:
            if self.config.device:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()
            paddle.set_device(self.device)
        except Exception as e:
            logger.error(
                "Set device failed, please check if the device is already in use and the parameter 'device' in the yaml file"
            )
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.device))
            return False

        try:
            self.executor._init_from_path(
                am=self.config.am,
                am_config=self.config.am_config,
                am_ckpt=self.config.am_ckpt,
                am_stat=self.config.am_stat,
                phones_dict=self.config.phones_dict,
                tones_dict=self.config.tones_dict,
                speaker_dict=self.config.speaker_dict,
                voc=self.config.voc,
                voc_config=self.config.voc_config,
                voc_ckpt=self.config.voc_ckpt,
                voc_stat=self.config.voc_stat,
                lang=self.config.lang)
        except Exception as e:
            logger.error("Failed to get model related files.")
            logger.error("Initialize TTS server engine Failed on device: %s." %
                         (self.device))
            return False

        self.am_block = self.config.am_block
        self.am_pad = self.config.am_pad
        self.voc_block = self.config.voc_block
        self.voc_pad = self.config.voc_pad

        logger.info("Initialize TTS server engine successfully on device: %s."
                    % (self.device))
        return True

    def preprocess(self, text_base64: str=None, text_bytes: bytes=None):
        # convert bytes to text
        if text_base64:
            text_bytes = base64.b64decode(text_base64)  # base64 to bytes
        text = text_bytes.decode('utf-8')  # bytes to text

        return text

    def run(self,
            sentence: str,
            spk_id: int=0,
            speed: float=1.0,
            volume: float=1.0,
            sample_rate: int=0,
            save_path: str=None):
        """ run includes inference and postprocessing.

        Args:
            sentence (str): text to be synthesized
            spk_id (int, optional): speaker id for multi-speaker speech synthesis. Defaults to 0.
            speed (float, optional): speed. Defaults to 1.0.
            volume (float, optional): volume. Defaults to 1.0.
            sample_rate (int, optional): target sample rate for the synthesized audio,
                0 means the same as the model sampling rate. Defaults to 0.
            save_path (str, optional): the save path of the synthesized audio.
                None means do not save the audio. Defaults to None.

        Returns:
            wav_base64: the base64 encoding of the synthesized audio.
        """

        lang = self.config.lang
        wav_list = []

        for wav in self.executor.infer(
                text=sentence,
                lang=lang,
                am=self.config.am,
                spk_id=spk_id,
                am_block=self.am_block,
                am_pad=self.am_pad,
                voc_block=self.voc_block,
                voc_pad=self.voc_pad):
            # wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
            wav = float2pcm(wav)  # float32 to int16
            wav_bytes = wav.tobytes()  # to bytes
            wav_base64 = base64.b64encode(wav_bytes).decode('utf8')  # to base64
            wav_list.append(wav)

            yield wav_base64

        wav_all = np.concatenate(wav_list, axis=0)
        logger.info("The duration of the audio is: {} s".format(
            len(wav_all) / self.executor.am_config.fs))
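# Worked example of the chunk trimming in TTSServerExecutor.infer above
# (illustrative numbers; n_shift depends on the vocoder config): with
# voc_block=14, voc_pad=14 and voc_upsample=n_shift=300, chunk 0 keeps
# samples [0 : 14*300), a middle chunk i keeps
# [front_pad*300 : (front_pad+14)*300) where front_pad=min(i*14, 14),
# and the last chunk keeps everything from front_pad*300 onward, so the
# concatenated sub_wavs reconstruct the waveform without duplicated padding.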
@ -0,0 +1,100 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import json
import os
import time

import requests

from paddlespeech.server.utils.audio_process import pcm2wav


def save_audio(buffer, audio_path) -> bool:
    if audio_path.endswith("pcm"):
        with open(audio_path, "wb") as f:
            f.write(buffer)
    elif audio_path.endswith("wav"):
        with open("./tmp.pcm", "wb") as f:
            f.write(buffer)
        pcm2wav("./tmp.pcm", audio_path, channels=1, bits=16, sample_rate=24000)
        os.remove("./tmp.pcm")
    else:
        print("Only pcm and wav formats are supported for saving audio.")
        return False

    return True


def test(args):
    params = {
        "text": args.text,
        "spk_id": args.spk_id,
        "speed": args.speed,
        "volume": args.volume,
        "sample_rate": args.sample_rate,
        "save_path": ''
    }

    buffer = b''
    flag = 1
    url = "http://" + str(args.server) + ":" + str(
        args.port) + "/paddlespeech/streaming/tts"
    st = time.time()
    html = requests.post(url, json.dumps(params), stream=True)
    for chunk in html.iter_content(chunk_size=1024):
        chunk = base64.b64decode(chunk)  # bytes
        if flag:
            first_response = time.time() - st
            print(f"first-packet latency: {first_response} s")
            flag = 0
        buffer += chunk

    final_response = time.time() - st
    duration = len(buffer) / 2.0 / 24000

    print(f"final-packet latency: {final_response} s")
    print(f"audio duration: {duration} s")
    print(f"RTF: {final_response / duration}")

    if args.save_path is not None:
        if save_audio(buffer, args.save_path):
            print("audio saved to:", args.save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--text',
        type=str,
        default="您好,欢迎使用语音合成服务。",
        help='A sentence to be synthesized')
    parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
    parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
    parser.add_argument(
        '--volume', type=float, default=1.0, help='Audio volume')
    parser.add_argument(
        '--sample_rate',
        type=int,
        default=0,
        help='Sampling rate, the default is the same as the model')
    parser.add_argument(
        "--server", type=str, help="server ip", default="127.0.0.1")
    parser.add_argument("--port", type=int, help="server port", default=8092)
    parser.add_argument(
        "--save_path", type=str, help="save audio path", default=None)

    args = parser.parse_args()
    test(args)
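# Assumed invocation of this client (the script name is illustrative):
#   python tts_http_client.py --text "您好" --server 127.0.0.1 --port 8092 \
#       --save_path ./out.wav
# Note the duration math above: the stream is 16-bit mono PCM at 24000 Hz,
# so seconds = bytes / 2 / 24000.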
@ -0,0 +1,112 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import base64
import json
import threading
import time

import pyaudio
import requests

mutex = threading.Lock()
buffer = b''
p = pyaudio.PyAudio()
stream = p.open(
    format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
max_fail = 50


def play_audio():
    global stream
    global buffer
    global max_fail
    while True:
        if not buffer:
            max_fail -= 1
            time.sleep(0.05)
            if max_fail < 0:
                break
        mutex.acquire()
        stream.write(buffer)
        buffer = b''
        mutex.release()


def test(args):
    global mutex
    global buffer
    params = {
        "text": args.text,
        "spk_id": args.spk_id,
        "speed": args.speed,
        "volume": args.volume,
        "sample_rate": args.sample_rate,
        "save_path": ''
    }

    all_bytes = 0.0
    t = threading.Thread(target=play_audio)
    flag = 1
    url = "http://" + str(args.server) + ":" + str(
        args.port) + "/paddlespeech/streaming/tts"
    st = time.time()
    html = requests.post(url, json.dumps(params), stream=True)
    for chunk in html.iter_content(chunk_size=1024):
        mutex.acquire()
        chunk = base64.b64decode(chunk)  # bytes
        buffer += chunk
        mutex.release()
        if flag:
            first_response = time.time() - st
            print(f"first-packet latency: {first_response} s")
            flag = 0
            t.start()
        all_bytes += len(chunk)

    final_response = time.time() - st
    duration = all_bytes / 2 / 24000

    print(f"final-packet latency: {final_response} s")
    print(f"audio duration: {duration} s")
    print(f"RTF: {final_response / duration}")

    t.join()
    stream.stop_stream()
    stream.close()
    p.terminate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--text',
        type=str,
        default="您好,欢迎使用语音合成服务。",
        help='A sentence to be synthesized')
    parser.add_argument('--spk_id', type=int, default=0, help='Speaker id')
    parser.add_argument('--speed', type=float, default=1.0, help='Audio speed')
    parser.add_argument(
        '--volume', type=float, default=1.0, help='Audio volume')
    parser.add_argument(
        '--sample_rate',
        type=int,
        default=0,
        help='Sampling rate, the default is the same as the model')
    parser.add_argument(
        "--server", type=str, help="server ip", default="127.0.0.1")
    parser.add_argument("--port", type=int, help="server port", default=8092)

    args = parser.parse_args()
    test(args)
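# Design note on the playback loop above: the receiving loop appends
# base64-decoded PCM to `buffer` under `mutex`, while the play_audio thread
# drains it into the PyAudio stream; playback stops after max_fail empty
# polls (50 * 0.05 s, i.e. about 2.5 s of accumulated silence).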
@ -0,0 +1,126 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import _thread as thread
import argparse
import base64
import json
import ssl
import time

import websocket

flag = 1
st = 0.0
all_bytes = b''


class WsParam(object):
    # initialization
    def __init__(self, text, server="127.0.0.1", port=8090):
        self.server = server
        self.port = port
        self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
        self.text = text

    # generate the url
    def create_url(self):
        return self.url


def on_message(ws, message):
    global flag
    global st
    global all_bytes

    try:
        message = json.loads(message)
        audio = message["audio"]
        audio = base64.b64decode(audio)  # bytes
        status = message["status"]
        all_bytes += audio

        if status == 0:
            print("created successfully.")
        elif status == 1:
            if flag:
                print(f"first-packet latency: {time.time() - st} s")
                flag = 0
        elif status == 2:
            final_response = time.time() - st
            duration = len(all_bytes) / 2.0 / 24000
            print(f"final-packet latency: {final_response} s")
            print(f"audio duration: {duration} s")
            print(f"RTF: {final_response / duration}")
            with open("./out.pcm", "wb") as f:
                f.write(all_bytes)
            print("ws is closed")
            ws.close()
        else:
            print("infer error")

    except Exception as e:
        print("received a message, but parsing raised an exception:", e)


# handle websocket errors
def on_error(ws, error):
    print("### error:", error)


# handle websocket close
def on_close(ws):
    print("### closed ###")


# handle websocket connection established
def on_open(ws):
    def run(*args):
        global st
        text_base64 = str(
            base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
        d = {"text": text_base64}
        d = json.dumps(d)
        print("Start sending text data")
        st = time.time()
        ws.send(d)

    thread.start_new_thread(run, ())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        type=str,
        help="A sentence to be synthesized",
        default="您好,欢迎使用语音合成服务。")
    parser.add_argument(
        "--server", type=str, help="server ip", default="127.0.0.1")
    parser.add_argument("--port", type=int, help="server port", default=8092)
    args = parser.parse_args()

    print("***************************************")
    print("Server ip: ", args.server)
    print("Server port: ", args.port)
    print("Sentence to be synthesized: ", args.text)
    print("***************************************")

    wsParam = WsParam(text=args.text, server=args.server, port=args.port)

    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(
        wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
    ws.on_open = on_open
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
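# Message protocol used by this client, as read off the handlers above: the
# request is {"text": "<base64 utf-8 text>"} and each server reply is
# {"status": <int>, "audio": "<base64 pcm>"}, where status 0 means the
# connection was set up, 1 carries an audio chunk, and 2 marks the end of
# the stream.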
@ -0,0 +1,160 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import _thread as thread
import argparse
import base64
import json
import ssl
import threading
import time

import pyaudio
import websocket

mutex = threading.Lock()
buffer = b''
p = pyaudio.PyAudio()
stream = p.open(
    format=p.get_format_from_width(2), channels=1, rate=24000, output=True)
flag = 1
st = 0.0
all_bytes = 0.0


class WsParam(object):
    # initialization
    def __init__(self, text, server="127.0.0.1", port=8090):
        self.server = server
        self.port = port
        self.url = "ws://" + self.server + ":" + str(self.port) + "/ws/tts"
        self.text = text

    # generate the url
    def create_url(self):
        return self.url


def play_audio():
    global stream
    global buffer
    while True:
        time.sleep(0.05)
        if not buffer:  # buffer is empty
            break
        mutex.acquire()
        stream.write(buffer)
        buffer = b''
        mutex.release()


t = threading.Thread(target=play_audio)


def on_message(ws, message):
    global flag
    global t
    global buffer
    global st
    global all_bytes

    try:
        message = json.loads(message)
        audio = message["audio"]
        audio = base64.b64decode(audio)  # bytes
        status = message["status"]
        all_bytes += len(audio)

        if status == 0:
            print("created successfully.")
        elif status == 1:
            mutex.acquire()
            buffer += audio
            mutex.release()
            if flag:
                print(f"first-packet latency: {time.time() - st} s")
                flag = 0
                print("Start playing audio")
                t.start()
        elif status == 2:
            final_response = time.time() - st
            duration = all_bytes / 2 / 24000
            print(f"final-packet latency: {final_response} s")
            print(f"audio duration: {duration} s")
            print(f"RTF: {final_response / duration}")
            print("ws is closed")
            ws.close()
        else:
            print("infer error")

    except Exception as e:
        print("received a message, but parsing raised an exception:", e)


# handle websocket errors
def on_error(ws, error):
    print("### error:", error)


# handle websocket close
def on_close(ws):
    print("### closed ###")


# handle websocket connection established
def on_open(ws):
    def run(*args):
        global st
        text_base64 = str(
            base64.b64encode((wsParam.text).encode('utf-8')), "UTF8")
        d = {"text": text_base64}
        d = json.dumps(d)
        print("Start sending text data")
        st = time.time()
        ws.send(d)

    thread.start_new_thread(run, ())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        type=str,
        help="A sentence to be synthesized",
        default="您好,欢迎使用语音合成服务。")
    parser.add_argument(
        "--server", type=str, help="server ip", default="127.0.0.1")
    parser.add_argument("--port", type=int, help="server port", default=8092)
    args = parser.parse_args()

    print("***************************************")
    print("Server ip: ", args.server)
    print("Server port: ", args.port)
    print("Sentence to be synthesized: ", args.text)
    print("***************************************")

    wsParam = WsParam(text=args.text, server=args.server, port=args.port)

    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(
        wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
    ws.on_open = on_open
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    t.join()
    print("End of playing audio")
    stream.stop_stream()
    stream.close()
    p.terminate()
@ -0,0 +1,62 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState

from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool

router = APIRouter()


@router.websocket('/ws/tts')
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    try:
        # careful here, changed the source code from starlette.websockets
        assert websocket.application_state == WebSocketState.CONNECTED
        message = await websocket.receive()
        websocket._raise_on_disconnect(message)

        # get the engine
        engine_pool = get_engine_pool()
        tts_engine = engine_pool['tts']

        # get the message and decode it to text
        message = json.loads(message["text"])
        text_base64 = message["text"]
        sentence = tts_engine.preprocess(text_base64=text_base64)

        # run
        wav_generator = tts_engine.run(sentence)

        while True:
            try:
                tts_results = next(wav_generator)
                resp = {"status": 1, "audio": tts_results}
                await websocket.send_json(resp)
                logger.info("streaming audio...")
            except StopIteration as e:
                resp = {"status": 2, "audio": ''}
                await websocket.send_json(resp)
                logger.info("Completed the transmission of the audio stream")
                break

    except WebSocketDisconnect:
        pass
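# Note on the loop above: tts_engine.run() is a generator that yields one
# base64 chunk per vocoder block, so the endpoint forwards chunks with
# status 1 until StopIteration, then sends a final status-2 frame that the
# clients use to close the connection.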
@ -0,0 +1,156 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path

import jsonlines
import numpy as np
import onnxruntime as ort
import soundfile as sf
from timer import timer

from paddlespeech.t2s.exps.syn_utils import get_test_dataset
from paddlespeech.t2s.utils import str2bool


def get_sess(args, field='am'):
    full_name = ''
    if field == 'am':
        full_name = args.am
    elif field == 'voc':
        full_name = args.voc
    model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

    if args.device == "gpu":
        # fastspeech2/mb_melgan can't use trt now!
        if args.use_trt:
            providers = ['TensorrtExecutionProvider']
        else:
            providers = ['CUDAExecutionProvider']
    elif args.device == "cpu":
        providers = ['CPUExecutionProvider']
    sess_options.intra_op_num_threads = args.cpu_threads
    sess = ort.InferenceSession(
        model_dir, providers=providers, sess_options=sess_options)
    return sess


def ort_predict(args):
    # construct dataset for evaluation
    with jsonlines.open(args.test_metadata, 'r') as reader:
        test_metadata = list(reader)
    am_name = args.am[:args.am.rindex('_')]
    am_dataset = args.am[args.am.rindex('_') + 1:]
    test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    fs = 24000 if am_dataset != 'ljspeech' else 22050

    # am
    am_sess = get_sess(args, field='am')

    # vocoder
    voc_sess = get_sess(args, field='voc')

    # am warmup
    for T in [27, 38, 54]:
        data = np.random.randint(1, 266, size=(T, ))
        am_sess.run(None, {"text": data})

    # voc warmup
    for T in [227, 308, 544]:
        data = np.random.rand(T, 80).astype("float32")
        voc_sess.run(None, {"logmel": data})
    print("warm up done!")

    N = 0
    T = 0
    for example in test_dataset:
        utt_id = example['utt_id']
        phone_ids = example["text"]
        with timer() as t:
            mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})
            mel = mel[0]
            wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})

        N += len(wav[0])
        T += t.elapse
        speed = len(wav[0]) / t.elapse
        rtf = fs / speed
        sf.write(
            str(output_dir / (utt_id + ".wav")),
            np.array(wav)[0],
            samplerate=fs)
        print(
            f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
        )
    print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")


def parse_args():
    parser = argparse.ArgumentParser(description="Inference with onnxruntime.")
    # acoustic model
    parser.add_argument(
        '--am',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'fastspeech2_csmsc',
        ],
        help='Choose acoustic model type of tts task.')

    # voc
    parser.add_argument(
        '--voc',
        type=str,
        default='hifigan_csmsc',
        choices=['hifigan_csmsc', 'mb_melgan_csmsc'],
        help='Choose vocoder type of tts task.')
    # other
    parser.add_argument(
        "--inference_dir", type=str, help="dir to save inference models")
    parser.add_argument("--test_metadata", type=str, help="test metadata.")
    parser.add_argument("--output_dir", type=str, help="output dir")

    # inference
    parser.add_argument(
        "--use_trt",
        type=str2bool,
        default=False,
        help="Whether to use the inference engine TensorRT.", )

    parser.add_argument(
        "--device",
        default="gpu",
        choices=["gpu", "cpu"],
        help="Device selected for inference.", )
    parser.add_argument('--cpu_threads', type=int, default=1)

    args, _ = parser.parse_known_args()
    return args


def main():
    args = parse_args()

    ort_predict(args)


if __name__ == "__main__":
    main()
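# Worked example of the RTF bookkeeping above (illustrative numbers): if a
# 48000-sample wave at fs=24000 Hz (2 s of audio) takes t.elapse=0.5 s, then
# speed = 48000 / 0.5 = 96000 Hz and rtf = 24000 / 96000 = 0.25, i.e. the
# pipeline synthesizes four times faster than real time.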
@ -0,0 +1,183 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf
from timer import timer

from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.utils import str2bool


def get_sess(args, field='am'):
    full_name = ''
    if field == 'am':
        full_name = args.am
    elif field == 'voc':
        full_name = args.voc
    model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

    if args.device == "gpu":
        # fastspeech2/mb_melgan can't use trt now!
        if args.use_trt:
            providers = ['TensorrtExecutionProvider']
        else:
            providers = ['CUDAExecutionProvider']
    elif args.device == "cpu":
        providers = ['CPUExecutionProvider']
    sess_options.intra_op_num_threads = args.cpu_threads
    sess = ort.InferenceSession(
        model_dir, providers=providers, sess_options=sess_options)
    return sess


def ort_predict(args):

    # frontend
    frontend = get_frontend(args)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    sentences = get_sentences(args)

    am_name = args.am[:args.am.rindex('_')]
    am_dataset = args.am[args.am.rindex('_') + 1:]
    fs = 24000 if am_dataset != 'ljspeech' else 22050

    # am
    am_sess = get_sess(args, field='am')

    # vocoder
    voc_sess = get_sess(args, field='voc')

    # am warmup
    for T in [27, 38, 54]:
        data = np.random.randint(1, 266, size=(T, ))
        am_sess.run(None, {"text": data})

    # voc warmup
    for T in [227, 308, 544]:
        data = np.random.rand(T, 80).astype("float32")
        voc_sess.run(None, {"logmel": data})
    print("warm up done!")

    # frontend warmup
    # loading the frontend model costs 0.5+ seconds
    if args.lang == 'zh':
        frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True)
    else:
        print("lang should be 'zh' here!")

    N = 0
    T = 0
    merge_sentences = True
    for utt_id, sentence in sentences:
        with timer() as t:
            if args.lang == 'zh':
                input_ids = frontend.get_input_ids(
                    sentence, merge_sentences=merge_sentences)

                phone_ids = input_ids["phone_ids"]
            else:
                print("lang should be 'zh' here!")
            # merge_sentences=True here, so we only use the first item of phone_ids
            phone_ids = phone_ids[0].numpy()
            mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})
            mel = mel[0]
            wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})

        N += len(wav[0])
        T += t.elapse
        speed = len(wav[0]) / t.elapse
        rtf = fs / speed
        sf.write(
            str(output_dir / (utt_id + ".wav")),
            np.array(wav)[0],
            samplerate=fs)
        print(
            f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
        )
    print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")


def parse_args():
    parser = argparse.ArgumentParser(description="Inference with onnxruntime.")
    # acoustic model
    parser.add_argument(
        '--am',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'fastspeech2_csmsc',
        ],
        help='Choose acoustic model type of tts task.')
    parser.add_argument(
        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--tones_dict", type=str, default=None, help="tone vocabulary file.")

    # voc
    parser.add_argument(
        '--voc',
        type=str,
        default='hifigan_csmsc',
        choices=['hifigan_csmsc', 'mb_melgan_csmsc'],
        help='Choose vocoder type of tts task.')
    # other
    parser.add_argument(
        "--inference_dir", type=str, help="dir to save inference models")
    parser.add_argument(
        "--text",
        type=str,
        help="text to synthesize, a 'utt_id sentence' pair per line")
    parser.add_argument("--output_dir", type=str, help="output dir")
    parser.add_argument(
        '--lang',
        type=str,
        default='zh',
        help='Choose model language. zh or en')

    # inference
    parser.add_argument(
        "--use_trt",
        type=str2bool,
        default=False,
        help="Whether to use the inference engine TensorRT.", )

    parser.add_argument(
        "--device",
        default="gpu",
        choices=["gpu", "cpu"],
        help="Device selected for inference.", )
    parser.add_argument('--cpu_threads', type=int, default=1)

    args, _ = parser.parse_known_args()
    return args


def main():
    args = parse_args()

    ort_predict(args)


if __name__ == "__main__":
    main()
@ -0,0 +1,192 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from dataclasses import dataclass
from dataclasses import fields

from paddle.io import Dataset

from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()

# the audio meta info in the vector CSVDataset
# utt_id: the utterance segment name
# duration: utterance segment time
# wav: utterance file path
# start: start point in the original wav file
# stop: stop point in the original wav file
# label: the utterance segment's label id


@dataclass
class meta_info:
    """the audio meta info in the vector CSVDataset

    Args:
        utt_id (str): the utterance segment name
        duration (float): utterance segment time
        wav (str): utterance file path
        start (int): start point in the original wav file
        stop (int): stop point in the original wav file
        label (str): the utterance segment's label id
    """
    utt_id: str
    duration: float
    wav: str
    start: int
    stop: int
    label: str


# feature types supported by the csv dataset
# raw: returns the pcm data sample points
# melspectrogram: fbank feature
feat_funcs = {
    'raw': None,
    'melspectrogram': melspectrogram,
}


class CSVDataset(Dataset):
    def __init__(self,
                 csv_path,
                 label2id_path=None,
                 config=None,
                 random_chunk=True,
                 feat_type: str="raw",
                 n_train_snts: int=-1,
                 **kwargs):
        """Implement the CSV Dataset

        Args:
            csv_path (str): csv dataset file path
            label2id_path (str): the utterance-label-to-integer-id map file path
            config (CfgNode): yaml config
            feat_type (str): dataset feature type. if it is raw, it returns pcm data.
            n_train_snts (int): select n_train_snts samples from the dataset.
                                if n_train_snts = -1, the dataset loads all the samples.
                                Default value is -1.
            kwargs : feature type args
        """
        super().__init__()
        self.csv_path = csv_path
        self.label2id_path = label2id_path
        self.config = config
        self.random_chunk = random_chunk
        self.feat_type = feat_type
        self.n_train_snts = n_train_snts
        self.feat_config = kwargs
        self.id2label = {}
        self.label2id = {}
        self.data = self.load_data_csv()
        self.load_speaker_to_label()

    def load_data_csv(self):
        """Load the csv dataset content and store it in the data property.
        The csv dataset's format has six fields, that is, audio_id or utt_id,
        audio duration, wav file path, segment start point, segment stop point
        and utterance label.
        Note that in the training period, the utterance label must have a map
        to an integer id in label2id_path.

        Returns:
            list: the csv data with meta_info type
        """
        data = []

        with open(self.csv_path, 'r') as rf:
            for line in rf.readlines()[1:]:
                audio_id, duration, wav, start, stop, spk_id = line.strip(
                ).split(',')
                data.append(
                    meta_info(audio_id,
                              float(duration), wav,
                              int(start), int(stop), spk_id))
        if self.n_train_snts > 0:
            sample_num = min(self.n_train_snts, len(data))
            data = data[0:sample_num]

        return data

    def load_speaker_to_label(self):
        """Load the utterance-label map content.
        In the vector domain, we call the utterance label the speaker label.
        The speaker label is a real speaker label in the speaker verification domain,
        and a language label in language identification.
        """
        if not self.label2id_path:
            logger.warning("No speaker id to label file")
            return

        with open(self.label2id_path, 'r') as f:
            for line in f.readlines():
                label_name, label_id = line.strip().split(' ')
                self.label2id[label_name] = int(label_id)
                self.id2label[int(label_id)] = label_name

    def convert_to_record(self, idx: int):
        """convert the dataset sample to a training record of the CSV Dataset

        Args:
            idx (int) : the requested index in the whole dataset
        """
        sample = self.data[idx]

        record = {}
        # To show all fields in a namedtuple: `type(sample)._fields`
        for field in fields(sample):
            record[field.name] = getattr(sample, field.name)

        waveform, sr = load_audio(record['wav'])

        # randomly select a chunk of audio samples from the audio
        if self.config and self.config.random_chunk:
            num_wav_samples = waveform.shape[0]
            num_chunk_samples = int(self.config.chunk_duration * sr)
            start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
            stop = start + num_chunk_samples
        else:
            start = record['start']
            stop = record['stop']

        # we only return the waveform as feat
        waveform = waveform[start:stop]

        # all available feature types are in feat_funcs
        assert self.feat_type in feat_funcs.keys(), \
            f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record.update({'feat': feat})
        if self.label2id:
            record.update({'label': self.label2id[record['label']]})

        return record

    def __getitem__(self, idx):
        """Return the sample at the specific index

        Args:
            idx (int) : the requested index in the whole dataset
        """
        return self.convert_to_record(idx)

    def __len__(self):
        """Return the dataset length

        Returns:
            int: the number of samples in the dataset
        """
        return len(self.data)
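# A minimal sketch of the csv layout parsed by load_data_csv above (the
# header line is skipped; the values are illustrative):
#   utt_id,duration,wav,start,stop,label
#   id10001-00001,3.0,/data/wav/id10001-00001.wav,0,48000,id10001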
@ -0,0 +1,116 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from dataclasses import fields

from paddle.io import Dataset

from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.compliance.librosa import mfcc


@dataclass
class meta_info:
    """the audio meta info in the vector JSONDataset

    Args:
        id (str): the segment name
        duration (float): segment time
        wav (str): wav file path
        start (int): start point in the original wav file
        stop (int): stop point in the original wav file
        record_id (str): the record id
    """
    id: str
    duration: float
    wav: str
    start: int
    stop: int
    record_id: str


# feature types supported by the JSON dataset
feat_funcs = {
    'raw': None,
    'melspectrogram': melspectrogram,
    'mfcc': mfcc,
}


class JSONDataset(Dataset):
    """
    dataset from json file.
    """

    def __init__(self, json_file: str, feat_type: str='raw', **kwargs):
        """
        Args:
            json_file (:obj:`str`): Data prep JSON file.
            feat_type (:obj:`str`, `optional`, defaults to `raw`):
                It identifies the feature type that the user wants to extract
                from an audio file.
        """
        if feat_type not in feat_funcs.keys():
            raise RuntimeError(
                f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
            )

        self.json_file = json_file
        self.feat_type = feat_type
        self.feat_config = kwargs
        self._data = self._get_data()
        super(JSONDataset, self).__init__()

    def _get_data(self):
        with open(self.json_file, "r") as f:
            meta_data = json.load(f)
        data = []
        for key in meta_data:
            sub_seg = meta_data[key]["wav"]
            wav = sub_seg["file"]
            duration = sub_seg["duration"]
            start = sub_seg["start"]
            stop = sub_seg["stop"]
            rec_id = str(key).rsplit("_", 2)[0]
            data.append(
                meta_info(
                    str(key),
                    float(duration), wav, int(start), int(stop), str(rec_id)))
        return data

    def _convert_to_record(self, idx: int):
        sample = self._data[idx]

        record = {}
        # copy all dataclass fields of the sample into the record dict
        for field in fields(sample):
            record[field.name] = getattr(sample, field.name)

        waveform, sr = load_audio(record['wav'])
        # cut the sub-segment out of the original recording
        waveform = waveform[record['start']:record['stop']]

        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sr=sr, **self.feat_config) if feat_func else waveform

        record.update({'feat': feat})

        return record

    def __getitem__(self, idx):
        return self._convert_to_record(idx)

    def __len__(self):
        return len(self._data)
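A minimal usage sketch of JSONDataset above; the JSON path and the n_mels keyword are illustrative placeholders, not files or values created by this diff:

# hypothetical usage sketch, not part of this diff
dataset = JSONDataset("./metadata/test.json", feat_type="melspectrogram", n_mels=80)
print(len(dataset))          # number of sub-segments listed in the JSON file
record = dataset[0]          # dict: id, duration, wav, start, stop, record_id, feat
print(record["feat"].shape)  # melspectrogram matrix of the sub-segment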
@ -0,0 +1,214 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import paddle


class InputNormalization:
    spk_dict_mean: Dict[int, paddle.Tensor]
    spk_dict_std: Dict[int, paddle.Tensor]
    spk_dict_count: Dict[int, int]

    def __init__(
            self,
            mean_norm=True,
            std_norm=True,
            norm_type="global", ):
        """Do feature or embedding mean and std norm

        Args:
            mean_norm (bool, optional): mean norm flag. Defaults to True.
            std_norm (bool, optional): std norm flag. Defaults to True.
            norm_type (str, optional): norm type. Defaults to "global".
        """
        super().__init__()
        self.training = True
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.glob_mean = paddle.to_tensor([0], dtype="float32")
        self.glob_std = paddle.to_tensor([0], dtype="float32")
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10

    def __call__(self,
                 x,
                 lengths,
                 spk_ids=paddle.to_tensor([], dtype="float32")):
        """Normalizes the input tensor.

        Args:
            x (paddle.Tensor): A batch of tensors.
            lengths (paddle.Tensor): A batch of tensors containing the relative length of each
                sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
                computing stats on zero-padded steps.
            spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
                It is used to perform per-speaker normalization when
                norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32").
        Returns:
            paddle.Tensor: The normalized feature or embedding.
        """
        N_batches = x.shape[0]
        current_means = []
        current_stds = []

        for snt_id in range(N_batches):
            # Avoiding padded time steps:
            # actual_size is the number of real (unpadded) time steps
            actual_size = paddle.round(lengths[snt_id] *
                                       x.shape[1]).astype("int32")
            # computing statistics over the real time steps only
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...].unsqueeze(0))
            current_means.append(current_mean)
            current_stds.append(current_std)

        if self.norm_type == "global":
            current_mean = paddle.mean(paddle.stack(current_means), axis=0)
            current_std = paddle.mean(paddle.stack(current_stds), axis=0)

            if self.training:
                if self.count == 0:
                    self.glob_mean = current_mean
                    self.glob_std = current_std
                else:
                    # running average over all batches seen so far
                    self.weight = 1 / (self.count + 1)

                    self.glob_mean = (
                        1 - self.weight
                    ) * self.glob_mean + self.weight * current_mean

                    self.glob_std = (
                        1 - self.weight
                    ) * self.glob_std + self.weight * current_std

                # detach the running statistics from the autograd graph
                self.glob_mean = self.glob_mean.detach()
                self.glob_std = self.glob_std.detach()

                self.count = self.count + 1
            x = (x - self.glob_mean) / (self.glob_std)
        return x

    def _compute_current_stats(self, x):
        """Computes the mean and std statistics of the input tensor.

        Args:
            x (paddle.Tensor): A batch of tensors.

        Returns:
            the (mean, std) statistics of the data
        """
        # Compute current mean
        if self.mean_norm:
            current_mean = paddle.mean(x, axis=0).detach()
        else:
            current_mean = paddle.to_tensor([0.0], dtype="float32")

        # Compute current std
        if self.std_norm:
            current_std = paddle.std(x, axis=0).detach()
        else:
            current_std = paddle.to_tensor([1.0], dtype="float32")

        # Improving numerical stability of std
        current_std = paddle.maximum(current_std,
                                     self.eps * paddle.ones_like(current_std))

        return current_mean, current_std

    def _statistics_dict(self):
        """Fills the dictionary containing the normalization statistics.
        """
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count

        return state

    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.

        Args:
            state (dict): A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]

        # Loading the spk_dict_mean
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]

        # Loading the spk_dict_std
        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk]

        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Puts the needed tensors in the right device.
        """
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self

    def save(self, path):
        """Save statistic dictionary.

        Args:
            path (str): A path where to save the dictionary.
        """
        stats = self._statistics_dict()
        paddle.save(stats, path)

    def _load(self, path, end_of_epoch=False, device=None):
        """Load statistic dictionary.

        Args:
            path (str): The path of the statistic dictionary.
            device (str, optional): Unused here; kept for API compatibility.
        """
        del end_of_epoch  # Unused here.
        stats = paddle.load(path)
        self._load_statistics_dict(stats)
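A short sketch of how the normalizer above might be driven; random tensors stand in for real fbank features, and the shapes and save path are illustrative, not part of this diff:

# illustrative sketch, not part of this diff
import paddle

norm = InputNormalization(norm_type="global")
feats = paddle.randn([4, 200, 80])                # (batch, time, n_mels)
lengths = paddle.to_tensor([1.0, 0.9, 0.8, 1.0])  # relative unpadded lengths
normalized = norm(feats, lengths)                 # also updates the running stats
norm.save("./norm_stats.pd")                      # persists stats via paddle.save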
@ -0,0 +1,32 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_chunks(seg_dur, audio_id, audio_duration):
    """Get all chunk segments from an utterance

    Args:
        seg_dur (float): segment chunk duration, in seconds
        audio_id (str): utterance name
        audio_duration (float): utterance duration, in seconds

    Returns:
        List: all the chunk segments
    """
    num_chunks = int(audio_duration / seg_dur)  # all in seconds
    chunk_lst = [
        audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
        for i in range(num_chunks)
    ]
    return chunk_lst
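For example, the helper above names each chunk by its start and end time; the utterance id here is made up:

# doctest-style illustration, not part of this diff
>>> get_chunks(3.0, "S0764W0290", 9.5)
['S0764W0290_0.0_3.0', 'S0764W0290_3.0_6.0', 'S0764W0290_6.0_9.0']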
@ -0,0 +1,24 @@
#!/usr/bin/env bash

data=$1
feat_scp=$2
split_feat_name=$3
numsplit=$4


if ! [ "$numsplit" -gt 0 ]; then
    echo "Invalid num-split argument";
    exit 1;
fi

directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
feat_split_scp=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_feat_name}; done)
echo $feat_split_scp
# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
    for n in `seq $numsplit`; do
        mkdir -p $data/split${numsplit}/$n
    done
fi

utils/split_scp.pl $feat_scp $feat_split_scp
@ -0,0 +1,14 @@
# This file contains the locations of the binaries built for running the examples.

SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples

SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure the project has been built successfully."; }

export LC_ALL=C

SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
@ -0,0 +1,113 @@
#!/bin/bash
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi


# 2. download model
if [ ! -d ../paddle_asr_model ]; then
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
    tar xzfv paddle_asr_model.tar.gz
    mv ./paddle_asr_model ../
    # produce wav scp
    echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi

mkdir -p data
data=$PWD/data
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
    unzip -d $data aishell_test.zip
    realpath $data/test/*/*.wav > $data/wavlist
    awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
    paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi

model_dir=$PWD/aishell_ds2_online_model
if [ ! -d $model_dir ]; then
    mkdir -p $model_dir
    wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir
fi

# 3. make feature
aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints
lm_model_dir=../paddle_asr_model
label_file=./aishell_result
wer=./aishell_wer

nj=40
export GLOG_logtostderr=1

# split the wav scp into $nj pieces, one per job
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj

# generate linear feature
cmvn=$PWD/cmvn.ark
cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn

utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \
    linear_spectrogram_without_db_norm_main \
    --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
    --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
    --cmvn_file=$cmvn \
    --streaming_chunk=0.36

text=$data/test/text

# 4. recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
    offline_decoder_sliding_chunk_main \
    --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
    --model_path=$aishell_online_model/avg_1.jit.pdmodel \
    --param_path=$aishell_online_model/avg_1.jit.pdiparams \
    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    --dict_file=$lm_model_dir/vocab.txt \
    --result_wspecifier=ark,t:$data/split${nj}/JOB/result

cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}

# 5. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
    offline_decoder_sliding_chunk_main \
    --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
    --model_path=$aishell_online_model/avg_1.jit.pdmodel \
    --param_path=$aishell_online_model/avg_1.jit.pdiparams \
    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    --dict_file=$lm_model_dir/vocab.txt \
    --lm_path=$lm_model_dir/avg_1.jit.klm \
    --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm

cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm

graph_dir=./aishell_graph
if [ ! -d $graph_dir ]; then
    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
    unzip aishell_graph.zip
fi

# 6. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
    offline_wfst_decoder_main \
    --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
    --model_path=$aishell_online_model/avg_1.jit.pdmodel \
    --param_path=$aishell_online_model/avg_1.jit.pdiparams \
    --word_symbol_table=$graph_dir/words.txt \
    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    --graph_path=$graph_dir/TLG.fst --max_active=7500 \
    --acoustic_scale=1.2 \
    --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg

cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
@ -0,0 +1 @@
../../../utils
@ -0,0 +1,158 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// todo refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active states in the decoder");
DEFINE_int32(receptive_field_length,
             7,
             "receptive field of two CNN(kernel=5) downsampling modules.");
DEFINE_int32(downsampling_rate,
             4,
             "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
              "save_infer_model/scale_0.tmp_1,save_infer_model/"
              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
              "scale_3.tmp_1",
              "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");

using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// test TLG decoder by feeding speech features.
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    std::string word_symbol_table = FLAGS_word_symbol_table;
    std::string graph_path = FLAGS_graph_path;
    LOG(INFO) << "model path: " << model_graph;
    LOG(INFO) << "model param: " << model_params;
    LOG(INFO) << "word symbol path: " << word_symbol_table;
    LOG(INFO) << "graph path: " << graph_path;

    int32 num_done = 0, num_err = 0;

    ppspeech::TLGDecoderOptions opts;
    opts.word_symbol_table = word_symbol_table;
    opts.fst_path = graph_path;
    opts.opts.max_active = FLAGS_max_active;
    opts.opts.beam = 15.0;
    opts.opts.lattice_beam = 7.5;
    ppspeech::TLGDecoder decoder(opts);

    ppspeech::ModelOptions model_opts;
    model_opts.model_path = model_graph;
    model_opts.params_path = model_params;
    model_opts.cache_shape = FLAGS_model_cache_names;
    model_opts.output_names = FLAGS_model_output_names;
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));

    int32 chunk_size = FLAGS_receptive_field_length;
    int32 chunk_stride = FLAGS_downsampling_rate;
    int32 receptive_field_length = FLAGS_receptive_field_length;
    LOG(INFO) << "chunk size (frame): " << chunk_size;
    LOG(INFO) << "chunk stride (frame): " << chunk_stride;
    LOG(INFO) << "receptive field (frame): " << receptive_field_length;
    decoder.InitDecoder();

    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
        raw_data->SetDim(feature.NumCols());
        LOG(INFO) << "process utt: " << utt;
        LOG(INFO) << "rows: " << feature.NumRows();
        LOG(INFO) << "cols: " << feature.NumCols();

        int32 padding_len = 0;
        int32 ori_feature_len = feature.NumRows();
        // pad so that (NumRows - chunk_size) is a multiple of chunk_stride
        if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
            padding_len =
                chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
            feature.Resize(feature.NumRows() + padding_len,
                           feature.NumCols(),
                           kaldi::kCopyData);
        }
        int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
                                                          feature.NumCols());
            int32 feature_chunk_size = 0;
            if (ori_feature_len > chunk_idx * chunk_stride) {
                feature_chunk_size = std::min(
                    ori_feature_len - chunk_idx * chunk_stride, chunk_size);
            }
            if (feature_chunk_size < receptive_field_length) break;

            int32 start = chunk_idx * chunk_stride;
            // copy chunk_size rows of the feature matrix into the flat chunk
            for (int row_id = 0; row_id < chunk_size; ++row_id) {
                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
                    feature_chunk.Data() + row_id * feature.NumCols(),
                    feature.NumCols());
                f_chunk_tmp.CopyFromVec(tmp);
                ++start;
            }
            raw_data->Accept(feature_chunk);
            if (chunk_idx == num_chunks - 1) {
                raw_data->SetFinished();
            }
            decoder.AdvanceDecode(decodable);
        }
        std::string result;
        result = decoder.GetFinalBestPath();
        decodable->Reset();
        decoder.Reset();
        if (result.empty()) {
            // the TokenWriter can not write an empty string.
            ++num_err;
            KALDI_LOG << " the result of " << utt << " is empty";
            continue;
        }
        KALDI_LOG << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }

    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
              << " with errors.";
    return (num_done != 0 ? 0 : 1);
}
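For reference, the padding and chunking arithmetic in the decode loop above can be sketched in Python; the function name and the example frame count are illustrative only, not part of this diff:

# illustrative sketch of the C++ chunking logic above, not part of this diff
def make_chunks(num_rows, chunk_size=7, chunk_stride=4):
    # pad so that (num_rows - chunk_size) divides evenly by chunk_stride
    padding = 0
    if (num_rows - chunk_size) % chunk_stride != 0:
        padding = chunk_stride - (num_rows - chunk_size) % chunk_stride
    num_chunks = (num_rows + padding - chunk_size) // chunk_stride + 1
    chunks = []
    for idx in range(num_chunks):
        start = idx * chunk_stride
        real_rows = min(num_rows - start, chunk_size)  # rows that are not padding
        if real_rows < chunk_size:  # here receptive_field_length == chunk_size
            break
        chunks.append((start, chunk_size))
    return chunks

print(make_chunks(103))  # [(0, 7), (4, 7), ..., (96, 7)]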