add vector cli component, test=doc

pull/1523/head
xiongxinlei 3 years ago
parent 506d26a957
commit d28ccfa96b

@ -14,10 +14,10 @@ random_chunk: True
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
sample_rate: 16000
sr: 16000 # sample rate
n_mels: 80
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_length: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
hop_size: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
###########################################################
# MODEL SETTING #

@ -1,15 +1,36 @@
#!/bin/bash
stage=-1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage=0
stop_stage=100
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
if [ $# -ne 2 ] ; then
echo "Usage: $0 [options] <data-dir> <conf-path>";
echo "e.g.: $0 ./data/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit 1;
fi
dir=$1
conf_path=$2
mkdir -p ${dir}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# we should use the local/convert.sh convert m4a to wav
python3 local/data_prepare.py \

@ -1,13 +1,51 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
. ./path.sh
exp_dir=exp/ecapa-tdnn-vox12-big//epoch_10/ # experiment directory
stage=0
stop_stage=100
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
conf_path=conf/ecapa_tdnn.yaml
audio_path="demo/voxceleb/00001.wav"
use_gpu=true
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
if [ $# -ne 0 ] ; then
echo "Usage: $0 [options]";
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
echo " --exp-dir # experiment directorh, where is has the model.pdparams"
echo " --conf-path # configuration file for extracting the embedding"
echo " --audio-path # audio-path, which will be processed to extract the embedding"
exit 1;
fi
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# set the test device
device="cpu"
if ${use_gpu}; then
device="gpu"
fi
# extract the audio embedding
python3 ${BIN_DIR}/extract_emb.py --device "gpu" \
--config ${conf_path} \
--audio-path ${audio_path} --load-checkpoint ${exp_dir}
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# extract the audio embedding
python3 ${BIN_DIR}/extract_emb.py --device ${device} \
--config ${conf_path} \
--audio-path ${audio_path} --load-checkpoint ${exp_dir}
fi

@ -1,8 +1,42 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage=1
stop_stage=100
use_gpu=true # if true, we run on GPU.
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
if [ $# -ne 3 ] ; then
echo "Usage: $0 [options] <data-dir> <exp-dir> <conf-path>";
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit 1;
fi
dir=$1
exp_dir=$2
conf_path=$3
python3 ${BIN_DIR}/test.py \
--config ${conf_path} \
--data-dir ${dir} \
--load-checkpoint ${exp_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# test the model and compute the eer metrics
python3 ${BIN_DIR}/test.py \
--data-dir ${dir} \
--load-checkpoint ${exp_dir} \
--config ${conf_path}
fi

@ -1,18 +1,57 @@
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
stage=0
stop_stage=100
use_gpu=true # if true, we run on GPU.
. ${MAIN_ROOT}/utils/parse_options.sh || exit -1;
if [ $# -ne 3 ] ; then
echo "Usage: $0 [options] <data-dir> <exp-dir> <conf-path>";
echo "e.g.: $0 ./data/ exp/voxceleb12/ conf/ecapa_tdnn.yaml"
echo "Options: "
echo " --use-gpu <true,false|true> # specify is gpu is to be used for training"
echo " --stage <stage|-1> # Used to run a partially-completed data process from somewhere in the middle."
echo " --stop-stage <stop-stage|100> # Used to run a partially-completed data process stop stage in the middle"
exit 1;
fi
dir=$1
exp_dir=$2
conf_path=$3
# get the gpu nums for training
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
# train the speaker identification task with voxceleb data
# Note: we will store the log file in exp/log directory
python3 -m paddle.distributed.launch --gpus=$CUDA_VISIBLE_DEVICES \
${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
--data-dir ${dir} --config ${conf_path}
# setting training device
device="cpu"
if ${use_gpu}; then
device="gpu"
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train the speaker identification task with voxceleb data
# and we will create the trained model parameters in ${exp_dir}/model.pdparams as the soft link
# Note: we will store the log file in exp/log directory
python3 -m paddle.distributed.launch --gpus=$CUDA_VISIBLE_DEVICES \
${BIN_DIR}/train.py --device ${device} --checkpoint-dir ${exp_dir} \
--data-dir ${dir} --config ${conf_path}
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -36,11 +36,10 @@ stop_stage=50
# data directory
# if we set the variable ${dir}, we will store the wav info to this directory
# otherwise, we will store the wav info to vox1 and vox2 directory respectively
# vox2 wav path, we must convert the m4a format to wav format
# dir=data-demo/ # data info directory
dir=demo/ # data info directory
# vox2 wav path, we must convert the m4a format to wav format
dir=data/ # data info directory
exp_dir=exp/ecapa-tdnn-vox12-big// # experiment directory
exp_dir=exp/ecapa-tdnn-vox12-big/ # experiment directory
conf_path=conf/ecapa_tdnn.yaml
gpus=0,1,2,3
@ -50,16 +49,15 @@ mkdir -p ${exp_dir}
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# stage 0: data prepare for vox1 and vox2, vox2 must be converted from m4a to wav
# and we should specify the vox2 data in the data.sh
bash ./local/data.sh ${dir} ${conf_path}|| exit -1;
fi
fi
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# stage 1: train the speaker identification model
CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${dir} ${exp_dir} ${conf_path}
fi
if [ $stage -le 2 ]; then
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# stage 2: get the speaker verification scores with cosine function
# now we only support use cosine to get the scores
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ${dir} ${exp_dir} ${conf_path}

@ -19,9 +19,15 @@ from sklearn.metrics import roc_curve
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
'''
Compute EER and return score threshold.
'''
"""Compute EER and return score threshold.
Args:
labels (np.ndarray): the trial labels, shape: [N], one-dimensional, where N is the number of samples
scores (np.ndarray): the trial scores, shape: [N], one-dimensional, where N is the number of samples
Returns:
List[float]: eer and the specific threshold
"""
fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
fnr = 1 - tpr
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
@ -54,7 +60,7 @@ def compute_minDCF(positive_scores,
p_target (float, optional): Prior probability of having a target (default 0.01).
Returns:
_type_: min dcf
List[float]: min dcf and the specific threshold
"""
# Computing candidate thresholds
if len(positive_scores.shape) > 1:

@ -21,5 +21,6 @@ from .st import STExecutor
from .stats import StatsExecutor
from .text import TextExecutor
from .tts import TTSExecutor
from .vector import VectorExecutor
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -0,0 +1,14 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import VectorExecutor

@ -0,0 +1,345 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
import librosa
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from ..download import get_path_from_url
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
# Registry of downloadable pretrained resources, keyed by model tag.
# Each entry gives the archive 'url', its 'md5', and the config ('cfg_path') and
# checkpoint ('ckpt_path') locations that are joined onto the resource directory.
# NOTE(review): the only entry is tagged "ecapa_tdnn-16k", but its url / md5 /
# ckpt_path all name the conformer wenetspeech ASR archive. This looks like a
# placeholder copied from the asr executor; confirm the real ecapa_tdnn archive.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ecapa_tdnn-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
}
# Maps a model name to a "module.path:ClassName" spec; it is passed to
# dynamic_import together with the model name to obtain the network class.
model_alias = {
"ecapa_tdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn",
}
@cli_register(
    name="paddlespeech.vector",
    description="Speech to vector embedding infer command.")
class VectorExecutor(BaseExecutor):
    """Executor that extracts a speaker embedding from an audio file.

    The same object backs the ``paddlespeech.vector`` command line entry
    (``execute``) and the Python API (``__call__``). One call runs, in order:
    input check -> model initialisation -> mel-spectrogram extraction ->
    backbone forward -> embedding.

    ``self._inputs`` / ``self._outputs`` and the task helpers used in
    ``execute`` are presumably provided by ``BaseExecutor`` (not visible
    here; confirm against the base class).
    """

    def __init__(self):
        super(VectorExecutor, self).__init__()

        self.parser = argparse.ArgumentParser(
            prog="paddlespeech.vector", add_help=True)
        self.parser.add_argument(
            "--model",
            type=str,
            default="ecapa_tdnn-voxceleb12",
            # the default must be an accepted choice, otherwise passing the
            # documented default explicitly on the command line is rejected
            choices=["ecapa_tdnn", "ecapa_tdnn-voxceleb12"],
            help="Choose model type of vector task.")
        self.parser.add_argument(
            "--task",
            type=str,
            default="spk",
            choices=["spk"],
            help="task type in vector domain")
        self.parser.add_argument(
            "--input",
            type=str,
            default=None,
            help="Audio file to extract the embedding from.")
        self.parser.add_argument(
            "--sample_rate",
            type=int,
            default=16000,
            choices=[16000, 8000],
            help="Choose the audio sample rate of the model. 8000 or 16000")
        self.parser.add_argument(
            "--ckpt_path",
            type=str,
            default=None,
            help="Checkpoint file of model.")
        self.parser.add_argument(
            '--config',
            type=str,
            default=None,
            help='Config of vector task. Use default config when it is None.')
        self.parser.add_argument(
            "--device",
            type=str,
            default=paddle.get_device(),
            help="Choose device to execute model inference.")
        self.parser.add_argument(
            '-d',
            '--job_dump_result',
            action='store_true',
            help='Save job result into file.')
        self.parser.add_argument(
            '-v',
            '--verbose',
            action='store_true',
            help='Increase logger verbosity of current task.')

    def execute(self, argv: List[str]) -> bool:
        """Command line entry for vector model.

        Args:
            argv (List[str]): command line args list

        Returns:
            bool:
                False: at least one audio raised an exception
                True: every audio was processed successfully
        """
        # stage 0: parse the args and get the required args
        parser_args = self.parser.parse_args(argv)
        model = parser_args.model
        sample_rate = parser_args.sample_rate
        config = parser_args.config
        ckpt_path = parser_args.ckpt_path
        device = parser_args.device

        # stage 1: configure the verbose flag
        if not parser_args.verbose:
            self.disable_task_loggers()

        # stage 2: read the input data and store them as a mapping id -> input
        task_source = self.get_task_source(parser_args.input)
        logger.info(f"task source: {task_source}")

        # stage 3: process the audio one by one
        task_result = OrderedDict()
        has_exceptions = False
        for id_, input_ in task_source.items():
            try:
                # device must be passed by keyword: the positional slot after
                # ckpt_path in __call__ is force_yes, not device
                res = self(
                    input_,
                    model,
                    sample_rate,
                    config,
                    ckpt_path,
                    device=device)
                task_result[id_] = res
            except Exception as e:
                # keep going with the remaining inputs, record the failure
                has_exceptions = True
                task_result[id_] = f'{e.__class__.__name__}: {e}'

        logger.info("task result as follows: ")
        logger.info(f"{task_result}")

        # stage 4: process all the task results
        self.process_task_results(parser_args.input, task_result,
                                  parser_args.job_dump_result)

        # stage 5: report whether every input succeeded
        return not has_exceptions

    @stats_wrapper
    def __call__(self,
                 audio_file: os.PathLike,
                 model: str='ecapa_tdnn-voxceleb12',
                 sample_rate: int=16000,
                 config: os.PathLike=None,
                 ckpt_path: os.PathLike=None,
                 force_yes: bool=False,
                 device=paddle.get_device()):
        """Python API: extract the embedding of one audio file.

        Args:
            audio_file (os.PathLike): path of the input audio.
            model (str): model type, see ``_init_from_path``.
            sample_rate (int): sample rate the model expects, 16000 or 8000.
            config (os.PathLike, optional): model config file; the pretrained
                resource is used when this or ``ckpt_path`` is None.
            ckpt_path (os.PathLike, optional): checkpoint path without the
                ".pdparams" suffix.
            force_yes (bool): currently unused by this executor.
            device: paddle device string to run the inference on.

        Returns:
            The embedding produced by ``postprocess``.

        Note:
            Exits the process (``sys.exit(-1)``) when the input check fails.
        """
        audio_file = os.path.abspath(audio_file)
        if not self._check(audio_file, sample_rate):
            sys.exit(-1)

        logger.info(f"device type: {device}")
        paddle.device.set_device(device)
        self._init_from_path(model, sample_rate, config, ckpt_path)
        self.preprocess(model, audio_file)
        self.infer(model)
        res = self.postprocess()

        return res

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """Download and decompress the pretrained resource registered as ``tag``.

        Args:
            tag (str): key of ``pretrained_models``.

        Returns:
            os.PathLike: absolute path of the decompressed resource directory.

        Raises:
            AssertionError: ``tag`` is not a registered pretrained model.
        """
        support_models = list(pretrained_models.keys())
        assert tag in pretrained_models, (
            'The model "{}" you want to use has not been supported, '
            'please choose other models.\n'
            'The support models includes \n\t\t{}'.format(
                tag, "\n\t\t".join(support_models)))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)
        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))
        return decompressed_path

    def _init_from_path(self,
                        model_type: str='ecapa_tdnn-voxceleb12',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        ckpt_path: Optional[os.PathLike]=None):
        """Build the model once and load its parameters.

        Args:
            model_type (str): "{model_name}" or "{model_name}-{suffix}";
                only the part before the last '-' selects the network class.
            sample_rate (int): used to pick the pretrained resource tag.
            cfg_path (os.PathLike, optional): user config file.
            ckpt_path (os.PathLike, optional): user checkpoint path without
                the ".pdparams" suffix.
        """
        if hasattr(self, "model"):
            logger.info("Model has been initialized")
            return

        # Note: we use the '-' to get the model name instead of '_'.
        # rsplit keeps a plain "ecapa_tdnn" intact instead of raising like
        # rindex('-') does when there is no '-' at all.
        model_name = model_type.rsplit('-', 1)[0]

        # stage 1: get the model and config path
        if cfg_path is None or ckpt_path is None:
            sample_rate_str = "16k" if sample_rate == 16000 else "8k"
            tag = model_type + "-" + sample_rate_str
            if tag not in pretrained_models:
                # fall back to the bare model name so that both
                # "ecapa_tdnn" and "ecapa_tdnn-<suffix>" resolve to the
                # registered "<model_name>-<rate>" resource
                tag = model_name + "-" + sample_rate_str
            logger.info(f"load the pretrained model: {tag}")
            res_path = self._get_pretrained_path(tag)
            self.res_path = res_path
            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.ckpt_path = os.path.join(
                res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams')
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        logger.info(f"start to read the ckpt from {self.ckpt_path}")
        logger.info(f"read the config from {self.cfg_path}")
        logger.info(f"get the res path {self.res_path}")

        # stage 2: read the config
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        # stage 3: instantiate the model network with dynamic_import
        logger.info("start to dynamic import the model class")
        logger.info(f"model name {model_name}")
        model_class = dynamic_import(model_name, model_alias)
        model_conf = self.config.model
        backbone = model_class(**model_conf)
        model = SpeakerIdetification(
            backbone=backbone, num_class=self.config.num_speakers)
        self.model = model
        self.model.eval()

        # stage 4: load the model parameters
        logger.info("start to set the model parameters to model")
        model_dict = paddle.load(self.ckpt_path)
        self.model.set_state_dict(model_dict)

        logger.info("create the model instance success")

    @paddle.no_grad()
    def infer(self, model_type: str):
        """Run the backbone on the prepared features and store the embedding.

        Reads ``self._inputs["feats"]`` / ``self._inputs["lengths"]`` and
        writes ``self._outputs["embedding"]``.

        Args:
            model_type (str): unused here, kept for the executor interface.
        """
        feats = self._inputs["feats"]
        lengths = self._inputs["lengths"]
        logger.info("start to do backbone network model forward")
        logger.info(
            f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")

        # embedding from (1, emb_size, 1) -> (emb_size)
        embedding = self.model.backbone(feats, lengths).squeeze().numpy()
        logger.info(f"embedding size: {embedding.shape}")

        self._outputs["embedding"] = embedding

    def postprocess(self) -> np.ndarray:
        """Return the embedding stored by ``infer``."""
        return self._outputs["embedding"]

    def preprocess(self, model_type: str, input_file: Union[str, os.PathLike]):
        """Load the audio and turn it into normalized model input features.

        Writes ``self._inputs["feats"]`` and ``self._inputs["lengths"]``.
        Exits the process (``sys.exit(-1)``) when feature extraction fails.

        Args:
            model_type (str): unused here, kept for the executor interface.
            input_file (Union[str, os.PathLike]): path of the input audio.
        """
        audio_file = input_file
        if isinstance(audio_file, (str, os.PathLike)):
            logger.info(f"Preprocess audio file: {audio_file}")

        # stage 1: load the audio
        waveform, _ = load_audio(audio_file)
        logger.info(f"load the audio sample points, shape is: {waveform.shape}")

        # stage 2: get the audio feat
        try:
            feat = melspectrogram(
                x=waveform,
                sr=self.config.sr,
                n_mels=self.config.n_mels,
                window_size=self.config.window_size,
                hop_length=self.config.hop_size)
            logger.info(f"extract the audio feat, shape is: {feat.shape}")
        except Exception as e:
            logger.info(f"feat occurs exception {e}")
            sys.exit(-1)

        # add the batch axis in front
        feat = paddle.to_tensor(feat).unsqueeze(0)

        # in inference period, the lengths is all one without padding
        lengths = paddle.ones([1])
        feat = feature_normalize(feat, mean_norm=True, std_norm=False)

        logger.info(f"feats shape: {feat.shape}")
        self._inputs["feats"] = feat
        self._inputs["lengths"] = lengths

        logger.info("audio extract the feat success")

    def _check(self, audio_file: str, sample_rate: int):
        """Validate the requested sample rate and the input audio file.

        Also stores ``sample_rate`` on ``self.sample_rate``.

        Args:
            audio_file (str): path of the input audio.
            sample_rate (int): sample rate the model expects.

        Returns:
            bool: True when the file exists, can be decoded and its sample
                rate equals ``sample_rate``; False otherwise.
        """
        self.sample_rate = sample_rate
        if self.sample_rate != 16000 and self.sample_rate != 8000:
            logger.error(
                "invalid sample rate, please input --sample_rate 8000 or --sample_rate 16000"
            )
            return False

        if isinstance(audio_file, (str, os.PathLike)):
            if not os.path.isfile(audio_file):
                logger.error("Please input the right audio file path")
                return False

        logger.info("checking the audio file format......")
        try:
            _, audio_sample_rate = soundfile.read(
                audio_file, dtype="float32", always_2d=True)
        except Exception as e:
            logger.exception(e)
            logger.error(
                "can not open the audio file, please check the audio file format is 'wav'. \n"
                "you can try to use sox to change the file format.\n"
                "For example: \n"
                "sample rate: 16k \n"
                "sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n"
                "sample rate: 8k \n"
                "sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n"
            )
            return False

        logger.info(f"The sample rate is {audio_sample_rate}")

        if audio_sample_rate != self.sample_rate:
            # no resampling is performed by this executor, so report the
            # mismatch honestly and let the caller abort
            logger.error(
                "The sample rate of the input file is {}, but the model expects {}.\n"
                "Please convert the input to a {} Hz, 16 bit, 1 channel wav file first.".
                format(audio_sample_rate, self.sample_rate, self.sample_rate))
            return False

        logger.info("The audio file format is right")
        return True

@ -63,16 +63,16 @@ def extract_audio_embedding(args, config):
# so the final shape is [1, dim, time]
start_time = time.time()
feat = melspectrogram(x=waveform,
sr=config.sample_rate,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_length)
hop_length=config.hop_size)
feat = paddle.to_tensor(feat).unsqueeze(0)
# in inference period, the lengths is all one without padding
lengths = paddle.ones([1])
feat = feature_normalize(
feat, mean_norm=True, std_norm=False, convert_to_numpy=True)
feat, mean_norm=True, std_norm=False)
# model backbone network forward the feats and get the embedding
embedding = model.backbone(

@ -49,8 +49,6 @@ def main(args, config):
# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
last_save_epoch = (config.epochs // config.save_interval) * config.save_interval
args.load_checkpoint = os.path.join(args.load_checkpoint, "epoch_" + str(last_save_epoch))
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
@ -61,6 +59,7 @@ def main(args, config):
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
# stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb(
subset='enroll',
target_dir=args.data_dir,
@ -68,7 +67,7 @@ def main(args, config):
random_chunk=False,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_length)
hop_length=config.hop_size)
enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
@ -85,7 +84,7 @@ def main(args, config):
random_chunk=False,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_length)
hop_length=config.hop_size)
test_sampler = BatchSampler(
test_dataset, batch_size=config.batch_size, shuffle=True)

@ -15,6 +15,7 @@ import argparse
import os
import numpy as np
import time
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
@ -35,6 +36,7 @@ from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer
from paddlespeech.vector.io.batch import batch_pad_right
logger = Log(__name__).getlog()
@ -55,7 +57,7 @@ def main(args, config):
train_dataset = VoxCeleb('train', target_dir=args.data_dir)
dev_dataset = VoxCeleb('dev', target_dir=args.data_dir)
if args.augment:
if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
else:
augment_pipeline = []
@ -126,6 +128,7 @@ def main(args, config):
# we will comment the training process
steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * config.epochs)
last_saved_epoch = ""
timer.start()
for epoch in range(start_epoch + 1, config.epochs + 1):
@ -135,9 +138,19 @@ def main(args, config):
avg_loss = 0
num_corrects = 0
num_samples = 0
train_reader_cost = 0.0
train_feat_cost = 0.0
train_run_cost = 0.0
reader_start = time.time()
for batch_idx, batch in enumerate(train_loader):
train_reader_cost += time.time() - reader_start
# stage 9-1: batch data is audio sample points and speaker id label
feat_start = time.time()
waveforms, labels = batch['waveforms'], batch['labels']
waveforms, lengths = batch_pad_right(waveforms.numpy())
waveforms = paddle.to_tensor(waveforms)
# stage 9-2: audio sample augment method, which is done on the audio sample point
# the original waveform and the augmented waveforms are concatenated in a batch
@ -153,18 +166,20 @@ def main(args, config):
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform,
sr=config.sample_rate,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_length)
hop_length=config.hop_size)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
# stage 9-4: feature normalization, which helps convergence and improves the performance
feats = feature_normalize(
feats, mean_norm=True, std_norm=False) # Features normalization
train_feat_cost += time.time() - feat_start
# stage 9-5: model forward, such ecapa-tdnn, x-vector
train_start = time.time()
logits = model(feats)
# stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
@ -177,6 +192,7 @@ def main(args, config):
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()
train_run_cost += time.time() - train_start
# stage 9-8: Calculate average loss per batch
avg_loss += loss.numpy()[0]
@ -186,7 +202,7 @@ def main(args, config):
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
timer.count() # step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
lr = optimizer.get_lr()
@ -197,6 +213,9 @@ def main(args, config):
epoch, config.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' avg_reader_cost: {:.5f} sec,'.format(train_reader_cost / config.log_interval)
print_msg += ' avg_feat_cost: {:.5f} sec,'.format(train_feat_cost / config.log_interval)
print_msg += ' avg_train_cost: {:.5f} sec,'.format(train_run_cost / config.log_interval)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
logger.info(print_msg)
@ -204,6 +223,11 @@ def main(args, config):
avg_loss = 0
num_corrects = 0
num_samples = 0
train_reader_cost = 0.0
train_feat_cost = 0.0
train_run_cost = 0.0
reader_start = time.time()
# stage 9-11: save the model parameters only on 0-rank per save-freq batchs
if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
@ -239,10 +263,10 @@ def main(args, config):
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform,
sr=config.sample_rate,
sr=config.sr,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_length)
hop_length=config.hop_size)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
@ -261,6 +285,7 @@ def main(args, config):
# stage 9-14: Save model parameters
save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch))
last_saved_epoch = os.path.join('epoch_{}'.format(epoch), "model.pdparams")
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
@ -270,6 +295,14 @@ def main(args, config):
if nranks > 1:
paddle.distributed.barrier() # Main process
# stage 10: create the final trained model.pdparams with soft link
if local_rank == 0:
final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
logger.info(f"we will create the final model: {final_model}")
if os.path.islink(final_model):
logger.info(f"An {final_model} already exists, we will rm is and create it again")
os.unlink(final_model)
os.symlink(last_saved_epoch, final_model)
if __name__ == "__main__":
# yapf: disable
@ -294,10 +327,6 @@ if __name__ == "__main__":
type=str,
default='./checkpoint',
help="Directory to save model checkpoints.")
parser.add_argument("--augment",
action="store_true",
default=False,
help="Apply audio augments.")
args = parser.parse_args()
# yapf: enable

@ -13,7 +13,7 @@
# limitations under the License.
import numpy as np
import paddle
import numpy
def waveform_collate_fn(batch):
waveforms = np.stack([item['feat'] for item in batch])
@ -80,4 +80,92 @@ def batch_feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
# we convert the original length of each utterance to the ratio of the max length
lengths = (lengths / lengths.max()).astype(np.float32)
return {'ids': ids, 'feats': feats, 'lengths': lengths}
return {'ids': ids, 'feats': feats, 'lengths': lengths}
def pad_right_to(array, target_shape, mode="constant", value=0):
    """Pad a numpy array on the right of every dimension up to ``target_shape``.

    Args:
        array: input numpy array whose dimensions we need to pad.
        target_shape: (list, tuple). Target shape we want for the padded array,
            its len must be equal to ``array.ndim``.
        mode: str. Pad mode, please refer to numpy.pad documentation.
        value: float. Pad value, only used when ``mode`` is "constant".

    Returns:
        array: numpy.array. Padded array.
        valid_vals: list. For each dimension, the proportion of the padded
            length that holds original, non-padded values.

    Raises:
        AssertionError: ``target_shape`` has the wrong rank or is smaller than
            ``array`` in some dimension.
    """
    assert len(target_shape) == array.ndim

    pads = []  # absolute length of the right padding for each dimension
    valid_vals = []  # relative valid length for each dimension
    for current, target in zip(array.shape, target_shape):
        assert (
            target >= current
        ), "Target shape must be >= original shape for every dim"
        pads.append([0, target - current])
        # a zero-sized target dimension holds no padding at all, so report it
        # as fully valid instead of dividing by zero
        valid_vals.append(current / target if target else 1.0)

    # numpy.pad rejects `constant_values` for every mode but "constant",
    # so only forward the pad value when it is meaningful
    pad_kwargs = {"constant_values": value} if mode == "constant" else {}
    array = numpy.pad(array, pads, mode=mode, **pad_kwargs)

    return array, valid_vals
def batch_pad_right(arrays, mode="constant", value=0):
    """Batch numpy arrays together by right-padding them to a common shape.

    Only the last dimension may differ between the arrays; every other
    dimension must already agree.

    Args:
        arrays: list. List of arrays we wish to pad together.
        mode: str. Padding mode, see numpy.pad documentation.
        value: float. Padding value, see numpy.pad documentation.

    Returns:
        array: numpy.array. The padded arrays stacked along a new first axis.
        valid_vals: numpy.array. For each input array, the proportion of the
            padded last dimension that holds original, non-padded values.

    Raises:
        IndexError: ``arrays`` is empty, or the arrays do not all have the
            same number of dimensions.
        EnvironmentError: the arrays differ in a dimension other than the last.
    """
    # len() rather than truthiness: `arrays` may itself be a numpy array
    if not len(arrays):
        raise IndexError("arrays list must not be empty")

    if len(arrays) == 1:
        # if there is only one array in the batch we simply unsqueeze it.
        return numpy.expand_dims(arrays[0], axis=0), numpy.array([1.0])

    first = arrays[0]
    ndim = first.ndim
    # every array must match, not just one of them
    if not all(x.ndim == ndim for x in arrays[1:]):
        raise IndexError("All arrays must have same number of dimensions")

    # FIXME we limit the support here: we allow padding of only the last dimension
    # need to remove this when feat extraction is updated to handle multichannel.
    max_shape = []
    for dim in range(ndim):
        if dim != ndim - 1:
            if not all(x.shape[dim] == first.shape[dim] for x in arrays[1:]):
                raise EnvironmentError(
                    "arrays should have same dimensions except for last one"
                )
        max_shape.append(max(x.shape[dim] for x in arrays))

    batched = []
    valid = []
    for item in arrays:
        # pad each array up to the common shape and keep the valid ratio
        # of its last dimension
        padded, valid_percent = pad_right_to(
            item, max_shape, mode=mode, value=value
        )
        batched.append(padded)
        valid.append(valid_percent[-1])

    return numpy.stack(batched), numpy.array(valid)

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is modified from SpeechBrain
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/nnet/losses.py
import math
import paddle
@ -20,6 +22,14 @@ import paddle.nn.functional as F
class AngularMargin(nn.Layer):
def __init__(self, margin=0.0, scale=1.0):
"""An implementation of Angular Margin (AM) proposed in the following
paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)
Args:
margin (float, optional): The margin for cosine similarity. Defaults to 0.0.
scale (float, optional): The scale for cosine similarity. Defaults to 1.0.
"""
super(AngularMargin, self).__init__()
self.margin = margin
self.scale = scale
@ -31,6 +41,15 @@ class AngularMargin(nn.Layer):
class AdditiveAngularMargin(AngularMargin):
def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
"""The Implementation of Additive Angular Margin (AAM) proposed
in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
(https://arxiv.org/abs/1906.07317)
Args:
margin (float, optional): margin factor. Defaults to 0.0.
scale (float, optional): scale factor. Defaults to 1.0.
easy_margin (bool, optional): easy_margin flag. Defaults to False.
"""
super(AdditiveAngularMargin, self).__init__(margin, scale)
self.easy_margin = easy_margin
@ -53,6 +72,11 @@ class AdditiveAngularMargin(AngularMargin):
class LogSoftmaxWrapper(nn.Layer):
def __init__(self, loss_fn):
"""Speaker identification loss function wrapper
including all of compositions of the loss transformation
Args:
loss_fn (_type_): the loss value of a batch
"""
super(LogSoftmaxWrapper, self).__init__()
self.loss_fn = loss_fn
self.criterion = paddle.nn.KLDivLoss(reduction="sum")

@ -24,13 +24,25 @@ class SpeakerIdetification(nn.Layer):
lin_blocks=0,
lin_neurons=192,
dropout=0.1, ):
"""Speaker identification model: a backbone embedding network followed by a linear classifier
Args:
backbone (Paddle.nn.Layer class): the speaker identification backbone network model
num_class (_type_): the speaker class num in the training dataset
lin_blocks (int, optional): the linear layer transform between the embedding and the final linear layer. Defaults to 0.
lin_neurons (int, optional): the output dimension of final linear layer. Defaults to 192.
dropout (float, optional): the dropout factor on the embedding. Defaults to 0.1.
"""
super(SpeakerIdetification, self).__init__()
# speaker identification backbone network model
# the output of the backbone network is the target embedding
self.backbone = backbone
if dropout > 0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
# construct the speaker classifier
input_size = self.backbone.emb_size
self.blocks = nn.LayerList()
for i in range(lin_blocks):
@ -40,12 +52,26 @@ class SpeakerIdetification(nn.Layer):
])
input_size = lin_neurons
# the final layer
self.weight = paddle.create_parameter(
shape=(input_size, num_class),
dtype='float32',
attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()), )
def forward(self, x, lengths=None):
"""Do the speaker identification model forward,
including the speaker embedding model and the classifier model network
Args:
x (Paddle.Tensor): input audio feats,
shape=[batch, dimension, times]
lengths (_type_, optional): input audio length.
shape=[batch, times]
Defaults to None.
Returns:
_type_: _description_
"""
# x.shape: (N, C, L)
x = self.backbone(x, lengths).squeeze(
-1) # (N, emb_size, 1) -> (N, emb_size)

Loading…
Cancel
Save