add vector cli annotation, test=doc

pull/1605/head
xiongxinlei 3 years ago
parent 30dc4585ce
commit d5142e5e15

@ -97,7 +97,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./zh.wav',
audio_file='./85236145389.wav',
force_yes=False,
device=paddle.get_device())
print('Audio embedding Result: \n{}'.format(audio_emb))

@ -42,13 +42,15 @@ pretrained_models = {
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_0.tar.gz',
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz',
'md5':
'85ff08ce0ef406b8c6d7b5ffc5b2b48f',
'a1c0dba7d4de997187786ff517d5b4ec',
'cfg_path':
'conf/model.yaml',
'conf/model.yaml', # the yaml config path
'ckpt_path':
'model/model',
'model/model', # the format is ${dir}/{model_name},
# so the first 'model' is dir, the second 'model' is the name
# this means we have a model stored as model/model.pdparams
},
}
@ -173,22 +175,54 @@ class VectorExecutor(BaseExecutor):
sample_rate: int=16000,
config: os.PathLike=None,
ckpt_path: os.PathLike=None,
force_yes: bool=False,
device=paddle.get_device()):
"""Extract the audio embedding
Args:
audio_file (os.PathLike): audio path,
whose format must be wav and sample rate must be matched the model
model (str, optional): mode type, which is been loaded from the pretrained model list.
Defaults to 'ecapatdnn-voxceleb12'.
sample_rate (int, optional): model sample rate. Defaults to 16000.
config (os.PathLike, optional): yaml config. Defaults to None.
ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None.
device (_type_, optional): paddle running host device. Defaults to paddle.get_device().
Returns:
dict: return the audio embedding and the embedding shape
"""
# stage 0: check the audio format
audio_file = os.path.abspath(audio_file)
if not self._check(audio_file, sample_rate):
sys.exit(-1)
# stage 1: set the paddle runtime host device
logger.info(f"device type: {device}")
paddle.device.set_device(device)
# stage 2: read the specific pretrained model
self._init_from_path(model, sample_rate, config, ckpt_path)
# stage 3: preprocess the audio and get the audio feat
self.preprocess(model, audio_file)
# stage 4: infer the model and get the audio embedding
self.infer(model)
# stage 5: process the result and set them to output dict
res = self.postprocess()
return res
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""get the neural network path from the pretrained model list
Args:
tag (str): model tag in the pretrained model list
Returns:
os.PathLike: the downloaded pretrained model path in the disk
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, \
'The model "{}" you want to use has not been supported,'\
@ -210,15 +244,33 @@ class VectorExecutor(BaseExecutor):
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
"""Init the neural network from the model path
Args:
model_type (str, optional): model tag in the pretrained model list.
Defaults to 'ecapatdnn_voxceleb12'.
sample_rate (int, optional): model sample rate.
Defaults to 16000.
cfg_path (Optional[os.PathLike], optional): yaml config file path.
Defaults to None.
ckpt_path (Optional[os.PathLike], optional): the pretrained model path, which is stored in the disk.
Defaults to None.
"""
# stage 0: avoid to init the mode again
if hasattr(self, "model"):
logger.info("Model has been initialized")
return
# stage 1: get the model and config path
# if we want init the network from the model stored in the disk,
# we must pass the config path and the ckpt model path
if cfg_path is None or ckpt_path is None:
# get the mode from pretrained list
sample_rate_str = "16k" if sample_rate == 16000 else "8k"
tag = model_type + "-" + sample_rate_str
logger.info(f"load the pretrained model: {tag}")
# get the model from the pretrained list
# we download the pretrained model and store it in the res_path
res_path = self._get_pretrained_path(tag)
self.res_path = res_path
@ -227,6 +279,7 @@ class VectorExecutor(BaseExecutor):
self.ckpt_path = os.path.join(
res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams')
else:
# get the model from disk
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname(
@ -241,7 +294,6 @@ class VectorExecutor(BaseExecutor):
self.config.merge_from_file(self.cfg_path)
# stage 3: get the model name to instance the model network with dynamic_import
# Noet: we use the '-' to get the model name instead of '_'
logger.info("start to dynamic import the model class")
model_name = model_type[:model_type.rindex('_')]
logger.info(f"model name {model_name}")
@ -262,31 +314,54 @@ class VectorExecutor(BaseExecutor):
@paddle.no_grad()
def infer(self, model_type: str):
"""Infer the model to get the embedding
Args:
model_type (str): speaker verification model type
"""
# stage 0: get the feat and length from _inputs
feats = self._inputs["feats"]
lengths = self._inputs["lengths"]
logger.info("start to do backbone network model forward")
logger.info(
f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")
# stage 1: get the audio embedding
# embedding from (1, emb_size, 1) -> (emb_size)
embedding = self.model.backbone(feats, lengths).squeeze().numpy()
logger.info(f"embedding size: {embedding.shape}")
# stage 2: put the embedding and dim info to _outputs property
self._outputs["embedding"] = embedding
def postprocess(self) -> Union[str, os.PathLike]:
"""Return the audio embedding info
Returns:
Union[str, os.PathLike]: audio embedding info
"""
embedding = self._outputs["embedding"]
dim = embedding.shape[0]
# return {"dim": dim, "embedding": embedding}
return self._outputs["embedding"]
def preprocess(self, model_type: str, input_file: Union[str, os.PathLike]):
"""Extract the audio feat
Args:
model_type (str): speaker verification model type
input_file (Union[str, os.PathLike]): audio file path
"""
audio_file = input_file
if isinstance(audio_file, (str, os.PathLike)):
logger.info(f"Preprocess audio file: {audio_file}")
# stage 1: load the audio
# stage 1: load the audio sample points
waveform, sr = load_audio(audio_file)
logger.info(f"load the audio sample points, shape is: {waveform.shape}")
# stage 2: get the audio feat
# Note: Now we only support fbank feature
try:
feat = melspectrogram(
x=waveform,
@ -302,8 +377,13 @@ class VectorExecutor(BaseExecutor):
feat = paddle.to_tensor(feat).unsqueeze(0)
# in inference period, the lengths is all one without padding
lengths = paddle.ones([1])
# stage 3: we do feature normalize,
# Now we assume that the feat must do normalize
feat = feature_normalize(feat, mean_norm=True, std_norm=False)
# stage 4: store the feat and length in the _inputs,
# which will be used in other function
logger.info(f"feats shape: {feat.shape}")
self._inputs["feats"] = feat
self._inputs["lengths"] = lengths
@ -311,6 +391,15 @@ class VectorExecutor(BaseExecutor):
logger.info("audio extract the feat success")
def _check(self, audio_file: str, sample_rate: int):
"""Check if the model sample match the audio sample rate
Args:
audio_file (str): audio file path, which will be extracted the embedding
sample_rate (int): the desired model sample rate
Returns:
bool: return if the audio sample rate matches the model sample rate
"""
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(

Loading…
Cancel
Save