diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 378a3d83..79d3b5db 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -68,12 +68,13 @@ class VectorExecutor(BaseExecutor): self.parser = argparse.ArgumentParser( prog="paddlespeech.vector", add_help=True) + self.parser.add_argument( "--model", type=str, default="ecapatdnn_voxceleb12", choices=["ecapatdnn_voxceleb12"], - help="Choose model type of asr task.") + help="Choose model type of vector task.") self.parser.add_argument( "--task", type=str, @@ -81,7 +82,7 @@ class VectorExecutor(BaseExecutor): choices=["spk"], help="task type in vector domain") self.parser.add_argument( - "--input", type=str, default=None, help="Audio file to recognize.") + "--input", type=str, default=None, help="Audio file to extract embedding.") self.parser.add_argument( "--sample_rate", type=int, @@ -186,7 +187,7 @@ class VectorExecutor(BaseExecutor): sample_rate (int, optional): model sample rate. Defaults to 16000. config (os.PathLike, optional): yaml config. Defaults to None. ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None. - device (_type_, optional): paddle running host device. Defaults to paddle.get_device(). + device (optional): paddle running host device. Defaults to paddle.get_device(). 
Returns: dict: return the audio embedding and the embedding shape @@ -216,6 +217,7 @@ class VectorExecutor(BaseExecutor): def _get_pretrained_path(self, tag: str) -> os.PathLike: """get the neural network path from the pretrained model list + we stored all the pretrained models in the variable `pretrained_models` Args: tag (str): model tag in the pretrained model list @@ -332,6 +334,7 @@ class VectorExecutor(BaseExecutor): logger.info(f"embedding size: {embedding.shape}") # stage 2: put the embedding and dim info to _outputs property + # the embedding type is numpy.array self._outputs["embedding"] = embedding def postprocess(self) -> Union[str, os.PathLike]: @@ -356,6 +359,7 @@ class VectorExecutor(BaseExecutor): logger.info(f"Preprocess audio file: {audio_file}") # stage 1: load the audio sample points + # Note: this process must match the training process waveform, sr = load_audio(audio_file) logger.info(f"load the audio sample points, shape is: {waveform.shape}") @@ -397,7 +401,7 @@ class VectorExecutor(BaseExecutor): sample_rate (int): the desired model sample rate Returns: - bool: return if the audio sample rate matches the model sample rate + bool: return if the audio sample rate matches the model sample rate """ self.sample_rate = sample_rate if self.sample_rate != 16000 and self.sample_rate != 8000: