From d5142e5e1591f2ef0fab6580c58aba0b6f1ec6a3 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Sat, 26 Mar 2022 01:00:19 +0800
Subject: [PATCH] add vector cli annotation, test=doc

---
 demos/speaker_verification/README_cn.md |   2 +-
 paddlespeech/cli/vector/infer.py        | 103 ++++++++++++++++++++++--
 2 files changed, 97 insertions(+), 8 deletions(-)

diff --git a/demos/speaker_verification/README_cn.md b/demos/speaker_verification/README_cn.md
index 242c07e6..fde636db 100644
--- a/demos/speaker_verification/README_cn.md
+++ b/demos/speaker_verification/README_cn.md
@@ -97,7 +97,7 @@ wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
         sample_rate=16000,
         config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
         ckpt_path=None,
-        audio_file='./zh.wav',
+        audio_file='./85236145389.wav',
         force_yes=False,
         device=paddle.get_device())
     print('Audio embedding Result: \n{}'.format(audio_emb))
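For context, the README hunk above only swaps the `audio_file` keyword inside a larger Python snippet. A minimal end-to-end sketch of that snippet, assuming `VectorExecutor` is exported from `paddlespeech.cli.vector` (it is the class defined in the module patched below); `force_yes` is omitted here because this patch removes that parameter from `__call__`:

    import paddle
    from paddlespeech.cli.vector import VectorExecutor

    # Extract a speaker embedding from the downloaded demo wav.
    vector_executor = VectorExecutor()
    audio_emb = vector_executor(
        model='ecapatdnn_voxceleb12',
        sample_rate=16000,
        config=None,  # Set `config` and `ckpt_path` to None to use the pretrained model.
        ckpt_path=None,
        audio_file='./85236145389.wav',
        device=paddle.get_device())
    print('Audio embedding Result: \n{}'.format(audio_emb))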
diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py
index 91974761..53324f93 100644
--- a/paddlespeech/cli/vector/infer.py
+++ b/paddlespeech/cli/vector/infer.py
@@ -42,13 +42,15 @@ pretrained_models = {
     # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
     "ecapatdnn_voxceleb12-16k": {
         'url':
-        'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_0.tar.gz',
+        'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz',
         'md5':
-        '85ff08ce0ef406b8c6d7b5ffc5b2b48f',
+        'a1c0dba7d4de997187786ff517d5b4ec',
         'cfg_path':
-        'conf/model.yaml',
+        'conf/model.yaml',  # the yaml config path
         'ckpt_path':
-        'model/model',
+        'model/model',  # the format is ${dir}/${model_name};
+        # the first 'model' is the directory and the second is the file name,
+        # i.e. the checkpoint is stored as model/model.pdparams
     },
 }
@@ -173,22 +175,54 @@ class VectorExecutor(BaseExecutor):
             sample_rate: int=16000,
             config: os.PathLike=None,
             ckpt_path: os.PathLike=None,
-            force_yes: bool=False,
             device=paddle.get_device()):
+        """Extract the audio embedding
+
+        Args:
+            audio_file (os.PathLike): audio path,
+                which must be a wav file whose sample rate matches the model's
+            model (str, optional): model type, loaded from the pretrained model list.
+                Defaults to 'ecapatdnn_voxceleb12'.
+            sample_rate (int, optional): model sample rate. Defaults to 16000.
+            config (os.PathLike, optional): yaml config. Defaults to None.
+            ckpt_path (os.PathLike, optional): pretrained model path. Defaults to None.
+            device (str, optional): paddle running device. Defaults to paddle.get_device().
+
+        Returns:
+            numpy.ndarray: the audio embedding extracted from the audio file
+        """
+        # stage 0: check the audio format
         audio_file = os.path.abspath(audio_file)
         if not self._check(audio_file, sample_rate):
             sys.exit(-1)
 
+        # stage 1: set the paddle runtime host device
         logger.info(f"device type: {device}")
         paddle.device.set_device(device)
+
+        # stage 2: read the specific pretrained model
         self._init_from_path(model, sample_rate, config, ckpt_path)
+
+        # stage 3: preprocess the audio and get the audio feat
         self.preprocess(model, audio_file)
+
+        # stage 4: infer the model and get the audio embedding
         self.infer(model)
+
+        # stage 5: process the result and return it
         res = self.postprocess()
 
         return res
 
     def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """Get the neural network path from the pretrained model list
+
+        Args:
+            tag (str): model tag in the pretrained model list
+
+        Returns:
+            os.PathLike: the downloaded pretrained model path on disk
+        """
         support_models = list(pretrained_models.keys())
         assert tag in pretrained_models, \
             'The model "{}" you want to use has not been supported,'\
@@ -210,15 +244,33 @@ class VectorExecutor(BaseExecutor):
             sample_rate: int=16000,
             cfg_path: Optional[os.PathLike]=None,
             ckpt_path: Optional[os.PathLike]=None):
+        """Init the neural network from the model path
+
+        Args:
+            model_type (str, optional): model tag in the pretrained model list.
+                Defaults to 'ecapatdnn_voxceleb12'.
+            sample_rate (int, optional): model sample rate.
+                Defaults to 16000.
+            cfg_path (Optional[os.PathLike], optional): yaml config file path.
+                Defaults to None.
+            ckpt_path (Optional[os.PathLike], optional): the pretrained model path stored on disk.
+                Defaults to None.
+        """
+        # stage 0: avoid initializing the model again
         if hasattr(self, "model"):
            logger.info("Model has been initialized")
            return

         # stage 1: get the model and config path
+        #          to init the network from a model stored on disk,
+        #          both the config path and the ckpt model path must be passed
         if cfg_path is None or ckpt_path is None:
+            # get the model from the pretrained model list
             sample_rate_str = "16k" if sample_rate == 16000 else "8k"
             tag = model_type + "-" + sample_rate_str
             logger.info(f"load the pretrained model: {tag}")
+            # download the pretrained model from the pretrained list
+            # and store it in the res_path
             res_path = self._get_pretrained_path(tag)
             self.res_path = res_path
@@ -227,6 +279,7 @@ class VectorExecutor(BaseExecutor):
             self.ckpt_path = os.path.join(
                 res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams')
         else:
+            # get the model from disk
             self.cfg_path = os.path.abspath(cfg_path)
             self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
             self.res_path = os.path.dirname(
@@ -241,7 +294,6 @@ class VectorExecutor(BaseExecutor):
         self.config.merge_from_file(self.cfg_path)
 
         # stage 3: get the model name to instance the model network with dynamic_import
-        # Noet: we use the '-' to get the model name instead of '_'
         logger.info("start to dynamic import the model class")
         model_name = model_type[:model_type.rindex('_')]
         logger.info(f"model name {model_name}")
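The two string manipulations annotated in the hunk above are easy to misread, so here is a pure-string sketch of how `_init_from_path` derives the pretrained tag and the dynamic-import name (the values are the defaults assumed by this PR):

    # Sketch of the name handling in _init_from_path.
    model_type = 'ecapatdnn_voxceleb12'
    sample_rate = 16000
    sample_rate_str = "16k" if sample_rate == 16000 else "8k"
    # key into the pretrained_models dict
    tag = model_type + "-" + sample_rate_str          # 'ecapatdnn_voxceleb12-16k'
    # class name handed to dynamic_import
    model_name = model_type[:model_type.rindex('_')]  # 'ecapatdnn'
    print(tag, model_name)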
@@ -262,31 +314,54 @@ class VectorExecutor(BaseExecutor):
 
     @paddle.no_grad()
     def infer(self, model_type: str):
+        """Infer the model to get the embedding
+        Args:
+            model_type (str): speaker verification model type
+        """
+        # stage 0: get the feats and lengths from _inputs
         feats = self._inputs["feats"]
         lengths = self._inputs["lengths"]
         logger.info("start to do backbone network model forward")
         logger.info(
             f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")
+
+        # stage 1: get the audio embedding
         # embedding from (1, emb_size, 1) -> (emb_size)
         embedding = self.model.backbone(feats, lengths).squeeze().numpy()
         logger.info(f"embedding size: {embedding.shape}")
 
+        # stage 2: put the embedding and dim info into the _outputs property
         self._outputs["embedding"] = embedding
 
     def postprocess(self) -> Union[str, os.PathLike]:
+        """Return the audio embedding info
+
+        Returns:
+            numpy.ndarray: the audio embedding
+        """
+        embedding = self._outputs["embedding"]
+        dim = embedding.shape[0]
+        # return {"dim": dim, "embedding": embedding}
         return self._outputs["embedding"]
 
     def preprocess(self, model_type: str, input_file: Union[str, os.PathLike]):
+        """Extract the audio feat
+
+        Args:
+            model_type (str): speaker verification model type
+            input_file (Union[str, os.PathLike]): audio file path
+        """
         audio_file = input_file
         if isinstance(audio_file, (str, os.PathLike)):
             logger.info(f"Preprocess audio file: {audio_file}")
 
-        # stage 1: load the audio
+        # stage 1: load the audio sample points
         waveform, sr = load_audio(audio_file)
         logger.info(f"load the audio sample points, shape is: {waveform.shape}")
 
         # stage 2: get the audio feat
+        # Note: only the fbank feature is supported for now
         try:
             feat = melspectrogram(
                 x=waveform,
@@ -302,8 +377,13 @@ class VectorExecutor(BaseExecutor):
         feat = paddle.to_tensor(feat).unsqueeze(0)
         # in inference period, the lengths is all one without padding
         lengths = paddle.ones([1])
+
+        # stage 3: do the feature normalization;
+        #          for now we assume the feat must be normalized
         feat = feature_normalize(feat, mean_norm=True, std_norm=False)
 
+        # stage 4: store the feat and length in _inputs,
+        #          which will be used by other functions
         logger.info(f"feats shape: {feat.shape}")
         self._inputs["feats"] = feat
         self._inputs["lengths"] = lengths
@@ -311,6 +391,15 @@ class VectorExecutor(BaseExecutor):
 
         logger.info("audio extract the feat success")
 
     def _check(self, audio_file: str, sample_rate: int):
+        """Check whether the audio sample rate matches the model sample rate
+
+        Args:
+            audio_file (str): audio file path, from which the embedding will be extracted
+            sample_rate (int): the desired model sample rate
+
+        Returns:
+            bool: whether the audio sample rate matches the model sample rate
+        """
         self.sample_rate = sample_rate
         if self.sample_rate != 16000 and self.sample_rate != 8000:
             logger.error(
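Beyond this patch: `postprocess` hands back the embedding as a plain numpy vector, and in the speaker-verification demo two such vectors are typically compared with cosine similarity. A minimal sketch under that assumption; the `cosine_score` helper is hypothetical and not part of PaddleSpeech:

    import numpy as np

    def cosine_score(emb1: np.ndarray, emb2: np.ndarray) -> float:
        # cosine similarity between two speaker embeddings;
        # the epsilon guards against zero-norm inputs
        return float(
            np.dot(emb1, emb2) /
            (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-12))

    # usage: emb_enroll and emb_test come from two VectorExecutor calls;
    # the pair is accepted as the same speaker when the score clears a tuned threshold.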