add some vector cli comments, test=doc

pull/1605/head
xiongxinlei 3 years ago
parent ef1bc5e815
commit 2c9dc0c89b

@ -68,12 +68,13 @@ class VectorExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog="paddlespeech.vector", add_help=True) prog="paddlespeech.vector", add_help=True)
self.parser.add_argument( self.parser.add_argument(
"--model", "--model",
type=str, type=str,
default="ecapatdnn_voxceleb12", default="ecapatdnn_voxceleb12",
choices=["ecapatdnn_voxceleb12"], choices=["ecapatdnn_voxceleb12"],
help="Choose model type of asr task.") help="Choose model type of vector task.")
self.parser.add_argument( self.parser.add_argument(
"--task", "--task",
type=str, type=str,
@ -81,7 +82,7 @@ class VectorExecutor(BaseExecutor):
choices=["spk"], choices=["spk"],
help="task type in vector domain") help="task type in vector domain")
self.parser.add_argument( self.parser.add_argument(
"--input", type=str, default=None, help="Audio file to recognize.") "--input", type=str, default=None, help="Audio file to extract embedding.")
self.parser.add_argument( self.parser.add_argument(
"--sample_rate", "--sample_rate",
type=int, type=int,
@ -186,7 +187,7 @@ class VectorExecutor(BaseExecutor):
sample_rate (int, optional): model sample rate. Defaults to 16000. sample_rate (int, optional): model sample rate. Defaults to 16000.
config (os.PathLike, optional): yaml config. Defaults to None. config (os.PathLike, optional): yaml config. Defaults to None.
ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None. ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None.
device (_type_, optional): paddle running host device. Defaults to paddle.get_device(). device (optional): paddle running host device. Defaults to paddle.get_device().
Returns: Returns:
dict: return the audio embedding and the embedding shape dict: return the audio embedding and the embedding shape
@ -216,6 +217,7 @@ class VectorExecutor(BaseExecutor):
def _get_pretrained_path(self, tag: str) -> os.PathLike: def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""get the neural network path from the pretrained model list """get the neural network path from the pretrained model list
we store all the pretrained models in the variable `pretrained_models`
Args: Args:
tag (str): model tag in the pretrained model list tag (str): model tag in the pretrained model list
@ -332,6 +334,7 @@ class VectorExecutor(BaseExecutor):
logger.info(f"embedding size: {embedding.shape}") logger.info(f"embedding size: {embedding.shape}")
# stage 2: put the embedding and dim info to _outputs property # stage 2: put the embedding and dim info to _outputs property
# the embedding type is numpy.ndarray
self._outputs["embedding"] = embedding self._outputs["embedding"] = embedding
def postprocess(self) -> Union[str, os.PathLike]: def postprocess(self) -> Union[str, os.PathLike]:
@ -356,6 +359,7 @@ class VectorExecutor(BaseExecutor):
logger.info(f"Preprocess audio file: {audio_file}") logger.info(f"Preprocess audio file: {audio_file}")
# stage 1: load the audio sample points # stage 1: load the audio sample points
# Note: this process must match the training process
waveform, sr = load_audio(audio_file) waveform, sr = load_audio(audio_file)
logger.info(f"load the audio sample points, shape is: {waveform.shape}") logger.info(f"load the audio sample points, shape is: {waveform.shape}")
@ -397,7 +401,7 @@ class VectorExecutor(BaseExecutor):
sample_rate (int): the desired model sample rate sample_rate (int): the desired model sample rate
Returns: Returns:
bool: return if the audio sample rate matches the model sample rate bool: return if the audio sample rate matches the model sample rate
""" """
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:

Loading…
Cancel
Save