diff --git a/README.md b/README.md index 8e338fde..f77b7da1 100644 --- a/README.md +++ b/README.md @@ -157,12 +157,12 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision - 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV). ### Recent Update +- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), Support ASR and Feature Extraction. - 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660). -- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech). -- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3). +- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech).- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3). - 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS. - 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend. -- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech. +- 👑 2022.10.11: Add [Wav2vec2ASR-en](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech. - 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and [ERNIE-SAT](https://arxiv.org/abs/2211.03545) in [PaddleSpeech Web Demo](./demos/speech_web). - ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder. - ⚡ 2022.08.25: Release TTS [finetune](./examples/other/tts_finetune/tts3) example. diff --git a/README_cn.md b/README_cn.md index 27b23912..eaa55005 100644 --- a/README_cn.md +++ b/README_cn.md @@ -164,12 +164,13 @@ ### 近期更新 +- 🔥 2022.11.18: 新增 [Wav2vec2 CLI 和 Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), 支持 ASR 和 特征提取. - 🎉 2022.11.17: TTS 新增[高质量男性音色](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660)。 - 🔥 2022.11.07: 新增 [U2/U2++ 高性能流式 ASR C++ 部署](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech)。 - 👑 2022.11.01: [中英文混合 TTS](./examples/zh_en_tts/tts3) 新增 [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) 模块。 - 🔥 2022.10.26: TTS 新增[韵律预测](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy)功能。 - 🎉 2022.10.21: TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。 -- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。 +- 👑 2022.10.11: 新增 [Wav2vec2ASR-en](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。 - 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 [ERNIE-SAT](https://arxiv.org/abs/2211.03545) 到 [PaddleSpeech 网页应用](./demos/speech_web)。 - ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。 - ⚡ 2022.08.25: 发布 TTS [finetune](./examples/other/tts_finetune/tts3) 示例。 diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md new file mode 100644 index 00000000..fdef37e7 --- /dev/null +++ b/demos/speech_ssl/README.md @@ -0,0 +1,102 @@ +([简体中文](./README_cn.md)|English) +# Speech SSL (Self-Supervised Learning) + +## Introduction +Speech SSL, or Self-Supervised Learning, refers to a training method on the large-scale unlabeled speech dataset. The model trained in this way can produce a good acoustic representation, and can be applied to other downstream speech tasks by fine-tuning on labeled datasets. + +This demo is an implementation to recognize text or produce the acoustic representation from a specific audio file by speech ssl models. It can be done by a single command or a few lines in python using `PaddleSpeech`. + +## Usage +### 1. Installation +see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md). + +You can choose one way from easy, meduim and hard to install paddlespeech. + +### 2. Prepare Input File +The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model. + +Here are sample files for this demo that can be downloaded: +```bash +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +``` + +### 3. Usage +- Command Line(Recommended) + ```bash + # to recognize text + paddlespeech ssl --task asr --lang en --input ./en.wav + + # to get acoustic representation + paddlespeech ssl --task vector --lang en --input ./en.wav + ``` + + Usage: + ```bash + paddlespeech ssl --help + ``` + Arguments: + - `input`(required): Audio file to recognize. + - `model`: Model type of asr task. Default: `wav2vec2ASR_librispeech`. + - `task`: Output type. Default: `asr`. + - `lang`: Model language. Default: `en`. + - `sample_rate`: Sample rate of the model. Default: `16000`. + - `config`: Config of asr task. Use pretrained model when it is None. Default: `None`. + - `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`. + - `yes`: No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate. Default: `False`. + - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment. + - `verbose`: Show the log information. + + +- Python API + ```python + import paddle + from paddlespeech.cli.ssl import SSLExecutor + + ssl_executor = SSLExecutor() + + # to recognize text + text = ssl_executor( + model='wav2vec2ASR_librispeech', + task='asr', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # to get acoustic representation + feature = ssl_executor( + model='wav2vec2', + task='vector', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + Output: + ```bash + ASR Result: + 我认为跑步最重要的就是给我带来了身体健康 + + Representation: + Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, + -0.04614586, 0.17853957], + [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855, + -0.04638699, 0.17855372], + [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341, + -0.04643029, 0.17856732], + ..., + [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373, + -0.04701405, 0.17862988], + [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728, + -0.04687393, 0.17864393], + [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174, + -0.17227529, 0.20338398]]]) + ``` diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md new file mode 100644 index 00000000..76ec2f1f --- /dev/null +++ b/demos/speech_ssl/README_cn.md @@ -0,0 +1,103 @@ +(简体中文|[English](./README.md)) + +# 语音自监督学习 +## 介绍 +语音自监督学习,指的是在大规模无标记的语音数据集上的训练方法。用这种方法训练出来的模型可以产生很好的声学表征。并且可以通过在有标签的数据集上进行微调,应用于其他下游的语音任务。 + +这个 demo 是通过语音自监督模型将一个特定的音频文件识别成文本或产生声学表征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。 + +## 使用方法 +### 1. 安装 +请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。 + +你可以从 easy,medium,hard 三中方式中选择一种方式安装。 + +### 2. 准备输入 +这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。 + +可以下载此 demo 的示例音频: +```bash +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +``` +### 3. 使用方法 +- 命令行 (推荐使用) + ```bash + + # 识别文本 + paddlespeech ssl --task asr --lang en --input ./en.wav + + # 产生声学表征 + paddlespeech ssl --task vector --lang en --input ./en.wav + ``` + + 使用方法: + ```bash + paddlespeech asr --help + ``` + 参数: + - `input`(必须输入):用于识别的音频文件。 + - `model`:ASR 任务的模型,默认值:`conformer_wenetspeech`。 + - `task`:输出类别,默认值:`asr`。 + - `lang`:模型语言,默认值:`zh`。 + - `sample_rate`:音频采样率,默认值:`16000`。 + - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。 + - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。 + - `yes`;不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。 + - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。 + - `verbose`: 如果使用,显示 logger 信息。 + + +- Python API + ```python + import paddle + from paddlespeech.cli.ssl import SSLExecutor + + ssl_executor = SSLExecutor() + + # 识别文本 + text = ssl_executor( + model='wav2vec2ASR_librispeech', + task='asr', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('ASR Result: \n{}'.format(text)) + + # 得到声学表征 + feature = ssl_executor( + model='wav2vec2', + task='vector', + lang='en', + sample_rate=16000, + config=None, # Set `config` and `ckpt_path` to None to use pretrained model. + ckpt_path=None, + audio_file='./en.wav', + device=paddle.get_device()) + print('Representation: \n{}'.format(feature)) + ``` + + + 输出: + ```bash + ASR Result: + 我认为跑步最重要的就是给我带来了身体健康 + + Representation: + Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122, + -0.04614586, 0.17853957], + [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855, + -0.04638699, 0.17855372], + [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341, + -0.04643029, 0.17856732], + ..., + [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373, + -0.04701405, 0.17862988], + [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728, + -0.04687393, 0.17864393], + [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174, + -0.17227529, 0.20338398]]]) + ``` diff --git a/demos/speech_ssl/run.sh b/demos/speech_ssl/run.sh new file mode 100644 index 00000000..ca94bc5c --- /dev/null +++ b/demos/speech_ssl/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# audio download +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav + +# to recognize text +paddlespeech ssl --task asr --lang en --input ./en.wav + +# to get acoustic representation +paddlespeech ssl --task vector --lang en --input ./en.wav diff --git a/examples/librispeech/asr3/RESULTS.md b/examples/librispeech/asr3/RESULTS.md index 1c5626d9..27a87e13 100644 --- a/examples/librispeech/asr3/RESULTS.md +++ b/examples/librispeech/asr3/RESULTS.md @@ -1,8 +1,8 @@ # LibriSpeech ## Wav2VecASR -train: Epoch 1, 1*V100-32G, batchsize:10 +train: Epoch 1, 1*V100-32G, batchsize: 6 | Model | Params | Config | Augmentation| Test set | Decode method | WER | | --- | --- | --- | --- | --- | --- | --- | -| wav2vec2ASR | 302.86 M | conf/wav2vec2ASR.yaml | spec_aug | test-clean | greedy search | 0.018887 | +| wav2vec2ASR | 302.86 M | conf/wav2vec2ASR.yaml | spec_aug | test-clean | greedy search | 0.018906 | diff --git a/examples/librispeech/asr3/conf/preprocess.yaml b/examples/librispeech/asr3/conf/preprocess.yaml index 4a908a83..724782ed 100644 --- a/examples/librispeech/asr3/conf/preprocess.yaml +++ b/examples/librispeech/asr3/conf/preprocess.yaml @@ -1,4 +1,3 @@ process: # use raw audio - type: wav_process - dither: 0.0 diff --git a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml index c45bd692..1ce2d94d 100644 --- a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml +++ b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml @@ -4,16 +4,21 @@ freeze_wav2vec2: True normalize_wav: True output_norm: True -dnn_blocks: 2 -dnn_neurons: 1024 -blank_id: 0 -ctc_dropout_rate: 0.0 +init_type: 'kaiming_uniform' # !Warning: need to convergence +enc: + input_shape: 1024 + dnn_blocks: 2 + dnn_neurons: 1024 + activation: True +ctc: + enc_n_units: 1024 + blank_id: 0 + dropout_rate: 0.0 wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams" ############################################ # Wav2Vec2.0 # ############################################ -vocab_size: 32 hidden_size: 1024 num_hidden_layers: 24 num_attention_heads: 16 @@ -54,9 +59,6 @@ diversity_loss_weight: 0.1 ctc_loss_reduction: "sum" ctc_zero_infinity: False use_weighted_layer_sum: False -pad_token_id: 0 -bos_token_id: 1 -eos_token_id: 2 add_adapter: False adapter_kernel_size: 3 adapter_stride: 2 @@ -78,7 +80,7 @@ unit_type: 'char' mean_std_filepath: "" preprocess_config: conf/preprocess.yaml sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for 'other' epochs -batch_size: 10 # Different batch_size may cause large differences in results +batch_size: 6 # Different batch_size may cause large differences in results maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced minibatches: 0 # for debug @@ -106,17 +108,26 @@ audio_augment: # for raw audio ########################################### n_epoch: 1 accum_grad: 1 -global_grad_clip: 3.0 +global_grad_clip: 5.0 model_optim: adadelta model_optim_conf: lr: 0.9 epsilon: 1.0e-6 rho: 0.95 -scheduler: constantlr -scheduler_conf: +model_scheduler: constantlr +model_scheduler_conf: + warmup_steps: 25000 + lr_decay: 1.0 +wav2vec2_optim: adadelta +wav2vec2_optim_conf: + lr: 0.9 + epsilon: 1.0e-6 + rho: 0.95 +wav2vec2_scheduler: constantlr +wav2vec2_scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 checkpoint: kbest_n: 50 - latest_n: 5 + latest_n: 5 \ No newline at end of file diff --git a/examples/librispeech/asr3/local/train.sh b/examples/librispeech/asr3/local/train.sh index 6913ed17..24776fd1 100644 --- a/examples/librispeech/asr3/local/train.sh +++ b/examples/librispeech/asr3/local/train.sh @@ -10,7 +10,8 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -ips=$3 +resume=$3 +ips=$4 if [ ! $ips ];then ips_config= @@ -21,7 +22,7 @@ fi mkdir -p exp # seed may break model convergence -seed=1998 +seed=1988 if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi @@ -34,13 +35,15 @@ python3 -u ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---seed ${seed} +--seed ${seed} \ +--resume ${resume} else python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \ --ngpu ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---seed ${seed} +--seed ${seed} \ +--resume ${resume} fi if [ ${seed} != 0 ]; then diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh index 3b1abb11..05ad505c 100644 --- a/examples/librispeech/asr3/run.sh +++ b/examples/librispeech/asr3/run.sh @@ -11,7 +11,7 @@ conf_path=conf/wav2vec2ASR.yaml ips= #xx.xx.xx.xx,xx.xx.xx.xx decode_conf_path=conf/tuning/decode.yaml avg_num=1 -dict_path=data/lang_char/vocab.txt +resume= # xx e.g. 30 . ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -28,7 +28,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -38,10 +38,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # greedy search decoder - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # test a single .wav file - CUDA_VISIBLE_DEVICES=${gpus} ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 fi diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py index cba60cfd..84812a2c 100644 --- a/paddlespeech/audio/transform/spectrogram.py +++ b/paddlespeech/audio/transform/spectrogram.py @@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi(): class WavProcess(): - def __init__(self, dither=0.0): + def __init__(self): """ Args: dither (float): Dithering constant @@ -391,9 +391,7 @@ class WavProcess(): Returns: """ - self.dither = dither - - def __call__(self, x, train): + def __call__(self, x): """ Args: x (np.ndarray): shape (Ti,) @@ -405,10 +403,10 @@ class WavProcess(): Returns: np.ndarray: (T, D) """ - dither = self.dither if train else 0.0 if x.ndim != 1: raise ValueError("Not support x: [Time, Channel]") - waveform = np.expand_dims(x, -1) + waveform = x.astype("float32") / 32768.0 + waveform = np.expand_dims(waveform, -1) return waveform diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index 7551b6c0..767d0df7 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -84,6 +84,7 @@ model_name_format = { 'text': 'Model-Task-Language', 'tts': 'Model-Language', 'vector': 'Model-Sample Rate', + 'ssl': 'Model-Language-Sample Rate', 'whisper': 'Model-Language-Sample Rate' } @@ -96,7 +97,7 @@ class StatsCommand: self.parser = argparse.ArgumentParser( prog='paddlespeech.stats', add_help=True) self.task_choices = [ - 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper' + 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl', 'whisper' ] self.parser.add_argument( '--task', @@ -144,6 +145,8 @@ _commands = { 'tts': ['Text to Speech infer command.', 'TTSExecutor'], 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], 'kws': ['Keyword Spotting infer command.', 'KWSExecutor'], + 'ssl': + ['Self-Supervised Learning Pretrained model infer command.', 'SSLExecutor'], 'whisper': [ 'Whisper model for speech to text or translate speech to English.', 'WhisperExecutor' diff --git a/paddlespeech/cli/ssl/__init__.py b/paddlespeech/cli/ssl/__init__.py new file mode 100644 index 00000000..2e53128e --- /dev/null +++ b/paddlespeech/cli/ssl/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import SSLExecutor diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py new file mode 100644 index 00000000..154c25f5 --- /dev/null +++ b/paddlespeech/cli/ssl/infer.py @@ -0,0 +1,449 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import io +import os +import sys +import time +from collections import OrderedDict +from typing import List +from typing import Optional +from typing import Union + +import librosa +import numpy as np +import paddle +import soundfile +from yacs.config import CfgNode + +from ..executor import BaseExecutor +from ..log import logger +from ..utils import CLI_TIMER +from ..utils import stats_wrapper +from ..utils import timer_register +from paddlespeech.audio.transform.transformation import Transformation +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.utils.utility import UpdateConfig + +__all__ = ['SSLExecutor'] + + +@timer_register +class SSLExecutor(BaseExecutor): + def __init__(self): + super().__init__('ssl') + self.parser = argparse.ArgumentParser( + prog='paddlespeech.ssl', add_help=True) + self.parser.add_argument( + '--input', type=str, default=None, help='Audio file to recognize.') + self.parser.add_argument( + '--model', + type=str, + default='wav2vec2ASR_librispeech', + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], + help='Choose model type of asr task.') + self.parser.add_argument( + '--task', + type=str, + default='asr', + choices=['asr', 'vector'], + help='Choose output type for ssl task') + self.parser.add_argument( + '--lang', + type=str, + default='en', + help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]' + ) + self.parser.add_argument( + "--sample_rate", + type=int, + default=16000, + choices=[8000, 16000], + help='Choose the audio sample rate of the model. 8000 or 16000') + self.parser.add_argument( + '--config', + type=str, + default=None, + help='Config of asr task. Use deault config when it is None.') + self.parser.add_argument( + '--decode_method', + type=str, + default='ctc_greedy_search', + choices=[ + 'ctc_greedy_search', + 'ctc_prefix_beam_search', + ], + help='only support asr task') + self.parser.add_argument( + '--ckpt_path', + type=str, + default=None, + help='Checkpoint file of model.') + self.parser.add_argument( + '--yes', + '-y', + action="store_true", + default=False, + help='No additional parameters required. \ + Once set this parameter, it means accepting the request of the program by default, \ + which includes transforming the audio sample rate') + self.parser.add_argument( + '--rtf', + action="store_true", + default=False, + help='Show Real-time Factor(RTF).') + self.parser.add_argument( + '--device', + type=str, + default=paddle.get_device(), + help='Choose device to execute model inference.') + self.parser.add_argument( + '-d', + '--job_dump_result', + action='store_true', + help='Save job result into file.') + self.parser.add_argument( + '-v', + '--verbose', + action='store_true', + help='Increase logger verbosity of current task.') + + def _init_from_path(self, + model_type: str='wav2vec2ASR_librispeech', + task: str='asr', + lang: str='en', + sample_rate: int=16000, + cfg_path: Optional[os.PathLike]=None, + decode_method: str='ctc_greedy_search', + ckpt_path: Optional[os.PathLike]=None): + """ + Init model and other resources from a specific path. + """ + logger.debug("start to init the model") + # default max_len: unit:second + self.max_len = 50 + if hasattr(self, 'model'): + logger.debug('Model had been initialized.') + return + if cfg_path is None or ckpt_path is None: + sample_rate_str = '16k' if sample_rate == 16000 else '8k' + if task == 'asr': + tag = model_type + '-' + lang + '-' + sample_rate_str + else: + tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str + self.task_resource.set_task_model(tag, version=None) + self.res_path = self.task_resource.res_dir + + self.cfg_path = os.path.join( + self.res_path, self.task_resource.res_dict['cfg_path']) + self.ckpt_path = os.path.join( + self.res_path, + self.task_resource.res_dict['ckpt_path'] + ".pdparams") + logger.debug(self.res_path) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") + self.res_path = os.path.dirname( + os.path.dirname(os.path.abspath(self.cfg_path))) + logger.debug(self.cfg_path) + logger.debug(self.ckpt_path) + + #Init body. + self.config = CfgNode(new_allowed=True) + self.config.merge_from_file(self.cfg_path) + if task == 'asr': + with UpdateConfig(self.config): + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath) + self.config.decode.decoding_method = decode_method + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + else: + model_name = 'wav2vec2' + model_class = self.task_resource.get_model_class(model_name) + + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.ckpt_path) + if task == 'asr': + self.model.set_state_dict(model_dict) + else: + self.model.wav2vec2.set_state_dict(model_dict) + + def preprocess(self, model_type: str, input: Union[str, os.PathLike]): + """ + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + """ + + audio_file = input + if isinstance(audio_file, (str, os.PathLike)): + logger.debug("Preprocess audio_file:" + audio_file) + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + # Get the object for feature extraction + logger.debug("get the preprocess conf") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + logger.debug("read the audio file") + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + if self.change_format: + if audio.shape[1] >= 2: + audio = audio.mean(axis=1, dtype=np.int16) + else: + audio = audio[:, 0] + # pcm16 -> pcm 32 + audio = self._pcm16to32(audio) + audio = librosa.resample( + audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate) + audio_sample_rate = self.sample_rate + # pcm32 -> pcm 16 + audio = self._pcm32to16(audio) + else: + audio = audio[:, 0] + + logger.debug(f"audio shape: {audio.shape}") + # fbank + audio = preprocessing(audio, **preprocess_args) + + audio_len = paddle.to_tensor(audio.shape[0]) + audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) + + self._inputs["audio"] = audio + self._inputs["audio_len"] = audio_len + logger.debug(f"audio feat shape: {audio.shape}") + + logger.debug("audio feat process success") + + @paddle.no_grad() + def infer(self, model_type: str, task: str): + """ + Model inference and result stored in self.output. + """ + logger.debug("start to infer the model to get the output") + audio = self._inputs["audio"] + if task == 'asr': + cfg = self.config.decode + logger.debug( + f"we will use the wav2vec2ASR like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: + logger.debug( + "we will use the wav2vec2 like model to extract audio feature") + try: + out_feature = self.model(audio[:, :, 0]) + self._outputs["result"] = out_feature[0] + except Exception as e: + logger.exception(e) + + def postprocess(self) -> Union[str, os.PathLike]: + """ + Output postprocess and return human-readable results such as texts and audio files. + """ + return self._outputs["result"] + + def _pcm16to32(self, audio): + assert (audio.dtype == np.int16) + audio = audio.astype("float32") + bits = np.iinfo(np.int16).bits + audio = audio / (2**(bits - 1)) + return audio + + def _pcm32to16(self, audio): + assert (audio.dtype == np.float32) + bits = np.iinfo(np.int16).bits + audio = audio * (2**(bits - 1)) + audio = np.round(audio).astype("int16") + return audio + + def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): + self.sample_rate = sample_rate + if self.sample_rate != 16000 and self.sample_rate != 8000: + logger.error( + "invalid sample rate, please input --sr 8000 or --sr 16000") + return False + + if isinstance(audio_file, (str, os.PathLike)): + if not os.path.isfile(audio_file): + logger.error("Please input the right audio file path") + return False + elif isinstance(audio_file, io.BytesIO): + audio_file.seek(0) + + logger.debug("checking the audio file format......") + try: + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + audio_duration = audio.shape[0] / audio_sample_rate + if audio_duration > self.max_len: + logger.error( + f"Please input audio file less then {self.max_len} seconds.\n" + ) + return False + except Exception as e: + logger.exception(e) + logger.error( + f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \ + you can try to use sox to change the file format.\n \ + For example: \n \ + sample rate: 16k \n \ + sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \ + sample rate: 8k \n \ + sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \ + ") + return False + logger.debug("The sample rate is %d" % audio_sample_rate) + if audio_sample_rate != self.sample_rate: + logger.warning("The sample rate of the input file is not {}.\n \ + The program will resample the wav file to {}.\n \ + If the result does not meet your expectations,\n \ + Please input the 16k 16 bit 1 channel wav file. \ + ".format(self.sample_rate, self.sample_rate)) + if force_yes is False: + while (True): + logger.debug( + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip( + ) == "Yes": + logger.debug( + "change the sampele rate, channel to 16k and 1 channel" + ) + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip( + ) == "No": + logger.debug("Exit the program") + return False + else: + logger.warning("Not regular input, please input again") + + self.change_format = True + else: + logger.debug("The audio file format is right") + self.change_format = False + + return True + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + + model = parser_args.model + task = parser_args.task + lang = parser_args.lang + sample_rate = parser_args.sample_rate + config = parser_args.config + ckpt_path = parser_args.ckpt_path + decode_method = parser_args.decode_method + force_yes = parser_args.yes + rtf = parser_args.rtf + device = parser_args.device + + if not parser_args.verbose: + self.disable_task_loggers() + + task_source = self.get_input_source(parser_args.input) + task_results = OrderedDict() + has_exceptions = False + + for id_, input_ in task_source.items(): + try: + res = self( + audio_file=input_, + model=model, + task=task, + lang=lang, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + decode_method=decode_method, + force_yes=force_yes, + rtf=rtf, + device=device) + task_results[id_] = res + + except Exception as e: + has_exceptions = True + task_results[id_] = f'{e.__class__.__name__}: {e}' + + if rtf: + self.show_rtf(CLI_TIMER[self.__class__.__name__]) + self.process_task_results(parser_args.input, task_results, + parser_args.job_dump_result) + if has_exceptions: + return False + else: + return True + + @stats_wrapper + def __call__(self, + audio_file: os.PathLike, + model: str='wav2vec2ASR_librispeech', + task: str='asr', + lang: str='en', + sample_rate: int=16000, + config: os.PathLike=None, + ckpt_path: os.PathLike=None, + decode_method: str='ctc_greedy_search', + force_yes: bool=False, + rtf: bool=False, + device=paddle.get_device()): + """ + Python API to call an executor. + """ + + audio_file = os.path.abspath(audio_file) + paddle.set_device(device) + self._init_from_path(model, task, lang, sample_rate, config, + decode_method, ckpt_path) + if not self._check(audio_file, sample_rate, force_yes): + sys.exit(-1) + if rtf: + k = self.__class__.__name__ + CLI_TIMER[k]['start'].append(time.time()) + self.preprocess(model, audio_file) + self.infer(model, task) + res = self.postprocess() # Retrieve result of asr. + + if rtf: + CLI_TIMER[k]['end'].append(time.time()) + audio, audio_sample_rate = soundfile.read( + audio_file, dtype="int16", always_2d=True) + CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate) + + return res diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index ce7fa662..ab0b1828 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -18,6 +18,12 @@ __all__ = [ # Records of model name to import class model_alias = { + # --------------------------------- + # -------------- SSL -------------- + # --------------------------------- + "wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"], + "wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"], + # --------------------------------- # -------------- ASR -------------- # --------------------------------- diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 1a3cbec1..bdab9167 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -25,6 +25,7 @@ __all__ = [ 'tts_static_pretrained_models', 'tts_onnx_pretrained_models', 'vector_dynamic_pretrained_models', + 'ssl_dynamic_pretrained_models', 'whisper_dynamic_pretrained_models', ] @@ -33,6 +34,44 @@ __all__ = [ # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" +# --------------------------------- +# -------------- SSL -------------- +# --------------------------------- +ssl_dynamic_pretrained_models = { + "wav2vec2-en-16k": { + '1.3': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2-large-960h-lv60-self_ckpt_1.3.0.model.tar.gz', + 'md5': + 'acc46900680e341e500437aa59193518', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'wav2vec2-large-960h-lv60-self', + 'model': + 'wav2vec2-large-960h-lv60-self.pdparams', + 'params': + 'wav2vec2-large-960h-lv60-self.pdparams', + }, + }, + "wav2vec2ASR_librispeech-en-16k": { + '1.3': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz', + 'md5': + 'cbe28d6c78f3dd2e189968402381f454', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/wav2vec2ASR/checkpoints/avg_1', + 'model': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + 'params': + 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', + }, + }, +} + # --------------------------------- # -------------- ASR -------------- # --------------------------------- diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index d3d89f4d..4eb0e32d 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -22,7 +22,9 @@ from ..utils.dynamic_import import dynamic_import from ..utils.env import MODEL_HOME from .model_alias import model_alias -task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper'] +task_supported = [ + 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl', 'whisper' +] model_format_supported = ['dynamic', 'static', 'onnx'] inference_mode_supported = ['online', 'offline'] @@ -108,7 +110,6 @@ class CommonTaskResource: """ assert model_name in model_alias, 'No model classes found for "{}"'.format( model_name) - ret = [] for import_path in model_alias[model_name]: ret.append(dynamic_import(import_path)) diff --git a/paddlespeech/s2t/exps/wav2vec2/__init__.py b/paddlespeech/s2t/exps/wav2vec2/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/s2t/exps/wav2vec2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py b/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py index 185a92b8..97043fd7 100644 --- a/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py +++ b/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py index 3ae3a9e7..29e7ef55 100644 --- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py +++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py @@ -34,9 +34,10 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument( + '--resume', type=str, default="", nargs="?", help='resume ckpt path.') args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html config = CfgNode(new_allowed=True) if args.config: diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 4f6bc0c5..6a3321e4 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -15,6 +15,7 @@ import json import math import os +import re import time from collections import defaultdict from collections import OrderedDict @@ -62,6 +63,19 @@ class Wav2Vec2ASRTrainer(Trainer): self.avg_train_loss -= self.avg_train_loss / (batch_index + 1) self.avg_train_loss += loss / (batch_index + 1) + def before_train(self): + from_scratch = self.resume_or_scratch() + if from_scratch: + # scratch: save init model, i.e. 0 epoch + self.save(tag='init', infos=None) + else: + # resume: train next_epoch and next_iteration + self.epoch += 1 + logger.info( + f"Resume train: epoch {self.epoch }, step {self.iteration}!") + + self.maybe_batch_sampler_step() + def train_batch(self, batch_index, batch, msg): train_conf = self.config start = time.time() @@ -69,14 +83,14 @@ class Wav2Vec2ASRTrainer(Trainer): # forward utt, wav, wavs_lens, target, target_lens = batch wavs_lens_rate = wavs_lens / wav.shape[1] - target_lens_rate = target_lens / target.shape[1] + wav = wav[:, :, 0] - if hasattr(train_conf, 'speech_augment'): + if hasattr(train_conf, 'audio_augment'): wav = self.speech_augmentation(wav, wavs_lens_rate) - loss = self.model(wav, wavs_lens_rate, target, target_lens_rate) + + loss = self.model(wav, wavs_lens_rate, target, target_lens) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - # update self.avg_train_loss self.update_average(batch_index, float(loss)) @@ -98,11 +112,17 @@ class Wav2Vec2ASRTrainer(Trainer): # optimizer step old if (batch_index + 1) % train_conf.accum_grad == 0: - self.optimizer.step() - self.optimizer.clear_grad() - self.lr_scheduler.step() + self.model_optimizer.step() + self.model_optimizer.clear_grad() + if not train_conf.freeze_wav2vec2: + self.wav2vec2_optimizer.step() + self.wav2vec2_optimizer.clear_grad() + if self.config.model_scheduler != 'newbobscheduler': + self.model_lr_scheduler.step() + if self.config.wav2vec2_scheduler != 'newbobscheduler': + if not train_conf.freeze_wav2vec2: + self.wav2vec2_lr_scheduler.step() self.iteration += 1 - losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad} iteration_time = time.time() - start for k, v in losses_np.items(): @@ -114,7 +134,10 @@ class Wav2Vec2ASRTrainer(Trainer): if (batch_index + 1) % train_conf.accum_grad == 0: if dist.get_rank() == 0 and self.visualizer: losses_np_v = losses_np.copy() - losses_np_v.update({"lr": self.lr_scheduler()}) + losses_np_v.update({ + "model_lr": self.model_lr_scheduler(), + "wav2vec2_lr": self.wav2vec2_lr_scheduler() + }) for key, val in losses_np_v.items(): self.visualizer.add_scalar( tag='train/' + key, value=val, step=self.iteration - 1) @@ -131,11 +154,10 @@ class Wav2Vec2ASRTrainer(Trainer): for i, batch in enumerate(self.valid_loader): utt, wav, wavs_lens, target, target_lens = batch wavs_lens_rate = wavs_lens / wav.shape[1] - target_lens_rate = target_lens / target.shape[1] wav = wav[:, :, 0] - loss = self.model(wav, wavs_lens_rate, target, target_lens_rate) + loss = self.model(wav, wavs_lens_rate, target, target_lens) - if paddle.isfinite(loss): + if math.isfinite(float(loss)): num_utts = batch[1].shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts @@ -160,6 +182,106 @@ class Wav2Vec2ASRTrainer(Trainer): dist.get_rank(), total_loss / num_seen_utts)) return total_loss, num_seen_utts + @mp_tools.rank_zero_only + def save(self, tag=None, infos: dict=None): + """Save checkpoint (model parameters and optimizer states). + + Args: + tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None. + infos (dict, optional): meta data to save. Defaults to None. + """ + + infos = infos if infos else dict() + infos.update({ + "epoch": self.epoch, + "model_lr": self.model_optimizer.get_lr(), + "wav2vec2_lr": self.wav2vec2_optimizer.get_lr() + }) + + checkpoint_path = os.path.join( + self.checkpoint_dir, + "{}".format(self.iteration if tag is None else tag)) + + model_dict = self.model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + logger.info("Saved model to {}".format(params_path)) + + model_opt_dict = self.model_optimizer.state_dict() + wav2vec2_opt_dict = self.wav2vec2_optimizer.state_dict() + + opt_dict = {'model': model_opt_dict, 'wav2vec2': wav2vec2_opt_dict} + + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + logger.info("Saved optimzier state to {}".format(optimizer_path)) + + scheduler_dict = {} + + if self.config.model_scheduler == 'newbobscheduler': + scheduler_dict['model'] = self.model_lr_scheduler.save() + if self.config.wav2vec2_scheduler == 'newbobscheduler': + scheduler_dict['wav2vec2'] = self.wav2vec2_lr_scheduler.save() + if scheduler_dict: + scheduler_path = checkpoint_path + ".pdlrs" + paddle.save(scheduler_dict, scheduler_path) + logger.info("Saved scheduler state to {}".format(scheduler_path)) + info_path = re.sub('.pdparams$', '.json', params_path) + infos = {} if infos is None else infos + with open(info_path, 'w') as fout: + data = json.dumps(infos) + fout.write(data) + + def resume_or_scratch(self): + """Resume from latest checkpoint at checkpoints in the output + directory or load a specified checkpoint. + + If ``args.checkpoint_path`` is not None, load the checkpoint, else + resume training. + """ + scratch = None + if self.args.resume: + # just restore ckpt + # lr will resotre from optimizer ckpt + resume_json_path = os.path.join(self.checkpoint_dir, + self.args.resume + '.json') + with open(resume_json_path, 'r') as f: + resume_json = json.load(f) + self.iteration = 0 + self.epoch = resume_json["epoch"] + + # resotre model from *.pdparams + params_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdparams' + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + # resotre optimizer from *.pdopt + optimizer_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdopt' + optimizer_dict = paddle.load(optimizer_path) + self.model_optimizer.set_state_dict(optimizer_dict['model']) + self.wav2vec2_optimizer.set_state_dict(optimizer_dict['wav2vec2']) + + # resotre lr_scheduler from *.pdlrs + scheduler_path = os.path.join(self.checkpoint_dir, + "{}".format(self.epoch)) + '.pdlrs' + if os.path.isfile(os.path.join(scheduler_path)): + scheduler_dict = paddle.load(scheduler_path) + if self.config.model_scheduler == 'newbobscheduler': + self.model_lr_scheduler.load(scheduler_dict['model']) + if self.config.wav2vec2_scheduler == 'newbobscheduler': + self.wav2vec2_lr_scheduler.load(scheduler_dict['wav2vec2']) + logger.info( + f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") + scratch = False + else: + self.iteration = 0 + self.epoch = 0 + scratch = True + logger.info("Init from scratch!") + return scratch + def do_train(self): """The training process control by step.""" # !!!IMPORTANT!!! @@ -170,7 +292,6 @@ class Wav2Vec2ASRTrainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - if not self.use_streamdata: logger.info( f"Train Total Examples: {len(self.train_loader.dataset)}") @@ -187,7 +308,9 @@ class Wav2Vec2ASRTrainer(Trainer): report("Rank", dist.get_rank()) report("epoch", self.epoch) report('step', self.iteration) - report("lr", self.lr_scheduler()) + report("model_lr", self.model_optimizer.get_lr()) + report("wav2vec2_lr", + self.wav2vec2_optimizer.get_lr()) self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) @@ -225,15 +348,25 @@ class Wav2Vec2ASRTrainer(Trainer): cv_loss = float(cv_loss) else: cv_loss = total_loss / num_seen_utts - logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) if self.visualizer: self.visualizer.add_scalar( tag='eval/cv_loss', value=cv_loss, step=self.epoch) self.visualizer.add_scalar( - tag='eval/lr', value=self.lr_scheduler(), step=self.epoch) - + tag='eval/model_lr', + value=self.model_lr_scheduler(), + step=self.epoch) + self.visualizer.add_scalar( + tag='eval/wav2vec2_lr', + value=self.wav2vec2_lr_scheduler(), + step=self.epoch) + + if self.config.model_scheduler == 'newbobscheduler': + self.model_lr_scheduler.step(cv_loss) + if self.config.wav2vec2_scheduler == 'newbobscheduler': + if not self.config.freeze_wav2vec2: + self.wav2vec2_lr_scheduler.step(cv_loss) self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.new_epoch() @@ -268,10 +401,11 @@ class Wav2Vec2ASRTrainer(Trainer): model_conf.output_dim = self.test_loader.vocab_size model = Wav2vec2ASR.from_config(model_conf) + model_dict = paddle.load(config.wav2vec2_params_path) + model.wav2vec2.set_state_dict(model_dict) if self.parallel: model = paddle.DataParallel(model, find_unused_parameters=True) - logger.info(f"{model}") layer_tools.print_params(model, logger.info) self.model = model @@ -286,46 +420,74 @@ class Wav2Vec2ASRTrainer(Trainer): return train_config = config - optim_type = train_config.model_optim - optim_conf = train_config.model_optim_conf - scheduler_type = train_config.scheduler - scheduler_conf = train_config.scheduler_conf - - scheduler_args = { - "learning_rate": optim_conf.lr, - "verbose": False, - "warmup_steps": scheduler_conf.warmup_steps, - "gamma": scheduler_conf.lr_decay, - "d_model": model_conf.dnn_neurons, - } - lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, - scheduler_args) + model_optim_type = train_config.model_optim + model_optim_conf = train_config.model_optim_conf + wav2vec2_optim_type = train_config.model_optim + wav2vec2_optim_conf = train_config.wav2vec2_optim_conf + + model_scheduler_type = train_config.model_scheduler + model_scheduler_conf = train_config.model_scheduler_conf + wav2vec2_scheduler_type = train_config.wav2vec2_scheduler + wav2vec2_scheduler_conf = train_config.wav2vec2_scheduler_conf + + model_scheduler_args = dict( + **{"learning_rate": model_optim_conf.lr, + "verbose": False}, **(dict(model_scheduler_conf))) + + wav2vec2_scheduler_args = dict( + **{"learning_rate": wav2vec2_optim_conf.lr, + "verbose": False}, **(dict(wav2vec2_scheduler_conf))) + + model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type, + model_scheduler_args) + wav2vec2_lr_scheduler = LRSchedulerFactory.from_args( + wav2vec2_scheduler_type, wav2vec2_scheduler_args) def optimizer_args( config, + optim_type, + optim_conf, parameters, lr_scheduler=None, ): train_config = config - optim_type = train_config.model_optim - optim_conf = train_config.model_optim_conf - scheduler_type = train_config.scheduler - scheduler_conf = train_config.scheduler_conf - return { - "grad_clip": train_config.global_grad_clip, - "learning_rate": lr_scheduler - if lr_scheduler else optim_conf.lr, - "epsilon": optim_conf.epsilon, - "rho": optim_conf.rho, - "parameters": parameters, - "beta1": 0.9 if optim_type == 'noam' else None, - "beat2": 0.98 if optim_type == 'noam' else None, - } - - optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) - optimizer = OptimizerFactory.from_args(optim_type, optimzer_args) - - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler + optim_arg = dict(optim_conf) + optim_arg.update({ + "grad_clip": + train_config.global_grad_clip, + "learning_rate": + lr_scheduler if lr_scheduler else optim_conf.lr, + "parameters": + parameters + }) + return optim_arg + + model_optimizer_args = optimizer_args(config, model_optim_type, + model_optim_conf, [{ + 'params': + model._layers.enc.parameters() + }, { + 'params': + model._layers.ctc.parameters() + }] if self.parallel else [{ + 'params': + model.enc.parameters() + }, { + 'params': + model.ctc.parameters() + }], model_lr_scheduler) + wav2vec2_optimizer_args = optimizer_args( + config, wav2vec2_optim_type, wav2vec2_optim_conf, + model._layers.wav2vec2.parameters() if self.parallel else + model.wav2vec2.parameters(), wav2vec2_lr_scheduler) + model_optimizer = OptimizerFactory.from_args(model_optim_type, + model_optimizer_args) + wav2vec2_optimizer = OptimizerFactory.from_args(wav2vec2_optim_type, + wav2vec2_optimizer_args) + + self.model_optimizer = model_optimizer + self.wav2vec2_optimizer = wav2vec2_optimizer + self.model_lr_scheduler = model_lr_scheduler + self.wav2vec2_lr_scheduler = wav2vec2_lr_scheduler logger.info("Setup optimizer/lr_scheduler!") diff --git a/paddlespeech/s2t/models/wav2vec2/__init__.py b/paddlespeech/s2t/models/wav2vec2/__init__.py index e69de29b..3a12a9cf 100644 --- a/paddlespeech/s2t/models/wav2vec2/__init__.py +++ b/paddlespeech/s2t/models/wav2vec2/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .wav2vec2_ASR import Wav2vec2ASR +from .wav2vec2_ASR import Wav2vec2Base + +__all__ = ["Wav2vec2ASR", "Wav2vec2Base"] diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py index cfd8f507..82313c33 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py @@ -18,6 +18,7 @@ import paddle from paddlespeech.s2t.models.wav2vec2.modules import containers from paddlespeech.s2t.models.wav2vec2.modules import linear +from paddlespeech.s2t.models.wav2vec2.modules.normalization import BatchNorm1d class VanillaNN(containers.Sequential): @@ -39,18 +40,34 @@ class VanillaNN(containers.Sequential): paddle.shape([10, 120, 512]) """ - def __init__( - self, - input_shape, - activation=paddle.nn.LeakyReLU, - dnn_blocks=2, - dnn_neurons=512, ): - super().__init__(input_shape=input_shape) + def __init__(self, + input_shape, + dnn_blocks=2, + dnn_neurons=512, + activation=True, + normalization=False, + dropout_rate=0.0): + super().__init__(input_shape=[None, None, input_shape]) + + if not isinstance(dropout_rate, list): + dropout_rate = [dropout_rate] * dnn_blocks + else: + assert len( + dropout_rate + ) == dnn_blocks, "len(dropout_rate) must equal to dnn_blocks" for block_index in range(dnn_blocks): self.append( linear.Linear, n_neurons=dnn_neurons, - bias=True, + bias_attr=None, layer_name="linear", ) - self.append(activation(), layer_name="act") + if normalization: + self.append( + BatchNorm1d, input_size=dnn_neurons, layer_name='bn') + if activation: + self.append(paddle.nn.LeakyReLU(), layer_name="act") + self.append( + paddle.nn.Dropout(), + p=dropout_rate[block_index], + layer_name='dropout') diff --git a/paddlespeech/s2t/models/wav2vec2/modules/__init__.py b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/models/wav2vec2/modules/containers.py b/paddlespeech/s2t/models/wav2vec2/modules/containers.py index 180d0bd3..6a6b94e9 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/containers.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/containers.py @@ -141,5 +141,4 @@ class Sequential(paddle.nn.LayerDict): x = layer(x) if isinstance(x, tuple): x = x[0] - return x diff --git a/paddlespeech/s2t/models/wav2vec2/modules/linear.py b/paddlespeech/s2t/models/wav2vec2/modules/linear.py index adae4514..3ea3716c 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/linear.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/linear.py @@ -53,7 +53,7 @@ class Linear(paddle.nn.Layer): n_neurons, input_shape=None, input_size=None, - bias=True, + bias_attr=None, combine_dims=False, ): super().__init__() self.combine_dims = combine_dims @@ -67,7 +67,7 @@ class Linear(paddle.nn.Layer): input_size = input_shape[2] * input_shape[3] # Weights are initialized following paddle approach - self.w = align.Linear(input_size, n_neurons, bias_attr=bias) + self.w = align.Linear(input_size, n_neurons, bias_attr=bias_attr) def forward(self, x): """Returns the linear transformation of input tensor. diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py index e484fff6..5670cb53 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py @@ -1120,9 +1120,6 @@ class Wav2Vec2ConfigPure(): self.output_hidden_states = False self.use_return_dict = True - self.pad_token_id = config.pad_token_id - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id self.hidden_size = config.hidden_size self.feat_extract_norm = config.feat_extract_norm self.feat_extract_activation = config.feat_extract_activation @@ -1145,7 +1142,6 @@ class Wav2Vec2ConfigPure(): self.layerdrop = config.layerdrop self.layer_norm_eps = config.layer_norm_eps self.initializer_range = config.initializer_range - self.vocab_size = config.vocab_size self.do_stable_layer_norm = config.do_stable_layer_norm self.use_weighted_layer_sum = config.use_weighted_layer_sum diff --git a/paddlespeech/s2t/models/wav2vec2/processing/__init__.py b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py index ac9bf45d..9224549a 100644 --- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py +++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py @@ -639,6 +639,170 @@ class DropChunk(nn.Layer): return dropped_waveform +class SpecAugment(paddle.nn.Layer): + """An implementation of the SpecAugment algorithm. + Reference: + https://arxiv.org/abs/1904.08779 + Arguments + --------- + time_warp : bool + Whether applying time warping. + time_warp_window : int + Time warp window. + time_warp_mode : str + Interpolation mode for time warping (default "bicubic"). + freq_mask : bool + Whether applying freq mask. + freq_mask_width : int or tuple + Freq mask width range. + n_freq_mask : int + Number of freq mask. + time_mask : bool + Whether applying time mask. + time_mask_width : int or tuple + Time mask width range. + n_time_mask : int + Number of time mask. + replace_with_zero : bool + If True, replace masked value with 0, else replace masked value with mean of the input tensor. + Example + ------- + >>> aug = SpecAugment() + >>> a = paddle.rand([8, 120, 80]) + >>> a = aug(a) + >>> print(a.shape) + paddle.Size([8, 120, 80]) + """ + + def __init__( + self, + time_warp=True, + time_warp_window=5, + time_warp_mode="bicubic", + freq_mask=True, + freq_mask_width=(0, 20), + n_freq_mask=2, + time_mask=True, + time_mask_width=(0, 100), + n_time_mask=2, + replace_with_zero=True, ): + super().__init__() + assert ( + time_warp or freq_mask or time_mask + ), "at least one of time_warp, time_mask, or freq_mask should be applied" + + self.apply_time_warp = time_warp + self.time_warp_window = time_warp_window + self.time_warp_mode = time_warp_mode + + self.freq_mask = freq_mask + if isinstance(freq_mask_width, int): + freq_mask_width = (0, freq_mask_width) + self.freq_mask_width = freq_mask_width + self.n_freq_mask = n_freq_mask + + self.time_mask = time_mask + if isinstance(time_mask_width, int): + time_mask_width = (0, time_mask_width) + self.time_mask_width = time_mask_width + self.n_time_mask = n_time_mask + + self.replace_with_zero = replace_with_zero + + def forward(self, x): + """Takes in input a tensors and returns an augmented one.""" + if self.apply_time_warp: + x = self.time_warp(x) + if self.freq_mask: + x = self.mask_along_axis(x, dim=2) + if self.time_mask: + x = self.mask_along_axis(x, dim=1) + return x + + def time_warp(self, x): + """Time warping with paddle.nn.functional.interpolate""" + original_size = x.shape + window = self.time_warp_window + + # 2d interpolation requires 4D or higher dimension tensors + # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq) + if x.dim() == 3: + x = x.unsqueeze(1) + + time = x.shape[2] + if time - window <= window: + return x.view(*original_size) + + # compute center and corresponding window + c = paddle.randint(window, time - window, (1, ))[0] + w = paddle.randint(c - window, c + window, (1, ))[0] + 1 + # c = 5 + # w = 10 + left = paddle.nn.functional.interpolate( + x[:, :, :c], + (w, x.shape[3]), + mode=self.time_warp_mode, + align_corners=True, ) + right = paddle.nn.functional.interpolate( + x[:, :, c:], + (time - w, x.shape[3]), + mode=self.time_warp_mode, + align_corners=True, ) + + x[:, :, :w] = left + x[:, :, w:] = right + return x.view(*original_size) + + def mask_along_axis(self, x, dim): + """Mask along time or frequency axis. + Arguments + --------- + x : tensor + Input tensor. + dim : int + Corresponding dimension to mask. + """ + original_size = x.shape + if x.dim() == 4: + x = x.view(-1, x.shape[2], x.shape[3]) + + batch, time, fea = x.shape + + if dim == 1: + D = time + n_mask = self.n_time_mask + width_range = self.time_mask_width + else: + D = fea + n_mask = self.n_freq_mask + width_range = self.freq_mask_width + + mask_len = paddle.randint(width_range[0], width_range[1], + (batch, n_mask)).unsqueeze(2) + + mask_pos = paddle.randint(0, max(1, D - mask_len.max()), + (batch, n_mask)).unsqueeze(2) + + # compute masks + arange = paddle.arange(end=D).view(1, 1, -1) + mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len)) + mask = mask.any(axis=1) + + if dim == 1: + mask = mask.unsqueeze(2) + else: + mask = mask.unsqueeze(1) + + if self.replace_with_zero: + val = 0.0 + else: + val = x.mean() + # same to x.masked_fill_(mask, val) + y = paddle.full(x.shape, val, x.dtype) + x = paddle.where(mask, y, x) + return x.view(*original_size) + + class TimeDomainSpecAugment(nn.Layer): """A time-domain approximation of the SpecAugment algorithm. This augmentation module implements three augmentations in diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py index e1334774..eda188da 100644 --- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py +++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py @@ -23,7 +23,9 @@ import paddle.nn.functional as F from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN +from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC +from paddlespeech.s2t.modules.initializer import DefaultInitializerContext from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.utility import log_add @@ -31,44 +33,41 @@ from paddlespeech.s2t.utils.utility import log_add class Wav2vec2ASR(nn.Layer): def __init__(self, config: dict): super().__init__() + init_type = config.get("init_type", None) + with DefaultInitializerContext(init_type): + self.config = config + wav2vec2_config = Wav2Vec2ConfigPure(config) + wav2vec2 = Wav2Vec2Model(wav2vec2_config) + self.normalize_wav = config.normalize_wav + self.output_norm = config.output_norm + if hasattr(config, 'spec_augment'): + self.spec_augment = SpecAugment(**config.spec_augment) - wav2vec2_config = Wav2Vec2ConfigPure(config) - wav2vec2 = Wav2Vec2Model(wav2vec2_config) - model_dict = paddle.load(config.wav2vec2_params_path) - wav2vec2.set_state_dict(model_dict) - self.normalize_wav = config.normalize_wav - self.output_norm = config.output_norm - if config.freeze_wav2vec2: - wav2vec2.eval() - for parm in wav2vec2.parameters(): - parm.trainable = False - self.wav2vec2 = wav2vec2 - self.enc = VanillaNN( - input_shape=[None, None, wav2vec2_config.hidden_size], - activation=nn.LeakyReLU, - dnn_blocks=config.dnn_blocks, - dnn_neurons=config.dnn_neurons) - self.ctc = CTC(odim=config.output_dim, - enc_n_units=config.dnn_neurons, - blank_id=config.blank_id, - dropout_rate=config.ctc_dropout_rate, - reduction='mean') - - def forward(self, wav, wavs_lens_rate, target, target_lens_rate): + if config.freeze_wav2vec2: + wav2vec2.eval() + for parm in wav2vec2.parameters(): + parm.trainable = False + self.wav2vec2 = wav2vec2 + self.enc = VanillaNN(**config.enc) + self.ctc = CTC(**config.ctc, + odim=config.output_dim, + batch_average=False, + reduction='mean') + + def forward(self, wav, wavs_lens_rate, target, target_lens): if self.normalize_wav: - wav = F.layer_norm(wav, wav.shape[1:]) + wav = F.layer_norm(wav, wav.shape) # Extract wav2vec output out = self.wav2vec2(wav)[0] # We normalize the output if required if self.output_norm: - out = F.layer_norm(out, out.shape[1:]) - feats = out - + out = F.layer_norm(out, out.shape) + if self.train and hasattr(self.config, 'spec_augment'): + feats = self.spec_augment(out) + else: + feats = out x = self.enc(feats) x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64) - target_lens = (target_lens_rate * - target.shape[1]).round().astype(paddle.int64) - ctc_loss = self.ctc(x, x_lens, target, target_lens) return ctc_loss @@ -239,3 +238,33 @@ class Wav2vec2ASR(nn.Layer): """ hyps = self._ctc_prefix_beam_search(wav, beam_size) return hyps[0][0] + + +class Wav2vec2Base(nn.Layer): + """Wav2vec2 model""" + + def __init__(self, config: dict): + super().__init__() + wav2vec2_config = Wav2Vec2ConfigPure(config) + wav2vec2 = Wav2Vec2Model(wav2vec2_config) + self.wav2vec2 = wav2vec2 + + @classmethod + def from_config(cls, configs: dict): + """init model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + nn.Layer: Wav2Vec2Base + """ + model = cls(configs) + return model + + def forward(self, wav): + out = self.wav2vec2(wav) + return out diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py index b22f7ef8..53c756ce 100644 --- a/paddlespeech/s2t/training/scheduler.py +++ b/paddlespeech/s2t/training/scheduler.py @@ -17,6 +17,7 @@ from typing import Dict from typing import Text from typing import Union +import paddle from paddle.optimizer.lr import LRScheduler from typeguard import check_argument_types @@ -107,6 +108,125 @@ class ConstantLR(LRScheduler): return self.base_lr +@register_scheduler +class NewBobScheduler(LRScheduler): + """Scheduler with new-bob technique, used for LR annealing. + + The learning rate is annealed based on the validation performance. + In particular: if (past_loss-current_loss)/past_loss< impr_threshold: + lr=lr * annealing_factor. + + Arguments + --------- + initial_value : float + The initial hyperparameter value. + annealing_factor : float + It is annealing factor used in new_bob strategy. + improvement_threshold : float + It is the improvement rate between losses used to perform learning + annealing in new_bob strategy. + patient : int + When the annealing condition is violated patient times, + the learning rate is finally reduced. + + Example + ------- + >>> scheduler = NewBobScheduler(initial_value=1.0) + >>> scheduler(metric_value=10.0) + (1.0, 1.0) + >>> scheduler(metric_value=2.0) + (1.0, 1.0) + >>> scheduler(metric_value=2.5) + (1.0, 0.5) + """ + + def __init__( + self, + learning_rate, + last_epoch=-1, + verbose=False, + annealing_factor=0.5, + improvement_threshold=0.0025, + patient=0, ): + self.hyperparam_value = learning_rate + self.annealing_factor = annealing_factor + self.improvement_threshold = improvement_threshold + self.patient = patient + self.metric_values = [] + self.current_patient = self.patient + super().__init__(learning_rate, last_epoch, verbose) + + def step(self, metric_value=None): + """ + + ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` . + The new learning rate will take effect on next ``optimizer.step`` . + + Args: + epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1. + + Returns: + None + """ + if metric_value is None: + self.last_epoch += 1 + self.last_lr = self.hyperparam_value + else: + self.last_epoch += 1 + self.last_lr = self.get_lr(metric_value) + + if self.verbose: + print('Epoch {}: {} set learning rate to {}.'.format( + self.last_epoch, self.__class__.__name__, self.last_lr)) + + def get_lr(self, metric_value): + """Returns the current and new value for the hyperparameter. + + Arguments + --------- + metric_value : int + A number for determining whether to change the hyperparameter value. + """ + new_value = self.hyperparam_value + if len(self.metric_values) > 0: + prev_metric = self.metric_values[-1] + # Update value if improvement too small and patience is 0 + if prev_metric == 0: # Prevent division by zero + improvement = 0 + else: + improvement = (prev_metric - metric_value) / prev_metric + if improvement < self.improvement_threshold: + if self.current_patient == 0: + new_value *= self.annealing_factor + self.current_patient = self.patient + else: + self.current_patient -= 1 + + # Store relevant info + self.metric_values.append(metric_value) + self.hyperparam_value = new_value + + return new_value + + def save(self): + """Saves the current metrics on the specified path.""" + data = { + "current_epoch_index": self.last_epoch, + "hyperparam_value": self.hyperparam_value, + "metric_values": self.metric_values, + "current_patient": self.current_patient + } + return data + + def load(self, data): + """Loads the needed information.""" + data = paddle.load(data) + self.last_epoch = data["current_epoch_index"] + self.hyperparam_value = data["hyperparam_value"] + self.metric_values = data["metric_values"] + self.current_patient = data["current_patient"] + + def dynamic_import_scheduler(module): """Import Scheduler class dynamically. diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index 902e9ff1..3a58626d 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -9,6 +9,10 @@ paddlespeech cls --input ./cat.wav --topk 10 # Punctuation_restoration paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast +# Speech SSL +paddlespeech ssl --task asr --lang en --input ./en.wav +paddlespeech ssl --task vector --lang en --input ./en.wav + # Speech_recognition wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav paddlespeech asr --input ./zh.wav