add some vector cli comments, test=doc

pull/1605/head
xiongxinlei 3 years ago
parent ef1bc5e815
commit 2c9dc0c89b

@ -68,12 +68,13 @@ class VectorExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser(
prog="paddlespeech.vector", add_help=True)
self.parser.add_argument(
"--model",
type=str,
default="ecapatdnn_voxceleb12",
choices=["ecapatdnn_voxceleb12"],
help="Choose model type of asr task.")
help="Choose model type of vector task.")
self.parser.add_argument(
"--task",
type=str,
@ -81,7 +82,7 @@ class VectorExecutor(BaseExecutor):
choices=["spk"],
help="task type in vector domain")
self.parser.add_argument(
"--input", type=str, default=None, help="Audio file to recognize.")
"--input", type=str, default=None, help="Audio file to extract embedding.")
self.parser.add_argument(
"--sample_rate",
type=int,
@ -186,7 +187,7 @@ class VectorExecutor(BaseExecutor):
sample_rate (int, optional): model sample rate. Defaults to 16000.
config (os.PathLike, optional): yaml config. Defaults to None.
ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None.
device (_type_, optional): paddle running host device. Defaults to paddle.get_device().
device (optional): paddle running host device. Defaults to paddle.get_device().
Returns:
dict: return the audio embedding and the embedding shape
@ -216,6 +217,7 @@ class VectorExecutor(BaseExecutor):
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""get the neural network path from the pretrained model list
we stored all the pretained mode in the variable `pretrained_models`
Args:
tag (str): model tag in the pretrained model list
@ -332,6 +334,7 @@ class VectorExecutor(BaseExecutor):
logger.info(f"embedding size: {embedding.shape}")
# stage 2: put the embedding and dim info to _outputs property
# the embedding type is numpy.array
self._outputs["embedding"] = embedding
def postprocess(self) -> Union[str, os.PathLike]:
@ -356,6 +359,7 @@ class VectorExecutor(BaseExecutor):
logger.info(f"Preprocess audio file: {audio_file}")
# stage 1: load the audio sample points
# Note: this process must match the training process
waveform, sr = load_audio(audio_file)
logger.info(f"load the audio sample points, shape is: {waveform.shape}")

Loading…
Cancel
Save