|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
#
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
# You may obtain a copy of the License at
|
|
|
#
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
#
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
import argparse
|
|
|
import os
|
|
|
import sys
|
|
|
from collections import OrderedDict
|
|
|
from typing import List
|
|
|
from typing import Optional
|
|
|
from typing import Union
|
|
|
|
|
|
import paddle
|
|
|
import soundfile
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
from ..log import logger
|
|
|
from ..utils import cli_register
|
|
|
from ..utils import download_and_decompress
|
|
|
from ..utils import MODEL_HOME
|
|
|
from ..utils import stats_wrapper
|
|
|
from paddleaudio.backends import load as load_audio
|
|
|
from paddleaudio.compliance.librosa import melspectrogram
|
|
|
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
|
|
from paddlespeech.vector.io.batch import feature_normalize
|
|
|
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
|
|
|
|
|
|
pretrained_models = {
    # Tags follow the pattern "{model_name}[_{dataset}][-{sr}][-...]",
    # e.g. "ecapatdnn_voxceleb12-16k".
    # The command line and python api take "{model_name}[_{dataset}]" as
    # --model and append the sample-rate suffix ("-16k") themselves, usage:
    # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12 --sample_rate 16000 --input ./input.wav"
    "ecapatdnn_voxceleb12-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz',
        'md5':
        'cc33023c54ab346cd318408f43fcaf95',
        # the yaml config path, relative to the decompressed archive
        'cfg_path':
        'conf/model.yaml',
        # the format is ${dir}/{model_name}: the first 'model' is the
        # directory and the second 'model' is the file name, i.e. the
        # parameters are stored as model/model.pdparams
        'ckpt_path':
        'model/model',
    },
}
|
|
|
|
|
|
# Maps the short model name (the part of --model before the last "_")
# to the "module.path:ClassName" string resolved by dynamic_import.
model_alias = dict(
    ecapatdnn="paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", )
|
|
|
|
|
|
|
|
|
@cli_register(
    name="paddlespeech.vector",
    description="Speech to vector embedding infer command.")
class VectorExecutor(BaseExecutor):
    """Extract a speaker embedding vector from a wav file.

    Usable from the command line through ``execute`` (registered as
    ``paddlespeech.vector``) or from Python by calling the instance.
    """

    def __init__(self):
        super(VectorExecutor, self).__init__()

        self.parser = argparse.ArgumentParser(
            prog="paddlespeech.vector", add_help=True)

        self.parser.add_argument(
            "--model",
            type=str,
            default="ecapatdnn_voxceleb12",
            choices=["ecapatdnn_voxceleb12"],
            help="Choose model type of vector task.")
        self.parser.add_argument(
            "--task",
            type=str,
            default="spk",
            choices=["spk"],
            help="task type in vector domain")
        self.parser.add_argument(
            "--input",
            type=str,
            default=None,
            help="Audio file to extract embedding.")
        self.parser.add_argument(
            "--sample_rate",
            type=int,
            default=16000,
            choices=[16000],
            help="Choose the audio sample rate of the model. Currently only 16000 is supported."
        )
        self.parser.add_argument(
            "--ckpt_path",
            type=str,
            default=None,
            help="Checkpoint file of model.")
        self.parser.add_argument(
            '--config',
            type=str,
            default=None,
            help='Config of vector task. Use default config when it is None.')
        self.parser.add_argument(
            "--device",
            type=str,
            default=paddle.get_device(),
            help="Choose device to execute model inference.")
        self.parser.add_argument(
            '-d',
            '--job_dump_result',
            action='store_true',
            help='Save job result into file.')
        self.parser.add_argument(
            '-v',
            '--verbose',
            action='store_true',
            help='Increase logger verbosity of current task.')

    def execute(self, argv: List[str]) -> bool:
        """Command line entry for the vector model.

        Every input returned by ``get_task_source`` is processed
        independently. A failure on one input is recorded in the task result
        as "ExceptionName: message" and does not stop the remaining inputs.

        Args:
            argv (List[str]): command line args list

        Returns:
            bool:
                False: processing failed for at least one audio
                True: all audio processed successfully
        """
        # stage 0: parse the args and get the required args
        parser_args = self.parser.parse_args(argv)
        model = parser_args.model
        sample_rate = parser_args.sample_rate
        config = parser_args.config
        ckpt_path = parser_args.ckpt_path
        device = parser_args.device

        # stage 1: configure the verbose flag
        if not parser_args.verbose:
            self.disable_task_loggers()

        # stage 2: read the input data; get_task_source returns an
        # id -> input mapping (it is iterated with .items() below)
        task_source = self.get_task_source(parser_args.input)
        logger.info(f"task source: {task_source}")

        # stage 3: process the audio one by one
        task_result = OrderedDict()
        has_exceptions = False
        for id_, input_ in task_source.items():
            try:
                res = self(input_, model, sample_rate, config, ckpt_path,
                           device)
                task_result[id_] = res
            except Exception as e:
                has_exceptions = True
                task_result[id_] = f'{e.__class__.__name__}: {e}'

        logger.info("task result as follows: ")
        logger.info(f"{task_result}")

        # stage 4: process all the task results
        self.process_task_results(parser_args.input, task_result,
                                  parser_args.job_dump_result)

        # stage 5: report whether every audio was processed successfully
        return not has_exceptions

    @stats_wrapper
    def __call__(self,
                 audio_file: os.PathLike,
                 model: str='ecapatdnn_voxceleb12',
                 sample_rate: int=16000,
                 config: os.PathLike=None,
                 ckpt_path: os.PathLike=None,
                 device=paddle.get_device()):
        """Extract the audio embedding.

        Args:
            audio_file (os.PathLike): audio path; it must be a file
                soundfile can read and its sample rate must match `sample_rate`.
            model (str, optional): model type, used to look up the pretrained
                model list. Defaults to 'ecapatdnn_voxceleb12'.
            sample_rate (int, optional): model sample rate. Defaults to 16000.
            config (os.PathLike, optional): yaml config. Defaults to None.
            ckpt_path (os.PathLike, optional): pretrained model path without
                the ".pdparams" suffix. Defaults to None.
            device (optional): paddle running host device.
                Defaults to paddle.get_device() evaluated at import time.

        Returns:
            The embedding produced by `postprocess` (the numpy array computed
            in `infer`).

        Raises:
            ValueError: if the audio file fails the `_check` validation.
        """
        # stage 0: check the audio format
        audio_file = os.path.abspath(audio_file)
        if not self._check(audio_file, sample_rate):
            # Raise instead of sys.exit(): SystemExit is not an Exception, so
            # exiting here would bypass execute()'s per-file error handling
            # and terminate the whole batch (or the caller's process).
            raise ValueError(f"Failed to check the audio file: {audio_file}")

        # stage 1: set the paddle runtime host device
        logger.info(f"device type: {device}")
        paddle.device.set_device(device)

        # stage 2: read the specific pretrained model
        self._init_from_path(model, sample_rate, config, ckpt_path)

        # stage 3: preprocess the audio and get the audio feat
        self.preprocess(model, audio_file)

        # stage 4: infer the model and get the audio embedding
        self.infer(model)

        # stage 5: process the result and return it
        return self.postprocess()

    def _get_pretrained_path(self, tag: str) -> os.PathLike:
        """Download (if needed) and locate a model from `pretrained_models`.

        Args:
            tag (str): model tag in the pretrained model list

        Returns:
            os.PathLike: absolute path of the decompressed pretrained model

        Raises:
            ValueError: if `tag` is not in `pretrained_models`.
        """
        if tag not in pretrained_models:
            # An explicit raise (not assert) so validation survives `python -O`.
            support_models = list(pretrained_models.keys())
            raise ValueError(
                'The model "{}" you want to use has not been supported, '
                'please choose other models.\n'
                'The support models includes\n\t\t{}'.format(
                    tag, "\n\t\t".join(support_models)))

        res_path = os.path.join(MODEL_HOME, tag)
        decompressed_path = download_and_decompress(pretrained_models[tag],
                                                    res_path)

        decompressed_path = os.path.abspath(decompressed_path)
        logger.info(
            'Use pretrained model stored in: {}'.format(decompressed_path))

        return decompressed_path

    def _init_from_path(self,
                        model_type: str='ecapatdnn_voxceleb12',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        ckpt_path: Optional[os.PathLike]=None):
        """Build the network and load its parameters.

        If both `cfg_path` and `ckpt_path` are given they are used directly;
        otherwise the pretrained model "{model_type}-{16k|8k}" is downloaded.

        Args:
            model_type (str, optional): model tag in the pretrained model list.
                Defaults to 'ecapatdnn_voxceleb12'.
            sample_rate (int, optional): model sample rate.
                Defaults to 16000.
            cfg_path (Optional[os.PathLike], optional): yaml config file path.
                Defaults to None.
            ckpt_path (Optional[os.PathLike], optional): model parameter path
                without the ".pdparams" suffix. Defaults to None.
        """
        # stage 0: avoid initializing the model again.
        # NOTE(review): once self.model exists, later calls with a different
        # model_type / cfg_path / ckpt_path reuse the first model.
        if hasattr(self, "model"):
            logger.info("Model has been initialized")
            return

        # stage 1: get the model and config path.
        # To init the network from a model stored on disk, both the config
        # path and the ckpt model path must be passed.
        if cfg_path is None or ckpt_path is None:
            sample_rate_str = "16k" if sample_rate == 16000 else "8k"
            tag = model_type + "-" + sample_rate_str
            logger.info(f"load the pretrained model: {tag}")
            # download the pretrained model and store it in res_path
            res_path = self._get_pretrained_path(tag)
            self.res_path = res_path

            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.ckpt_path = os.path.join(
                res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams')
        else:
            # get the model from disk
            self.cfg_path = os.path.abspath(cfg_path)
            self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
            # the config is expected at <res_path>/<dir>/<file>.yaml
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        logger.info(f"start to read the ckpt from {self.ckpt_path}")
        logger.info(f"read the config from {self.cfg_path}")
        logger.info(f"get the res path {self.res_path}")

        # stage 2: read the config
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        # stage 3: derive the model name ("ecapatdnn_voxceleb12" ->
        # "ecapatdnn") and instantiate the network with dynamic_import
        logger.info("start to dynamic import the model class")
        model_name = model_type[:model_type.rindex('_')]
        logger.info(f"model name {model_name}")
        model_class = dynamic_import(model_name, model_alias)
        model_conf = self.config.model
        backbone = model_class(**model_conf)
        model = SpeakerIdetification(
            backbone=backbone, num_class=self.config.num_speakers)
        self.model = model
        self.model.eval()

        # stage 4: load the model parameters
        logger.info("start to set the model parameters to model")
        model_dict = paddle.load(self.ckpt_path)
        self.model.set_state_dict(model_dict)

        logger.info("create the model instance success")

    @paddle.no_grad()
    def infer(self, model_type: str):
        """Run the backbone on the prepared features and store the embedding.

        Reads ``self._inputs["feats"]`` / ``self._inputs["lengths"]`` (set by
        `preprocess`) and writes ``self._outputs["embedding"]``.

        Args:
            model_type (str): speaker verification model type (unused here)
        """
        # stage 0: get the feat and length from _inputs
        feats = self._inputs["feats"]
        lengths = self._inputs["lengths"]
        logger.info("start to do backbone network model forward")
        logger.info(
            f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")

        # stage 1: get the audio embedding; squeeze() drops all size-1 dims,
        # presumably (1, emb_size, 1) -> (emb_size,)
        embedding = self.model.backbone(feats, lengths).squeeze().numpy()
        logger.info(f"embedding size: {embedding.shape}")

        # stage 2: store the embedding (a numpy array) in _outputs
        self._outputs["embedding"] = embedding

    def postprocess(self):
        """Return the audio embedding.

        Returns:
            The numpy array stored in ``self._outputs["embedding"]`` by `infer`.
        """
        return self._outputs["embedding"]

    def preprocess(self, model_type: str, input_file: Union[str, os.PathLike]):
        """Extract the normalized audio features into ``self._inputs``.

        Args:
            model_type (str): speaker verification model type (unused here)
            input_file (Union[str, os.PathLike]): audio file path

        Raises:
            Exception: any error raised by feature extraction is logged and
                re-raised so the caller can handle it per file.
        """
        audio_file = input_file
        if isinstance(audio_file, (str, os.PathLike)):
            logger.info(f"Preprocess audio file: {audio_file}")

        # stage 1: load the audio sample points
        # Note: this process must match the training process
        waveform, _ = load_audio(audio_file)
        logger.info(f"load the audio sample points, shape is: {waveform.shape}")

        # stage 2: get the audio feat
        # Note: currently only the mel spectrogram feature is supported
        try:
            feat = melspectrogram(
                x=waveform,
                sr=self.config.sr,
                n_mels=self.config.n_mels,
                window_size=self.config.window_size,
                hop_length=self.config.hop_size)
            logger.info(f"extract the audio feat, shape is: {feat.shape}")
        except Exception as e:
            # Re-raise rather than sys.exit() so the failure stays per-file.
            logger.error(f"feat occurs exception {e}")
            raise

        # add a leading batch dimension
        feat = paddle.to_tensor(feat).unsqueeze(0)
        # in inference, a single un-padded utterance has relative length one
        lengths = paddle.ones([1])

        # stage 3: feature normalization (mean only, as done here)
        feat = feature_normalize(feat, mean_norm=True, std_norm=False)

        # stage 4: store the feat and length for infer()
        logger.info(f"feats shape: {feat.shape}")
        self._inputs["feats"] = feat
        self._inputs["lengths"] = lengths

        logger.info("audio extract the feat success")

    def _check(self, audio_file: str, sample_rate: int) -> bool:
        """Check that the audio file is readable and matches the sample rate.

        Side effect: stores `sample_rate` in ``self.sample_rate``.

        Args:
            audio_file (str): audio file path to extract the embedding from
            sample_rate (int): the desired model sample rate

        Returns:
            bool: True if the file is readable and its sample rate equals
            `sample_rate`; False otherwise (the reason is logged).
        """
        self.sample_rate = sample_rate
        if self.sample_rate not in (16000, 8000):
            logger.error(
                "invalid sample rate, please input --sample_rate 8000 or --sample_rate 16000"
            )
            return False

        if isinstance(audio_file, (str, os.PathLike)):
            if not os.path.isfile(audio_file):
                logger.error("Please input the right audio file path")
                return False

        logger.info("checking the audio file format......")
        try:
            _, audio_sample_rate = soundfile.read(
                audio_file, dtype="float32", always_2d=True)
        except Exception as e:
            logger.exception(e)
            logger.error(
                "can not open the audio file, please check the audio file format is 'wav'.\n"
                "you can try to use sox to change the file format.\n"
                "For example:\n"
                "sample rate: 16k\n"
                "sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav\n"
                "sample rate: 8k\n"
                "sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav\n"
            )
            return False

        logger.info(f"The sample rate is {audio_sample_rate}")

        if audio_sample_rate != self.sample_rate:
            # No resampling is performed, so reject the file instead of
            # claiming to resample it.
            logger.error(
                "The sample rate of the input file is {}, but the model expects {}.\n"
                "Please convert it to a {} Hz, 16 bit, 1 channel wav file, e.g.\n"
                "sox input_audio.xx --rate {} --bits 16 --channels 1 output_audio.wav\n".
                format(audio_sample_rate, self.sample_rate, self.sample_rate,
                       self.sample_rate))
            return False

        logger.info("The audio file format is right")
        return True
|