diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md new file mode 100755 index 000000000..fdef37e7b --- /dev/null +++ b/demos/speech_ssl/README.md @@ -0,0 +1,102 @@ +([简体中文](./README_cn.md)|English) +# Speech SSL (Self-Supervised Learning) + +## Introduction +Speech SSL, or Self-Supervised Learning, refers to training a model on large-scale unlabeled speech datasets. A model trained in this way can produce a good acoustic representation and can be applied to downstream speech tasks by fine-tuning on labeled datasets. + +This demo recognizes text from, or produces the acoustic representation of, a given audio file using speech SSL models. It can be done with a single command or a few lines of Python using `PaddleSpeech`. + +## Usage +### 1. Installation +See [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md). + +You can choose one of the easy, medium, and hard ways to install paddlespeech. + +### 2. Prepare Input File +The input of this demo should be a WAV file (`.wav`), and its sample rate must be the same as the model's. + +A sample file for this demo can be downloaded: +```bash +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +``` + +### 3. Usage +- Command Line (Recommended) + ```bash + # to recognize text + paddlespeech ssl --task asr --lang en --input ./en.wav + + # to get acoustic representation + paddlespeech ssl --task vector --lang en --input ./en.wav + ``` + + Usage: + ```bash + paddlespeech ssl --help + ``` + Arguments: + - `input`(required): Audio file to recognize. + - `model`: Model type of asr task. Default: `wav2vec2ASR_librispeech`. + - `task`: Output type, `asr` or `vector`. Default: `asr`. + - `lang`: Model language. Default: `en`. + - `sample_rate`: Sample rate of the model. Default: `16000`. + - `config`: Config of asr task. Use pretrained model when it is None. Default: `None`. + - `ckpt_path`: Model checkpoint. Use pretrained model when it is None.
Default: `None`. + - `yes`: Flag that takes no value. When set, every request from the program is accepted by default, including resampling the audio to the model's sample rate. Default: `False`. + - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment. + - `verbose`: Show the log information. + + +- Python API + ```python + import paddle + from paddlespeech.cli.ssl import SSLExecutor + + ssl_executor = SSLExecutor() + + # to recognize text + text = ssl_executor( + model='wav2vec2ASR_librispeech', + task='asr', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # to get acoustic representation + feature = ssl_executor( + model='wav2vec2', + task='vector', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + Output: + ```bash + ASR Result: + 我认为跑步最重要的就是给我带来了身体健康 + + Representation: + Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, + -0.04614586, 0.17853957], + [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855, + -0.04638699, 0.17855372], + [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341, + -0.04643029, 0.17856732], + ..., + [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373, + -0.04701405, 0.17862988], + [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728, + -0.04687393, 0.17864393], + [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174, + -0.17227529, 0.20338398]]]) + ``` diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md new file mode 100755 index 000000000..8e95f6a4d --- /dev/null +++ b/demos/speech_ssl/README_cn.md @@ -0,0 +1,103 @@ +(简体中文|[English](./README.md)) + +# 语音自监督学习 +## 介绍 +语音自监督学习,指的是在大规模无标记的语音数据集上的训练方法。用这种方法训练出来的模型可以产生很好的声学表征。并且可以通过在有标签的数据集上进行微调,应用于其他下游的语音任务。 + +这个 demo 是通过语音自监督模型将一个特定的音频文件识别成文本或产生声学表征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。 + +## 使用方法 +### 1. 安装 +请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 + +你可以从 easy,medium,hard 三种方式中选择一种方式安装。 + +### 2. 准备输入 +这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 + +可以下载此 demo 的示例音频: +```bash +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +``` +### 3.
使用方法 +- 命令行 (推荐使用) + ```bash + + # 识别文本 + paddlespeech ssl --task asr --lang en --input ./en.wav + + # 产生声学表征 + paddlespeech ssl --task vector --lang en --input ./en.wav + ``` + + 使用方法: + ```bash + paddlespeech ssl --help + ``` + 参数: + - `input`(必须输入):用于识别的音频文件。 + - `model`:ASR 任务的模型,默认值:`wav2vec2ASR_librispeech`。 + - `task`:输出类别,默认值:`asr`。 + - `lang`:模型语言,默认值:`en`。 + - `sample_rate`:音频采样率,默认值:`16000`。 + - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 + - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。 + - `yes`:不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。 + - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。 + - `verbose`: 如果使用,显示 logger 信息。 + + +- Python API + ```python + import paddle + from paddlespeech.cli.ssl import SSLExecutor + + ssl_executor = SSLExecutor() + + # 识别文本 + text = ssl_executor( + model='wav2vec2ASR_librispeech', + task='asr', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # 得到声学表征 + feature = ssl_executor( + model='wav2vec2', + task='vector', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + + 输出: + ```bash + ASR Result: + 我认为跑步最重要的就是给我带来了身体健康 + + Representation: + Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, + -0.04614586, 0.17853957], + [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855, + -0.04638699, 0.17855372], + [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341, + -0.04643029, 0.17856732], + ..., + [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373, + -0.04701405, 0.17862988], + [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728, + -0.04687393, 0.17864393], + [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174, + -0.17227529, 0.20338398]]]) + ``` \ No newline at end of file diff --git a/demos/speech_ssl/run.sh b/demos/speech_ssl/run.sh new file mode 100755 index 000000000..204ccc826 --- /dev/null +++ b/demos/speech_ssl/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# audio download +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav + +# to recognize text +paddlespeech ssl --task asr --lang en --input ./en.wav + +# to get acoustic representation +paddlespeech ssl --task vector --lang en --input ./en.wav + diff --git a/paddlespeech/cli/ssl/__init__.py b/paddlespeech/cli/ssl/__init__.py new file mode 100755 index 000000000..2e53128ea --- /dev/null +++ b/paddlespeech/cli/ssl/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import SSLExecutor diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py new file mode 100755 index 000000000..4bc8a5074 --- /dev/null +++ b/paddlespeech/cli/ssl/infer.py @@ -0,0 +1,451 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import io +import os +import sys +import time +from collections import OrderedDict +from typing import List +from typing import Optional +from typing import Union + +import librosa +import numpy as np +import paddle +import soundfile +from yacs.config import CfgNode + +from ...utils.env import MODEL_HOME +from ..download import get_path_from_url +from ..executor import BaseExecutor +from ..log import logger +from ..utils import CLI_TIMER +from ..utils import stats_wrapper +from ..utils import timer_register +from paddlespeech.audio.transform.transformation import Transformation +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.utils.utility import UpdateConfig + +__all__ = ['SSLExecutor'] + + +@timer_register +class SSLExecutor(BaseExecutor): + def __init__(self): + super().__init__('ssl') + self.parser = argparse.ArgumentParser( + prog='paddlespeech.ssl', add_help=True) + self.parser.add_argument( + '--input', type=str, default=None, help='Audio file to recognize.') + self.parser.add_argument( + '--model', + type=str, + default='wav2vec2ASR_librispeech', + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], + help='Choose model type of asr task.') + self.parser.add_argument( + '--task', + type=str, + default='asr', + choices=['asr', 'vector'], + help='Choose output type for ssl task') + self.parser.add_argument( + '--lang', + type=str, + default='en', + help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]' + ) + self.parser.add_argument( + "--sample_rate", + type=int, + default=16000, + choices=[8000, 16000], + help='Choose the audio sample rate of the model. 8000 or 16000') + self.parser.add_argument( + '--config', + type=str, + default=None, + help='Config of asr task. 
Use default config when it is None.') + self.parser.add_argument( + '--decode_method', + type=str, + default='ctc_greedy_search', + choices=[ + 'ctc_greedy_search', 'ctc_prefix_beam_search', + ], + help='Decoding method. Only supported for the asr task.') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') + self.parser.add_argument( + '--yes', + '-y', + action="store_true", + default=False, + help='No additional parameters required. \ + Once set this parameter, it means accepting the request of the program by default, \ + which includes transforming the audio sample rate') + self.parser.add_argument( + '--rtf', + action="store_true", + default=False, + help='Show Real-time Factor(RTF).') + self.parser.add_argument( + '--device', + type=str, + default=paddle.get_device(), + help='Choose device to execute model inference.') + self.parser.add_argument( + '-d', + '--job_dump_result', + action='store_true', + help='Save job result into file.') + self.parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Increase logger verbosity of current task.') + + def _init_from_path(self, + model_type: str='wav2vec2ASR_librispeech', + task: str='asr', + lang: str='en', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='ctc_greedy_search', + ckpt_path: Optional[os.PathLike]=None): + """ + Init model and other resources from a specific path.
+ """ + logger.debug("start to init the model") + # default max_len: unit:second + self.max_len = 50 + if hasattr(self, 'model'): + logger.debug('Model had been initialized.') + return + if cfg_path is None or ckpt_path is None: + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + if task == 'asr': + tag = model_type + '-' + lang + '-' + sample_rate_str + else: + tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str + self.task_resource.set_task_model(tag, version=None) + self.res_path = self.task_resource.res_dir + + self.cfg_path = os.path.join( + self.res_path, self.task_resource.res_dict['cfg_path']) + self.ckpt_path = os.path.join( + self.res_path, + self.task_resource.res_dict['ckpt_path'] + ".pdparams") + logger.debug(self.res_path) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + logger.debug(self.cfg_path) + logger.debug(self.ckpt_path) + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + if task == 'asr': + with UpdateConfig(self.config): + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath) + self.config.decode.decoding_method = decode_method + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + else: + model_name = 'wav2vec2' + model_class = self.task_resource.get_model_class(model_name) + + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.ckpt_path) + if task == 'asr': + self.model.set_state_dict(model_dict) + else: + self.model.wav2vec2.set_state_dict(model_dict) + + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. 
+ Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + + audio_file = input + if isinstance(audio_file, (str, os.PathLike)): + logger.debug("Preprocess audio_file:" + audio_file) + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + # Get the object for feature extraction + logger.debug("get the preprocess conf") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + logger.debug("read the audio file") + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1, dtype=np.int16) + else: + audio = audio[:, 0] + # pcm16 -> pcm 32 + audio = self._pcm16to32(audio) + audio = librosa.resample( + audio, + orig_sr=audio_sample_rate, + target_sr=self.sample_rate) + audio_sample_rate = self.sample_rate + # pcm32 -> pcm 16 + audio = self._pcm32to16(audio) + else: + audio = audio[:, 0] + + logger.debug(f"audio shape: {audio.shape}") + # fbank + audio = preprocessing(audio, **preprocess_args) + + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.debug(f"audio feat shape: {audio.shape}") + + logger.debug("audio feat process success") + + @paddle.no_grad() + def infer(self, model_type: str, task: str): + """ + Model inference and result stored in self.output. 
+ """ + logger.debug("start to infer the model to get the output") + audio = self._inputs["audio"] + if task == 'asr': + cfg = self.config.decode + logger.debug( + f"we will use the wav2vec2ASR like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: + logger.debug( + f"we will use the wav2vec2 like model to extract audio feature") + try: + out_feature = self.model(audio[:, :, 0]) + self._outputs["result"] = out_feature[0] + except Exception as e: + logger.exception(e) + + def postprocess(self) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + return self._outputs["result"] + + def _pcm16to32(self, audio): + assert (audio.dtype == np.int16) + audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) + return audio + + def _pcm32to16(self, audio): + assert (audio.dtype == np.float32) + bits = np.iinfo(np.int16).bits + audio = audio * (2**(bits - 1)) + audio = np.round(audio).astype("int16") + return audio + + def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: + logger.error( + "invalid sample rate, please input --sr 8000 or --sr 16000") + return False + + if isinstance(audio_file, (str, os.PathLike)): + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + return False + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + logger.debug("checking the audio file format......") + try: + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + audio_duration = audio.shape[0] / audio_sample_rate + if audio_duration > 
self.max_len: + logger.error( + f"Please input audio file less than {self.max_len} seconds.\n" + ) + return False + except Exception as e: + logger.exception(e) + logger.error( + f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + return False + logger.debug("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16 bit 1 channel wav file. \ + ".format(self.sample_rate, self.sample_rate)) + if force_yes is False: + while (True): + logger.debug( + "Whether to change the sample rate and the channel. Y: change the sample rate. N: exit the program." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip( + ) == "Yes": + logger.debug( + "change the sample rate, channel to 16k and 1 channel" + ) + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip( + ) == "No": + logger.debug("Exit the program") + return False + else: + logger.warning("Not regular input, please input again") + + self.change_format = True + else: + logger.debug("The audio file format is right") + self.change_format = False + + return True + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry.
+ """ + parser_args = self.parser.parse_args(argv) + + model = parser_args.model + task = parser_args.task + lang = parser_args.lang + sample_rate = parser_args.sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + decode_method = parser_args.decode_method + force_yes = parser_args.yes + rtf = parser_args.rtf + device = parser_args.device + + if not parser_args.verbose: + self.disable_task_loggers() + + task_source = self.get_input_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self( + audio_file=input_, + model=model, + task=task, + lang=lang, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + decode_method=decode_method, + force_yes=force_yes, + rtf=rtf, + device=device) + task_results[id_] = res + + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + if rtf: + self.show_rtf(CLI_TIMER[self.__class__.__name__]) + self.process_task_results(parser_args.input, task_results, + parser_args.job_dump_result) + if has_exceptions: + return False + else: + return True + + @stats_wrapper + def __call__(self, + audio_file: os.PathLike, + model: str='wav2vec2ASR_librispeech', + task: str='asr', + lang: str='en', + sample_rate: int=16000, + config: os.PathLike=None, + ckpt_path: os.PathLike=None, + decode_method: str='ctc_greedy_search', + force_yes: bool=False, + rtf: bool=False, + device=paddle.get_device()): + """ + Python API to call an executor. + """ + + audio_file = os.path.abspath(audio_file) + paddle.set_device(device) + self._init_from_path(model, task, lang, sample_rate, config, decode_method, ckpt_path) + if not self._check(audio_file, sample_rate, force_yes): + sys.exit(-1) + if rtf: + k = self.__class__.__name__ + CLI_TIMER[k]['start'].append(time.time()) + self.preprocess(model, audio_file) + self.infer(model, task) + res = self.postprocess() # Retrieve result of asr. 
+ + if rtf: + CLI_TIMER[k]['end'].append(time.time()) + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate) + + return res diff --git a/paddlespeech/s2t/models/wav2vec2/modules/__init__.py b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py new file mode 100755 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/models/wav2vec2/modules/normalization.py b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py new file mode 100755 index 000000000..7716c755a --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py @@ -0,0 +1,104 @@ +# Authors +# * Mirco Ravanelli 2020 +# * Guillermo Cámbara 2021 +# * Sarthak Yadav 2022 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/normalization.py) + + +import paddle +import paddle.nn as nn +from paddlespeech.s2t.modules.align import BatchNorm1D + +class BatchNorm1d(nn.Layer): + """Applies 1d batch normalization to the input tensor. + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + affine : bool + When set to True, the affine parameters are learned. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + combine_batch_time : bool + When true, it combines the batch and time axes. + Example + ------- + >>> input = paddle.randn([100, 10, 20]) + >>> norm = BatchNorm1d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + [100, 10, 20] + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.9, + combine_batch_time=False, + skip_transpose=False, + ): + super().__init__() + self.combine_batch_time = combine_batch_time + self.skip_transpose = skip_transpose + + if input_size is None and skip_transpose: + input_size = input_shape[1] + elif input_size is None: + input_size = input_shape[-1] + + self.norm = nn.BatchNorm1D( + input_size, + momentum=momentum, + epsilon=eps + ) + + def forward(self, x): + """Returns the normalized input tensor. + Arguments + --------- + x : paddle.Tensor (batch, time, [channels]) + input to normalize.
2d or 3d tensors are expected in input + 4d tensors can be used when combine_batch_time=True. + """ + shape_or = x.shape + if self.combine_batch_time: + if x.ndim == 3: + x = x.reshape([shape_or[0] * shape_or[1], shape_or[2]]) + else: + x = x.reshape( + [shape_or[0] * shape_or[1], shape_or[3], shape_or[2]] + ) + + elif not self.skip_transpose: + x = x.transpose([0, 2, 1]) + + x_n = self.norm(x) + if self.combine_batch_time: + x_n = x_n.reshape(shape_or) + elif not self.skip_transpose: + x_n = x_n.transpose([0, 2, 1]) + + return x_n \ No newline at end of file diff --git a/paddlespeech/s2t/models/wav2vec2/processing/__init__.py b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py new file mode 100755 index 000000000..97043fd7b --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
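
Note on `SSLExecutor._init_from_path`: when no config or checkpoint is given, the diff builds a pretrained-resource tag from the model type, language, and sample rate, and separately derives the model class name from `model_type`. The sketch below restates only that string logic so it can be checked without PaddleSpeech installed. The helper names `build_tag` and `model_name_for` are hypothetical and do not appear in the patch.

```python
def build_tag(model_type: str, task: str, lang: str, sample_rate: int) -> str:
    """Hypothetical helper mirroring the tag construction in `_init_from_path`."""
    sample_rate_str = '16k' if sample_rate == 16000 else '8k'
    # For the 'vector' task the patch ignores `model_type` and always uses 'wav2vec2'.
    prefix = model_type if task == 'asr' else 'wav2vec2'
    return f'{prefix}-{lang}-{sample_rate_str}'


def model_name_for(model_type: str, task: str) -> str:
    """Hypothetical helper mirroring how the model class name is chosen."""
    if task == 'asr':
        # model_type is expected to look like {model_name}_{dataset};
        # str.rindex raises ValueError when there is no underscore.
        return model_type[:model_type.rindex('_')]
    return 'wav2vec2'


print(build_tag('wav2vec2ASR_librispeech', 'asr', 'en', 16000))
print(build_tag('wav2vec2ASR_librispeech', 'vector', 'en', 16000))
print(model_name_for('wav2vec2ASR_librispeech', 'asr'))
```

One consequence worth checking in review: with `task='asr'`, a `model` value without an underscore (for example `wav2vec2`) would fail at the `rindex('_')` step, so the README's `model='wav2vec2'` call only works because it is paired with `task='vector'`.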