From 4cdfa5ccfdb57b3e1d178a0b3b7fddc1dd3acdfa Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Wed, 16 Nov 2022 04:20:42 +0000 Subject: [PATCH] wav2vec2_cli --- demos/speech_ssl/README.md | 102 ++++ demos/speech_ssl/README_cn.md | 103 ++++ demos/speech_ssl/run.sh | 12 + paddlespeech/cli/base_commands.py | 9 +- paddlespeech/cli/ssl/__init__.py | 14 + paddlespeech/cli/ssl/infer.py | 449 ++++++++++++++++++ paddlespeech/resource/model_alias.py | 6 + paddlespeech/resource/pretrained_models.py | 39 ++ paddlespeech/resource/resource.py | 3 +- paddlespeech/s2t/exps/wav2vec2/__init__.py | 13 + .../s2t/exps/wav2vec2/bin/__init__.py | 2 +- paddlespeech/s2t/exps/wav2vec2/model.py | 4 + paddlespeech/s2t/models/wav2vec2/__init__.py | 17 + .../s2t/models/wav2vec2/modules/__init__.py | 13 + .../models/wav2vec2/processing/__init__.py | 13 + .../s2t/models/wav2vec2/wav2vec2_ASR.py | 32 +- tests/unit/cli/test_cli.sh | 4 + 17 files changed, 828 insertions(+), 7 deletions(-) create mode 100644 demos/speech_ssl/README.md create mode 100644 demos/speech_ssl/README_cn.md create mode 100644 demos/speech_ssl/run.sh create mode 100644 paddlespeech/cli/ssl/__init__.py create mode 100644 paddlespeech/cli/ssl/infer.py create mode 100644 paddlespeech/s2t/exps/wav2vec2/__init__.py create mode 100644 paddlespeech/s2t/models/wav2vec2/modules/__init__.py create mode 100644 paddlespeech/s2t/models/wav2vec2/processing/__init__.py diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md new file mode 100644 index 000000000..fdef37e7b --- /dev/null +++ b/demos/speech_ssl/README.md @@ -0,0 +1,102 @@ +([简体中文](./README_cn.md)|English) +# Speech SSL (Self-Supervised Learning) + +## Introduction +Speech SSL, or Self-Supervised Learning, refers to a training method on the large-scale unlabeled speech dataset. The model trained in this way can produce a good acoustic representation, and can be applied to other downstream speech tasks by fine-tuning on labeled datasets. + +This demo is an implementation to recognize text or produce the acoustic representation from a specific audio file by speech ssl models. It can be done by a single command or a few lines in python using `PaddleSpeech`. + +## Usage +### 1. Installation +see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md). + +You can choose one way from easy, meduim and hard to install paddlespeech. + +### 2. Prepare Input File +The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. + +Here are sample files for this demo that can be downloaded: +```bash +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +``` + +### 3. Usage +- Command Line(Recommended) + ```bash + # to recognize text + paddlespeech ssl --task asr --lang en --input ./en.wav + + # to get acoustic representation + paddlespeech ssl --task vector --lang en --input ./en.wav + ``` + + Usage: + ```bash + paddlespeech ssl --help + ``` + Arguments: + - `input`(required): Audio file to recognize. + - `model`: Model type of asr task. Default: `wav2vec2ASR_librispeech`. + - `task`: Output type. Default: `asr`. + - `lang`: Model language. Default: `en`. + - `sample_rate`: Sample rate of the model. Default: `16000`. + - `config`: Config of asr task. Use pretrained model when it is None. Default: `None`. + - `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`. + - `yes`: No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate. Default: `False`. + - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment. + - `verbose`: Show the log information. + + +- Python API + ```python + import paddle + from paddlespeech.cli.ssl import SSLExecutor + + ssl_executor = SSLExecutor() + + # to recognize text + text = ssl_executor( + model='wav2vec2ASR_librispeech', + task='asr', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # to get acoustic representation + feature = ssl_executor( + model='wav2vec2', + task='vector', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + Output: + ```bash + ASR Result: + 我认为跑步最重要的就是给我带来了身体健康 + + Representation: + Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, + -0.04614586, 0.17853957], + [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855, + -0.04638699, 0.17855372], + [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341, + -0.04643029, 0.17856732], + ..., + [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373, + -0.04701405, 0.17862988], + [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728, + -0.04687393, 0.17864393], + [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174, + -0.17227529, 0.20338398]]]) + ``` diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md new file mode 100644 index 000000000..76ec2f1ff --- /dev/null +++ b/demos/speech_ssl/README_cn.md @@ -0,0 +1,103 @@ +(简体中文|[English](./README.md)) + +# 语音自监督学习 +## 介绍 +语音自监督学习,指的是在大规模无标记的语音数据集上的训练方法。用这种方法训练出来的模型可以产生很好的声学表征。并且可以通过在有标签的数据集上进行微调,应用于其他下游的语音任务。 + +这个 demo 是通过语音自监督模型将一个特定的音频文件识别成文本或产生声学表征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。 + +## 使用方法 +### 1. 安装 +请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 + +你可以从 easy,medium,hard 三中方式中选择一种方式安装。 + +### 2. 准备输入 +这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 + +可以下载此 demo 的示例音频: +```bash +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +``` +### 3. 使用方法 +- 命令行 (推荐使用) + ```bash + + # 识别文本 + paddlespeech ssl --task asr --lang en --input ./en.wav + + # 产生声学表征 + paddlespeech ssl --task vector --lang en --input ./en.wav + ``` + + 使用方法: + ```bash + paddlespeech asr --help + ``` + 参数: + - `input`(必须输入):用于识别的音频文件。 + - `model`:ASR 任务的模型,默认值:`conformer_wenetspeech`。 + - `task`:输出类别,默认值:`asr`。 + - `lang`:模型语言,默认值:`zh`。 + - `sample_rate`:音频采样率,默认值:`16000`。 + - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 + - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。 + - `yes`;不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。 + - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。 + - `verbose`: 如果使用,显示 logger 信息。 + + +- Python API + ```python + import paddle + from paddlespeech.cli.ssl import SSLExecutor + + ssl_executor = SSLExecutor() + + # 识别文本 + text = ssl_executor( + model='wav2vec2ASR_librispeech', + task='asr', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # 得到声学表征 + feature = ssl_executor( + model='wav2vec2', + task='vector', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + + 输出: + ```bash + ASR Result: + 我认为跑步最重要的就是给我带来了身体健康 + + Representation: + Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, + -0.04614586, 0.17853957], + [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855, + -0.04638699, 0.17855372], + [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341, + -0.04643029, 0.17856732], + ..., + [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373, + -0.04701405, 0.17862988], + [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728, + -0.04687393, 0.17864393], + [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174, + -0.17227529, 0.20338398]]]) + ``` diff --git a/demos/speech_ssl/run.sh b/demos/speech_ssl/run.sh new file mode 100644 index 000000000..b71e0883a --- /dev/null +++ b/demos/speech_ssl/run.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# audio download +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav + +# to recognize text +paddlespeech ssl --task asr --lang en --input ./en.wav + +# to get acoustic representation +paddlespeech ssl --task vector --lang en --input ./en.wav + +README_cn \ No newline at end of file diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index 7210091a9..e62f59e6e 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -83,7 +83,8 @@ model_name_format = { 'st': 'Model-Source language-Target language', 'text': 'Model-Task-Language', 'tts': 'Model-Language', - 'vector': 'Model-Sample Rate' + 'vector': 'Model-Sample Rate', + 'ssl': 'Model-Language-Sample Rate' } @@ -94,7 +95,9 @@ class StatsCommand: def __init__(self): self.parser = argparse.ArgumentParser( prog='paddlespeech.stats', add_help=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] + self.task_choices = [ + 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl' + ] self.parser.add_argument( '--task', type=str, @@ -141,6 +144,8 @@ _commands = { 'tts': ['Text to Speech infer command.', 'TTSExecutor'], 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], 'kws': ['Keyword Spotting infer command.', 'KWSExecutor'], + 'ssl': + ['Self-Supervised Learning Pretrained model infer command.', 'SSLExecutor'] } for com, info in _commands.items(): diff --git a/paddlespeech/cli/ssl/__init__.py b/paddlespeech/cli/ssl/__init__.py new file mode 100644 index 000000000..2e53128ea --- /dev/null +++ b/paddlespeech/cli/ssl/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import SSLExecutor diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py new file mode 100644 index 000000000..154c25f53 --- /dev/null +++ b/paddlespeech/cli/ssl/infer.py @@ -0,0 +1,449 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import io +import os +import sys +import time +from collections import OrderedDict +from typing import List +from typing import Optional +from typing import Union + +import librosa +import numpy as np +import paddle +import soundfile +from yacs.config import CfgNode + +from ..executor import BaseExecutor +from ..log import logger +from ..utils import CLI_TIMER +from ..utils import stats_wrapper +from ..utils import timer_register +from paddlespeech.audio.transform.transformation import Transformation +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.utils.utility import UpdateConfig + +__all__ = ['SSLExecutor'] + + +@timer_register +class SSLExecutor(BaseExecutor): + def __init__(self): + super().__init__('ssl') + self.parser = argparse.ArgumentParser( + prog='paddlespeech.ssl', add_help=True) + self.parser.add_argument( + '--input', type=str, default=None, help='Audio file to recognize.') + self.parser.add_argument( + '--model', + type=str, + default='wav2vec2ASR_librispeech', + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], + help='Choose model type of asr task.') + self.parser.add_argument( + '--task', + type=str, + default='asr', + choices=['asr', 'vector'], + help='Choose output type for ssl task') + self.parser.add_argument( + '--lang', + type=str, + default='en', + help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]' + ) + self.parser.add_argument( + "--sample_rate", + type=int, + default=16000, + choices=[8000, 16000], + help='Choose the audio sample rate of the model. 8000 or 16000') + self.parser.add_argument( + '--config', + type=str, + default=None, + help='Config of asr task. Use deault config when it is None.') + self.parser.add_argument( + '--decode_method', + type=str, + default='ctc_greedy_search', + choices=[ + 'ctc_greedy_search', + 'ctc_prefix_beam_search', + ], + help='only support asr task') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') + self.parser.add_argument( + '--yes', + '-y', + action="store_true", + default=False, + help='No additional parameters required. \ + Once set this parameter, it means accepting the request of the program by default, \ + which includes transforming the audio sample rate') + self.parser.add_argument( + '--rtf', + action="store_true", + default=False, + help='Show Real-time Factor(RTF).') + self.parser.add_argument( + '--device', + type=str, + default=paddle.get_device(), + help='Choose device to execute model inference.') + self.parser.add_argument( + '-d', + '--job_dump_result', + action='store_true', + help='Save job result into file.') + self.parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Increase logger verbosity of current task.') + + def _init_from_path(self, + model_type: str='wav2vec2ASR_librispeech', + task: str='asr', + lang: str='en', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='ctc_greedy_search', + ckpt_path: Optional[os.PathLike]=None): + """ + Init model and other resources from a specific path. + """ + logger.debug("start to init the model") + # default max_len: unit:second + self.max_len = 50 + if hasattr(self, 'model'): + logger.debug('Model had been initialized.') + return + if cfg_path is None or ckpt_path is None: + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + if task == 'asr': + tag = model_type + '-' + lang + '-' + sample_rate_str + else: + tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str + self.task_resource.set_task_model(tag, version=None) + self.res_path = self.task_resource.res_dir + + self.cfg_path = os.path.join( + self.res_path, self.task_resource.res_dict['cfg_path']) + self.ckpt_path = os.path.join( + self.res_path, + self.task_resource.res_dict['ckpt_path'] + ".pdparams") + logger.debug(self.res_path) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + logger.debug(self.cfg_path) + logger.debug(self.ckpt_path) + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + if task == 'asr': + with UpdateConfig(self.config): + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath) + self.config.decode.decoding_method = decode_method + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + else: + model_name = 'wav2vec2' + model_class = self.task_resource.get_model_class(model_name) + + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.ckpt_path) + if task == 'asr': + self.model.set_state_dict(model_dict) + else: + self.model.wav2vec2.set_state_dict(model_dict) + + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + + audio_file = input + if isinstance(audio_file, (str, os.PathLike)): + logger.debug("Preprocess audio_file:" + audio_file) + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + # Get the object for feature extraction + logger.debug("get the preprocess conf") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + logger.debug("read the audio file") + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1, dtype=np.int16) + else: + audio = audio[:, 0] + # pcm16 -> pcm 32 + audio = self._pcm16to32(audio) + audio = librosa.resample( + audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate) + audio_sample_rate = self.sample_rate + # pcm32 -> pcm 16 + audio = self._pcm32to16(audio) + else: + audio = audio[:, 0] + + logger.debug(f"audio shape: {audio.shape}") + # fbank + audio = preprocessing(audio, **preprocess_args) + + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.debug(f"audio feat shape: {audio.shape}") + + logger.debug("audio feat process success") + + @paddle.no_grad() + def infer(self, model_type: str, task: str): + """ + Model inference and result stored in self.output. + """ + logger.debug("start to infer the model to get the output") + audio = self._inputs["audio"] + if task == 'asr': + cfg = self.config.decode + logger.debug( + f"we will use the wav2vec2ASR like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: + logger.debug( + "we will use the wav2vec2 like model to extract audio feature") + try: + out_feature = self.model(audio[:, :, 0]) + self._outputs["result"] = out_feature[0] + except Exception as e: + logger.exception(e) + + def postprocess(self) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + return self._outputs["result"] + + def _pcm16to32(self, audio): + assert (audio.dtype == np.int16) + audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) + return audio + + def _pcm32to16(self, audio): + assert (audio.dtype == np.float32) + bits = np.iinfo(np.int16).bits + audio = audio * (2**(bits - 1)) + audio = np.round(audio).astype("int16") + return audio + + def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: + logger.error( + "invalid sample rate, please input --sr 8000 or --sr 16000") + return False + + if isinstance(audio_file, (str, os.PathLike)): + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + return False + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + logger.debug("checking the audio file format......") + try: + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + audio_duration = audio.shape[0] / audio_sample_rate + if audio_duration > self.max_len: + logger.error( + f"Please input audio file less then {self.max_len} seconds.\n" + ) + return False + except Exception as e: + logger.exception(e) + logger.error( + f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + return False + logger.debug("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16 bit 1 channel wav file. \ + ".format(self.sample_rate, self.sample_rate)) + if force_yes is False: + while (True): + logger.debug( + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip( + ) == "Yes": + logger.debug( + "change the sampele rate, channel to 16k and 1 channel" + ) + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip( + ) == "No": + logger.debug("Exit the program") + return False + else: + logger.warning("Not regular input, please input again") + + self.change_format = True + else: + logger.debug("The audio file format is right") + self.change_format = False + + return True + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model = parser_args.model + task = parser_args.task + lang = parser_args.lang + sample_rate = parser_args.sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + decode_method = parser_args.decode_method + force_yes = parser_args.yes + rtf = parser_args.rtf + device = parser_args.device + + if not parser_args.verbose: + self.disable_task_loggers() + + task_source = self.get_input_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self( + audio_file=input_, + model=model, + task=task, + lang=lang, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + decode_method=decode_method, + force_yes=force_yes, + rtf=rtf, + device=device) + task_results[id_] = res + + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + if rtf: + self.show_rtf(CLI_TIMER[self.__class__.__name__]) + self.process_task_results(parser_args.input, task_results, + parser_args.job_dump_result) + if has_exceptions: + return False + else: + return True + + @stats_wrapper + def __call__(self, + audio_file: os.PathLike, + model: str='wav2vec2ASR_librispeech', + task: str='asr', + lang: str='en', + sample_rate: int=16000, + config: os.PathLike=None, + ckpt_path: os.PathLike=None, + decode_method: str='ctc_greedy_search', + force_yes: bool=False, + rtf: bool=False, + device=paddle.get_device()): + """ + Python API to call an executor. + """ + + audio_file = os.path.abspath(audio_file) + paddle.set_device(device) + self._init_from_path(model, task, lang, sample_rate, config, + decode_method, ckpt_path) + if not self._check(audio_file, sample_rate, force_yes): + sys.exit(-1) + if rtf: + k = self.__class__.__name__ + CLI_TIMER[k]['start'].append(time.time()) + self.preprocess(model, audio_file) + self.infer(model, task) + res = self.postprocess() # Retrieve result of asr. + + if rtf: + CLI_TIMER[k]['end'].append(time.time()) + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate) + + return res diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index 8e9ecc4ba..5de8f55a1 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -18,6 +18,12 @@ __all__ = [ # Records of model name to import class model_alias = { + # --------------------------------- + # -------------- SSL -------------- + # --------------------------------- + "wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"], + "wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"], + # --------------------------------- # -------------- ASR -------------- # --------------------------------- diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index df50a6a9d..0a5095a4d 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -25,6 +25,7 @@ __all__ = [ 'tts_static_pretrained_models', 'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models', + 'ssl_dynamic_pretrained_models', ] # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". @@ -32,6 +33,44 @@ __all__ = [ # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" +# --------------------------------- +# -------------- SSL -------------- +# --------------------------------- +ssl_dynamic_pretrained_models = { + "wav2vec2-en-16k": { + '1.3': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2-large-960h-lv60-self_ckpt_1.3.0.model.tar.gz', + 'md5': + 'd00de8506ac8e67751419c932fb4d84e', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'wav2vec2-large-960h-lv60-self', + 'model': + 'wav2vec2-large-960h-lv60-self.pdparams', + 'params': + 'wav2vec2-large-960h-lv60-self.pdparams', + }, + }, + "wav2vec2ASR_librispeech-en-16k": { + '1.3': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz', + 'md5': + 'fee083b3463ceaaec960fce84fdba0cd', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/wav2vec2ASR/checkpoints/avg_1', + 'model': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + 'params': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + }, + }, +} + # --------------------------------- # -------------- ASR -------------- # --------------------------------- diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 8e9914b2e..001e09ba3 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -22,7 +22,7 @@ from ..utils.dynamic_import import dynamic_import from ..utils.env import MODEL_HOME from .model_alias import model_alias -task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] +task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl'] model_format_supported = ['dynamic', 'static', 'onnx'] inference_mode_supported = ['online', 'offline'] @@ -108,7 +108,6 @@ class CommonTaskResource: """ assert model_name in model_alias, 'No model classes found for "{}"'.format( model_name) - ret = [] for import_path in model_alias[model_name]: ret.append(dynamic_import(import_path)) diff --git a/paddlespeech/s2t/exps/wav2vec2/__init__.py b/paddlespeech/s2t/exps/wav2vec2/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/s2t/exps/wav2vec2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py b/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py index 185a92b8d..97043fd7b 100644 --- a/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py +++ b/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 4f6bc0c5b..368a694c8 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -269,6 +269,10 @@ class Wav2Vec2ASRTrainer(Trainer): model = Wav2vec2ASR.from_config(model_conf) + # load pretrained wav2vec2 model params + wav2vec2_dict = paddle.load(config.wav2vec2_params_path) + model.wav2vec2.set_state_dict(wav2vec2_dict) + if self.parallel: model = paddle.DataParallel(model, find_unused_parameters=True) diff --git a/paddlespeech/s2t/models/wav2vec2/__init__.py b/paddlespeech/s2t/models/wav2vec2/__init__.py index e69de29bb..3a12a9cf3 100644 --- a/paddlespeech/s2t/models/wav2vec2/__init__.py +++ b/paddlespeech/s2t/models/wav2vec2/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .wav2vec2_ASR import Wav2vec2ASR +from .wav2vec2_ASR import Wav2vec2Base + +__all__ = ["Wav2vec2ASR", "Wav2vec2Base"] diff --git a/paddlespeech/s2t/models/wav2vec2/modules/__init__.py b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/models/wav2vec2/processing/__init__.py b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py new file mode 100644 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py index e13347740..09b254c51 100644 --- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py +++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py @@ -34,8 +34,6 @@ class Wav2vec2ASR(nn.Layer): wav2vec2_config = Wav2Vec2ConfigPure(config) wav2vec2 = Wav2Vec2Model(wav2vec2_config) - model_dict = paddle.load(config.wav2vec2_params_path) - wav2vec2.set_state_dict(model_dict) self.normalize_wav = config.normalize_wav self.output_norm = config.output_norm if config.freeze_wav2vec2: @@ -239,3 +237,33 @@ class Wav2vec2ASR(nn.Layer): """ hyps = self._ctc_prefix_beam_search(wav, beam_size) return hyps[0][0] + + +class Wav2vec2Base(nn.Layer): + """Wav2vec2 model""" + + def __init__(self, config: dict): + super().__init__() + wav2vec2_config = Wav2Vec2ConfigPure(config) + wav2vec2 = Wav2Vec2Model(wav2vec2_config) + self.wav2vec2 = wav2vec2 + + @classmethod + def from_config(cls, configs: dict): + """init model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + nn.Layer: Wav2Vec2Base + """ + model = cls(configs) + return model + + def forward(self, wav): + out = self.wav2vec2(wav) + return out diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index d571aa78f..dcf376e07 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -9,6 +9,10 @@ paddlespeech cls --input ./cat.wav --topk 10 # Punctuation_restoration paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast +# Speech SSL +paddlespeech ssl --task asr --lang en --input ./en.wav +paddlespeech ssl --task vector --lang en --input ./en.wav + # Speech_recognition wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav paddlespeech asr --input ./zh.wav