diff --git a/README.md b/README.md
index 0a8566940..32e1c23d8 100644
--- a/README.md
+++ b/README.md
@@ -157,11 +157,14 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
-- 🔥 2022.11.07: [U2/U2++ C++ High Performance Streaming Asr Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech).
+- 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), supporting multilingual recognition and translation.
+- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), supporting ASR and feature extraction.
+- 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
+- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech).
- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3).
- 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
-- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
+- 👑 2022.10.11: Add [Wav2vec2ASR-en](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and [ERNIE-SAT](https://arxiv.org/abs/2211.03545) in [PaddleSpeech Web Demo](./demos/speech_web).
- ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder.
- ⚡ 2022.08.25: Release TTS [finetune](./examples/other/tts_finetune/tts3) example.
@@ -978,6 +981,7 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
- Many thanks to [jerryuhoo](https://github.com/jerryuhoo)/[VTuberTalk](https://github.com/jerryuhoo/VTuberTalk) for developing a GUI tool based on PaddleSpeech TTS and code for making datasets from videos based on PaddleSpeech ASR.
- Many thanks to [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) for developing a rasa chatbot,which is able to speak and listen thanks to PaddleSpeech.
- Many thanks to [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) for the C++ inference implementation of PaddleSpeech ASR.
+- Many thanks to [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) for the real-time voice typing tool implementation of PaddleSpeech ASR streaming services.
Besides, PaddleSpeech depends on a lot of open source repositories. See [references](./docs/source/reference.md) for more information.
diff --git a/README_cn.md b/README_cn.md
index 9f33a4cb8..427d59caf 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -164,10 +164,14 @@
### 近期更新
+- 👑 2022.11.18: 新增 [Whisper CLI 和 Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640),支持多种语言的识别与翻译。
+- 🔥 2022.11.18: 新增 [Wav2vec2 CLI 和 Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl),支持 ASR 和特征提取。
+- 🎉 2022.11.17: TTS 新增[高质量男性音色](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660)。
+- 🔥 2022.11.07: 新增 [U2/U2++ 高性能流式 ASR C++ 部署](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech)。
- 👑 2022.11.01: [中英文混合 TTS](./examples/zh_en_tts/tts3) 新增 [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) 模块。
- 🔥 2022.10.26: TTS 新增[韵律预测](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy)功能。
- 🎉 2022.10.21: TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
-- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。
+- 👑 2022.10.11: 新增 [Wav2vec2ASR-en](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。
- 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 [ERNIE-SAT](https://arxiv.org/abs/2211.03545) 到 [PaddleSpeech 网页应用](./demos/speech_web)。
- ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。
- ⚡ 2022.08.25: 发布 TTS [finetune](./examples/other/tts_finetune/tts3) 示例。
@@ -983,6 +987,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
- 非常感谢 [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) 基于 PaddleSpeech 的 ASR 与 TTS 设计的可听、说对话机器人。
- 非常感谢 [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) 对 PaddleSpeech 的 ASR 进行 C++ 推理实现。
+- 非常感谢 [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) 基于 PaddleSpeech 的 ASR 流式服务实现的实时语音输入法工具。
此外,PaddleSpeech 依赖于许多开源存储库。有关更多信息,请参阅 [references](./docs/source/reference.md)。
diff --git a/demos/README.md b/demos/README.md
index 72b70b237..a41967864 100644
--- a/demos/README.md
+++ b/demos/README.md
@@ -17,3 +17,5 @@ This directory contains many speech applications in multiple scenarios.
* story talker - book reader based on OCR and TTS
* style_fs2 - multi style control for FastSpeech2 model
* text_to_speech - convert text into speech
+* self-supervised pretraining - speech feature extraction and speech recognition based on wav2vec2
+* Whisper - speech recognition and translation based on the Whisper model
diff --git a/demos/README_cn.md b/demos/README_cn.md
index 04fc1fa7d..ffb028f0e 100644
--- a/demos/README_cn.md
+++ b/demos/README_cn.md
@@ -17,3 +17,5 @@
* 会说话的故事书 - 基于 OCR 和语音合成的会说话的故事书。
* 个性化语音合成 - 基于 FastSpeech2 模型的个性化语音合成。
* 语音合成 - 基于给定的文本生成语音音频。
+* 自监督预训练模型 - 基于 wav2vec2 的语音特征提取和语音识别。
+* Whisper - 基于 Whisper 模型的语音识别与翻译。
diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md
new file mode 100644
index 000000000..b98a7cc61
--- /dev/null
+++ b/demos/speech_ssl/README.md
@@ -0,0 +1,102 @@
+([简体中文](./README_cn.md)|English)
+# Speech SSL (Self-Supervised Learning)
+
+## Introduction
+Speech SSL (Self-Supervised Learning) refers to training on large-scale unlabeled speech data. A model trained in this way learns good acoustic representations and can be applied to downstream speech tasks by fine-tuning on labeled datasets.
+
+This demo shows how to recognize text or extract acoustic representations from a specific audio file with speech SSL models, using either a single command or a few lines of Python with `PaddleSpeech`.
+
+## Usage
+### 1. Installation
+see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
+
+You can choose one of the easy, medium, or hard ways to install paddlespeech.
+
+### 2. Prepare Input File
+The input of this demo should be a WAV file (`.wav`), and its sample rate must match that of the model.
+
+Here are sample files for this demo that can be downloaded:
+```bash
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+```
+
+### 3. Usage
+- Command Line (Recommended)
+ ```bash
+ # to recognize text
+ paddlespeech ssl --task asr --lang en --input ./en.wav
+
+ # to get acoustic representation
+ paddlespeech ssl --task vector --lang en --input ./en.wav
+ ```
+
+ Usage:
+ ```bash
+ paddlespeech ssl --help
+ ```
+ Arguments:
+ - `input`(required): Audio file to recognize.
+ - `model`: Model type of asr task. Default: `wav2vec2ASR_librispeech`.
+ - `task`: Output type. Default: `asr`.
+ - `lang`: Model language. Default: `en`.
+ - `sample_rate`: Sample rate of the model. Default: `16000`.
+ - `config`: Config of asr task. Use pretrained model when it is None. Default: `None`.
+ - `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
+ - `yes`: Takes no value. Once set, all requests of the program are accepted by default, including converting the audio sample rate. Default: `False`.
+ - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment.
+ - `verbose`: Show the log information.
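+
+ For example, several of the arguments above can be combined in one call. This is only a sketch based on the argument list; the `cpu` device string is an assumption (by default the current paddlepaddle device is used):
+ ```bash
+ # transcribe a 16 kHz English recording on CPU, accepting prompts such as resampling
+ paddlespeech ssl --task asr --lang en --sample_rate 16000 --device cpu --yes --input ./en.wav
+ ```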
+
+
+- Python API
+ ```python
+ import paddle
+ from paddlespeech.cli.ssl import SSLExecutor
+
+ ssl_executor = SSLExecutor()
+
+ # to recognize text
+ text = ssl_executor(
+ model='wav2vec2ASR_librispeech',
+ task='asr',
+ lang='en',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./en.wav',
+ device=paddle.get_device())
+ print('ASR Result: \n{}'.format(text))
+
+ # to get acoustic representation
+ feature = ssl_executor(
+ model='wav2vec2',
+ task='vector',
+ lang='en',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./en.wav',
+ device=paddle.get_device())
+ print('Representation: \n{}'.format(feature))
+ ```
+
+ Output:
+ ```bash
+ ASR Result:
+ i knocked at the door on the ancient side of the building
+
+ Representation:
+ Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+ [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122,
+ -0.04614586, 0.17853957],
+ [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855,
+ -0.04638699, 0.17855372],
+ [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341,
+ -0.04643029, 0.17856732],
+ ...,
+ [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373,
+ -0.04701405, 0.17862988],
+ [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728,
+ -0.04687393, 0.17864393],
+ [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174,
+ -0.17227529, 0.20338398]]])
+ ```
diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md
new file mode 100644
index 000000000..65961ce90
--- /dev/null
+++ b/demos/speech_ssl/README_cn.md
@@ -0,0 +1,103 @@
+(简体中文|[English](./README.md))
+
+# 语音自监督学习
+## 介绍
+语音自监督学习,指的是在大规模无标注语音数据集上进行训练的方法。用这种方法训练出来的模型可以产生很好的声学表征,并且可以通过在有标注的数据集上进行微调,应用于其他下游的语音任务。
+
+这个 demo 是通过语音自监督模型将一个特定的音频文件识别成文本或产生声学表征,它可以通过使用 `PaddleSpeech` 的单个命令或 python 中的几行代码来实现。
+
+## 使用方法
+### 1. 安装
+请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。
+
+你可以从 easy、medium、hard 三种方式中选择一种方式安装。
+
+### 2. 准备输入
+这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
+
+可以下载此 demo 的示例音频:
+```bash
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+```
+### 3. 使用方法
+- 命令行 (推荐使用)
+ ```bash
+
+ # 识别文本
+ paddlespeech ssl --task asr --lang en --input ./en.wav
+
+ # 产生声学表征
+ paddlespeech ssl --task vector --lang en --input ./en.wav
+ ```
+
+ 使用方法:
+ ```bash
+ paddlespeech ssl --help
+ ```
+ 参数:
+ - `input`(必须输入):用于识别的音频文件。
+ - `model`:ASR 任务的模型,默认值:`wav2vec2ASR_librispeech`。
+ - `task`:输出类别,默认值:`asr`。
+ - `lang`:模型语言,默认值:`en`。
+ - `sample_rate`:音频采样率,默认值:`16000`。
+ - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。
+ - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。
+ - `yes`:不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。
+ - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。
+ - `verbose`: 如果使用,显示 logger 信息。
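+
+ 例如,可以把上面的多个参数组合在一条命令中使用(以下只是示意用法,其中 `cpu` 设备字符串为假定取值,默认使用当前环境下 paddlepaddle 的设备):
+ ```bash
+ # 在 CPU 上识别 16 kHz 的英文音频,并默认同意重采样等请求
+ paddlespeech ssl --task asr --lang en --sample_rate 16000 --device cpu --yes --input ./en.wav
+ ```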
+
+
+- Python API
+ ```python
+ import paddle
+ from paddlespeech.cli.ssl import SSLExecutor
+
+ ssl_executor = SSLExecutor()
+
+ # 识别文本
+ text = ssl_executor(
+ model='wav2vec2ASR_librispeech',
+ task='asr',
+ lang='en',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./en.wav',
+ device=paddle.get_device())
+ print('ASR Result: \n{}'.format(text))
+
+ # 得到声学表征
+ feature = ssl_executor(
+ model='wav2vec2',
+ task='vector',
+ lang='en',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./en.wav',
+ device=paddle.get_device())
+ print('Representation: \n{}'.format(feature))
+ ```
+
+
+ 输出:
+ ```bash
+ ASR Result:
+ i knocked at the door on the ancient side of the building
+
+ Representation:
+ Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+ [[[ 0.02351918, -0.12980647, 0.17868176, ..., 0.10118122,
+ -0.04614586, 0.17853957],
+ [ 0.02361383, -0.12978461, 0.17870593, ..., 0.10103855,
+ -0.04638699, 0.17855372],
+ [ 0.02345137, -0.12982975, 0.17883906, ..., 0.10104341,
+ -0.04643029, 0.17856732],
+ ...,
+ [ 0.02313030, -0.12918393, 0.17845058, ..., 0.10073373,
+ -0.04701405, 0.17862988],
+ [ 0.02176583, -0.12929161, 0.17797582, ..., 0.10097728,
+ -0.04687393, 0.17864393],
+ [ 0.05269200, 0.01297141, -0.23336855, ..., -0.11257174,
+ -0.17227529, 0.20338398]]])
+ ```
diff --git a/demos/speech_ssl/run.sh b/demos/speech_ssl/run.sh
new file mode 100644
index 000000000..ca94bc5cc
--- /dev/null
+++ b/demos/speech_ssl/run.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# audio download
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+
+# to recognize text
+paddlespeech ssl --task asr --lang en --input ./en.wav
+
+# to get acoustic representation
+paddlespeech ssl --task vector --lang en --input ./en.wav
diff --git a/demos/whisper/README.md b/demos/whisper/README.md
new file mode 100644
index 000000000..9b12554e6
--- /dev/null
+++ b/demos/whisper/README.md
@@ -0,0 +1,95 @@
+([简体中文](./README_cn.md)|English)
+
+# Whisper
+## Introduction
+Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
+
+The Whisper model is trained by OpenAI; see https://github.com/openai/whisper for details.
+
+## Usage
+ ### 1. Installation
+ see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
+
+ You can choose one of the easy, medium, or hard ways to install paddlespeech.
+
+ ### 2. Prepare Input File
+ The input of this demo should be a WAV file (`.wav`), and its sample rate must match that of the model.
+
+ Here are sample files for this demo that can be downloaded:
+ ```bash
+ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+ ```
+
+ ### 3. Usage
+ - Command Line (Recommended)
+ ```bash
+ # to recognize text
+ paddlespeech whisper --task transcribe --input ./zh.wav
+
+ # to use the English-only model of base size
+ paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav
+
+ # to recognize text and translate to English
+ paddlespeech whisper --task translate --input ./zh.wav
+
+ ```
+
+ Usage:
+ ```bash
+ paddlespeech whisper --help
+ ```
+ Arguments:
+ - `input`(required): Audio file to recognize.
+ - `model`: Model type of asr task. Default: `whisper-large`.
+ - `task`: Output type. Default: `transcribe`.
+ - `lang`: Model language. Default: ``. Use `en` to choose the English-only model. Currently the [medium, base, small, tiny] sizes support English-only models.
+ - `size`: Model size for decoding. Default: `large`. Currently [large, medium, base, small, tiny] are supported.
+ - `language`: Force the decoding language. Default: `None`, which lets the model detect the language itself.
+ - `sample_rate`: Sample rate of the model. Default: `16000`. Other sample rates are not supported yet.
+ - `config`: Config of asr task. Use pretrained model when it is None. Default: `None`.
+ - `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
+ - `yes`: Takes no value. Once set, all requests of the program are accepted by default, including converting the audio sample rate. Default: `False`.
+ - `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment.
+ - `verbose`: Show the log information.
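+
+ For example, the task and size arguments above can be combined. This is only a sketch; `medium` is one of the supported sizes listed above:
+ ```bash
+ # translate Chinese speech to English with the medium multilingual model, accepting prompts such as resampling
+ paddlespeech whisper --task translate --size medium --yes --input ./zh.wav
+ ```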
+
+
+ - Python API
+ ```python
+ import paddle
+ from paddlespeech.cli.whisper import WhisperExecutor
+
+ whisper_executor = WhisperExecutor()
+
+ # to recognize text
+ text = whisper_executor(
+ model='whisper',
+ task='transcribe',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./zh.wav',
+ device=paddle.get_device())
+ print('Transcribe Result: \n{}'.format(text))
+
+ # to recognize text and translate to English
+ feature = whisper_executor(
+ model='whisper',
+ task='translate',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./zh.wav',
+ device=paddle.get_device())
+ print('Translate Result: \n{}'.format(feature))
+ ```
+
+ Output:
+ ```bash
+ Transcribe Result:
+ Detected language: Chinese
+ [00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康
+ {'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
+
+ Translate Result:
+ Detected language: Chinese
+ [00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health.
+ {'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
diff --git a/demos/whisper/README_cn.md b/demos/whisper/README_cn.md
new file mode 100644
index 000000000..6f7c35f04
--- /dev/null
+++ b/demos/whisper/README_cn.md
@@ -0,0 +1,96 @@
+(简体中文|[English](./README.md))
+
+# Whisper模型
+## 介绍
+Whisper 是一种通用的语音识别模型。它是在多种音频的大规模数据集上训练的,也是一个多任务模型,可以执行多语言语音识别、语音翻译和语言识别。
+
+Whisper 模型由 OpenAI 训练并开源,详见:https://github.com/openai/whisper
+
+## 使用方法
+### 1. 安装
+ 请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。
+
+ 你可以从 easy、medium、hard 三种方式中选择一种方式安装。
+
+### 2. 准备输入
+ 这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
+
+ 可以下载此 demo 的示例音频:
+ ```bash
+ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+ ```
+
+### 3. 使用方法
+ - 命令行 (推荐使用)
+ ```bash
+
+ # 识别文本
+ paddlespeech whisper --task transcribe --input ./zh.wav
+
+ # 选择只支持英文的模型,并且更换不同大小的模型
+ paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav
+
+ # 将语音翻译成英语
+ paddlespeech whisper --task translate --input ./zh.wav
+ ```
+ 使用方法:
+ ```bash
+ paddlespeech whisper --help
+ ```
+ 参数:
+ - `input`(必须输入):用于识别的音频文件。
+ - `model`:ASR 任务的模型,默认值:`whisper-large`。
+ - `task`:输出类别,默认值:`transcribe`。
+ - `lang`: 模型语言,默认值:``,使用 `en` 选择只支持英文的模型,目前可选择 `en` 的模型有 [medium, base, small, tiny]。
+ - `size`: 模型大小,默认值:`large`,目前支持 [large, medium, base, small, tiny]。
+ - `language`:设定解码语言,默认值:`None`,强制设定识别出的语言,默认为模型自行判定。
+ - `sample_rate`:音频采样率,默认值:`16000`,目前 Whisper 暂不支持其他采样率。
+ - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。
+ - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。
+ - `yes`:不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。
+ - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。
+ - `verbose`: 如果使用,显示 logger 信息。
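+
+ 例如,可以把上面的 task 和 size 参数组合使用(以下只是示意用法,`medium` 为上文列出的可选模型大小之一):
+ ```bash
+ # 使用 medium 大小的多语言模型将中文语音翻译成英语,并默认同意重采样等请求
+ paddlespeech whisper --task translate --size medium --yes --input ./zh.wav
+ ```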
+
+
+- Python API
+ ```python
+ import paddle
+ from paddlespeech.cli.whisper import WhisperExecutor
+
+ whisper_executor = WhisperExecutor()
+
+ # 识别文本
+ text = whisper_executor(
+ model='whisper',
+ task='transcribe',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./zh.wav',
+ device=paddle.get_device())
+ print('Transcribe Result: \n{}'.format(text))
+
+ # 将语音翻译成英语
+ feature = whisper_executor(
+ model='whisper',
+ task='translate',
+ sample_rate=16000,
+ config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
+ ckpt_path=None,
+ audio_file='./zh.wav',
+ device=paddle.get_device())
+ print('Translate Result: \n{}'.format(feature))
+ ```
+
+
+ 输出:
+ ```bash
+ Transcribe Result:
+ Detected language: Chinese
+ [00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康
+ {'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
+
+ Translate Result:
+ Detected language: Chinese
+ [00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health.
+ {'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
diff --git a/demos/whisper/run.sh b/demos/whisper/run.sh
new file mode 100644
index 000000000..b9595735f
--- /dev/null
+++ b/demos/whisper/run.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+# audio download
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+
+# to recognize text
+paddlespeech whisper --task transcribe --input ./zh.wav
+
+# to recognize text and translate to English
+paddlespeech whisper --task translate --input ./zh.wav
+
+# to change model English-Only model
+paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav
\ No newline at end of file
diff --git a/docs/source/install.md b/docs/source/install.md
index 20d7895df..aa3d311c7 100644
--- a/docs/source/install.md
+++ b/docs/source/install.md
@@ -12,8 +12,8 @@ There are 3 ways to use `PaddleSpeech`. According to the degree of difficulty, t
- Python >= 3.7
- PaddlePaddle latest version (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html))
- C++ compilation environment
-- Hip: For Linux and Mac, do not use command `sh` instead of command `bash` in installation document.
-- Hip: We recommand you to install `paddlepaddle` from https://mirror.baidu.com/pypi/simple and install `paddlespeech` from https://pypi.tuna.tsinghua.edu.cn/simple.
+- Tip: For Linux and Mac, do not use `sh` instead of `bash` when running the commands in this installation document.
+- Tip: We recommend installing `paddlepaddle` from https://mirror.baidu.com/pypi/simple and `paddlespeech` from https://pypi.tuna.tsinghua.edu.cn/simple.
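+
+A minimal sketch of installing from the recommended sources with pip (see the sections below for the full installation steps):
+```bash
+pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
+pip install paddlespeech -i https://pypi.tuna.tsinghua.edu.cn/simple
+```
+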
## Easy: Get the Basic Function (Support Linux, Mac, and Windows)
- If you are newer to `PaddleSpeech` and want to experience it easily without your machine. We recommend you to use [AI Studio](https://aistudio.baidu.com/aistudio/index) to experience it. There is a step-by-step [tutorial](https://aistudio.baidu.com/aistudio/education/group/info/25130) for `PaddleSpeech`, and you can use the basic function of `PaddleSpeech` with a free machine.
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 79e8f4f46..9f0c2bea6 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -40,36 +40,37 @@ Language Model | Training Data | Token-based | Size | Descriptions
## Text-to-Speech Models
### Acoustic Models
-Model Type | Dataset| Example Link | Pretrained Models|Static/ONNX Models|Size (static)
+Model Type | Dataset| Example Link | Pretrained Models|Static / ONNX / Paddle-Lite Models|Size (static)
:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)|||
Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB|
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)|||
-SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2)|[speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip) [speedyspeech_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip)|13MB|
-FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip) [fastspeech2_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip)|157MB|
+SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2)|[speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip) [speedyspeech_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip) [speedyspeech_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_pdlite_1.3.0.zip)|13MB|
+FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip) [fastspeech2_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip) [fastspeech2_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_pdlite_1.3.0.zip)|157MB|
FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)|||
-FastSpeech2-CNNDecoder| CSMSC| [fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)| [fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip) | [fastspeech2_cnndecoder_csmsc_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_static_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip) | 84MB|
-FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_1.1.0.zip)|[fastspeech2_aishell3_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_static_1.1.0.zip) [fastspeech2_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_onnx_1.1.0.zip)|147MB|
-FastSpeech2| LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts3)|[fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)|[fastspeech2_ljspeech_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_static_1.1.0.zip) [fastspeech2_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_onnx_1.1.0.zip)|145MB|
-FastSpeech2| VCTK |[fastspeech2-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/tts3)|[fastspeech2_vctk_ckpt_1.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_ckpt_1.2.0.zip)|[fastspeech2_vctk_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_static_1.1.0.zip) [fastspeech2_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip) | 145MB|
+FastSpeech2-CNNDecoder| CSMSC| [fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)| [fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip) | [fastspeech2_cnndecoder_csmsc_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_static_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip) [fastspeech2_cnndecoder_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_pdlite_1.3.0.zip) [fastspeech2_cnndecoder_csmsc_streaming_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_pdlite_1.3.0.zip)| 84MB|
+FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_1.1.0.zip)|[fastspeech2_aishell3_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_static_1.1.0.zip) [fastspeech2_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_onnx_1.1.0.zip) [fastspeech2_aishell3_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_pdlite_1.3.0.zip) |147MB|
+FastSpeech2| LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts3)|[fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)|[fastspeech2_ljspeech_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_static_1.1.0.zip) [fastspeech2_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_onnx_1.1.0.zip) [fastspeech2_ljspeech_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_pdlite_1.3.0.zip)|145MB|
+FastSpeech2| VCTK |[fastspeech2-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/tts3)|[fastspeech2_vctk_ckpt_1.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_ckpt_1.2.0.zip)|[fastspeech2_vctk_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_static_1.1.0.zip) [fastspeech2_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip) [fastspeech2_vctk_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_pdlite_1.3.0.zip)| 145MB|
FastSpeech2| ZH_EN |[fastspeech2-zh_en](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/zh_en_tts/tts3)|[fastspeech2_mix_ckpt_1.2.0.zip](https://paddlespeech.bj.bcebos.com/t2s/chinse_english_mixed/models/fastspeech2_mix_ckpt_1.2.0.zip)|[fastspeech2_mix_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/t2s/chinse_english_mixed/models/fastspeech2_mix_static_0.2.0.zip) [fastspeech2_mix_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/t2s/chinse_english_mixed/models/fastspeech2_mix_onnx_0.2.0.zip) | 145MB|
-
+FastSpeech2| Male ||[fastspeech2_male_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_male_ckpt_1.3.0.zip)| | |
### Vocoders
-Model Type | Dataset| Example Link | Pretrained Models| Static/ONNX Models|Size (static)
+Model Type | Dataset| Example Link | Pretrained Models| Static / ONNX / Paddle-Lite Models|Size (static)
:-----:| :-----:| :-----: | :-----:| :-----:| :-----:
WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)|||
-Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip) [pwgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_csmsc_onnx_0.2.0.zip)|4.8MB|
-Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|[pwgan_ljspeech_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_static_1.1.0.zip) [pwgan_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_onnx_1.1.0.zip)|4.8MB|
-Parallel WaveGAN| AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)| [pwgan_aishell3_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_static_1.1.0.zip) [pwgan_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_onnx_1.1.0.zip)|4.8MB|
-Parallel WaveGAN| VCTK |[PWGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1)|[pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip)|[pwgan_vctk_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_static_1.1.0.zip) [pwgan_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_onnx_1.1.0.zip)|4.8MB|
-|Multi Band MelGAN | CSMSC |[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip)
[mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip)|[mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) [mb_melgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip)|7.6MB|
+Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip) [pwgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_csmsc_onnx_0.2.0.zip) [pwgan_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_csmsc_pdlite_1.3.0.zip)|4.8MB|
+Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|[pwgan_ljspeech_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_static_1.1.0.zip) [pwgan_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_onnx_1.1.0.zip) [pwgan_ljspeech_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_pdlite_1.3.0.zip)|4.8MB|
+Parallel WaveGAN| AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)| [pwgan_aishell3_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_static_1.1.0.zip) [pwgan_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_onnx_1.1.0.zip) [pwgan_aishell3_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_pdlite_1.3.0.zip)|4.8MB|
+Parallel WaveGAN| VCTK |[PWGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1)|[pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip)|[pwgan_vctk_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_static_1.1.0.zip) [pwgan_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_onnx_1.1.0.zip) [pwgan_vctk_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_pdlite_1.3.0.zip)|4.8MB|
+|Multi Band MelGAN | CSMSC |[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip)
[mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip)|[mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) [mb_melgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip) [mb_melgan_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_pdlite_1.3.0.zip)|7.6MB|
Style MelGAN | CSMSC |[Style MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc4)|[style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip)| | |
-HiFiGAN | CSMSC |[HiFiGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc5)|[hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip)|[hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip) [hifigan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip)|46MB|
-HiFiGAN | LJSpeech |[HiFiGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc5)|[hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip)|[hifigan_ljspeech_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_static_1.1.0.zip) [hifigan_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_onnx_1.1.0.zip) |49MB|
-HiFiGAN | AISHELL-3 |[HiFiGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc5)|[hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip)|[hifigan_aishell3_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_static_1.1.0.zip) [hifigan_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_onnx_1.1.0.zip)|46MB|
-HiFiGAN | VCTK |[HiFiGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc5)|[hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip)|[hifigan_vctk_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_static_1.1.0.zip) [hifigan_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_onnx_1.1.0.zip)|46MB|
+HiFiGAN | CSMSC |[HiFiGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc5)|[hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip)|[hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip) [hifigan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip) [hifigan_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_pdlite_1.3.0.zip)|46MB|
+HiFiGAN | LJSpeech |[HiFiGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc5)|[hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip)|[hifigan_ljspeech_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_static_1.1.0.zip) [hifigan_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_onnx_1.1.0.zip) [hifigan_ljspeech_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_pdlite_1.3.0.zip) |49MB|
+HiFiGAN | AISHELL-3 |[HiFiGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc5)|[hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip)|[hifigan_aishell3_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_static_1.1.0.zip) [hifigan_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_onnx_1.1.0.zip) [hifigan_aishell3_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_pdlite_1.3.0.zip)|46MB|
+HiFiGAN | VCTK |[HiFiGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc5)|[hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip)|[hifigan_vctk_static_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_static_1.1.0.zip) [hifigan_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_onnx_1.1.0.zip) [hifigan_vctk_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_pdlite_1.3.0.zip)|46MB|
WaveRNN | CSMSC |[WaveRNN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc6)|[wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip)|[wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip)|18MB|
+Parallel WaveGAN| Male ||[pwg_male_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_male_ckpt_1.3.0.zip)|||
### Voice Cloning
diff --git a/examples/aishell3/tts3/README.md b/examples/aishell3/tts3/README.md
index 3e1dee2fb..49801c4c3 100644
--- a/examples/aishell3/tts3/README.md
+++ b/examples/aishell3/tts3/README.md
@@ -226,6 +226,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [fastspeech2_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [fastspeech2_aishell3_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_pdlite_1.3.0.zip)
+
FastSpeech2 checkpoint contains files listed below.
```text
diff --git a/examples/aishell3/tts3/local/lite_predict.sh b/examples/aishell3/tts3/local/lite_predict.sh
new file mode 100755
index 000000000..e77e8b6c2
--- /dev/null
+++ b/examples/aishell3/tts3/local/lite_predict.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_aishell3 \
+ --voc=pwgan_aishell3 \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_aishell3 \
+ --voc=hifigan_aishell3 \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0
+fi
diff --git a/examples/aishell3/tts3/run.sh b/examples/aishell3/tts3/run.sh
index 90b342125..b5da076b2 100755
--- a/examples/aishell3/tts3/run.sh
+++ b/examples/aishell3/tts3/run.sh
@@ -60,11 +60,11 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
- # ./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_aishell3 x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_aishell3 x86
- # x86 ok, arm ok
- ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_aishell3 x86
+ ./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_aishell3 x86
+ ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_aishell3 x86
+ # ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_aishell3 x86
+fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
fi
diff --git a/examples/aishell3/voc1/README.md b/examples/aishell3/voc1/README.md
index bc25f43cf..467653cbe 100644
--- a/examples/aishell3/voc1/README.md
+++ b/examples/aishell3/voc1/README.md
@@ -139,6 +139,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [pwgan_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [pwgan_aishell3_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_aishell3_pdlite_1.3.0.zip)
+
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss:| eval/spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
default| 1(gpu) x 400000|1.968762|0.759008|0.218524
diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md
index 7f99a52e3..7f62ed0d0 100644
--- a/examples/aishell3/voc5/README.md
+++ b/examples/aishell3/voc5/README.md
@@ -122,6 +122,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [hifigan_aishell3_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [hifigan_aishell3_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_pdlite_1.3.0.zip)
+
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
default| 1(gpu) x 2500000|24.060|0.1068|7.499
diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md
index f45561719..ec88959d1 100644
--- a/examples/csmsc/tts2/README.md
+++ b/examples/csmsc/tts2/README.md
@@ -230,6 +230,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [speedyspeech_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [speedyspeech_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_pdlite_1.3.0.zip)
+
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss
:-------------:| :------------:| :-----: | :-----: | :--------:|:--------:
diff --git a/examples/csmsc/tts2/local/lite_predict.sh b/examples/csmsc/tts2/local/lite_predict.sh
new file mode 100755
index 000000000..d0c6c0584
--- /dev/null
+++ b/examples/csmsc/tts2/local/lite_predict.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=speedyspeech_csmsc \
+ --voc=pwgan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --tones_dict=dump/tone_id_map.txt
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=speedyspeech_csmsc \
+ --voc=mb_melgan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --tones_dict=dump/tone_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=speedyspeech_csmsc \
+ --voc=hifigan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --tones_dict=dump/tone_id_map.txt
+fi
diff --git a/examples/csmsc/tts2/run.sh b/examples/csmsc/tts2/run.sh
index 75fdb2109..1b608992f 100755
--- a/examples/csmsc/tts2/run.sh
+++ b/examples/csmsc/tts2/run.sh
@@ -63,13 +63,12 @@ fi
# must run after stage 3 (which stage generated static models)
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
./local/export2lite.sh ${train_output_path} inference pdlite speedyspeech_csmsc x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_csmsc x86
- # x86 ok, arm Segmentation fault
+ ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_csmsc x86
# ./local/export2lite.sh ${train_output_path} inference pdlite mb_melgan_csmsc x86
- # x86 ok, arm ok
# ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_csmsc x86
fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
+fi
diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md
index 371034e77..39926259d 100644
--- a/examples/csmsc/tts3/README.md
+++ b/examples/csmsc/tts3/README.md
@@ -238,6 +238,12 @@ The ONNX model can be downloaded here:
- [fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip)
- [fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip)
+The Paddle-Lite model can be downloaded here:
+> Please compile the develop branch of Paddle-Lite to export and run TTS models, because support for TTS models was added in https://github.com/PaddlePaddle/Paddle-Lite/pull/9587 and https://github.com/PaddlePaddle/Paddle-Lite/pull/9706.
+- [fastspeech2_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_pdlite_1.3.0.zip)
+- [fastspeech2_cnndecoder_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_pdlite_1.3.0.zip)
+- [fastspeech2_cnndecoder_csmsc_streaming_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_pdlite_1.3.0.zip)
+
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
default| 2(gpu) x 76000|1.0991|0.59132|0.035815|0.31915|0.15287|
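As a side note, the Paddle-Lite archives linked in the README hunk above can be fetched like any other released model; a minimal sketch, assuming `wget` and `unzip` are available (the URL is taken verbatim from the list above):

```bash
# Fetch and unpack one of the released Paddle-Lite models referenced in the README.
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_pdlite_1.3.0.zip
unzip fastspeech2_csmsc_pdlite_1.3.0.zip -d pdlite_models
ls pdlite_models   # should contain the optimized model file(s) consumed by Paddle-Lite
```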
diff --git a/examples/csmsc/tts3/local/export2lite.sh b/examples/csmsc/tts3/local/export2lite.sh
index f99905cfe..c2687ec73 100755
--- a/examples/csmsc/tts3/local/export2lite.sh
+++ b/examples/csmsc/tts3/local/export2lite.sh
@@ -7,12 +7,12 @@ valid_targets=$5
model_name=${model%_*}
echo model_name: ${model_name}
-
+suffix=${valid_targets%,*}
mkdir -p ${train_output_path}/${output_dir}
paddle_lite_opt \
--model_file ${train_output_path}/${model_dir}/${model}.pdmodel \
--param_file ${train_output_path}/${model_dir}/${model}.pdiparams \
- --optimize_out ${train_output_path}/${output_dir}/${model}_${valid_targets} \
+ --optimize_out ${train_output_path}/${output_dir}/${model}_${suffix} \
--valid_targets ${valid_targets}
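The new `suffix` assignment uses `%,*` to strip the shortest trailing `,<target>` chunk from `valid_targets`, so the `--optimize_out` name stays compact when more than one target is passed; a quick illustration (the example values are made up, only `x86` appears in this patch):

```bash
valid_targets="x86,arm"
suffix=${valid_targets%,*}     # shortest suffix matching ",*" removed -> "x86"
echo "${suffix}"               # x86

valid_targets="x86"
echo "${valid_targets%,*}"     # no comma present, value unchanged -> x86
```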
diff --git a/examples/csmsc/tts3/local/lite_predict.sh b/examples/csmsc/tts3/local/lite_predict.sh
new file mode 100755
index 000000000..1ed2f108d
--- /dev/null
+++ b/examples/csmsc/tts3/local/lite_predict.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_csmsc \
+ --voc=pwgan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_csmsc \
+ --voc=mb_melgan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_csmsc \
+ --voc=hifigan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt
+fi
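This script is normally driven by `run.sh` stage 8, but it can also be invoked on its own; a sketch, assuming `path.sh` exports `BIN_DIR` as in the other `local/*.sh` scripts and using `exp/default` as a placeholder training output directory:

```bash
# Standalone invocation of the new Paddle-Lite prediction script.
source path.sh                      # assumed to export BIN_DIR
train_output_path=exp/default       # placeholder; use your own training output dir
./local/lite_predict.sh ${train_output_path}
# synthesized wavs are written to ${train_output_path}/lite_infer_out
```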
diff --git a/examples/csmsc/tts3/local/lite_predict_streaming.sh b/examples/csmsc/tts3/local/lite_predict_streaming.sh
new file mode 100755
index 000000000..4570cb4eb
--- /dev/null
+++ b/examples/csmsc/tts3/local/lite_predict_streaming.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../lite_predict_streaming.py \
+ --inference_dir=${train_output_path}/pdlite_streaming \
+ --am=fastspeech2_csmsc \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=pwgan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out_streaming \
+ --phones_dict=dump/phone_id_map.txt \
+ --am_streaming=True
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../lite_predict_streaming.py \
+ --inference_dir=${train_output_path}/pdlite_streaming \
+ --am=fastspeech2_csmsc \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=mb_melgan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out_streaming \
+ --phones_dict=dump/phone_id_map.txt \
+ --am_streaming=True
+fi
+
+# hifigan
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ python3 ${BIN_DIR}/../lite_predict_streaming.py \
+ --inference_dir=${train_output_path}/pdlite_streaming \
+ --am=fastspeech2_csmsc \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=hifigan_csmsc \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/lite_infer_out_streaming \
+ --phones_dict=dump/phone_id_map.txt \
+ --am_streaming=True
+fi
+
diff --git a/examples/csmsc/tts3/run.sh b/examples/csmsc/tts3/run.sh
index 8d646ecc3..14308af4e 100755
--- a/examples/csmsc/tts3/run.sh
+++ b/examples/csmsc/tts3/run.sh
@@ -64,13 +64,15 @@ fi
# must run after stage 3 (which stage generated static models)
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
+ # NOTE by yuantian 2022.11.21: please compile the develop branch of Paddle-Lite to export and run TTS models,
+ # because TTS model support was added in https://github.com/PaddlePaddle/Paddle-Lite/pull/9587
+ # and https://github.com/PaddlePaddle/Paddle-Lite/pull/9706
./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_csmsc x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_csmsc x86
- # x86 ok, arm Segmentation fault
+ ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_csmsc x86
# ./local/export2lite.sh ${train_output_path} inference pdlite mb_melgan_csmsc x86
- # x86 ok, arm ok
# ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_csmsc x86
fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
+fi
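With the new stage 8 hook, Paddle-Lite export and inference for this recipe can be chained directly; a sketch, assuming `run.sh` forwards `--stage/--stop_stage` through `parse_options.sh` as the asr recipe further down in this patch does:

```bash
# Stage 7: convert the static models with export2lite.sh; stage 8: synthesize with lite_predict.sh.
bash run.sh --stage 7 --stop_stage 8
```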
diff --git a/examples/csmsc/tts3/run_cnndecoder.sh b/examples/csmsc/tts3/run_cnndecoder.sh
index 645d1af09..8cc9c5da2 100755
--- a/examples/csmsc/tts3/run_cnndecoder.sh
+++ b/examples/csmsc/tts3/run_cnndecoder.sh
@@ -98,32 +98,27 @@ fi
# must run after stage 3 (which stage generated static models)
if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_csmsc x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_csmsc x86
- # x86 ok, arm Segmentation fault
+ ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_csmsc x86
# ./local/export2lite.sh ${train_output_path} inference pdlite mb_melgan_csmsc x86
- # x86 ok, arm ok
# ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_csmsc x86
fi
-# must run after stage 5 (which stage generated static models)
if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
+fi
+
+# must run after stage 5 (the stage that generates the static models)
+if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
# streaming acoustic model
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
- # ./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_csmsc x86
./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming fastspeech2_csmsc_am_encoder_infer x86
- # x86 ok, arm Segmentation fault
./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming fastspeech2_csmsc_am_decoder x86
- # x86 ok, arm Segmentation fault
./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming fastspeech2_csmsc_am_postnet x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming pwgan_csmsc x86
- # x86 ok, arm Segmentation fault
+ ./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming pwgan_csmsc x86
# ./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming mb_melgan_csmsc x86
- # x86 ok, arm ok
# ./local/export2lite.sh ${train_output_path} inference_streaming pdlite_streaming hifigan_csmsc x86
fi
+
+if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict_streaming.sh ${train_output_path} || exit -1
+fi
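The cnndecoder recipe now keeps the non-streaming Lite export and inference in stages 11–12 and the streaming variants in stages 13–14; under the same `--stage/--stop_stage` assumption, the streaming path alone would be run as:

```bash
# Export the streaming sub-models plus vocoder to Paddle-Lite, then synthesize with them.
bash run_cnndecoder.sh --stage 13 --stop_stage 14
```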
diff --git a/examples/csmsc/voc1/README.md b/examples/csmsc/voc1/README.md
index 4646a0345..252c2b920 100644
--- a/examples/csmsc/voc1/README.md
+++ b/examples/csmsc/voc1/README.md
@@ -136,6 +136,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [pwgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_csmsc_onnx_0.2.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [pwgan_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_csmsc_pdlite_1.3.0.zip)
+
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss| eval/spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
default| 1(gpu) x 400000|1.948763|0.670098|0.248882
diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md
index 09fb8836c..f2a1eef7f 100644
--- a/examples/csmsc/voc3/README.md
+++ b/examples/csmsc/voc3/README.md
@@ -164,6 +164,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [mb_melgan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [mb_melgan_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_pdlite_1.3.0.zip)
+
Model | Step | eval/generator_loss | eval/log_stft_magnitude_loss|eval/spectral_convergence_loss |eval/sub_log_stft_magnitude_loss|eval/sub_spectral_convergence_loss
:-------------:| :------------:| :-----: | :-----: | :--------:| :--------:| :--------:
default| 1(gpu) x 1000000| 2.4851|0.71778 |0.2761 |0.66334 |0.2777|
diff --git a/examples/csmsc/voc5/README.md b/examples/csmsc/voc5/README.md
index ef552fd30..3347c6473 100644
--- a/examples/csmsc/voc5/README.md
+++ b/examples/csmsc/voc5/README.md
@@ -121,6 +121,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [hifigan_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [hifigan_csmsc_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_pdlite_1.3.0.zip)
+
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
default| 1(gpu) x 2500000|24.927|0.1262|7.554
diff --git a/examples/librispeech/asr3/RESULTS.md b/examples/librispeech/asr3/RESULTS.md
index 1c5626d9e..27a87e137 100644
--- a/examples/librispeech/asr3/RESULTS.md
+++ b/examples/librispeech/asr3/RESULTS.md
@@ -1,8 +1,8 @@
# LibriSpeech
## Wav2VecASR
-train: Epoch 1, 1*V100-32G, batchsize:10
+train: Epoch 1, 1*V100-32G, batchsize: 6
| Model | Params | Config | Augmentation| Test set | Decode method | WER |
| --- | --- | --- | --- | --- | --- | --- |
-| wav2vec2ASR | 302.86 M | conf/wav2vec2ASR.yaml | spec_aug | test-clean | greedy search | 0.018887 |
+| wav2vec2ASR | 302.86 M | conf/wav2vec2ASR.yaml | spec_aug | test-clean | greedy search | 0.018906 |
diff --git a/examples/librispeech/asr3/conf/preprocess.yaml b/examples/librispeech/asr3/conf/preprocess.yaml
index 4a908a83b..724782ed6 100644
--- a/examples/librispeech/asr3/conf/preprocess.yaml
+++ b/examples/librispeech/asr3/conf/preprocess.yaml
@@ -1,4 +1,3 @@
process:
# use raw audio
- type: wav_process
- dither: 0.0
diff --git a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml
index c45bd692a..1ce2d94db 100644
--- a/examples/librispeech/asr3/conf/wav2vec2ASR.yaml
+++ b/examples/librispeech/asr3/conf/wav2vec2ASR.yaml
@@ -4,16 +4,21 @@
freeze_wav2vec2: True
normalize_wav: True
output_norm: True
-dnn_blocks: 2
-dnn_neurons: 1024
-blank_id: 0
-ctc_dropout_rate: 0.0
+init_type: 'kaiming_uniform' # !Warning: needed for convergence
+enc:
+ input_shape: 1024
+ dnn_blocks: 2
+ dnn_neurons: 1024
+ activation: True
+ctc:
+ enc_n_units: 1024
+ blank_id: 0
+ dropout_rate: 0.0
wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
############################################
# Wav2Vec2.0 #
############################################
-vocab_size: 32
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
@@ -54,9 +59,6 @@ diversity_loss_weight: 0.1
ctc_loss_reduction: "sum"
ctc_zero_infinity: False
use_weighted_layer_sum: False
-pad_token_id: 0
-bos_token_id: 1
-eos_token_id: 2
add_adapter: False
adapter_kernel_size: 3
adapter_stride: 2
@@ -78,7 +80,7 @@ unit_type: 'char'
mean_std_filepath: ""
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for 'other' epochs
-batch_size: 10 # Different batch_size may cause large differences in results
+batch_size: 6 # Different batch_size may cause large differences in results
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
minibatches: 0 # for debug
@@ -106,17 +108,26 @@ audio_augment: # for raw audio
###########################################
n_epoch: 1
accum_grad: 1
-global_grad_clip: 3.0
+global_grad_clip: 5.0
model_optim: adadelta
model_optim_conf:
lr: 0.9
epsilon: 1.0e-6
rho: 0.95
-scheduler: constantlr
-scheduler_conf:
+model_scheduler: constantlr
+model_scheduler_conf:
+ warmup_steps: 25000
+ lr_decay: 1.0
+wav2vec2_optim: adadelta
+wav2vec2_optim_conf:
+ lr: 0.9
+ epsilon: 1.0e-6
+ rho: 0.95
+wav2vec2_scheduler: constantlr
+wav2vec2_scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 50
- latest_n: 5
+ latest_n: 5
\ No newline at end of file
diff --git a/examples/librispeech/asr3/local/train.sh b/examples/librispeech/asr3/local/train.sh
index 6913ed17e..24776fd17 100644
--- a/examples/librispeech/asr3/local/train.sh
+++ b/examples/librispeech/asr3/local/train.sh
@@ -10,7 +10,8 @@ echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
-ips=$3
+resume=$3
+ips=$4
if [ ! $ips ];then
ips_config=
@@ -21,7 +22,7 @@ fi
mkdir -p exp
# seed may break model convergence
-seed=1998
+seed=1988
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
@@ -34,13 +35,15 @@ python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
---seed ${seed}
+--seed ${seed} \
+--resume ${resume}
else
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
---seed ${seed}
+--seed ${seed} \
+--resume ${resume}
fi
if [ ${seed} != 0 ]; then
diff --git a/examples/librispeech/asr3/run.sh b/examples/librispeech/asr3/run.sh
index 3b1abb11b..05ad505c7 100644
--- a/examples/librispeech/asr3/run.sh
+++ b/examples/librispeech/asr3/run.sh
@@ -11,7 +11,7 @@ conf_path=conf/wav2vec2ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
-dict_path=data/lang_char/vocab.txt
+resume= # xx e.g. 30
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -28,7 +28,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
- CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -38,10 +38,10 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# greedy search decoder
- CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+ CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
- CUDA_VISIBLE_DEVICES=${gpus} ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
+ CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi
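The new `resume` option threads a checkpoint epoch number down into `local/train.sh`; a sketch of resuming from epoch 30 (the value hinted by the `# xx e.g. 30` comment), where the ckpt name `wav2vec2ASR` is a placeholder and `--resume` is assumed to be parsed by `parse_options.sh`:

```bash
# Resume wav2vec2ASR training from a saved checkpoint via run.sh.
bash run.sh --stage 1 --stop_stage 1 --resume 30

# Or call the training script directly (positional args: config, ckpt name, resume, [ips]).
CUDA_VISIBLE_DEVICES=0 ./local/train.sh conf/wav2vec2ASR.yaml wav2vec2ASR 30
```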
diff --git a/examples/ljspeech/tts3/README.md b/examples/ljspeech/tts3/README.md
index d786c1571..23b433d4e 100644
--- a/examples/ljspeech/tts3/README.md
+++ b/examples/ljspeech/tts3/README.md
@@ -221,6 +221,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [fastspeech2_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [fastspeech2_ljspeech_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_pdlite_1.3.0.zip)
+
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
diff --git a/examples/ljspeech/tts3/local/lite_predict.sh b/examples/ljspeech/tts3/local/lite_predict.sh
new file mode 100755
index 000000000..75db6a0ea
--- /dev/null
+++ b/examples/ljspeech/tts3/local/lite_predict.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_ljspeech \
+ --voc=pwgan_ljspeech \
+ --text=${BIN_DIR}/../sentences_en.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --lang=en
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_ljspeech \
+ --voc=hifigan_ljspeech \
+ --text=${BIN_DIR}/../sentences_en.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --lang=en
+fi
diff --git a/examples/ljspeech/tts3/run.sh b/examples/ljspeech/tts3/run.sh
index 7ab591862..aacd4cc03 100755
--- a/examples/ljspeech/tts3/run.sh
+++ b/examples/ljspeech/tts3/run.sh
@@ -62,11 +62,11 @@ fi
# must run after stage 3 (which stage generated static models)
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_ljspeech x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_ljspeech x86
- # x86 ok, arm ok
+ ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_ljspeech x86
# ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_ljspeech x86
+fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
fi
\ No newline at end of file
diff --git a/examples/ljspeech/voc1/README.md b/examples/ljspeech/voc1/README.md
index ad6cd2982..a7ac2af41 100644
--- a/examples/ljspeech/voc1/README.md
+++ b/examples/ljspeech/voc1/README.md
@@ -136,6 +136,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [pwgan_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [pwgan_ljspeech_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_ljspeech_pdlite_1.3.0.zip)
+
Parallel WaveGAN checkpoint contains files listed below.
diff --git a/examples/ljspeech/voc5/README.md b/examples/ljspeech/voc5/README.md
index eaa51e507..65fa53267 100644
--- a/examples/ljspeech/voc5/README.md
+++ b/examples/ljspeech/voc5/README.md
@@ -121,6 +121,8 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [hifigan_ljspeech_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [hifigan_ljspeech_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_pdlite_1.3.0.zip)
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
diff --git a/examples/vctk/tts3/README.md b/examples/vctk/tts3/README.md
index 2a2f27fd4..0bf2037f5 100644
--- a/examples/vctk/tts3/README.md
+++ b/examples/vctk/tts3/README.md
@@ -224,6 +224,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [fastspeech2_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [fastspeech2_vctk_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_pdlite_1.3.0.zip)
+
FastSpeech2 checkpoint contains files listed below.
```text
fastspeech2_vctk_ckpt_1.2.0
diff --git a/examples/vctk/tts3/local/lite_predict.sh b/examples/vctk/tts3/local/lite_predict.sh
new file mode 100755
index 000000000..eb608535b
--- /dev/null
+++ b/examples/vctk/tts3/local/lite_predict.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_vctk \
+ --voc=pwgan_vctk \
+ --text=${BIN_DIR}/../sentences_en.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0 \
+ --lang=en
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ python3 ${BIN_DIR}/../lite_predict.py \
+ --inference_dir=${train_output_path}/pdlite \
+ --am=fastspeech2_vctk \
+ --voc=hifigan_vctk \
+ --text=${BIN_DIR}/../sentences_en.txt \
+ --output_dir=${train_output_path}/lite_infer_out \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0 \
+ --lang=en
+fi
diff --git a/examples/vctk/tts3/run.sh b/examples/vctk/tts3/run.sh
index 16f1eae18..a112b94b7 100755
--- a/examples/vctk/tts3/run.sh
+++ b/examples/vctk/tts3/run.sh
@@ -61,11 +61,11 @@ fi
# must run after stage 3 (which stage generated static models)
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
- # This model is not supported, because 3 ops are not supported on 'arm'. These unsupported ops are: 'round, set_value, share_data'.
- # This model is not supported, because 4 ops are not supported on 'x86'. These unsupported ops are: 'matmul_v2, round, set_value, share_data'.
./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_vctk x86
- # x86 ok, arm Segmentation fault
- # ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_vctk x86
- # x86 ok, arm ok
+ ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_vctk x86
# ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_vctk x86
fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
+fi
\ No newline at end of file
diff --git a/examples/vctk/voc1/README.md b/examples/vctk/voc1/README.md
index 2d80e7563..761f9bddb 100644
--- a/examples/vctk/voc1/README.md
+++ b/examples/vctk/voc1/README.md
@@ -141,6 +141,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [pwgan_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [pwgan_vctk_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwgan_vctk_pdlite_1.3.0.zip)
+
Parallel WaveGAN checkpoint contains files listed below.
diff --git a/examples/vctk/voc5/README.md b/examples/vctk/voc5/README.md
index e937679b5..5a104f56f 100644
--- a/examples/vctk/voc5/README.md
+++ b/examples/vctk/voc5/README.md
@@ -127,6 +127,9 @@ The static model can be downloaded here:
The ONNX model can be downloaded here:
- [hifigan_vctk_onnx_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_onnx_1.1.0.zip)
+The Paddle-Lite model can be downloaded here:
+- [hifigan_vctk_pdlite_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_pdlite_1.3.0.zip)
+
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py
index cba60cfdb..84812a2cf 100644
--- a/paddlespeech/audio/transform/spectrogram.py
+++ b/paddlespeech/audio/transform/spectrogram.py
@@ -383,7 +383,7 @@ class LogMelSpectrogramKaldi():
class WavProcess():
- def __init__(self, dither=0.0):
+ def __init__(self):
"""
Args:
dither (float): Dithering constant
@@ -391,9 +391,7 @@ class WavProcess():
Returns:
"""
- self.dither = dither
-
- def __call__(self, x, train):
+ def __call__(self, x):
"""
Args:
x (np.ndarray): shape (Ti,)
@@ -405,10 +403,10 @@ class WavProcess():
Returns:
np.ndarray: (T, D)
"""
- dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
- waveform = np.expand_dims(x, -1)
+ waveform = x.astype("float32") / 32768.0
+ waveform = np.expand_dims(waveform, -1)
return waveform
diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py
index 7210091a9..767d0df78 100644
--- a/paddlespeech/cli/base_commands.py
+++ b/paddlespeech/cli/base_commands.py
@@ -83,7 +83,9 @@ model_name_format = {
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language',
- 'vector': 'Model-Sample Rate'
+ 'vector': 'Model-Sample Rate',
+ 'ssl': 'Model-Language-Sample Rate',
+ 'whisper': 'Model-Language-Sample Rate'
}
@@ -94,7 +96,9 @@ class StatsCommand:
def __init__(self):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
- self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws']
+ self.task_choices = [
+ 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl', 'whisper'
+ ]
self.parser.add_argument(
'--task',
type=str,
@@ -141,6 +145,12 @@ _commands = {
'tts': ['Text to Speech infer command.', 'TTSExecutor'],
'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'],
'kws': ['Keyword Spotting infer command.', 'KWSExecutor'],
+ 'ssl':
+ ['Self-Supervised Learning Pretrained model infer command.', 'SSLExecutor'],
+ 'whisper': [
+ 'Whisper model for speech to text or translate speech to English.',
+ 'WhisperExecutor'
+ ]
}
for com, info in _commands.items():
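With `ssl` and `whisper` added to `model_name_format`, `task_choices` and `_commands`, the stats helper can enumerate the pretrained checkpoints for the new tasks; a quick check might look like this:

```bash
# List the pretrained models registered for the new CLI tasks.
paddlespeech stats --task ssl
paddlespeech stats --task whisper
```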
diff --git a/paddlespeech/cli/ssl/__init__.py b/paddlespeech/cli/ssl/__init__.py
new file mode 100644
index 000000000..2e53128ea
--- /dev/null
+++ b/paddlespeech/cli/ssl/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .infer import SSLExecutor
diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py
new file mode 100644
index 000000000..154c25f53
--- /dev/null
+++ b/paddlespeech/cli/ssl/infer.py
@@ -0,0 +1,449 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import io
+import os
+import sys
+import time
+from collections import OrderedDict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import librosa
+import numpy as np
+import paddle
+import soundfile
+from yacs.config import CfgNode
+
+from ..executor import BaseExecutor
+from ..log import logger
+from ..utils import CLI_TIMER
+from ..utils import stats_wrapper
+from ..utils import timer_register
+from paddlespeech.audio.transform.transformation import Transformation
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.utils.utility import UpdateConfig
+
+__all__ = ['SSLExecutor']
+
+
+@timer_register
+class SSLExecutor(BaseExecutor):
+ def __init__(self):
+ super().__init__('ssl')
+ self.parser = argparse.ArgumentParser(
+ prog='paddlespeech.ssl', add_help=True)
+ self.parser.add_argument(
+ '--input', type=str, default=None, help='Audio file to recognize.')
+ self.parser.add_argument(
+ '--model',
+ type=str,
+ default='wav2vec2ASR_librispeech',
+ choices=[
+ tag[:tag.index('-')]
+ for tag in self.task_resource.pretrained_models.keys()
+ ],
+ help='Choose model type of asr task.')
+ self.parser.add_argument(
+ '--task',
+ type=str,
+ default='asr',
+ choices=['asr', 'vector'],
+ help='Choose output type for ssl task')
+ self.parser.add_argument(
+ '--lang',
+ type=str,
+ default='en',
+ help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]'
+ )
+ self.parser.add_argument(
+ "--sample_rate",
+ type=int,
+ default=16000,
+ choices=[8000, 16000],
+ help='Choose the audio sample rate of the model. 8000 or 16000')
+ self.parser.add_argument(
+ '--config',
+ type=str,
+ default=None,
+ help='Config of asr task. Use the default config when it is None.')
+ self.parser.add_argument(
+ '--decode_method',
+ type=str,
+ default='ctc_greedy_search',
+ choices=[
+ 'ctc_greedy_search',
+ 'ctc_prefix_beam_search',
+ ],
+ help='only support asr task')
+ self.parser.add_argument(
+ '--ckpt_path',
+ type=str,
+ default=None,
+ help='Checkpoint file of model.')
+ self.parser.add_argument(
+ '--yes',
+ '-y',
+ action="store_true",
+ default=False,
+ help='No additional parameters required. \
+ Once set this parameter, it means accepting the request of the program by default, \
+ which includes transforming the audio sample rate')
+ self.parser.add_argument(
+ '--rtf',
+ action="store_true",
+ default=False,
+ help='Show Real-time Factor(RTF).')
+ self.parser.add_argument(
+ '--device',
+ type=str,
+ default=paddle.get_device(),
+ help='Choose device to execute model inference.')
+ self.parser.add_argument(
+ '-d',
+ '--job_dump_result',
+ action='store_true',
+ help='Save job result into file.')
+ self.parser.add_argument(
+ '-v',
+ '--verbose',
+ action='store_true',
+ help='Increase logger verbosity of current task.')
+
+ def _init_from_path(self,
+ model_type: str='wav2vec2ASR_librispeech',
+ task: str='asr',
+ lang: str='en',
+ sample_rate: int=16000,
+ cfg_path: Optional[os.PathLike]=None,
+ decode_method: str='ctc_greedy_search',
+ ckpt_path: Optional[os.PathLike]=None):
+ """
+ Init model and other resources from a specific path.
+ """
+ logger.debug("start to init the model")
+ # default max_len: unit:second
+ self.max_len = 50
+ if hasattr(self, 'model'):
+ logger.debug('Model has already been initialized.')
+ return
+ if cfg_path is None or ckpt_path is None:
+ sample_rate_str = '16k' if sample_rate == 16000 else '8k'
+ if task == 'asr':
+ tag = model_type + '-' + lang + '-' + sample_rate_str
+ else:
+ tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str
+ self.task_resource.set_task_model(tag, version=None)
+ self.res_path = self.task_resource.res_dir
+
+ self.cfg_path = os.path.join(
+ self.res_path, self.task_resource.res_dict['cfg_path'])
+ self.ckpt_path = os.path.join(
+ self.res_path,
+ self.task_resource.res_dict['ckpt_path'] + ".pdparams")
+ logger.debug(self.res_path)
+ else:
+ self.cfg_path = os.path.abspath(cfg_path)
+ self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
+ self.res_path = os.path.dirname(
+ os.path.dirname(os.path.abspath(self.cfg_path)))
+ logger.debug(self.cfg_path)
+ logger.debug(self.ckpt_path)
+
+ #Init body.
+ self.config = CfgNode(new_allowed=True)
+ self.config.merge_from_file(self.cfg_path)
+ if task == 'asr':
+ with UpdateConfig(self.config):
+ self.text_feature = TextFeaturizer(
+ unit_type=self.config.unit_type,
+ vocab=self.config.vocab_filepath)
+ self.config.decode.decoding_method = decode_method
+ model_name = model_type[:model_type.rindex(
+ '_')] # model_type: {model_name}_{dataset}
+ else:
+ model_name = 'wav2vec2'
+ model_class = self.task_resource.get_model_class(model_name)
+
+ model_conf = self.config
+ model = model_class.from_config(model_conf)
+ self.model = model
+ self.model.eval()
+
+ # load model
+ model_dict = paddle.load(self.ckpt_path)
+ if task == 'asr':
+ self.model.set_state_dict(model_dict)
+ else:
+ self.model.wav2vec2.set_state_dict(model_dict)
+
+ def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
+ """
+ Input preprocess and return paddle.Tensor stored in self.input.
+ Input content can be text (tts), a file (asr, cls) or a stream (not supported yet).
+ """
+
+ audio_file = input
+ if isinstance(audio_file, (str, os.PathLike)):
+ logger.debug("Preprocess audio_file:" + audio_file)
+ elif isinstance(audio_file, io.BytesIO):
+ audio_file.seek(0)
+
+ # Get the object for feature extraction
+ logger.debug("get the preprocess conf")
+ preprocess_conf = self.config.preprocess_config
+ preprocess_args = {"train": False}
+ preprocessing = Transformation(preprocess_conf)
+ logger.debug("read the audio file")
+ audio, audio_sample_rate = soundfile.read(
+ audio_file, dtype="int16", always_2d=True)
+ if self.change_format:
+ if audio.shape[1] >= 2:
+ audio = audio.mean(axis=1, dtype=np.int16)
+ else:
+ audio = audio[:, 0]
+ # pcm16 -> pcm 32
+ audio = self._pcm16to32(audio)
+ audio = librosa.resample(
+ audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate)
+ audio_sample_rate = self.sample_rate
+ # pcm32 -> pcm 16
+ audio = self._pcm32to16(audio)
+ else:
+ audio = audio[:, 0]
+
+ logger.debug(f"audio shape: {audio.shape}")
+ # fbank
+ audio = preprocessing(audio, **preprocess_args)
+
+ audio_len = paddle.to_tensor(audio.shape[0])
+ audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
+
+ self._inputs["audio"] = audio
+ self._inputs["audio_len"] = audio_len
+ logger.debug(f"audio feat shape: {audio.shape}")
+
+ logger.debug("audio feat process success")
+
+ @paddle.no_grad()
+ def infer(self, model_type: str, task: str):
+ """
+ Model inference and result stored in self.output.
+ """
+ logger.debug("start to infer the model to get the output")
+ audio = self._inputs["audio"]
+ if task == 'asr':
+ cfg = self.config.decode
+ logger.debug(
+ f"we will use the wav2vec2ASR like model : {model_type}")
+ try:
+ result_transcripts = self.model.decode(
+ audio,
+ text_feature=self.text_feature,
+ decoding_method=cfg.decoding_method,
+ beam_size=cfg.beam_size)
+ self._outputs["result"] = result_transcripts[0][0]
+ except Exception as e:
+ logger.exception(e)
+ else:
+ logger.debug(
+ "we will use the wav2vec2 like model to extract audio feature")
+ try:
+ out_feature = self.model(audio[:, :, 0])
+ self._outputs["result"] = out_feature[0]
+ except Exception as e:
+ logger.exception(e)
+
+ def postprocess(self) -> Union[str, os.PathLike]:
+ """
+ Output postprocess and return human-readable results such as texts and audio files.
+ """
+ return self._outputs["result"]
+
+ def _pcm16to32(self, audio):
+ assert (audio.dtype == np.int16)
+ audio = audio.astype("float32")
+ bits = np.iinfo(np.int16).bits
+ audio = audio / (2**(bits - 1))
+ return audio
+
+ def _pcm32to16(self, audio):
+ assert (audio.dtype == np.float32)
+ bits = np.iinfo(np.int16).bits
+ audio = audio * (2**(bits - 1))
+ audio = np.round(audio).astype("int16")
+ return audio
+
+ def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False):
+ self.sample_rate = sample_rate
+ if self.sample_rate != 16000 and self.sample_rate != 8000:
+ logger.error(
+ "invalid sample rate, please input --sr 8000 or --sr 16000")
+ return False
+
+ if isinstance(audio_file, (str, os.PathLike)):
+ if not os.path.isfile(audio_file):
+ logger.error("Please input the right audio file path")
+ return False
+ elif isinstance(audio_file, io.BytesIO):
+ audio_file.seek(0)
+
+ logger.debug("checking the audio file format......")
+ try:
+ audio, audio_sample_rate = soundfile.read(
+ audio_file, dtype="int16", always_2d=True)
+ audio_duration = audio.shape[0] / audio_sample_rate
+ if audio_duration > self.max_len:
+ logger.error(
+ f"Please input an audio file no longer than {self.max_len} seconds.\n"
+ )
+ return False
+ except Exception as e:
+ logger.exception(e)
+ logger.error(
+ f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \
+ you can try to use sox to change the file format.\n \
+ For example: \n \
+ sample rate: 16k \n \
+ sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+ sample rate: 8k \n \
+ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+ ")
+ return False
+ logger.debug("The sample rate is %d" % audio_sample_rate)
+ if audio_sample_rate != self.sample_rate:
+ logger.warning("The sample rate of the input file is not {}.\n \
+ The program will resample the wav file to {}.\n \
+ If the result does not meet your expectations,\n \
+ Please input the 16k 16 bit 1 channel wav file. \
+ ".format(self.sample_rate, self.sample_rate))
+ if force_yes is False:
+ while (True):
+ logger.debug(
+ "Change the sample rate and channels? Y: convert the audio. N: exit the program."
+ )
+ content = input("Input(Y/N):")
+ if content.strip() == "Y" or content.strip(
+ ) == "y" or content.strip() == "yes" or content.strip(
+ ) == "Yes":
+ logger.debug(
+ "change the sample rate and channels to match the model"
+ )
+ break
+ elif content.strip() == "N" or content.strip(
+ ) == "n" or content.strip() == "no" or content.strip(
+ ) == "No":
+ logger.debug("Exit the program")
+ return False
+ else:
+ logger.warning("Not regular input, please input again")
+
+ self.change_format = True
+ else:
+ logger.debug("The audio file format is right")
+ self.change_format = False
+
+ return True
+
+ def execute(self, argv: List[str]) -> bool:
+ """
+ Command line entry.
+ """
+ parser_args = self.parser.parse_args(argv)
+
+ model = parser_args.model
+ task = parser_args.task
+ lang = parser_args.lang
+ sample_rate = parser_args.sample_rate
+ config = parser_args.config
+ ckpt_path = parser_args.ckpt_path
+ decode_method = parser_args.decode_method
+ force_yes = parser_args.yes
+ rtf = parser_args.rtf
+ device = parser_args.device
+
+ if not parser_args.verbose:
+ self.disable_task_loggers()
+
+ task_source = self.get_input_source(parser_args.input)
+ task_results = OrderedDict()
+ has_exceptions = False
+
+ for id_, input_ in task_source.items():
+ try:
+ res = self(
+ audio_file=input_,
+ model=model,
+ task=task,
+ lang=lang,
+ sample_rate=sample_rate,
+ config=config,
+ ckpt_path=ckpt_path,
+ decode_method=decode_method,
+ force_yes=force_yes,
+ rtf=rtf,
+ device=device)
+ task_results[id_] = res
+
+ except Exception as e:
+ has_exceptions = True
+ task_results[id_] = f'{e.__class__.__name__}: {e}'
+
+ if rtf:
+ self.show_rtf(CLI_TIMER[self.__class__.__name__])
+ self.process_task_results(parser_args.input, task_results,
+ parser_args.job_dump_result)
+ if has_exceptions:
+ return False
+ else:
+ return True
+
+ @stats_wrapper
+ def __call__(self,
+ audio_file: os.PathLike,
+ model: str='wav2vec2ASR_librispeech',
+ task: str='asr',
+ lang: str='en',
+ sample_rate: int=16000,
+ config: os.PathLike=None,
+ ckpt_path: os.PathLike=None,
+ decode_method: str='ctc_greedy_search',
+ force_yes: bool=False,
+ rtf: bool=False,
+ device=paddle.get_device()):
+ """
+ Python API to call an executor.
+ """
+
+ audio_file = os.path.abspath(audio_file)
+ paddle.set_device(device)
+ self._init_from_path(model, task, lang, sample_rate, config,
+ decode_method, ckpt_path)
+ if not self._check(audio_file, sample_rate, force_yes):
+ sys.exit(-1)
+ if rtf:
+ k = self.__class__.__name__
+ CLI_TIMER[k]['start'].append(time.time())
+ self.preprocess(model, audio_file)
+ self.infer(model, task)
+ res = self.postprocess() # Retrieve result of asr.
+
+ if rtf:
+ CLI_TIMER[k]['end'].append(time.time())
+ audio, audio_sample_rate = soundfile.read(
+ audio_file, dtype="int16", always_2d=True)
+ CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate)
+
+ return res
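Putting the new executor together, a CLI session might look like the sketch below; `./en.wav` is a placeholder input and the flags mirror the argparse definitions above:

```bash
# Speech recognition with the wav2vec2-based SSL executor.
paddlespeech ssl --task asr --lang en --input ./en.wav

# Feature extraction with the bare wav2vec2 encoder instead of the ASR head.
paddlespeech ssl --task vector --lang en --input ./en.wav
```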
diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py
index 3eb597156..707518c05 100644
--- a/paddlespeech/cli/tts/infer.py
+++ b/paddlespeech/cli/tts/infer.py
@@ -67,6 +67,7 @@ class TTSExecutor(BaseExecutor):
'fastspeech2_mix',
'tacotron2_csmsc',
'tacotron2_ljspeech',
+ 'fastspeech2_male',
],
help='Choose acoustic model type of tts task.')
self.parser.add_argument(
@@ -122,6 +123,7 @@ class TTSExecutor(BaseExecutor):
'hifigan_aishell3',
'hifigan_vctk',
'wavernn_csmsc',
+ 'pwgan_male',
],
help='Choose vocoder type of tts task.')
diff --git a/paddlespeech/cli/whisper/__init__.py b/paddlespeech/cli/whisper/__init__.py
new file mode 100644
index 000000000..3bafc10d2
--- /dev/null
+++ b/paddlespeech/cli/whisper/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .infer import WhisperExecutor
diff --git a/paddlespeech/cli/whisper/infer.py b/paddlespeech/cli/whisper/infer.py
new file mode 100644
index 000000000..c016b453a
--- /dev/null
+++ b/paddlespeech/cli/whisper/infer.py
@@ -0,0 +1,493 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import io
+import os
+import sys
+import time
+from collections import OrderedDict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import librosa
+import numpy as np
+import paddle
+import soundfile
+from yacs.config import CfgNode
+
+from ...utils.env import DATA_HOME
+from ..download import get_path_from_url
+from ..executor import BaseExecutor
+from ..log import logger
+from ..utils import CLI_TIMER
+from ..utils import stats_wrapper
+from ..utils import timer_register
+from paddlespeech.s2t.models.whisper import log_mel_spectrogram
+from paddlespeech.s2t.models.whisper import ModelDimensions
+from paddlespeech.s2t.models.whisper import Whisper
+from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES
+from paddlespeech.s2t.models.whisper.tokenizer import TO_LANGUAGE_CODE
+from paddlespeech.s2t.utils.utility import UpdateConfig
+
+__all__ = ['WhisperExecutor']
+
+
+@timer_register
+class WhisperExecutor(BaseExecutor):
+ def __init__(self):
+ super().__init__('whisper')
+ self.parser = argparse.ArgumentParser(
+ prog='paddlespeech.whisper', add_help=True)
+ self.parser.add_argument(
+ '--input', type=str, default=None, help='Audio file to recognize.')
+ self.parser.add_argument(
+ '--model',
+ type=str,
+ default='whisper',
+ choices=['whisper'],
+ help='Choose model type of asr task.')
+ self.parser.add_argument(
+ '--lang',
+ type=str,
+ default='',
+ choices=['', 'en'],
+ help='Choose model language. Default is "" (multilingual); set "en" for the English-only model.'
+ )
+ self.parser.add_argument(
+ '--task',
+ type=str,
+ default='transcribe',
+ choices=["transcribe", "translate"],
+ help='Choose task type: transcribe or translate.')
+ self.parser.add_argument(
+ '--size',
+ type=str,
+ default='large',
+ choices=['large', 'medium', 'base', 'small', 'tiny'],
+ help='Choose model size. Only "large" is supported now: [whisper-large-16k].'
+ )
+ self.parser.add_argument(
+ '--language',
+ type=str,
+ default='None',
+ choices=sorted(LANGUAGES.keys()) + sorted(
+ [k.title() for k in TO_LANGUAGE_CODE.keys()]),
+ help='Choose the decoding language. Default is None, in which case the language is detected by the model.'
+ )
+ self.parser.add_argument(
+ "--sample_rate",
+ type=int,
+ default=16000,
+ choices=[16000],
+ help='Choose the audio sample rate of the model. Only 16000 is supported.')
+ self.parser.add_argument(
+ '--config',
+ type=str,
+ default=None,
+ help='Config of asr task. Use the default config when it is None.')
+ self.parser.add_argument(
+ '--decode_method',
+ type=str,
+ default='ctc_prefix_beam_search',
+ choices=['ctc_greedy_search', 'ctc_prefix_beam_search'],
+ help='only support transformer and conformer model')
+ self.parser.add_argument(
+ '--ckpt_path',
+ type=str,
+ default=None,
+ help='Checkpoint file of model.')
+ self.parser.add_argument(
+ '--yes',
+ '-y',
+ action="store_true",
+ default=False,
+ help='No additional parameters required. \
+ Once set this parameter, it means accepting the request of the program by default, \
+ which includes transforming the audio sample rate')
+ self.parser.add_argument(
+ '--rtf',
+ action="store_true",
+ default=False,
+ help='Show Real-time Factor(RTF).')
+ self.parser.add_argument(
+ '--device',
+ type=str,
+ default=paddle.get_device(),
+ help='Choose device to execute model inference.')
+ self.parser.add_argument(
+ '-d',
+ '--job_dump_result',
+ action='store_true',
+ help='Save job result into file.')
+ self.parser.add_argument(
+ '-v',
+ '--verbose',
+ action='store_true',
+ help='Increase logger verbosity of current task.')
+
+ def _init_from_path(self,
+ model_type: str='whisper',
+ lang: str='',
+ task: str='transcribe',
+ size: str='large',
+ language: str='None',
+ sample_rate: int=16000,
+ cfg_path: Optional[os.PathLike]=None,
+ decode_method: str='ctc_prefix_beam_search',
+ num_decoding_left_chunks: int=-1,
+ ckpt_path: Optional[os.PathLike]=None):
+ """
+ Init model and other resources from a specific path.
+ """
+ logger.debug("start to init the model")
+ # default max_len: unit:second
+ self.max_len = 50
+ if hasattr(self, 'model'):
+ logger.debug('Model has already been initialized.')
+ return
+
+ if cfg_path is None or ckpt_path is None:
+ sample_rate_str = '16k' if sample_rate == 16000 else '8k'
+ if lang == "":
+ tag = model_type + '-' + size + '-' + sample_rate_str
+ else:
+ tag = model_type + '-' + size + '-' + lang + '-' + sample_rate_str
+ self.task_resource.set_task_model(tag, version=None)
+ self.res_path = self.task_resource.res_dir
+
+ self.cfg_path = os.path.join(
+ self.res_path, self.task_resource.res_dict['cfg_path'])
+ self.ckpt_path = os.path.join(
+ self.res_path,
+ self.task_resource.res_dict['ckpt_path'] + ".pdparams")
+ logger.debug(self.res_path)
+
+ else:
+ self.cfg_path = os.path.abspath(cfg_path)
+ self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
+ self.res_path = os.path.dirname(
+ os.path.dirname(os.path.abspath(self.cfg_path)))
+ logger.debug(self.cfg_path)
+ logger.debug(self.ckpt_path)
+
+ #Init body.
+ self.config = CfgNode(new_allowed=True)
+ self.config.merge_from_file(self.cfg_path)
+
+ with UpdateConfig(self.config):
+ if "whisper" in model_type:
+ resource_url = self.task_resource.res_dict['resource_data']
+ resource_md5 = self.task_resource.res_dict['resource_data_md5']
+
+ self.resource_path = os.path.join(
+ DATA_HOME, self.task_resource.version, 'whisper')
+ self.download_resource(resource_url, self.resource_path,
+ resource_md5)
+ else:
+ raise Exception("Unsupported model type: " + model_type)
+
+ # load model
+ model_dict = paddle.load(self.ckpt_path)
+ dims = ModelDimensions(**model_dict["dims"])
+ self.model = Whisper(dims)
+ self.model.load_dict(model_dict)
+ self.model.eval()
+
+ #set task
+ if task is not None:
+ self.task = task
+
+ #set language
+ if language is not None:
+ if lang == 'en' and language != 'en':
+ logger.info(
+ "An English-only model is selected, setting language to 'en'.")
+ self.language = 'en'
+ else:
+ self.language = language
+
+ def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
+ """
+ Input preprocess and return paddle.Tensor stored in self.input.
+ Input content can be text (tts), a file (asr, cls) or a stream (not supported yet).
+ """
+
+ audio_file = input
+ if isinstance(audio_file, (str, os.PathLike)):
+ logger.debug("Preprocess audio_file:" + audio_file)
+ elif isinstance(audio_file, io.BytesIO):
+ audio_file.seek(0)
+
+ # Get the object for feature extraction
+ # whisper hard-coded audio hyperparameters, params in paddlespeech/s2t/models/whisper/whisper.py
+ logger.debug("read the audio file")
+ audio, audio_sample_rate = soundfile.read(
+ audio_file, dtype="float32", always_2d=True)
+ if self.change_format:
+ if audio.shape[1] >= 2:
+ audio = audio.mean(axis=1, dtype=np.int16)
+ else:
+ audio = audio[:, 0]
+ # pcm16 -> pcm 32
+ audio = self._pcm16to32(audio)
+ audio = librosa.resample(
+ audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate)
+ audio_sample_rate = self.sample_rate
+ # pcm32 -> pcm 16
+ audio = self._pcm32to16(audio)
+ else:
+ audio = audio[:, 0]
+
+ logger.debug(f"audio shape: {audio.shape}")
+ # fbank
+ audio = log_mel_spectrogram(audio, resource_path=self.resource_path)
+
+ audio_len = paddle.to_tensor(audio.shape[0])
+
+ self._inputs["audio"] = audio
+ self._inputs["audio_len"] = audio_len
+ logger.debug(f"audio feat shape: {audio.shape}")
+
+ logger.debug("audio feat process success")
+
+ @paddle.no_grad()
+ def infer(self, model_type: str):
+ """
+ Model inference and result stored in self.output.
+ """
+ logger.debug("start to infer the model to get the output")
+ cfg = self.config
+ audio = self._inputs["audio"]
+ if cfg.temperature_increment_on_fallback is not None:
+ temperature = tuple(
+ np.arange(cfg.temperature, 1.0 + 1e-6,
+ cfg.temperature_increment_on_fallback))
+ else:
+ temperature = [cfg.temperature]
+
+ self._outputs["result"] = self.model.transcribe(
+ audio,
+ verbose=cfg.verbose,
+ task=self.task,
+ language=self.language,
+ resource_path=self.resource_path,
+ temperature=temperature,
+ compression_ratio_threshold=cfg.compression_ratio_threshold,
+ logprob_threshold=cfg.logprob_threshold,
+ best_of=cfg.best_of,
+ beam_size=cfg.beam_size,
+ patience=cfg.patience,
+ length_penalty=cfg.length_penalty,
+ initial_prompt=cfg.initial_prompt,
+ condition_on_previous_text=cfg.condition_on_previous_text,
+ no_speech_threshold=cfg.no_speech_threshold)
+
+ def postprocess(self) -> Union[str, os.PathLike]:
+ """
+ Output postprocess and return human-readable results such as texts and audio files.
+ """
+ return self._outputs["result"]
+
+ def download_resource(self, url, lm_dir, md5sum):
+ download_path = get_path_from_url(
+ url=url,
+ root_dir=lm_dir,
+ md5sum=md5sum,
+ decompress=True, )
+
+ def _pcm16to32(self, audio):
+ assert (audio.dtype == np.int16)
+ audio = audio.astype("float32")
+ bits = np.iinfo(np.int16).bits
+ audio = audio / (2**(bits - 1))
+ return audio
+
+ def _pcm32to16(self, audio):
+ assert (audio.dtype == np.float32)
+ bits = np.iinfo(np.int16).bits
+ audio = audio * (2**(bits - 1))
+ audio = np.round(audio).astype("int16")
+ return audio
+
+ def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False):
+ self.sample_rate = sample_rate
+ if self.sample_rate != 16000 and self.sample_rate != 8000:
+ logger.error(
+ "invalid sample rate, please input --sr 8000 or --sr 16000")
+ return False
+
+ if isinstance(audio_file, (str, os.PathLike)):
+ if not os.path.isfile(audio_file):
+ logger.error("Please input the right audio file path")
+ return False
+ elif isinstance(audio_file, io.BytesIO):
+ audio_file.seek(0)
+
+ logger.debug("checking the audio file format......")
+ try:
+ audio, audio_sample_rate = soundfile.read(
+ audio_file, dtype="int16", always_2d=True)
+ audio_duration = audio.shape[0] / audio_sample_rate
+ if audio_duration > self.max_len:
+ logger.error(
+ f"Please input an audio file no longer than {self.max_len} seconds.\n"
+ )
+ return False
+ except Exception as e:
+ logger.exception(e)
+ logger.error(
+ f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \
+ you can try to use sox to change the file format.\n \
+ For example: \n \
+ sample rate: 16k \n \
+ sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+ sample rate: 8k \n \
+ sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+ ")
+ return False
+ logger.debug("The sample rate is %d" % audio_sample_rate)
+ if audio_sample_rate != self.sample_rate:
+ logger.warning("The sample rate of the input file is not {}.\n \
+ The program will resample the wav file to {}.\n \
+ If the result does not meet your expectations,\n \
+ Please input the 16k 16 bit 1 channel wav file. \
+ ".format(self.sample_rate, self.sample_rate))
+ if force_yes is False:
+ while (True):
+ logger.debug(
+ "Change the sample rate and channels? Y: convert the audio. N: exit the program."
+ )
+ content = input("Input(Y/N):")
+ if content.strip() == "Y" or content.strip(
+ ) == "y" or content.strip() == "yes" or content.strip(
+ ) == "Yes":
+ logger.debug(
+ "change the sample rate and channels to match the model"
+ )
+ break
+ elif content.strip() == "N" or content.strip(
+ ) == "n" or content.strip() == "no" or content.strip(
+ ) == "No":
+ logger.debug("Exit the program")
+ return False
+ else:
+ logger.warning("Not regular input, please input again")
+
+ self.change_format = True
+ else:
+ logger.debug("The audio file format is right")
+ self.change_format = False
+
+ return True
+
+ def execute(self, argv: List[str]) -> bool:
+ """
+ Command line entry.
+ """
+ parser_args = self.parser.parse_args(argv)
+
+ model = parser_args.model
+ lang = parser_args.lang
+ task = parser_args.task
+ size = parser_args.size
+ language = parser_args.language
+ sample_rate = parser_args.sample_rate
+ config = parser_args.config
+ ckpt_path = parser_args.ckpt_path
+ decode_method = parser_args.decode_method
+ force_yes = parser_args.yes
+ rtf = parser_args.rtf
+ device = parser_args.device
+
+ if not parser_args.verbose:
+ self.disable_task_loggers()
+
+ task_source = self.get_input_source(parser_args.input)
+ task_results = OrderedDict()
+ has_exceptions = False
+
+ for id_, input_ in task_source.items():
+ try:
+ res = self(
+ audio_file=input_,
+ model=model,
+ lang=lang,
+ task=task,
+ size=size,
+ language=language,
+ sample_rate=sample_rate,
+ config=config,
+ ckpt_path=ckpt_path,
+ decode_method=decode_method,
+ force_yes=force_yes,
+ rtf=rtf,
+ device=device)
+ task_results[id_] = res
+ except Exception as e:
+ has_exceptions = True
+ task_results[id_] = f'{e.__class__.__name__}: {e}'
+
+ if rtf:
+ self.show_rtf(CLI_TIMER[self.__class__.__name__])
+
+ self.process_task_results(parser_args.input, task_results,
+ parser_args.job_dump_result)
+
+ if has_exceptions:
+ return False
+ else:
+ return True
+
+ @stats_wrapper
+ def __call__(self,
+ audio_file: os.PathLike,
+ model: str='whisper',
+ lang: str='',
+ task: str='transcribe',
+ size: str='large',
+ language: str='None',
+ sample_rate: int=16000,
+ config: os.PathLike=None,
+ ckpt_path: os.PathLike=None,
+ decode_method: str='attention_rescoring',
+ num_decoding_left_chunks: int=-1,
+ force_yes: bool=False,
+ rtf: bool=False,
+ device=paddle.get_device()):
+ """
+ Python API to call an executor.
+ """
+ audio_file = os.path.abspath(audio_file)
+ paddle.set_device(device)
+ self._init_from_path(model, lang, task, size, language, sample_rate,
+ config, decode_method, num_decoding_left_chunks,
+ ckpt_path)
+ if not self._check(audio_file, sample_rate, force_yes):
+ sys.exit(-1)
+ if rtf:
+ k = self.__class__.__name__
+ CLI_TIMER[k]['start'].append(time.time())
+
+ self.preprocess(model, audio_file)
+ self.infer(model)
+        res = self.postprocess()  # Retrieve the recognition/translation result.
+
+ if rtf:
+ CLI_TIMER[k]['end'].append(time.time())
+ audio, audio_sample_rate = soundfile.read(
+ audio_file, dtype="int16", always_2d=True)
+ CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate)
+
+ return res
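+
+    # A hedged usage sketch of the Python API above; the class/module names
+    # below are assumptions based on this diff, not verified against the repo:
+    #
+    #     from paddlespeech.cli.whisper import WhisperExecutor
+    #     whisper_executor = WhisperExecutor()
+    #     text = whisper_executor(
+    #         model='whisper', task='transcribe', size='large',
+    #         sample_rate=16000, audio_file='./zh.wav', device=paddle.get_device())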
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index 8e9ecc4ba..ab0b1828c 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -18,6 +18,12 @@ __all__ = [
# Records of model name to import class
model_alias = {
+ # ---------------------------------
+ # -------------- SSL --------------
+ # ---------------------------------
+ "wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"],
+ "wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"],
+
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------
@@ -29,6 +35,11 @@ model_alias = {
"transformer": ["paddlespeech.s2t.models.u2:U2Model"],
"wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
+ # ---------------------------------
+ # ------------ Whisper ------------
+ # ---------------------------------
+ "whisper": ["paddlespeech.s2t.models.whisper:Whisper"],
+
# ---------------------------------
# -------------- CLS --------------
# ---------------------------------
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index df50a6a9d..067246749 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -25,6 +25,8 @@ __all__ = [
'tts_static_pretrained_models',
'tts_onnx_pretrained_models',
'vector_dynamic_pretrained_models',
+ 'ssl_dynamic_pretrained_models',
+ 'whisper_dynamic_pretrained_models',
]
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
@@ -32,6 +34,44 @@ __all__ = [
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
+# ---------------------------------
+# -------------- SSL --------------
+# ---------------------------------
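+# A hedged usage sketch for the SSL models registered below; the exact CLI
+# flags are assumptions based on the speech_ssl demo:
+#   paddlespeech ssl --task asr --lang en --sr 16000 --input ./en.wav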
+ssl_dynamic_pretrained_models = {
+ "wav2vec2-en-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2-large-960h-lv60-self_ckpt_1.3.0.model.tar.gz',
+ 'md5':
+ 'acc46900680e341e500437aa59193518',
+ 'cfg_path':
+ 'model.yaml',
+ 'ckpt_path':
+ 'wav2vec2-large-960h-lv60-self',
+ 'model':
+ 'wav2vec2-large-960h-lv60-self.pdparams',
+ 'params':
+ 'wav2vec2-large-960h-lv60-self.pdparams',
+ },
+ },
+ "wav2vec2ASR_librispeech-en-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz',
+ 'md5':
+ 'cbe28d6c78f3dd2e189968402381f454',
+ 'cfg_path':
+ 'model.yaml',
+ 'ckpt_path':
+ 'exp/wav2vec2ASR/checkpoints/avg_1',
+ 'model':
+ 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+ 'params':
+ 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+ },
+ },
+}
+
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------
@@ -424,6 +464,189 @@ asr_onnx_pretrained_models = {
},
}
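+
+# A hedged usage sketch for the Whisper models registered below; the exact CLI
+# flags are assumptions based on the whisper executor in this PR:
+#   paddlespeech whisper --task transcribe --size large --input ./zh.wav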
+whisper_dynamic_pretrained_models = {
+ "whisper-large-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-large-model.tar.gz',
+ 'md5':
+ 'cf1557af9d8ffa493fefad9cb08ae189',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-large-model',
+ 'model':
+ 'whisper-large-model.pdparams',
+ 'params':
+ 'whisper-large-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-base-en-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-en-model.tar.gz',
+ 'md5':
+ 'b156529aefde6beb7726d2ea98fd067a',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-base-en-model',
+ 'model':
+ 'whisper-base-en-model.pdparams',
+ 'params':
+ 'whisper-base-en-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-base-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-model.tar.gz',
+ 'md5':
+ '6b012a5abd583db14398c3492e47120b',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-base-model',
+ 'model':
+ 'whisper-base-model.pdparams',
+ 'params':
+ 'whisper-base-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-medium-en-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-en-model.tar.gz',
+ 'md5':
+ 'c7f57d270bd20c7b170ba9dcf6c16f74',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-medium-en-model',
+ 'model':
+ 'whisper-medium-en-model.pdparams',
+ 'params':
+ 'whisper-medium-en-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-medium-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-model.tar.gz',
+ 'md5':
+ '4c7dcd0df25f408199db4a4548336786',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-medium-model',
+ 'model':
+ 'whisper-medium-model.pdparams',
+ 'params':
+ 'whisper-medium-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-small-en-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-en-model.tar.gz',
+ 'md5':
+ '2b24efcb2e93f3275af7c0c7f598ff1c',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-small-en-model',
+ 'model':
+ 'whisper-small-en-model.pdparams',
+ 'params':
+ 'whisper-small-en-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-small-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-model.tar.gz',
+ 'md5':
+ '5a57911dd41651dd6ed78c5763912825',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-small-model',
+ 'model':
+ 'whisper-small-model.pdparams',
+ 'params':
+ 'whisper-small-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-tiny-en-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-en-model.tar.gz',
+ 'md5':
+ '14969164a3f713fd58e56978c34188f6',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-tiny-en-model',
+ 'model':
+ 'whisper-tiny-en-model.pdparams',
+ 'params':
+ 'whisper-tiny-en-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+ "whisper-tiny-16k": {
+ '1.3': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-model.tar.gz',
+ 'md5':
+ 'a5b82a1f2067a2ca400f17fabd62b81b',
+ 'cfg_path':
+ 'whisper.yaml',
+ 'ckpt_path':
+ 'whisper-tiny-model',
+ 'model':
+ 'whisper-tiny-model.pdparams',
+ 'params':
+ 'whisper-tiny-model.pdparams',
+ 'resource_data':
+ 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+ 'resource_data_md5':
+ '37a0a8abdb3641a51194f79567a93b61',
+ },
+ },
+}
+
# ---------------------------------
# -------------- CLS --------------
# ---------------------------------
@@ -723,6 +946,22 @@ tts_dynamic_pretrained_models = {
'speaker_id_map.txt',
},
},
+ "fastspeech2_male-zh": {
+ '1.0': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_male_ckpt_1.3.0.zip',
+ 'md5':
+ 'a4b1a2f667b878ec8f67375357b04282',
+ 'config':
+ 'default.yaml',
+ 'ckpt':
+ 'snapshot_iter_76000.pdz',
+ 'speech_stats':
+ 'speech_stats.npy',
+ 'phones_dict':
+ 'phone_id_map.txt',
+ },
+ },
# tacotron2
"tacotron2_csmsc-zh": {
'1.0': {
@@ -813,6 +1052,20 @@ tts_dynamic_pretrained_models = {
'feats_stats.npy',
},
},
+ "pwgan_male-zh": {
+ '1.0': {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_male_ckpt_1.3.0.zip',
+ 'md5':
+ 'c98cdb889c809973f8cc764437311132',
+ 'config':
+ 'default.yaml',
+ 'ckpt':
+ 'snapshot_iter_200000.pdz',
+ 'speech_stats':
+ 'feats_stats.npy',
+ },
+ },
# mb_melgan
"mb_melgan_csmsc-zh": {
'1.0': {
diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py
index 8e9914b2e..4eb0e32d7 100644
--- a/paddlespeech/resource/resource.py
+++ b/paddlespeech/resource/resource.py
@@ -22,7 +22,9 @@ from ..utils.dynamic_import import dynamic_import
from ..utils.env import MODEL_HOME
from .model_alias import model_alias
-task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws']
+task_supported = [
+ 'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'ssl', 'whisper'
+]
model_format_supported = ['dynamic', 'static', 'onnx']
inference_mode_supported = ['online', 'offline']
@@ -108,7 +110,6 @@ class CommonTaskResource:
"""
assert model_name in model_alias, 'No model classes found for "{}"'.format(
model_name)
-
ret = []
for import_path in model_alias[model_name]:
ret.append(dynamic_import(import_path))
diff --git a/paddlespeech/s2t/exps/wav2vec2/__init__.py b/paddlespeech/s2t/exps/wav2vec2/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/s2t/exps/wav2vec2/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py b/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py
index 185a92b8d..97043fd7b 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
index 3ae3a9e73..29e7ef552 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
@@ -34,9 +34,10 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
+ parser.add_argument(
+ '--resume', type=str, default="", nargs="?", help='resume ckpt path.')
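+    # A hedged usage sketch: `--resume` takes the checkpoint tag written by
+    # Wav2Vec2ASRTrainer.save() (e.g. an epoch number), not a full file path:
+    #   python3 train.py --config conf/wav2vec2ASR.yaml --resume 10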
args = parser.parse_args()
print_arguments(args, globals())
-
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py
index 4f6bc0c5b..6a3321e46 100644
--- a/paddlespeech/s2t/exps/wav2vec2/model.py
+++ b/paddlespeech/s2t/exps/wav2vec2/model.py
@@ -15,6 +15,7 @@
import json
import math
import os
+import re
import time
from collections import defaultdict
from collections import OrderedDict
@@ -62,6 +63,19 @@ class Wav2Vec2ASRTrainer(Trainer):
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
self.avg_train_loss += loss / (batch_index + 1)
+ def before_train(self):
+ from_scratch = self.resume_or_scratch()
+ if from_scratch:
+ # scratch: save init model, i.e. 0 epoch
+ self.save(tag='init', infos=None)
+ else:
+ # resume: train next_epoch and next_iteration
+ self.epoch += 1
+ logger.info(
+ f"Resume train: epoch {self.epoch }, step {self.iteration}!")
+
+ self.maybe_batch_sampler_step()
+
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
start = time.time()
@@ -69,14 +83,14 @@ class Wav2Vec2ASRTrainer(Trainer):
# forward
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
- target_lens_rate = target_lens / target.shape[1]
+
wav = wav[:, :, 0]
- if hasattr(train_conf, 'speech_augment'):
+ if hasattr(train_conf, 'audio_augment'):
wav = self.speech_augmentation(wav, wavs_lens_rate)
- loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
+
+ loss = self.model(wav, wavs_lens_rate, target, target_lens)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
-
# update self.avg_train_loss
self.update_average(batch_index, float(loss))
@@ -98,11 +112,17 @@ class Wav2Vec2ASRTrainer(Trainer):
# optimizer step old
if (batch_index + 1) % train_conf.accum_grad == 0:
- self.optimizer.step()
- self.optimizer.clear_grad()
- self.lr_scheduler.step()
+ self.model_optimizer.step()
+ self.model_optimizer.clear_grad()
+ if not train_conf.freeze_wav2vec2:
+ self.wav2vec2_optimizer.step()
+ self.wav2vec2_optimizer.clear_grad()
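+            # newbobscheduler is stepped on the validation loss at the end of
+            # each epoch (see do_train), so it is not stepped per batch here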
+ if self.config.model_scheduler != 'newbobscheduler':
+ self.model_lr_scheduler.step()
+ if self.config.wav2vec2_scheduler != 'newbobscheduler':
+ if not train_conf.freeze_wav2vec2:
+ self.wav2vec2_lr_scheduler.step()
self.iteration += 1
-
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
iteration_time = time.time() - start
for k, v in losses_np.items():
@@ -114,7 +134,10 @@ class Wav2Vec2ASRTrainer(Trainer):
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
- losses_np_v.update({"lr": self.lr_scheduler()})
+ losses_np_v.update({
+ "model_lr": self.model_lr_scheduler(),
+ "wav2vec2_lr": self.wav2vec2_lr_scheduler()
+ })
for key, val in losses_np_v.items():
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@@ -131,11 +154,10 @@ class Wav2Vec2ASRTrainer(Trainer):
for i, batch in enumerate(self.valid_loader):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
- target_lens_rate = target_lens / target.shape[1]
wav = wav[:, :, 0]
- loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
+ loss = self.model(wav, wavs_lens_rate, target, target_lens)
- if paddle.isfinite(loss):
+ if math.isfinite(float(loss)):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
@@ -160,6 +182,106 @@ class Wav2Vec2ASRTrainer(Trainer):
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
+ @mp_tools.rank_zero_only
+ def save(self, tag=None, infos: dict=None):
+ """Save checkpoint (model parameters and optimizer states).
+
+ Args:
+            tag (int or str, optional): None to use the current iteration as the tag, else the given tag, e.g. epoch. Defaults to None.
+ infos (dict, optional): meta data to save. Defaults to None.
+ """
+
+ infos = infos if infos else dict()
+ infos.update({
+ "epoch": self.epoch,
+ "model_lr": self.model_optimizer.get_lr(),
+ "wav2vec2_lr": self.wav2vec2_optimizer.get_lr()
+ })
+
+ checkpoint_path = os.path.join(
+ self.checkpoint_dir,
+ "{}".format(self.iteration if tag is None else tag))
+
+ model_dict = self.model.state_dict()
+ params_path = checkpoint_path + ".pdparams"
+ paddle.save(model_dict, params_path)
+ logger.info("Saved model to {}".format(params_path))
+
+ model_opt_dict = self.model_optimizer.state_dict()
+ wav2vec2_opt_dict = self.wav2vec2_optimizer.state_dict()
+
+ opt_dict = {'model': model_opt_dict, 'wav2vec2': wav2vec2_opt_dict}
+
+ optimizer_path = checkpoint_path + ".pdopt"
+ paddle.save(opt_dict, optimizer_path)
+ logger.info("Saved optimzier state to {}".format(optimizer_path))
+
+ scheduler_dict = {}
+
+ if self.config.model_scheduler == 'newbobscheduler':
+ scheduler_dict['model'] = self.model_lr_scheduler.save()
+ if self.config.wav2vec2_scheduler == 'newbobscheduler':
+ scheduler_dict['wav2vec2'] = self.wav2vec2_lr_scheduler.save()
+ if scheduler_dict:
+ scheduler_path = checkpoint_path + ".pdlrs"
+ paddle.save(scheduler_dict, scheduler_path)
+ logger.info("Saved scheduler state to {}".format(scheduler_path))
+ info_path = re.sub('.pdparams$', '.json', params_path)
+ infos = {} if infos is None else infos
+ with open(info_path, 'w') as fout:
+ data = json.dumps(infos)
+ fout.write(data)
+
+ def resume_or_scratch(self):
+ """Resume from latest checkpoint at checkpoints in the output
+ directory or load a specified checkpoint.
+
+ If ``args.checkpoint_path`` is not None, load the checkpoint, else
+ resume training.
+ """
+ scratch = None
+ if self.args.resume:
+ # just restore ckpt
+            # lr will be restored from the optimizer ckpt
+ resume_json_path = os.path.join(self.checkpoint_dir,
+ self.args.resume + '.json')
+ with open(resume_json_path, 'r') as f:
+ resume_json = json.load(f)
+ self.iteration = 0
+ self.epoch = resume_json["epoch"]
+
+            # restore model from *.pdparams
+ params_path = os.path.join(self.checkpoint_dir,
+ "{}".format(self.epoch)) + '.pdparams'
+ model_dict = paddle.load(params_path)
+ self.model.set_state_dict(model_dict)
+
+            # restore optimizer from *.pdopt
+ optimizer_path = os.path.join(self.checkpoint_dir,
+ "{}".format(self.epoch)) + '.pdopt'
+ optimizer_dict = paddle.load(optimizer_path)
+ self.model_optimizer.set_state_dict(optimizer_dict['model'])
+ self.wav2vec2_optimizer.set_state_dict(optimizer_dict['wav2vec2'])
+
+            # restore lr_scheduler from *.pdlrs
+ scheduler_path = os.path.join(self.checkpoint_dir,
+ "{}".format(self.epoch)) + '.pdlrs'
+            if os.path.isfile(scheduler_path):
+ scheduler_dict = paddle.load(scheduler_path)
+ if self.config.model_scheduler == 'newbobscheduler':
+ self.model_lr_scheduler.load(scheduler_dict['model'])
+ if self.config.wav2vec2_scheduler == 'newbobscheduler':
+ self.wav2vec2_lr_scheduler.load(scheduler_dict['wav2vec2'])
+ logger.info(
+ f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
+ scratch = False
+ else:
+ self.iteration = 0
+ self.epoch = 0
+ scratch = True
+ logger.info("Init from scratch!")
+ return scratch
+
def do_train(self):
"""The training process control by step."""
# !!!IMPORTANT!!!
@@ -170,7 +292,6 @@ class Wav2Vec2ASRTrainer(Trainer):
# paddle.jit.save(script_model, script_model_path)
self.before_train()
-
if not self.use_streamdata:
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
@@ -187,7 +308,9 @@ class Wav2Vec2ASRTrainer(Trainer):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
- report("lr", self.lr_scheduler())
+ report("model_lr", self.model_optimizer.get_lr())
+ report("wav2vec2_lr",
+ self.wav2vec2_optimizer.get_lr())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
@@ -225,15 +348,25 @@ class Wav2Vec2ASRTrainer(Trainer):
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
-
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
- tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
-
+ tag='eval/model_lr',
+ value=self.model_lr_scheduler(),
+ step=self.epoch)
+ self.visualizer.add_scalar(
+ tag='eval/wav2vec2_lr',
+ value=self.wav2vec2_lr_scheduler(),
+ step=self.epoch)
+
+ if self.config.model_scheduler == 'newbobscheduler':
+ self.model_lr_scheduler.step(cv_loss)
+ if self.config.wav2vec2_scheduler == 'newbobscheduler':
+ if not self.config.freeze_wav2vec2:
+ self.wav2vec2_lr_scheduler.step(cv_loss)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
@@ -268,10 +401,11 @@ class Wav2Vec2ASRTrainer(Trainer):
model_conf.output_dim = self.test_loader.vocab_size
model = Wav2vec2ASR.from_config(model_conf)
+ model_dict = paddle.load(config.wav2vec2_params_path)
+ model.wav2vec2.set_state_dict(model_dict)
if self.parallel:
model = paddle.DataParallel(model, find_unused_parameters=True)
-
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
self.model = model
@@ -286,46 +420,74 @@ class Wav2Vec2ASRTrainer(Trainer):
return
train_config = config
- optim_type = train_config.model_optim
- optim_conf = train_config.model_optim_conf
- scheduler_type = train_config.scheduler
- scheduler_conf = train_config.scheduler_conf
-
- scheduler_args = {
- "learning_rate": optim_conf.lr,
- "verbose": False,
- "warmup_steps": scheduler_conf.warmup_steps,
- "gamma": scheduler_conf.lr_decay,
- "d_model": model_conf.dnn_neurons,
- }
- lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
- scheduler_args)
+ model_optim_type = train_config.model_optim
+ model_optim_conf = train_config.model_optim_conf
+ wav2vec2_optim_type = train_config.model_optim
+ wav2vec2_optim_conf = train_config.wav2vec2_optim_conf
+
+ model_scheduler_type = train_config.model_scheduler
+ model_scheduler_conf = train_config.model_scheduler_conf
+ wav2vec2_scheduler_type = train_config.wav2vec2_scheduler
+ wav2vec2_scheduler_conf = train_config.wav2vec2_scheduler_conf
+
+ model_scheduler_args = dict(
+ **{"learning_rate": model_optim_conf.lr,
+ "verbose": False}, **(dict(model_scheduler_conf)))
+
+ wav2vec2_scheduler_args = dict(
+ **{"learning_rate": wav2vec2_optim_conf.lr,
+ "verbose": False}, **(dict(wav2vec2_scheduler_conf)))
+
+ model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
+ model_scheduler_args)
+ wav2vec2_lr_scheduler = LRSchedulerFactory.from_args(
+ wav2vec2_scheduler_type, wav2vec2_scheduler_args)
def optimizer_args(
config,
+ optim_type,
+ optim_conf,
parameters,
lr_scheduler=None, ):
train_config = config
- optim_type = train_config.model_optim
- optim_conf = train_config.model_optim_conf
- scheduler_type = train_config.scheduler
- scheduler_conf = train_config.scheduler_conf
- return {
- "grad_clip": train_config.global_grad_clip,
- "learning_rate": lr_scheduler
- if lr_scheduler else optim_conf.lr,
- "epsilon": optim_conf.epsilon,
- "rho": optim_conf.rho,
- "parameters": parameters,
- "beta1": 0.9 if optim_type == 'noam' else None,
- "beat2": 0.98 if optim_type == 'noam' else None,
- }
-
- optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
- optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
-
- self.optimizer = optimizer
- self.lr_scheduler = lr_scheduler
+ optim_arg = dict(optim_conf)
+ optim_arg.update({
+ "grad_clip":
+ train_config.global_grad_clip,
+ "learning_rate":
+ lr_scheduler if lr_scheduler else optim_conf.lr,
+ "parameters":
+ parameters
+ })
+ return optim_arg
+
+ model_optimizer_args = optimizer_args(config, model_optim_type,
+ model_optim_conf, [{
+ 'params':
+ model._layers.enc.parameters()
+ }, {
+ 'params':
+ model._layers.ctc.parameters()
+ }] if self.parallel else [{
+ 'params':
+ model.enc.parameters()
+ }, {
+ 'params':
+ model.ctc.parameters()
+ }], model_lr_scheduler)
+ wav2vec2_optimizer_args = optimizer_args(
+ config, wav2vec2_optim_type, wav2vec2_optim_conf,
+ model._layers.wav2vec2.parameters() if self.parallel else
+ model.wav2vec2.parameters(), wav2vec2_lr_scheduler)
+ model_optimizer = OptimizerFactory.from_args(model_optim_type,
+ model_optimizer_args)
+ wav2vec2_optimizer = OptimizerFactory.from_args(wav2vec2_optim_type,
+ wav2vec2_optimizer_args)
+
+ self.model_optimizer = model_optimizer
+ self.wav2vec2_optimizer = wav2vec2_optimizer
+ self.model_lr_scheduler = model_lr_scheduler
+ self.wav2vec2_lr_scheduler = wav2vec2_lr_scheduler
logger.info("Setup optimizer/lr_scheduler!")
diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py
new file mode 100644
index 000000000..e04eec4f2
--- /dev/null
+++ b/paddlespeech/s2t/exps/whisper/test_wav.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from Whisper (https://github.com/openai/whisper/whisper/)
+import os.path
+import sys
+
+import distutils
+import numpy as np
+import paddle
+import soundfile
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.models.whisper import log_mel_spectrogram
+from paddlespeech.s2t.models.whisper import ModelDimensions
+from paddlespeech.s2t.models.whisper import transcribe
+from paddlespeech.s2t.models.whisper import Whisper
+from paddlespeech.s2t.training.cli import default_argument_parser
+from paddlespeech.s2t.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+class WhisperInfer():
+ def __init__(self, config, args):
+ self.args = args
+ self.config = config
+ self.audio_file = args.audio_file
+
+ paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
+ config.pop("ngpu")
+
+        # load model
+ model_dict = paddle.load(self.config.model_file)
+ config.pop("model_file")
+ dims = ModelDimensions(**model_dict["dims"])
+ self.model = Whisper(dims)
+ self.model.load_dict(model_dict)
+
+    def run(self):
+        check(self.audio_file)
+
+        with paddle.no_grad():
+            temperature = self.config.pop("temperature")
+            temperature_increment_on_fallback = self.config.pop(
+                "temperature_increment_on_fallback")
+            if temperature_increment_on_fallback is not None:
+                temperature = tuple(
+                    np.arange(temperature, 1.0 + 1e-6,
+                              temperature_increment_on_fallback))
+            else:
+                temperature = [temperature]
+
+            # load audio and compute the log-Mel spectrogram
+            mel = log_mel_spectrogram(
+                self.audio_file, resource_path=self.config.resource_path)
+
+            result = transcribe(
+                self.model, mel, temperature=temperature, **self.config)
+            if self.args.result_file is not None:
+                with open(self.args.result_file, 'w') as f:
+                    f.write(str(result))
+            return result
+
+
+def check(audio_file: str):
+ if not os.path.isfile(audio_file):
+ print("Please input the right audio file path")
+ sys.exit(-1)
+
+ logger.info("checking the audio file format......")
+ try:
+ _, sample_rate = soundfile.read(audio_file)
+ except Exception as e:
+ logger.error(str(e))
+ logger.error(
+ "can not open the wav file, please check the audio file format")
+ sys.exit(-1)
+ logger.info("The sample rate is %d" % sample_rate)
+ assert (sample_rate == 16000)
+ logger.info("The audio file format is right")
+
+
+def main(config, args):
+ WhisperInfer(config, args).run()
+
+
+if __name__ == "__main__":
+ parser = default_argument_parser()
+ # save asr result to
+ parser.add_argument(
+ "--result_file", type=str, help="path of save the asr result")
+ parser.add_argument(
+ "--audio_file", type=str, help="path of the input audio file")
+ parser.add_argument(
+ "--debug",
+ type=distutils.util.strtobool,
+ default=False,
+ help="for debug.")
+ args = parser.parse_args()
+
+ config = CfgNode(new_allowed=True)
+
+ if args.config:
+ config.merge_from_file(args.config)
+ if args.decode_cfg:
+ decode_confs = CfgNode(new_allowed=True)
+ decode_confs.merge_from_file(args.decode_cfg)
+ config.decode = decode_confs
+ if args.opts:
+ config.merge_from_list(args.opts)
+ config.freeze()
+ main(config, args)
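+
+    # A hedged usage sketch; the config keys (model_file, resource_path, ...)
+    # are assumptions based on WhisperInfer above:
+    #   python3 test_wav.py --config conf/whisper.yaml --audio_file ./zh.wav \
+    #       --result_file ./whisper.result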
diff --git a/paddlespeech/s2t/models/wav2vec2/__init__.py b/paddlespeech/s2t/models/wav2vec2/__init__.py
index e69de29bb..3a12a9cf3 100644
--- a/paddlespeech/s2t/models/wav2vec2/__init__.py
+++ b/paddlespeech/s2t/models/wav2vec2/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .wav2vec2_ASR import Wav2vec2ASR
+from .wav2vec2_ASR import Wav2vec2Base
+
+__all__ = ["Wav2vec2ASR", "Wav2vec2Base"]
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
index cfd8f507e..9c88796bb 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
@@ -18,6 +18,7 @@ import paddle
from paddlespeech.s2t.models.wav2vec2.modules import containers
from paddlespeech.s2t.models.wav2vec2.modules import linear
+from paddlespeech.s2t.models.wav2vec2.modules.normalization import BatchNorm1d
class VanillaNN(containers.Sequential):
@@ -39,18 +40,33 @@ class VanillaNN(containers.Sequential):
paddle.shape([10, 120, 512])
"""
- def __init__(
- self,
- input_shape,
- activation=paddle.nn.LeakyReLU,
- dnn_blocks=2,
- dnn_neurons=512, ):
- super().__init__(input_shape=input_shape)
+ def __init__(self,
+ input_shape,
+ dnn_blocks=2,
+ dnn_neurons=512,
+ activation=True,
+ normalization=False,
+ dropout_rate=0.5):
+ super().__init__(input_shape=[None, None, input_shape])
+
+ if not isinstance(dropout_rate, list):
+ dropout_rate = [dropout_rate] * dnn_blocks
+ else:
+ assert len(
+ dropout_rate
+ ) == dnn_blocks, "len(dropout_rate) must equal to dnn_blocks"
for block_index in range(dnn_blocks):
self.append(
linear.Linear,
n_neurons=dnn_neurons,
- bias=True,
+ bias_attr=None,
layer_name="linear", )
- self.append(activation(), layer_name="act")
+ if normalization:
+ self.append(
+ BatchNorm1d, input_size=dnn_neurons, layer_name='bn')
+ if activation:
+ self.append(paddle.nn.LeakyReLU(), layer_name="act")
+ self.append(
+ paddle.nn.Dropout(p=dropout_rate[block_index]),
+ layer_name='dropout')
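+
+
+# A hedged usage sketch (the dimensions are assumptions): a 2-block DNN head
+# over 1024-dim wav2vec2 features, with batch norm and dropout enabled:
+#   enc = VanillaNN(input_shape=1024, dnn_blocks=2, dnn_neurons=1024,
+#                   activation=True, normalization=True, dropout_rate=0.15)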
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/__init__.py b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/s2t/models/wav2vec2/modules/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/containers.py b/paddlespeech/s2t/models/wav2vec2/modules/containers.py
index 180d0bd32..6a6b94e95 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/containers.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/containers.py
@@ -141,5 +141,4 @@ class Sequential(paddle.nn.LayerDict):
x = layer(x)
if isinstance(x, tuple):
x = x[0]
-
return x
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/linear.py b/paddlespeech/s2t/models/wav2vec2/modules/linear.py
index adae4514a..3ea3716c4 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/linear.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/linear.py
@@ -53,7 +53,7 @@ class Linear(paddle.nn.Layer):
n_neurons,
input_shape=None,
input_size=None,
- bias=True,
+ bias_attr=None,
combine_dims=False, ):
super().__init__()
self.combine_dims = combine_dims
@@ -67,7 +67,7 @@ class Linear(paddle.nn.Layer):
input_size = input_shape[2] * input_shape[3]
# Weights are initialized following paddle approach
- self.w = align.Linear(input_size, n_neurons, bias_attr=bias)
+ self.w = align.Linear(input_size, n_neurons, bias_attr=bias_attr)
def forward(self, x):
"""Returns the linear transformation of input tensor.
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
index e484fff68..5670cb531 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
@@ -1120,9 +1120,6 @@ class Wav2Vec2ConfigPure():
self.output_hidden_states = False
self.use_return_dict = True
- self.pad_token_id = config.pad_token_id
- self.bos_token_id = config.bos_token_id
- self.eos_token_id = config.eos_token_id
self.hidden_size = config.hidden_size
self.feat_extract_norm = config.feat_extract_norm
self.feat_extract_activation = config.feat_extract_activation
@@ -1145,7 +1142,6 @@ class Wav2Vec2ConfigPure():
self.layerdrop = config.layerdrop
self.layer_norm_eps = config.layer_norm_eps
self.initializer_range = config.initializer_range
- self.vocab_size = config.vocab_size
self.do_stable_layer_norm = config.do_stable_layer_norm
self.use_weighted_layer_sum = config.use_weighted_layer_sum
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/normalization.py b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py
new file mode 100644
index 000000000..912981058
--- /dev/null
+++ b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py
@@ -0,0 +1,97 @@
+# Authors
+# * Mirco Ravanelli 2020
+# * Guillermo Cámbara 2021
+# * Sarthak Yadav 2022
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/normalization.py)
+import paddle.nn as nn
+
+from paddlespeech.s2t.modules.align import BatchNorm1D
+
+
+class BatchNorm1d(nn.Layer):
+ """Applies 1d batch normalization to the input tensor.
+ Arguments
+ ---------
+ input_shape : tuple
+ The expected shape of the input. Alternatively, use ``input_size``.
+ input_size : int
+ The expected size of the input. Alternatively, use ``input_shape``.
+ eps : float
+ This value is added to std deviation estimation to improve the numerical
+ stability.
+ momentum : float
+ It is a value used for the running_mean and running_var computation.
+ affine : bool
+ When set to True, the affine parameters are learned.
+ track_running_stats : bool
+ When set to True, this module tracks the running mean and variance,
+ and when set to False, this module does not track such statistics.
+ combine_batch_time : bool
+        When true, it combines the batch and time axes.
+ Example
+ -------
+ >>> input = paddle.randn([100, 10])
+ >>> norm = BatchNorm1d(input_shape=input.shape)
+ >>> output = norm(input)
+ >>> output.shape
+    [100, 10]
+ """
+
+ def __init__(
+ self,
+ input_shape=None,
+ input_size=None,
+ eps=1e-05,
+ momentum=0.9,
+ combine_batch_time=False,
+ skip_transpose=False, ):
+ super().__init__()
+ self.combine_batch_time = combine_batch_time
+ self.skip_transpose = skip_transpose
+
+ if input_size is None and skip_transpose:
+ input_size = input_shape[1]
+ elif input_size is None:
+ input_size = input_shape[-1]
+
+ self.norm = BatchNorm1D(input_size, momentum=momentum, epsilon=eps)
+
+ def forward(self, x):
+ """Returns the normalized input tensor.
+ Arguments
+ ---------
+ x : paddle.Tensor (batch, time, [channels])
+ input to normalize. 2d or 3d tensors are expected in input
+ 4d tensors can be used when combine_dims=True.
+ """
+ shape_or = x.shape
+ if self.combine_batch_time:
+ if x.ndim == 3:
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
+ else:
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
+ shape_or[2])
+
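+        # paddle's BatchNorm1D normalizes over axis 1 (channels), so
+        # (batch, time, channels) inputs are transposed to (batch, channels, time)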
+ elif not self.skip_transpose:
+ x = x.transpose([0, 2, 1])
+
+ x_n = self.norm(x)
+ if self.combine_batch_time:
+ x_n = x_n.reshape(shape_or)
+ elif not self.skip_transpose:
+ x_n = x_n.transpose([0, 2, 1])
+
+ return x_n
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/__init__.py b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/s2t/models/wav2vec2/processing/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
index ac9bf45db..9224549a4 100644
--- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
+++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
@@ -639,6 +639,170 @@ class DropChunk(nn.Layer):
return dropped_waveform
+class SpecAugment(paddle.nn.Layer):
+ """An implementation of the SpecAugment algorithm.
+ Reference:
+ https://arxiv.org/abs/1904.08779
+ Arguments
+ ---------
+ time_warp : bool
+ Whether applying time warping.
+ time_warp_window : int
+ Time warp window.
+ time_warp_mode : str
+ Interpolation mode for time warping (default "bicubic").
+ freq_mask : bool
+ Whether applying freq mask.
+ freq_mask_width : int or tuple
+ Freq mask width range.
+ n_freq_mask : int
+ Number of freq mask.
+ time_mask : bool
+ Whether applying time mask.
+ time_mask_width : int or tuple
+ Time mask width range.
+ n_time_mask : int
+ Number of time mask.
+ replace_with_zero : bool
+ If True, replace masked value with 0, else replace masked value with mean of the input tensor.
+ Example
+ -------
+ >>> aug = SpecAugment()
+ >>> a = paddle.rand([8, 120, 80])
+ >>> a = aug(a)
+ >>> print(a.shape)
+    [8, 120, 80]
+ """
+
+ def __init__(
+ self,
+ time_warp=True,
+ time_warp_window=5,
+ time_warp_mode="bicubic",
+ freq_mask=True,
+ freq_mask_width=(0, 20),
+ n_freq_mask=2,
+ time_mask=True,
+ time_mask_width=(0, 100),
+ n_time_mask=2,
+ replace_with_zero=True, ):
+ super().__init__()
+ assert (
+ time_warp or freq_mask or time_mask
+ ), "at least one of time_warp, time_mask, or freq_mask should be applied"
+
+ self.apply_time_warp = time_warp
+ self.time_warp_window = time_warp_window
+ self.time_warp_mode = time_warp_mode
+
+ self.freq_mask = freq_mask
+ if isinstance(freq_mask_width, int):
+ freq_mask_width = (0, freq_mask_width)
+ self.freq_mask_width = freq_mask_width
+ self.n_freq_mask = n_freq_mask
+
+ self.time_mask = time_mask
+ if isinstance(time_mask_width, int):
+ time_mask_width = (0, time_mask_width)
+ self.time_mask_width = time_mask_width
+ self.n_time_mask = n_time_mask
+
+ self.replace_with_zero = replace_with_zero
+
+ def forward(self, x):
+ """Takes in input a tensors and returns an augmented one."""
+ if self.apply_time_warp:
+ x = self.time_warp(x)
+ if self.freq_mask:
+ x = self.mask_along_axis(x, dim=2)
+ if self.time_mask:
+ x = self.mask_along_axis(x, dim=1)
+ return x
+
+ def time_warp(self, x):
+ """Time warping with paddle.nn.functional.interpolate"""
+ original_size = x.shape
+ window = self.time_warp_window
+
+ # 2d interpolation requires 4D or higher dimension tensors
+ # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
+ if x.dim() == 3:
+ x = x.unsqueeze(1)
+
+ time = x.shape[2]
+ if time - window <= window:
+ return x.view(*original_size)
+
+ # compute center and corresponding window
+ c = paddle.randint(window, time - window, (1, ))[0]
+ w = paddle.randint(c - window, c + window, (1, ))[0] + 1
+        # warp [0, c) to width w and [c, time) to width time - w, then
+        # stitch the two interpolated halves back together
+ left = paddle.nn.functional.interpolate(
+ x[:, :, :c],
+ (w, x.shape[3]),
+ mode=self.time_warp_mode,
+ align_corners=True, )
+ right = paddle.nn.functional.interpolate(
+ x[:, :, c:],
+ (time - w, x.shape[3]),
+ mode=self.time_warp_mode,
+ align_corners=True, )
+
+ x[:, :, :w] = left
+ x[:, :, w:] = right
+ return x.view(*original_size)
+
+ def mask_along_axis(self, x, dim):
+ """Mask along time or frequency axis.
+ Arguments
+ ---------
+ x : tensor
+ Input tensor.
+ dim : int
+ Corresponding dimension to mask.
+ """
+ original_size = x.shape
+ if x.dim() == 4:
+ x = x.view(-1, x.shape[2], x.shape[3])
+
+ batch, time, fea = x.shape
+
+ if dim == 1:
+ D = time
+ n_mask = self.n_time_mask
+ width_range = self.time_mask_width
+ else:
+ D = fea
+ n_mask = self.n_freq_mask
+ width_range = self.freq_mask_width
+
+ mask_len = paddle.randint(width_range[0], width_range[1],
+ (batch, n_mask)).unsqueeze(2)
+
+ mask_pos = paddle.randint(0, max(1, D - mask_len.max()),
+ (batch, n_mask)).unsqueeze(2)
+
+ # compute masks
+ arange = paddle.arange(end=D).view(1, 1, -1)
+ mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
+ mask = mask.any(axis=1)
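+        # the comparisons above broadcast (batch, n_mask, 1) against (1, 1, D)
+        # into a (batch, n_mask, D) boolean mask; any(axis=1) merges the bands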
+
+ if dim == 1:
+ mask = mask.unsqueeze(2)
+ else:
+ mask = mask.unsqueeze(1)
+
+ if self.replace_with_zero:
+ val = 0.0
+ else:
+ val = x.mean()
+ # same to x.masked_fill_(mask, val)
+ y = paddle.full(x.shape, val, x.dtype)
+ x = paddle.where(mask, y, x)
+ return x.view(*original_size)
+
+
class TimeDomainSpecAugment(nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
This augmentation module implements three augmentations in
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index e13347740..eda188da5 100644
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -23,7 +23,9 @@ import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
+from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
+from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.utility import log_add
@@ -31,44 +33,41 @@ from paddlespeech.s2t.utils.utility import log_add
class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict):
super().__init__()
+ init_type = config.get("init_type", None)
+ with DefaultInitializerContext(init_type):
+ self.config = config
+ wav2vec2_config = Wav2Vec2ConfigPure(config)
+ wav2vec2 = Wav2Vec2Model(wav2vec2_config)
+ self.normalize_wav = config.normalize_wav
+ self.output_norm = config.output_norm
+ if hasattr(config, 'spec_augment'):
+ self.spec_augment = SpecAugment(**config.spec_augment)
- wav2vec2_config = Wav2Vec2ConfigPure(config)
- wav2vec2 = Wav2Vec2Model(wav2vec2_config)
- model_dict = paddle.load(config.wav2vec2_params_path)
- wav2vec2.set_state_dict(model_dict)
- self.normalize_wav = config.normalize_wav
- self.output_norm = config.output_norm
- if config.freeze_wav2vec2:
- wav2vec2.eval()
- for parm in wav2vec2.parameters():
- parm.trainable = False
- self.wav2vec2 = wav2vec2
- self.enc = VanillaNN(
- input_shape=[None, None, wav2vec2_config.hidden_size],
- activation=nn.LeakyReLU,
- dnn_blocks=config.dnn_blocks,
- dnn_neurons=config.dnn_neurons)
- self.ctc = CTC(odim=config.output_dim,
- enc_n_units=config.dnn_neurons,
- blank_id=config.blank_id,
- dropout_rate=config.ctc_dropout_rate,
- reduction='mean')
-
- def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
+ if config.freeze_wav2vec2:
+ wav2vec2.eval()
+ for parm in wav2vec2.parameters():
+ parm.trainable = False
+ self.wav2vec2 = wav2vec2
+ self.enc = VanillaNN(**config.enc)
+ self.ctc = CTC(**config.ctc,
+ odim=config.output_dim,
+ batch_average=False,
+ reduction='mean')
+
+ def forward(self, wav, wavs_lens_rate, target, target_lens):
if self.normalize_wav:
- wav = F.layer_norm(wav, wav.shape[1:])
+ wav = F.layer_norm(wav, wav.shape)
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
- out = F.layer_norm(out, out.shape[1:])
- feats = out
-
+ out = F.layer_norm(out, out.shape)
+ if self.train and hasattr(self.config, 'spec_augment'):
+ feats = self.spec_augment(out)
+ else:
+ feats = out
x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
- target_lens = (target_lens_rate *
- target.shape[1]).round().astype(paddle.int64)
-
ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss
@@ -239,3 +238,33 @@ class Wav2vec2ASR(nn.Layer):
"""
hyps = self._ctc_prefix_beam_search(wav, beam_size)
return hyps[0][0]
+
+
+class Wav2vec2Base(nn.Layer):
+ """Wav2vec2 model"""
+
+ def __init__(self, config: dict):
+ super().__init__()
+ wav2vec2_config = Wav2Vec2ConfigPure(config)
+ wav2vec2 = Wav2Vec2Model(wav2vec2_config)
+ self.wav2vec2 = wav2vec2
+
+ @classmethod
+ def from_config(cls, configs: dict):
+ """init model.
+
+ Args:
+ configs (dict): config dict.
+
+ Raises:
+            ValueError: raised when an unsupported encoder type is used.
+
+ Returns:
+ nn.Layer: Wav2Vec2Base
+ """
+ model = cls(configs)
+ return model
+
+ def forward(self, wav):
+ out = self.wav2vec2(wav)
+ return out
diff --git a/paddlespeech/s2t/models/whisper/__init__.py b/paddlespeech/s2t/models/whisper/__init__.py
new file mode 100644
index 000000000..98ab23610
--- /dev/null
+++ b/paddlespeech/s2t/models/whisper/__init__.py
@@ -0,0 +1,12 @@
+# MIT License, Copyright (c) 2022 OpenAI.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/__init__.py)
+from paddlespeech.s2t.models.whisper.whipser import decode
+from paddlespeech.s2t.models.whisper.whipser import DecodingOptions
+from paddlespeech.s2t.models.whisper.whipser import DecodingResult
+from paddlespeech.s2t.models.whisper.whipser import detect_language
+from paddlespeech.s2t.models.whisper.whipser import log_mel_spectrogram
+from paddlespeech.s2t.models.whisper.whipser import ModelDimensions
+from paddlespeech.s2t.models.whisper.whipser import transcribe
+from paddlespeech.s2t.models.whisper.whipser import Whisper
diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py
new file mode 100644
index 000000000..8bd85c914
--- /dev/null
+++ b/paddlespeech/s2t/models/whisper/tokenizer.py
@@ -0,0 +1,362 @@
+# MIT License, Copyright (c) 2022 OpenAI.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
+import os
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import paddle
+from paddlenlp.transformers import GPTTokenizer
+
+LANGUAGES = {
+ "en": "english",
+ "zh": "chinese",
+ "de": "german",
+ "es": "spanish",
+ "ru": "russian",
+ "ko": "korean",
+ "fr": "french",
+ "ja": "japanese",
+ "pt": "portuguese",
+ "tr": "turkish",
+ "pl": "polish",
+ "ca": "catalan",
+ "nl": "dutch",
+ "ar": "arabic",
+ "sv": "swedish",
+ "it": "italian",
+ "id": "indonesian",
+ "hi": "hindi",
+ "fi": "finnish",
+ "vi": "vietnamese",
+ "iw": "hebrew",
+ "uk": "ukrainian",
+ "el": "greek",
+ "ms": "malay",
+ "cs": "czech",
+ "ro": "romanian",
+ "da": "danish",
+ "hu": "hungarian",
+ "ta": "tamil",
+ "no": "norwegian",
+ "th": "thai",
+ "ur": "urdu",
+ "hr": "croatian",
+ "bg": "bulgarian",
+ "lt": "lithuanian",
+ "la": "latin",
+ "mi": "maori",
+ "ml": "malayalam",
+ "cy": "welsh",
+ "sk": "slovak",
+ "te": "telugu",
+ "fa": "persian",
+ "lv": "latvian",
+ "bn": "bengali",
+ "sr": "serbian",
+ "az": "azerbaijani",
+ "sl": "slovenian",
+ "kn": "kannada",
+ "et": "estonian",
+ "mk": "macedonian",
+ "br": "breton",
+ "eu": "basque",
+ "is": "icelandic",
+ "hy": "armenian",
+ "ne": "nepali",
+ "mn": "mongolian",
+ "bs": "bosnian",
+ "kk": "kazakh",
+ "sq": "albanian",
+ "sw": "swahili",
+ "gl": "galician",
+ "mr": "marathi",
+ "pa": "punjabi",
+ "si": "sinhala",
+ "km": "khmer",
+ "sn": "shona",
+ "yo": "yoruba",
+ "so": "somali",
+ "af": "afrikaans",
+ "oc": "occitan",
+ "ka": "georgian",
+ "be": "belarusian",
+ "tg": "tajik",
+ "sd": "sindhi",
+ "gu": "gujarati",
+ "am": "amharic",
+ "yi": "yiddish",
+ "lo": "lao",
+ "uz": "uzbek",
+ "fo": "faroese",
+ "ht": "haitian creole",
+ "ps": "pashto",
+ "tk": "turkmen",
+ "nn": "nynorsk",
+ "mt": "maltese",
+ "sa": "sanskrit",
+ "lb": "luxembourgish",
+ "my": "myanmar",
+ "bo": "tibetan",
+ "tl": "tagalog",
+ "mg": "malagasy",
+ "as": "assamese",
+ "tt": "tatar",
+ "haw": "hawaiian",
+ "ln": "lingala",
+ "ha": "hausa",
+ "ba": "bashkir",
+ "jw": "javanese",
+ "su": "sundanese",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+ **{language: code for code, language in LANGUAGES.items()},
+ "burmese": "my",
+ "valencian": "ca",
+ "flemish": "nl",
+ "haitian": "ht",
+ "letzeburgesch": "lb",
+ "pushto": "ps",
+ "panjabi": "pa",
+ "moldavian": "ro",
+ "moldovan": "ro",
+ "sinhalese": "si",
+ "castilian": "es",
+}
+
+
+@dataclass(frozen=True)
+class Tokenizer:
+ """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
+
+ tokenizer: "GPTTokenizer"
+ language: Optional[str]
+ sot_sequence: Tuple[int]
+
+ def encode(self, text, **kwargs):
+ return self.tokenizer.encode(text, **kwargs)
+
+ def decode(self,
+ token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
+ **kwargs):
+ if len(token_ids) > 1:
+ ids_list = []
+ for ids in token_ids:
+ if paddle.is_tensor(ids):
+ ids = ids.item()
+ if ids < len(self.tokenizer):
+ ids_list.append(ids)
+ token_ids = ids_list
+
+ return self.tokenizer.decode(token_ids, **kwargs)
+
+ def decode_with_timestamps(self, tokens) -> str:
+ """
+ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
+        This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
+ """
+ outputs = [[]]
+ for token in tokens:
+ if token >= self.timestamp_begin:
+ timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+ outputs.append(timestamp)
+ outputs.append([])
+ else:
+ outputs[-1].append(token)
+ outputs = [
+ s if isinstance(s, str) else self.tokenizer.decode(s)
+ for s in outputs
+ ]
+ return "".join(outputs)
+
+ @property
+ @lru_cache()
+ def eot(self) -> int:
+ return self.tokenizer.eos_token_id
+
+ @property
+ @lru_cache()
+ def sot(self) -> int:
+ return self._get_single_token_id("<|startoftranscript|>")
+
+ @property
+ @lru_cache()
+ def sot_lm(self) -> int:
+ return self._get_single_token_id("<|startoflm|>")
+
+ @property
+ @lru_cache()
+ def sot_prev(self) -> int:
+ return self._get_single_token_id("<|startofprev|>")
+
+ @property
+ @lru_cache()
+ def no_speech(self) -> int:
+ return self._get_single_token_id("<|nospeech|>")
+
+ @property
+ @lru_cache()
+ def no_timestamps(self) -> int:
+ return self._get_single_token_id("<|notimestamps|>")
+
+ @property
+ @lru_cache()
+ def timestamp_begin(self) -> int:
+ return self.tokenizer.all_special_ids[-1] + 1
+
+ @property
+ @lru_cache()
+ def language_token(self) -> int:
+ """Returns the token id corresponding to the value of the `language` field"""
+ if self.language is None:
+ raise ValueError(
+ "This tokenizer does not have language token configured")
+
+ additional_tokens = dict(
+ zip(
+ self.tokenizer.additional_special_tokens,
+ self.tokenizer.additional_special_tokens_ids, ))
+ candidate = f"<|{self.language}|>"
+ if candidate in additional_tokens:
+ return additional_tokens[candidate]
+
+ raise KeyError(f"Language {self.language} not found in tokenizer.")
+
+ @property
+ @lru_cache()
+ def all_language_tokens(self) -> Tuple[int]:
+ result = []
+ for token, token_id in zip(
+ self.tokenizer.additional_special_tokens,
+ self.tokenizer.additional_special_tokens_ids, ):
+ if token.strip("<|>") in LANGUAGES:
+ result.append(token_id)
+ return tuple(result)
+
+ @property
+ @lru_cache()
+ def all_language_codes(self) -> Tuple[str]:
+ return tuple(
+ self.decode([l]).strip("<|>") for l in self.all_language_tokens)
+
+ @property
+ @lru_cache()
+ def sot_sequence_including_notimestamps(self) -> Tuple[int]:
+ return tuple(list(self.sot_sequence) + [self.no_timestamps])
+
+ @property
+ @lru_cache()
+ def non_speech_tokens(self) -> Tuple[int]:
+ """
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+ - ♪♪♪
+ - ( SPEAKING FOREIGN LANGUAGE )
+ - [DAVID] Hey there,
+
+        keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
+ """
+ symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
+ symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split(
+ )
+
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
+ # In case they're multiple tokens, suppress the first token, which is safe because:
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+ miscellaneous = set("♩♪♫♬♭♮♯")
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+ result = {
+ self.tokenizer.encode(" -").input_ids[0],
+ self.tokenizer.encode(" '").input_ids[0]
+ }
+ for symbol in symbols + list(miscellaneous):
+ for tokens in [
+ self.tokenizer.encode(symbol).input_ids,
+ self.tokenizer.encode(" " + symbol).input_ids
+ ]:
+ if len(tokens) == 1 or symbol in miscellaneous:
+ result.add(tokens[0])
+
+ return tuple(sorted(result))
+
+ def _get_single_token_id(self, text) -> int:
+ tokens = self.tokenizer.encode(text).input_ids
+ assert len(tokens) == 1, f"{text} is not encoded as a single token"
+ return tokens[0]
+
+
+@lru_cache(maxsize=None)
+def build_tokenizer(resource_path: str, name: str="gpt2"):
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ path = os.path.join(resource_path, "assets", name)
+ tokenizer = GPTTokenizer.from_pretrained(path)
+
+ specials = [
+ "<|startoftranscript|>",
+ * [f"<|{lang}|>" for lang in LANGUAGES.keys()],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nospeech|>",
+ "<|notimestamps|>",
+ ]
+
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+ return tokenizer
+
+
+@lru_cache(maxsize=None)
+def get_tokenizer(
+ multilingual: bool,
+ resource_path: str,
+ *,
+ task: Optional[str]=None, # Literal["transcribe", "translate", None]
+ language: Optional[str]=None, ) -> Tokenizer:
+ if language is not None:
+ language = language.lower()
+ if language not in LANGUAGES:
+ if language in TO_LANGUAGE_CODE:
+ language = TO_LANGUAGE_CODE[language]
+ else:
+ raise ValueError(f"Unsupported language: {language}")
+
+ if multilingual:
+ tokenizer_name = "multilingual"
+ task = task or "transcribe"
+ language = language or "en"
+ else:
+ tokenizer_name = "gpt2"
+ task = None
+ language = None
+
+ tokenizer = build_tokenizer(
+ resource_path=resource_path, name=tokenizer_name)
+ all_special_ids: List[int] = tokenizer.all_special_ids
+ sot: int = all_special_ids[1]
+ translate: int = all_special_ids[-6]
+ transcribe: int = all_special_ids[-5]
+
+ langs = tuple(LANGUAGES.keys())
+ sot_sequence = [sot]
+ if language is not None:
+ sot_sequence.append(sot + 1 + langs.index(language))
+ if task is not None:
+ sot_sequence.append(transcribe if task == "transcribe" else translate)
+
+ return Tokenizer(
+ tokenizer=tokenizer,
+ language=language,
+ sot_sequence=tuple(sot_sequence))
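+
+
+# Minimal usage sketch (resource_path must point to a directory containing
+# "assets/multilingual"; the concrete path is an assumption here):
+#     tok = get_tokenizer(True, resource_path, task="transcribe", language="zh")
+#     tok.sot_sequence  # token ids for <|startoftranscript|>, <|zh|>, <|transcribe|>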
diff --git a/paddlespeech/s2t/models/whisper/utils.py b/paddlespeech/s2t/models/whisper/utils.py
new file mode 100644
index 000000000..d067af7d2
--- /dev/null
+++ b/paddlespeech/s2t/models/whisper/utils.py
@@ -0,0 +1,92 @@
+# MIT License, Copyright (c) 2022 OpenAI.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/utils.py)
+import zlib
+from typing import Iterator
+from typing import TextIO
+
+
+def exact_div(x, y):
+ assert x % y == 0
+ return x // y
+
+
+def str2bool(string):
+ str2val = {"True": True, "False": False}
+ if string in str2val:
+ return str2val[string]
+ else:
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+def optional_int(string):
+ return None if string == "None" else int(string)
+
+
+def optional_float(string):
+ return None if string == "None" else float(string)
+
+
+def compression_ratio(text) -> float:
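+    # Highly repetitive text (e.g. "ab" * 100) compresses well and yields a ratio
+    # far above the 2.4 threshold used in `transcribe`, flagging degenerate output.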
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
+
+
+def format_timestamp(seconds: float,
+ always_include_hours: bool=False,
+ decimal_marker: str='.'):
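+    # e.g. format_timestamp(3661.5) -> "01:01:01.500";
+    # format_timestamp(7.25, always_include_hours=True, decimal_marker=",") -> "00:00:07,250"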
+ assert seconds >= 0, "non-negative timestamp expected"
+ milliseconds = round(seconds * 1000.0)
+
+ hours = milliseconds // 3_600_000
+ milliseconds -= hours * 3_600_000
+
+ minutes = milliseconds // 60_000
+ milliseconds -= minutes * 60_000
+
+ seconds = milliseconds // 1_000
+ milliseconds -= seconds * 1_000
+
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+
+
+def write_txt(transcript: Iterator[dict], file: TextIO):
+ for segment in transcript:
+ print(segment['text'].strip(), file=file, flush=True)
+
+
+def write_vtt(transcript: Iterator[dict], file: TextIO):
+ print("WEBVTT\n", file=file)
+ for segment in transcript:
+ print(
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+ f"{segment['text'].strip().replace('-->', '->')}\n",
+ file=file,
+ flush=True, )
+
+
+def write_srt(transcript: Iterator[dict], file: TextIO):
+ """
+ Write a transcript to a file in SRT format.
+
+ Example usage:
+ from pathlib import Path
+ from whisper.utils import write_srt
+
+ result = transcribe(model, audio_path, temperature=temperature, **args)
+
+ # save SRT
+ audio_basename = Path(audio_path).stem
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+ write_srt(result["segments"], file=srt)
+ """
+ for i, segment in enumerate(transcript, start=1):
+ # write srt lines
+ print(
+ f"{i}\n"
+ f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
+ f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
+ f"{segment['text'].strip().replace('-->', '->')}\n",
+ file=file,
+ flush=True, )
diff --git a/paddlespeech/s2t/models/whisper/whipser.py b/paddlespeech/s2t/models/whisper/whipser.py
new file mode 100644
index 000000000..ba9983338
--- /dev/null
+++ b/paddlespeech/s2t/models/whisper/whipser.py
@@ -0,0 +1,1478 @@
+# MIT License, Copyright (c) 2022 OpenAI.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
+import os
+from dataclasses import dataclass
+from dataclasses import field
+from functools import lru_cache
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.nn.functional as F
+import soundfile
+import tqdm
+from paddle import nn
+from paddle.distribution import Categorical
+
+import paddlespeech.s2t.modules.align as paddlespeech_nn
+from paddlespeech.s2t.models.whisper import utils
+from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
+from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES
+from paddlespeech.s2t.models.whisper.tokenizer import Tokenizer
+from paddlespeech.s2t.utils.log import Log
+logger = Log(__name__).getlog()
+
+_MODELS = ["large"]
+SAMPLE_RATE = 16000
+N_FFT = 400
+N_MELS = 80
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
+N_FRAMES = utils.exact_div(
+ N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
+
+
+@dataclass
+class ModelDimensions:
+ n_mels: int
+ n_audio_ctx: int
+ n_audio_state: int
+ n_audio_head: int
+ n_audio_layer: int
+ n_vocab: int
+ n_text_ctx: int
+ n_text_state: int
+ n_text_head: int
+ n_text_layer: int
+
+
+class LayerNorm(paddlespeech_nn.LayerNorm):
+ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+ return super().forward(x)
+
+
+class Linear(paddlespeech_nn.Linear):
+ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+ return F.linear(x, self.weight, None
+ if self.bias is None else self.bias)
+
+
+class Conv1d(paddlespeech_nn.Conv1D):
+ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+ return super().forward(x)
+
+
+class MultiHeadAttention(nn.Layer):
+ def __init__(self, n_state: int, n_head: int):
+ super().__init__()
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state, bias_attr=True)
+ self.key = Linear(n_state, n_state, bias_attr=False)
+ self.value = Linear(n_state, n_state, bias_attr=True)
+ self.out = Linear(n_state, n_state, bias_attr=True)
+
+ def forward(
+ self,
+ x: paddle.Tensor,
+ xa: Optional[paddle.Tensor]=None,
+ mask: Optional[paddle.Tensor]=None,
+ kv_cache: Optional[dict]=None, ):
+ q = self.query(x)
+
+ if kv_cache is None or xa is None or self.key not in kv_cache:
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
+ k = self.key(x if xa is None else xa)
+ v = self.value(x if xa is None else xa)
+ else:
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+ k = kv_cache[self.key]
+ v = kv_cache[self.value]
+
+ wv = self.qkv_attention(q, k, v, mask)
+ return self.out(wv)
+
+ def qkv_attention(self,
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ mask: Optional[paddle.Tensor]=None):
+ n_batch, n_ctx, n_state = q.shape
+ scale = (n_state // self.n_head)**-0.25
+ q = paddle.transpose(
+ q.view(*q.shape[:2], self.n_head, -1), (0, 2, 1, 3)) * scale
+ k = paddle.transpose(
+ k.view(*k.shape[:2], self.n_head, -1), (0, 2, 3, 1)) * scale
+ v = paddle.transpose(
+ v.view(*v.shape[:2], self.n_head, -1), (0, 2, 1, 3))
+
+ qk = q @ k
+ if mask is not None:
+ qk = qk + mask[:n_ctx, :n_ctx]
+
+ w = F.softmax(qk.float(), axis=-1).to(q.dtype)
+ return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
+
+
+class ResidualAttentionBlock(nn.Layer):
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool=False):
+ super().__init__()
+
+ self.attn = MultiHeadAttention(n_state, n_head)
+ self.attn_ln = LayerNorm(n_state)
+
+ self.cross_attn = MultiHeadAttention(
+ n_state, n_head) if cross_attention else None
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(
+ Linear(n_state, n_mlp, bias_attr=True),
+ nn.GELU(), Linear(n_mlp, n_state, bias_attr=True))
+ self.mlp_ln = LayerNorm(n_state)
+
+ def forward(
+ self,
+ x: paddle.Tensor,
+ xa: Optional[paddle.Tensor]=None,
+ mask: Optional[paddle.Tensor]=None,
+ kv_cache: Optional[dict]=None, ):
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+ if self.cross_attn:
+ x = x + self.cross_attn(
+ self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
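+    # The result has shape (length, channels): the first half of the channel axis
+    # holds sin terms and the second half cos terms, with timescales spaced
+    # geometrically from 1 up to max_timescale.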
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = paddle.exp(-log_timescale_increment * paddle.arange(
+ channels // 2, dtype=paddle.float32))
+ scaled_time = paddle.arange(
+ length,
+ dtype=paddle.float32)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return paddle.to_tensor(
+ paddle.concat(
+ [paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1))
+
+
+class AudioEncoder(nn.Layer):
+ def __init__(self,
+ n_mels: int,
+ n_ctx: int,
+ n_state: int,
+ n_head: int,
+ n_layer: int):
+ super().__init__()
+ self.conv1 = Conv1d(
+ n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True)
+ self.conv2 = Conv1d(
+ n_state,
+ n_state,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias_attr=True)
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.LayerList(
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
+ self.ln_post = LayerNorm(n_state)
+
+ def forward(self, x: paddle.Tensor):
+ """
+ x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = paddle.transpose(x, (0, 2, 1))
+
+ assert x.shape[
+ 1:] == self.positional_embedding.shape, "incorrect audio shape"
+ x = (x + self.positional_embedding)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+class TextDecoder(nn.Layer):
+ def __init__(self,
+ n_vocab: int,
+ n_ctx: int,
+ n_state: int,
+ n_head: int,
+ n_layer: int):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
+ self.positional_embedding = paddle.create_parameter(
+ shape=[n_ctx, n_state], dtype='float32')
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.LayerList([
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+ for _ in range(n_layer)
+ ])
+ self.ln = LayerNorm(n_state)
+
+        mask = fluid.layers.fill_constant(
+            shape=[n_ctx, n_ctx], value=-np.inf, dtype='float32')
+ mask = paddle.triu(mask, diagonal=1)
+ self.register_buffer("mask", mask, persistable=False)
+
+ def forward(self,
+ x: paddle.Tensor,
+ xa: paddle.Tensor,
+ kv_cache: Optional[dict]=None):
+ """
+ x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
+ the text tokens
+ xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
+ the encoded audio features to be attended on
+ """
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+ x = self.token_embedding(x) + self.positional_embedding[offset:offset +
+ x.shape[-1]]
+ x = x.to(xa.dtype)
+
+ for block in self.blocks:
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+ x = self.ln(x)
+ logits = (x @ paddle.transpose(self.token_embedding.weight, (1, 0)))
+
+ return logits
+
+
+@dataclass(frozen=True)
+class DecodingOptions:
+ task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
+ language: Optional[
+ str] = None # language that the audio is in; uses detected language if None
+ # sampling-related options
+ temperature: float = 0.0
+ sample_len: Optional[int] = None # maximum number of tokens to sample
+ best_of: Optional[
+ int] = None # number of independent samples to collect, when t > 0
+ beam_size: Optional[
+ int] = None # number of beams in beam search, when t == 0
+ patience: Optional[
+ float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
+
+ # options for ranking generations (either beams or best-of-N samples)
+ length_penalty: Optional[
+ float] = None # "alpha" in Google NMT, None defaults to length norm
+
+ # prompt, prefix, and token suppression
+ prompt: Optional[Union[str, List[
+ int]]] = None # text or tokens for the previous context
+ prefix: Optional[Union[str, List[
+ int]]] = None # text or tokens to prefix the current context
+ suppress_blank: bool = True # this will suppress blank outputs
+
+ # list of tokens ids (or comma-separated token ids) to suppress
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+
+ # timestamp sampling options
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
+ max_initial_timestamp: Optional[
+ float] = 1.0 # the initial timestamp cannot be later than this
+
+ # implementation details
+ fp16: bool = False # use fp16 for most of the calculation
+
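+
+# Illustrative construction (all fields are defined above): greedy English
+# transcription without timestamps.
+#     options = DecodingOptions(task="transcribe", language="en",
+#                               temperature=0.0, without_timestamps=True)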
+
+@dataclass(frozen=True)
+class DecodingResult:
+ audio_features: paddle.Tensor
+ language: str
+ language_probs: Optional[Dict[str, float]] = None
+ tokens: List[int] = field(default_factory=list)
+ text: str = ""
+ avg_logprob: float = np.nan
+ no_speech_prob: float = np.nan
+ temperature: float = np.nan
+ compression_ratio: float = np.nan
+
+
+class Inference:
+ def logits(self, tokens: paddle.Tensor,
+ audio_features: paddle.Tensor) -> paddle.Tensor:
+ """Perform a forward pass on the decoder and return per-token logits"""
+ raise NotImplementedError
+
+ def rearrange_kv_cache(self, source_indices) -> None:
+ """Update the key-value cache according to the updated beams"""
+ raise NotImplementedError
+
+ def cleanup_caching(self) -> None:
+ """Clean up any resources or hooks after decoding is finished"""
+ pass
+
+
+class WhisperInference(Inference):
+ def __init__(self, model: "Whisper", initial_token_length: int):
+ self.model: "Whisper" = model
+ self.initial_token_length = initial_token_length
+ self.kv_cache = {}
+ self.hooks = []
+
+ def logits(self, tokens: paddle.Tensor,
+ audio_features: paddle.Tensor) -> paddle.Tensor:
+ if not self.kv_cache:
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+
+ if tokens.shape[-1] > self.initial_token_length:
+ # only need to use the last token except in the first forward pass
+ tokens = tokens[:, -1:]
+
+ return self.model.decoder(
+ tokens, audio_features, kv_cache=self.kv_cache)
+
+ def cleanup_caching(self):
+ for hook in self.hooks:
+ hook.remove()
+
+ self.kv_cache = {}
+ self.hooks = []
+
+ def rearrange_kv_cache(self, source_indices):
+ for module, tensor in self.kv_cache.items():
+ # update the key/value cache to contain the selected sequences
+ self.kv_cache[module] = tensor[source_indices].detach()
+
+
+@paddle.no_grad()
+def detect_language(
+ model: "Whisper",
+ mel: paddle.Tensor,
+ resource_path: str,
+ tokenizer: Tokenizer=None) -> Tuple[paddle.Tensor, List[dict]]:
+ """
+ Detect the spoken language in the audio, and return them as list of strings, along with the ids
+ of the most probable language tokens and the probability distribution over all language tokens.
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
+
+ Returns
+ -------
+ language_tokens : Tensor, shape = (batch_size,)
+        ids of the most probable language tokens, which appear after the startoftranscript token.
+ language_probs : List[Dict[str, float]], length = batch_size
+ list of dictionaries containing the probability distribution over all languages.
+ """
+ if tokenizer is None:
+ tokenizer = get_tokenizer(
+ model.is_multilingual, resource_path=resource_path)
+ if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
+ raise ValueError(
+ "This model doesn't have language tokens so it can't perform lang id"
+ )
+
+ single = mel.ndim == 2
+ if single:
+ mel = mel.unsqueeze(0)
+
+ # skip encoder forward pass if already-encoded audio features were given
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+ mel = model.encoder(mel)
+
+ # forward pass using a single token, startoftranscript
+ batch_size = mel.shape[0]
+ x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1]
+ logits = model.logits(x, mel)[:, 0]
+
+ # collect detected languages; suppress all non-language tokens
+ mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
+ mask[list(tokenizer.all_language_tokens)] = False
+ logits[:, mask] = -np.inf
+ language_tokens = paddle.argmax(logits, axis=-1)
+ language_token_probs = F.softmax(logits, axis=-1)
+ language_probs = [{
+ c: language_token_probs[i, j].tolist()
+ for j, c in zip(tokenizer.all_language_tokens,
+ tokenizer.all_language_codes)
+ } for i in range(batch_size)]
+
+ if single:
+ language_tokens = language_tokens[0]
+ language_probs = language_probs[0]
+
+ return language_tokens, language_probs
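+
+
+# Sketch of standalone language detection on one padded (80, 3000) log-Mel
+# segment (variable names are illustrative):
+#     _, probs = model.detect_language(segment, resource_path)
+#     language = max(probs, key=probs.get)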
+
+
+def transcribe(
+ model: "Whisper",
+ mel: paddle.Tensor,
+ resource_path: str,
+ *,
+ verbose: Optional[bool]=None,
+ temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8,
+ 1.0),
+ compression_ratio_threshold: Optional[float]=2.4,
+ logprob_threshold: Optional[float]=-1.0,
+ no_speech_threshold: Optional[float]=0.6,
+ condition_on_previous_text: bool=True,
+ **decode_options, ):
+ """
+ Transcribe an audio file using Whisper
+
+ Parameters
+ ----------
+ model: Whisper
+ The Whisper model instance
+
+ mel: paddle.Tensor
+ The audio feature
+
+ verbose: bool
+        Whether to display the text being decoded to the console. If True, displays all the details;
+        if False, displays minimal details. If None, does not display anything.
+
+ temperature: Union[float, Tuple[float, ...]]
+        Temperature for sampling. It can be a tuple of temperatures, which will be used successively
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
+
+ compression_ratio_threshold: float
+ If the gzip compression ratio is above this value, treat as failed
+
+ logprob_threshold: float
+ If the average log probability over sampled tokens is below this value, treat as failed
+
+ no_speech_threshold: float
+ If the no_speech probability is higher than this value AND the average log probability
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
+
+ condition_on_previous_text: bool
+ if True, the previous output of the model is provided as a prompt for the next window;
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
+
+ decode_options: dict
+ Keyword arguments to construct `DecodingOptions` instances
+
+ Returns
+ -------
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+ """
+    dtype = np.float32  # paddle only supports float32
+
+ if dtype == np.float32:
+ decode_options["fp16"] = False
+
+    if decode_options.get("language", None) is None:
+ if not model.is_multilingual:
+ decode_options["language"] = "en"
+ else:
+ if verbose:
+ print(
+ "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+ )
+ segment = pad_or_trim(mel, N_FRAMES)
+ _, probs = model.detect_language(segment, resource_path)
+ decode_options["language"] = max(probs, key=probs.get)
+ if verbose is not None:
+ print(
+ f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+ )
+
+ language = decode_options["language"]
+ task = decode_options.get("task", "transcribe")
+ tokenizer = get_tokenizer(
+ model.is_multilingual,
+ resource_path=resource_path,
+ language=language,
+ task=task)
+
+ def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
+ temperatures = [temperature] if isinstance(temperature, (
+ int, float)) else temperature
+ decode_result = None
+
+ for t in temperatures:
+ kwargs = {**decode_options}
+ if t > 0:
+ # disable beam_size and patience when t > 0
+ kwargs.pop("beam_size", None)
+ kwargs.pop("patience", None)
+ else:
+ # disable best_of when t == 0
+ kwargs.pop("best_of", None)
+
+ options = DecodingOptions(**kwargs, temperature=t)
+ decode_result = model.decode(segment, options, resource_path)
+
+ needs_fallback = False
+ if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+ needs_fallback = True # too repetitive
+ if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+ needs_fallback = True # average log probability is too low
+
+ if not needs_fallback:
+ break
+
+ return decode_result
+
+ seek = 0
+ input_stride = utils.exact_div(
+ N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2
+ time_precision = (input_stride * HOP_LENGTH /
+ SAMPLE_RATE) # time per output token: 0.02 (seconds)
+ all_tokens = []
+ all_segments = []
+ prompt_reset_since = 0
+
+ initial_prompt = decode_options.pop("initial_prompt", None) or []
+ if initial_prompt:
+ initial_prompt = tokenizer.encode(" " +
+ initial_prompt.strip()).input_ids
+ all_tokens.extend(initial_prompt)
+
+ def add_segment(*,
+ start: float,
+ end: float,
+ text_tokens: paddle.Tensor,
+ result: DecodingResult):
+ text = tokenizer.decode(
+ [token for token in text_tokens if token < tokenizer.eot])
+ if len(text.strip()) == 0: # skip empty text output
+ return
+
+ all_segments.append({
+ "id": len(all_segments),
+ "seek": seek,
+ "start": start,
+ "end": end,
+ "text": text,
+ "tokens": result.tokens,
+ "temperature": result.temperature,
+ "avg_logprob": result.avg_logprob,
+ "compression_ratio": result.compression_ratio,
+ "no_speech_prob": result.no_speech_prob,
+ })
+ if verbose:
+ print(
+ f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
+ )
+
+ # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+ num_frames = mel.shape[-1]
+ previous_seek_value = seek
+
+ with tqdm.tqdm(
+ total=num_frames, unit='frames',
+ disable=verbose is not False) as pbar:
+ while seek < num_frames:
+ timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+ segment = pad_or_trim(mel[:, seek:], N_FRAMES)
+ segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+ decode_options["prompt"] = all_tokens[prompt_reset_since:]
+ result: DecodingResult = decode_with_fallback(segment)
+ tokens = paddle.to_tensor(result.tokens)
+
+ if no_speech_threshold is not None:
+ # no voice activity check
+ should_skip = result.no_speech_prob > no_speech_threshold
+ if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+ # don't skip if the logprob is high enough, despite the no_speech_prob
+ should_skip = False
+
+ if should_skip:
+ seek += segment.shape[
+ -1] # fast-forward to the next segment boundary
+ continue
+
+ timestamp_tokens: paddle.Tensor = tokens.greater_equal(
+ paddle.to_tensor(tokenizer.timestamp_begin))
+
+ consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
+ 1:])[0]
+ if len(
+ consecutive
+ ) > 0: # if the output contains two consecutive timestamp tokens
+ consecutive = paddle.add(consecutive, paddle.to_tensor(1))
+ last_slice = 0
+ for current_slice in consecutive:
+ sliced_tokens = tokens[last_slice:current_slice]
+ start_timestamp_position = (
+ sliced_tokens[0].item() - tokenizer.timestamp_begin)
+ end_timestamp_position = (
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin)
+ add_segment(
+ start=timestamp_offset + start_timestamp_position *
+ time_precision,
+ end=timestamp_offset + end_timestamp_position *
+ time_precision,
+ text_tokens=sliced_tokens[1:-1],
+ result=result, )
+ last_slice = current_slice
+ last_timestamp_position = (
+ tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
+ seek += last_timestamp_position * input_stride
+ all_tokens.extend(tokens[:last_slice + 1].tolist())
+ else:
+ duration = segment_duration
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+ if len(timestamps) > 0 and timestamps[
+ -1].item() != tokenizer.timestamp_begin:
+ # no consecutive timestamps but it has a timestamp; use the last one.
+ # single timestamp at the end means no speech after the last timestamp.
+ last_timestamp_position = timestamps[
+ -1].item() - tokenizer.timestamp_begin
+ duration = last_timestamp_position * time_precision
+
+ add_segment(
+ start=timestamp_offset,
+ end=timestamp_offset + duration,
+ text_tokens=tokens,
+ result=result, )
+
+ seek += segment.shape[-1]
+ all_tokens.extend(tokens.tolist())
+
+ if not condition_on_previous_text or result.temperature > 0.5:
+ # do not feed the prompt tokens if a high temperature was used
+ prompt_reset_since = len(all_tokens)
+
+ # update progress bar
+ pbar.update(min(num_frames, seek) - previous_seek_value)
+ previous_seek_value = seek
+
+ return dict(
+ text=tokenizer.decode(all_tokens[len(initial_prompt):]),
+ segments=all_segments,
+ language=language)
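+
+
+# End-to-end sketch (the wav path and language choice are assumptions):
+#     mel = log_mel_spectrogram("input.wav", resource_path=resource_path)
+#     result = model.transcribe(mel, resource_path, language="en")
+#     print(result["text"])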
+
+
+class SequenceRanker:
+ def rank(self,
+ tokens: List[List[paddle.Tensor]],
+ sum_logprobs: List[List[float]]) -> List[int]:
+ """
+ Given a list of groups of samples and their cumulative log probabilities,
+ return the indices of the samples in each group to select as the final result
+ """
+ raise NotImplementedError
+
+
+class MaximumLikelihoodRanker(SequenceRanker):
+ """
+ Select the sample with the highest log probabilities, penalized using either
+ a simple length normalization or Google NMT paper's length penalty
+ """
+
+ def __init__(self, length_penalty: Optional[float]):
+ self.length_penalty = length_penalty
+
+ def rank(self,
+ tokens: List[List[paddle.Tensor]],
+ sum_logprobs: List[List[float]]):
+ def scores(logprobs, lengths):
+ result = []
+ for logprob, length in zip(logprobs, lengths):
+ if self.length_penalty is None:
+ penalty = length
+ else:
+ # from the Google NMT paper
+ penalty = ((5 + length) / 6)**self.length_penalty
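+                    # e.g. with length_penalty=0.6 and a 20-token sequence:
+                    # ((5 + 20) / 6) ** 0.6 ≈ 2.35, a sub-linear discount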
+ result.append(logprob / penalty)
+ return result
+
+ # get the sequence with the highest score
+ lengths = [[len(t) for t in s] for s in tokens]
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
+
+
+class TokenDecoder:
+ def reset(self):
+ """Initialize any stateful variables for decoding a new sequence"""
+
+ def update(self,
+ tokens: paddle.Tensor,
+ logits: paddle.Tensor,
+ sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]:
+ """Specify how to select the next token, based on the current trace and logits
+
+ Parameters
+ ----------
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
+ all tokens in the context so far, including the prefix and sot_sequence tokens
+
+ logits : Tensor, shape = (n_batch, vocab_size)
+ per-token logits of the probability distribution at the current step
+
+ sum_logprobs : Tensor, shape = (n_batch)
+ cumulative log probabilities for each sequence
+
+ Returns
+ -------
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
+ the tokens, appended with the selected next token
+
+ completed : bool
+            True if all sequences have reached the end of text
+
+ """
+ raise NotImplementedError
+
+ def finalize(
+ self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
+ ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
+ """Finalize search and return the final candidate sequences
+
+ Parameters
+ ----------
+ tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
+ all tokens in the context so far, including the prefix and sot_sequence
+
+ sum_logprobs : Tensor, shape = (batch_size, beam_size)
+ cumulative log probabilities for each sequence
+
+ Returns
+ -------
+ tokens : Sequence[Sequence[Tensor]], length = batch_size
+ sequence of Tensors containing candidate token sequences, for each audio input
+
+ sum_logprobs : List[List[float]], length = batch_size
+ sequence of cumulative log probabilities corresponding to the above
+
+ """
+ raise NotImplementedError
+
+
+class GreedyDecoder(TokenDecoder):
+ def __init__(self, temperature: float, eot: int):
+ self.temperature = temperature
+ self.eot = eot
+
+ def update(self,
+ tokens: paddle.Tensor,
+ logits: paddle.Tensor,
+ sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]:
+ temperature = self.temperature
+ if temperature == 0:
+ next_tokens = paddle.argmax(logits, axis=-1)
+ else:
+ next_tokens = Categorical(logits=logits / temperature).sample(
+ shape=logits.shape)
+
+ logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
+ current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
+ next_tokens]
+ sum_logprobs += current_logprobs * paddle.to_tensor(
+ (tokens[:, -1] != self.eot), dtype=paddle.float32)
+
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
+ tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
+
+ completed = paddle.all((tokens[:, -1] == self.eot))
+ return tokens, completed
+
+ def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
+ # make sure each sequence has at least one EOT token at the end
+ tokens = F.pad(tokens, (0, 1), value=self.eot, data_format="NCL")
+ return tokens, sum_logprobs.tolist()
+
+
+class BeamSearchDecoder(TokenDecoder):
+ def __init__(self,
+ beam_size: int,
+ eot: int,
+ inference: Inference,
+ patience: Optional[float]=None):
+ self.beam_size = beam_size
+ self.eot = eot
+ self.inference = inference
+ self.patience = patience or 1.0
+ self.max_candidates: int = round(beam_size * self.patience)
+ self.finished_sequences = None
+
+ assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+
+ def reset(self):
+ self.finished_sequences = None
+
+ def update(self,
+ tokens: paddle.Tensor,
+ logits: paddle.Tensor,
+ sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]:
+ if tokens.shape[0] % self.beam_size != 0:
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
+
+ batch_size = tokens.shape[0] // self.beam_size
+ if self.finished_sequences is None: # for the first update
+ self.finished_sequences = [{} for _ in range(batch_size)]
+
+ logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
+ next_tokens, source_indices, finished_sequences = [], [], []
+ for i in range(batch_size):
+ scores, sources, finished = {}, {}, {}
+
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
+ for j in range(self.beam_size):
+ idx = i * self.beam_size + j
+ prefix = tokens[idx].tolist()
+ logprob, token = paddle.topk(
+ logprobs[idx], k=self.beam_size + 1)
+ for logprob, token in zip(logprob, token):
+ new_logprob = (sum_logprobs[idx] + logprob).tolist()[0]
+ sequence = tuple(prefix + [token.tolist()[0]])
+ scores[sequence] = new_logprob
+ sources[sequence] = idx
+
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
+ saved = 0
+ for sequence in sorted(scores, key=scores.get, reverse=True):
+ if sequence[-1] == self.eot:
+ finished[sequence] = scores[sequence]
+ else:
+ sum_logprobs[len(next_tokens)] = scores[sequence]
+ next_tokens.append(sequence)
+ source_indices.append(sources[sequence])
+
+ saved += 1
+ if saved == self.beam_size:
+ break
+
+ finished_sequences.append(finished)
+
+ tokens = paddle.to_tensor(next_tokens)
+ self.inference.rearrange_kv_cache(source_indices)
+
+ # add newly finished sequences to self.finished_sequences
+ assert len(self.finished_sequences) == len(finished_sequences)
+ for previously_finished, newly_finished in zip(self.finished_sequences,
+ finished_sequences):
+ for seq in sorted(
+ newly_finished, key=newly_finished.get, reverse=True):
+ if len(previously_finished) >= self.max_candidates:
+ break # the candidate list is full
+ previously_finished[seq] = newly_finished[seq]
+
+ # mark as completed if all audio has enough number of samples
+ completed = all(
+ len(sequences) >= self.max_candidates
+ for sequences in self.finished_sequences)
+ return tokens, completed
+
+ def finalize(self,
+ preceding_tokens: paddle.Tensor,
+ sum_logprobs: paddle.Tensor):
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
+ sum_logprobs = sum_logprobs.cpu()
+ for i, sequences in enumerate(self.finished_sequences):
+ if len(sequences
+ ) < self.beam_size: # when not enough sequences are finished
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
+ if len(sequences) >= self.beam_size:
+ break
+
+ tokens: List[List[paddle.Tensor]] = [
+ [paddle.to_tensor(seq) for seq in sequences.keys()]
+ for sequences in self.finished_sequences
+ ]
+ sum_logprobs: List[List[float]] = [
+ list(sequences.values()) for sequences in self.finished_sequences
+ ]
+ return tokens, sum_logprobs
+
+
+class LogitFilter:
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
+ """Apply any filtering or masking to logits in-place
+
+ Parameters
+ ----------
+ logits : Tensor, shape = (n_batch, vocab_size)
+ per-token logits of the probability distribution at the current step
+
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
+ all tokens in the context so far, including the prefix and sot_sequence tokens
+
+ """
+ raise NotImplementedError
+
+
+class SuppressBlank(LogitFilter):
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
+ self.tokenizer = tokenizer
+ self.sample_begin = sample_begin
+
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+ if tokens.shape[1] == self.sample_begin:
+ logits[:, self.tokenizer.encode(" ").input_ids +
+ [self.tokenizer.eot]] = -np.inf
+
+
+class SuppressTokens(LogitFilter):
+ def __init__(self, suppress_tokens: Sequence[int]):
+ self.suppress_tokens = list(suppress_tokens)
+
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+ logits[:, self.suppress_tokens] = -np.inf
+
+
+class ApplyTimestampRules(LogitFilter):
+ def __init__(self,
+ tokenizer: Tokenizer,
+ sample_begin: int,
+ max_initial_timestamp_index: Optional[int]):
+ self.tokenizer = tokenizer
+ self.sample_begin = sample_begin
+ self.max_initial_timestamp_index = max_initial_timestamp_index
+
+ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+ # suppress <|notimestamps|> which is handled by without_timestamps
+ if self.tokenizer.no_timestamps is not None:
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
+
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+ for k in range(tokens.shape[0]):
+ seq = [t for t in tokens[k, self.sample_begin:].tolist()]
+ last_was_timestamp = len(seq) >= 1 and seq[
+ -1] >= self.tokenizer.timestamp_begin
+ penultimate_was_timestamp = len(seq) < 2 or seq[
+ -2] >= self.tokenizer.timestamp_begin
+
+ if last_was_timestamp:
+ if penultimate_was_timestamp: # has to be non-timestamp
+ logits[k, self.tokenizer.timestamp_begin:] = -np.inf
+ else: # cannot be normal text tokens
+ logits[k, :self.tokenizer.eot] = -np.inf
+
+ # apply the `max_initial_timestamp` option
+ if tokens.shape[
+ 1] == self.sample_begin and self.max_initial_timestamp_index is not None:
+ last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+ logits[:, last_allowed + 1:] = -np.inf
+
+ # if sum of probability over timestamps is above any other token, sample timestamp
+ logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
+ for k in range(tokens.shape[0]):
+ timestamp_logprob = paddle.logsumexp(
+ logprobs[k, self.tokenizer.timestamp_begin:], axis=-1)
+ max_text_token_logprob = paddle.max(
+ logprobs[k, :self.tokenizer.timestamp_begin])
+ if timestamp_logprob > max_text_token_logprob:
+ logits[k, :self.tokenizer.timestamp_begin] = -np.inf
+
+
+class DecodingTask:
+ inference: Inference
+ sequence_ranker: SequenceRanker
+ decoder: TokenDecoder
+ logit_filters: List[LogitFilter]
+
+ def __init__(self,
+ model: "Whisper",
+ options: DecodingOptions,
+ resource_path: str):
+ self.model = model
+
+ language = options.language or "en"
+ tokenizer = get_tokenizer(
+ model.is_multilingual,
+ resource_path=resource_path,
+ language=language,
+ task=options.task)
+ self.tokenizer: Tokenizer = tokenizer
+ self.options: DecodingOptions = self._verify_options(options)
+ self.resource_path: str = resource_path
+
+ self.beam_size: int = options.beam_size or options.best_of or 1
+ self.n_ctx: int = model.dims.n_text_ctx
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
+
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
+ if self.options.without_timestamps:
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
+
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
+ self.sample_begin: int = len(self.initial_tokens)
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
+
+ # inference: implements the forward pass through the decoder, including kv caching
+ self.inference = WhisperInference(model, len(self.initial_tokens))
+
+ # sequence ranker: implements how to rank a group of sampled sequences
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
+
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
+ if options.beam_size is not None:
+ self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot,
+ self.inference, options.patience)
+ else:
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
+
+ # logit filters: applies various rules to suppress or penalize certain tokens
+ self.logit_filters = []
+ if self.options.suppress_blank:
+ self.logit_filters.append(
+ SuppressBlank(self.tokenizer, self.sample_begin))
+ if self.options.suppress_tokens:
+ self.logit_filters.append(
+ SuppressTokens(self._get_suppress_tokens()))
+ if not options.without_timestamps:
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
+ max_initial_timestamp_index = None
+ if options.max_initial_timestamp:
+ max_initial_timestamp_index = round(
+ self.options.max_initial_timestamp / precision)
+ self.logit_filters.append(
+ ApplyTimestampRules(tokenizer, self.sample_begin,
+ max_initial_timestamp_index))
+
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
+ if options.beam_size is not None and options.best_of is not None:
+ raise ValueError("beam_size and best_of can't be given together")
+ if options.temperature == 0:
+ if options.best_of is not None:
+ raise ValueError(
+ "best_of with greedy sampling (T=0) is not compatible")
+ if options.patience is not None and options.beam_size is None:
+ raise ValueError("patience requires beam_size to be given")
+ if options.length_penalty is not None and not (
+ 0 <= options.length_penalty <= 1):
+ raise ValueError(
+ "length_penalty (alpha) should be a value between 0 and 1")
+
+ return options
+
+ def _get_initial_tokens(self) -> Tuple[int]:
+ tokens = list(self.sot_sequence)
+ prefix = self.options.prefix
+ prompt = self.options.prompt
+
+ if prefix:
+ prefix_tokens = (
+                self.tokenizer.encode(" " + prefix.strip()).input_ids
+ if isinstance(prefix, str) else prefix)
+ if self.sample_len is not None:
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
+ tokens = tokens + prefix_tokens
+
+ if prompt:
+ prompt_tokens = (
+                self.tokenizer.encode(" " + prompt.strip()).input_ids
+ if isinstance(prompt, str) else prompt)
+ tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2
+ - 1):] + tokens
+
+ return tuple(tokens)
+
+ def _get_suppress_tokens(self) -> Tuple[int]:
+ suppress_tokens = self.options.suppress_tokens
+
+ if isinstance(suppress_tokens, str):
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
+
+ if -1 in suppress_tokens:
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
+ suppress_tokens = [] # interpret empty string as an empty list
+ else:
+ assert isinstance(suppress_tokens,
+ list), "suppress_tokens must be a list"
+
+ suppress_tokens.extend([
+ self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm
+ ])
+ if self.tokenizer.no_speech is not None:
+ # no-speech probability is collected separately
+ suppress_tokens.append(self.tokenizer.no_speech)
+
+ return tuple(sorted(set(suppress_tokens)))
+
+ def _get_audio_features(self, mel: paddle.Tensor):
+ #if self.options.fp16:
+ # mel = mel.half()
+
+ if mel.shape[-2:] == (self.model.dims.n_audio_ctx,
+ self.model.dims.n_audio_state):
+ # encoded audio features are given; skip audio encoding
+ audio_features = mel
+ else:
+ audio_features = self.model.encoder(mel)
+
+ #if audio_features.dtype != (np.float16 if self.options.fp16 else np.float32):
+ # return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
+
+ return audio_features
+
+ def _detect_language(self,
+ audio_features: paddle.Tensor,
+ tokens: paddle.Tensor,
+ resource_path: str):
+ languages = [self.options.language] * audio_features.shape[0]
+ lang_probs = None
+
+ if self.options.language is None or self.options.task == "lang_id":
+ lang_tokens, lang_probs = self.model.detect_language(
+ audio_features, self.tokenizer, self.resource_path)
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
+ if self.options.language is None:
+ tokens[:, self.sot_index +
+ 1] = lang_tokens # write language tokens
+
+ return languages, lang_probs
+
+ def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
+ assert audio_features.shape[0] == tokens.shape[0]
+ n_batch = tokens.shape[0]
+ sum_logprobs: paddle.Tensor = paddle.zeros(
+ paddle.to_tensor(n_batch), dtype=paddle.float32)
+ no_speech_probs = [np.nan] * n_batch
+
+ try:
+ for i in range(self.sample_len):
+ logits = self.inference.logits(tokens, audio_features)
+
+ if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
+ probs_at_sot = F.softmax(
+ logits[:, self.sot_index],
+ axis=-1,
+ dtype=paddle.float32)
+ no_speech_probs = probs_at_sot[:, self.tokenizer.
+ no_speech].tolist()
+
+ # now we need to consider the logits at the last token only
+ logits = logits[:, -1]
+
+                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
+ for logit_filter in self.logit_filters:
+ logit_filter.apply(logits, tokens)
+
+ # expand the tokens tensor with the selected next tokens
+ tokens, completed = self.decoder.update(tokens, logits,
+ sum_logprobs)
+ if completed or tokens.shape[-1] > self.n_ctx:
+ break
+ finally:
+ self.inference.cleanup_caching()
+
+ return tokens, sum_logprobs, no_speech_probs
+
+ @paddle.no_grad()
+ def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
+ self.decoder.reset()
+ tokenizer: Tokenizer = self.tokenizer
+ batch_size: int = mel.shape[0]
+
+ audio_features: paddle.Tensor = self._get_audio_features(
+ mel) # encoder forward pass
+
+        # one copy of the initial token sequence per item in the batch
+        tokens: paddle.Tensor = paddle.tile(
+            paddle.to_tensor([self.initial_tokens]),
+            repeat_times=[batch_size, 1])
+
+ # detect language if requested, overwriting the language token
+ languages, language_probs = self._detect_language(
+ paddle.to_tensor(audio_features),
+ paddle.to_tensor(tokens), self.resource_path)
+
+ if self.options.task == "lang_id":
+ return [
+ DecodingResult(
+ audio_features=features,
+ language=language,
+ language_probs=probs)
+ for features, language, probs in zip(audio_features, languages,
+ language_probs)
+ ]
+
+ # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
+
+ audio_features = paddle.repeat_interleave(
+ audio_features, self.beam_size, axis=0)
+ tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
+
+ # call the main sampling loop
+ tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features,
+ tokens)
+
+ # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
+ audio_features = audio_features[::self.beam_size]
+ no_speech_probs = no_speech_probs[::self.beam_size]
+ assert audio_features.shape[0] == len(no_speech_probs) == batch_size
+
+ tokens = tokens.reshape([batch_size, self.beam_size, -1])
+ sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
+
+ # get the final candidates for each group, and slice between the first sampled token and EOT
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+ tokens: List[List[paddle.Tensor]] = [[
+ t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s
+ ] for s in tokens]
+
+ # select the top-ranked sample in each group
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
+ tokens: List[List[
+ int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
+
+ sum_logprobs: List[
+ float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
+ avg_logprobs: List[
+ float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+
+ fields = (texts, languages, tokens, audio_features, avg_logprobs,
+ no_speech_probs)
+ if len(set(map(len, fields))) != 1:
+ raise RuntimeError(
+ f"inconsistent result lengths: {list(map(len, fields))}")
+
+ return [
+ DecodingResult(
+ audio_features=features,
+ language=language,
+ tokens=tokens,
+ text=text,
+ avg_logprob=avg_logprob,
+ no_speech_prob=no_speech_prob,
+ temperature=self.options.temperature,
+ compression_ratio=utils.compression_ratio(text), )
+ for text, language, tokens, features, avg_logprob, no_speech_prob in
+ zip(*fields)
+ ]
+
+
+@paddle.no_grad()
+def decode(
+ model: "Whisper",
+ mel: paddle.Tensor,
+ options: DecodingOptions=DecodingOptions(),
+        resource_path: str=None, ) -> Union[DecodingResult, List[DecodingResult]]:
+ """
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
+
+ Parameters
+ ----------
+ model: Whisper
+ the Whisper model instance
+
+ mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
+ A tensor containing the Mel spectrogram(s)
+
+ options: DecodingOptions
+ A dataclass that contains all necessary options for decoding 30-second segments
+
+ Returns
+ -------
+ result: Union[DecodingResult, List[DecodingResult]]
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
+ """
+ single = mel.ndim == 2
+ if single:
+ mel = mel.unsqueeze(0)
+
+ result = DecodingTask(model, options, resource_path).run(mel)
+
+ if single:
+ result = result[0]
+
+ return result
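+
+# Sketch: decode a single 30-second segment with default options
+# (variable names are illustrative):
+#     result = model.decode(mel_segment, DecodingOptions(), resource_path)
+#     print(result.text, result.avg_logprob)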
+
+
+class Whisper(nn.Layer):
+ def __init__(self, dims: ModelDimensions):
+ super().__init__()
+ self.dims = dims
+ self.encoder = AudioEncoder(
+ self.dims.n_mels,
+ self.dims.n_audio_ctx,
+ self.dims.n_audio_state,
+ self.dims.n_audio_head,
+ self.dims.n_audio_layer, )
+ self.decoder = TextDecoder(
+ self.dims.n_vocab,
+ self.dims.n_text_ctx,
+ self.dims.n_text_state,
+ self.dims.n_text_head,
+ self.dims.n_text_layer, )
+
+ def embed_audio(self, mel: paddle.Tensor):
+ return self.encoder.forward(mel)
+
+ def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
+ return self.decoder.forward(tokens, audio_features)
+
+ def forward(self, mel: paddle.Tensor,
+ tokens: paddle.Tensor) -> Dict[str, paddle.Tensor]:
+ return self.decoder(tokens, self.encoder(mel))
+
+ @property
+ def device(self):
+ return paddle.device.get_device()
+
+ @property
+ def is_multilingual(self):
+ return self.dims.n_vocab == 51865
+
+ def install_kv_cache_hooks(self, cache: Optional[dict]=None):
+ """
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+ tensors calculated for the previous positions. This method returns a dictionary that stores
+ all caches, and the necessary hooks for the key and value projection modules that save the
+ intermediate tensors to be reused during later calculations.
+
+ Returns
+ -------
+        cache : Dict[nn.Layer, paddle.Tensor]
+            A dictionary object mapping the key/value projection modules to their caches
+        hooks : List
+            List of forward-hook handles; call ``remove()`` on each to uninstall the hooks
+ """
+ cache = {**cache} if cache is not None else {}
+ hooks = []
+
+ def save_to_cache(module, _, output):
+ if module not in cache or output.shape[
+ 1] > self.decoder.positional_embedding.shape[0]:
+ cache[
+ module] = output # save as-is, for the first token or cross attention
+ else:
+ cache[module] = paddle.concat(
+ [cache[module], output], axis=1).detach()
+ return cache[module]
+
+ def install_hooks(layer: nn.Layer):
+ if isinstance(layer, MultiHeadAttention):
+ hooks.append(
+ layer.key.register_forward_post_hook(save_to_cache))
+ hooks.append(
+ layer.value.register_forward_post_hook(save_to_cache))
+
+ self.decoder.apply(install_hooks)
+ return cache, hooks
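+    # Usage sketch, mirroring WhisperInference: install once per utterance,
+    # feed the cache to the decoder, then remove the hooks when decoding finishes.
+    #     cache, hooks = model.install_kv_cache_hooks()
+    #     logits = model.decoder(tokens, audio_features, kv_cache=cache)
+    #     for hook in hooks:
+    #         hook.remove()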
+
+ detect_language = detect_language
+ transcribe = transcribe
+ decode = decode
+
+
+def pad_or_trim(array, length: int=N_SAMPLES, *, axis: int=-1):
+ """
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+ """
+ if paddle.is_tensor(array):
+ if array.shape[axis] > length:
+ array = array.index_select(axis=axis, index=paddle.arange(length))
+
+ if array.shape[axis] < length:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length - array.shape[axis])
+ array = paddle.transpose(array, (1, 0))
+ array = F.pad(
+ array, [pad for sizes in pad_widths[::-1] for pad in sizes],
+ data_format='NLC')
+ array = paddle.transpose(array, (1, 0))
+ else:
+ if array.shape[axis] > length:
+ array = array.take(indices=range(length), axis=axis)
+
+ if array.shape[axis] < length:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length - array.shape[axis])
+ array = paddle.transpose(array, (1, 0))
+ array = np.pad(array, pad_widths)
+ array = paddle.transpose(array, (1, 0))
+
+ return array
+
+
+def hann_window(n_fft: int=N_FFT):
+ """
+ hanning window
+ n_fft: The number of frequency components of the discrete Fourier transform.
+ """
+ return paddle.to_tensor(
+ [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
+ dtype=paddle.float32)
+
+
+@lru_cache(maxsize=None)
+def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
+ """
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+ Allows decoupling librosa dependency; saved using:
+
+ np.savez_compressed(
+ "mel_filters.npz",
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+ )
+ """
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+ with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
+ return paddle.to_tensor(f[f"mel_{n_mels}"])
+
+
+def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
+ n_mels: int=N_MELS,
+ resource_path: str=None):
+ """
+    Compute the log-Mel spectrogram of the input audio.
+
+ Parameters
+ ----------
+ audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
+        The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
+
+ n_mels: int
+ The number of Mel-frequency filters, only 80 is supported
+
+ Returns
+ -------
+ paddle.Tensor, shape = (80, n_frames)
+ A Tensor that contains the Mel spectrogram
+ """
+ if not paddle.is_tensor(audio):
+ if isinstance(audio, str):
+ audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
+ audio = audio[:, 0]
+ logger.info(f"audio shape: {audio.shape}")
+ audio = paddle.to_tensor(audio)
+
+ window = hann_window(N_FFT)
+ stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
+
+ magnitudes = stft[:, :-1].abs()**2
+
+ filters = mel_filters(resource_path, n_mels)
+ mel_spec = filters @ magnitudes
+ mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
+
+ log_spec = paddle.clip(mel_spec, min=1e-10).log10()
+ log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ return log_spec
diff --git a/paddlespeech/s2t/models/whisper/whisper_LICENSE b/paddlespeech/s2t/models/whisper/whisper_LICENSE
new file mode 100644
index 000000000..49e465e19
--- /dev/null
+++ b/paddlespeech/s2t/models/whisper/whisper_LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 OpenAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/paddlespeech/s2t/training/scheduler.py b/paddlespeech/s2t/training/scheduler.py
index b22f7ef85..53c756ce3 100644
--- a/paddlespeech/s2t/training/scheduler.py
+++ b/paddlespeech/s2t/training/scheduler.py
@@ -17,6 +17,7 @@ from typing import Dict
from typing import Text
from typing import Union
+import paddle
from paddle.optimizer.lr import LRScheduler
from typeguard import check_argument_types
@@ -107,6 +108,125 @@ class ConstantLR(LRScheduler):
return self.base_lr
+@register_scheduler
+class NewBobScheduler(LRScheduler):
+    """Scheduler with the new-bob technique, used for LR annealing.
+
+    The learning rate is annealed based on the validation performance.
+    In particular: if (past_loss - current_loss) / past_loss < improvement_threshold,
+    then lr = lr * annealing_factor.
+
+    Arguments
+    ---------
+    learning_rate : float
+        The initial learning rate.
+    annealing_factor : float
+        The annealing factor used in the new-bob strategy.
+    improvement_threshold : float
+        The minimum relative improvement between consecutive losses required to
+        avoid annealing the learning rate.
+    patient : int
+        When the annealing condition is violated `patient` times, the learning
+        rate is finally reduced.
+
+    Example
+    -------
+    >>> scheduler = NewBobScheduler(learning_rate=1.0)
+    >>> scheduler.step(metric_value=10.0)
+    >>> scheduler.step(metric_value=2.0)
+    >>> scheduler.step(metric_value=2.5)
+    >>> scheduler.last_lr
+    0.5
+    """
+
+ def __init__(
+ self,
+ learning_rate,
+ last_epoch=-1,
+ verbose=False,
+ annealing_factor=0.5,
+ improvement_threshold=0.0025,
+ patient=0, ):
+ self.hyperparam_value = learning_rate
+ self.annealing_factor = annealing_factor
+ self.improvement_threshold = improvement_threshold
+ self.patient = patient
+ self.metric_values = []
+ self.current_patient = self.patient
+ super().__init__(learning_rate, last_epoch, verbose)
+
+ def step(self, metric_value=None):
+        """Update the learning rate based on the latest validation metric.
+
+        ``step`` should be called after ``optimizer.step``. The new learning rate
+        will take effect on the next ``optimizer.step``.
+
+        Args:
+            metric_value (float, optional): the latest validation metric (e.g. the
+                validation loss). If None, the learning rate is left unchanged.
+                Default: None.
+
+        Returns:
+            None
+        """
+ if metric_value is None:
+ self.last_epoch += 1
+ self.last_lr = self.hyperparam_value
+ else:
+ self.last_epoch += 1
+ self.last_lr = self.get_lr(metric_value)
+
+ if self.verbose:
+ print('Epoch {}: {} set learning rate to {}.'.format(
+ self.last_epoch, self.__class__.__name__, self.last_lr))
+
+ def get_lr(self, metric_value):
+        """Return the new learning rate, annealing it if the improvement is too small.
+
+        Arguments
+        ---------
+        metric_value : float
+            A number (e.g. the validation loss) used to decide whether to anneal
+            the learning rate.
+        """
+ new_value = self.hyperparam_value
+ if len(self.metric_values) > 0:
+ prev_metric = self.metric_values[-1]
+ # Update value if improvement too small and patience is 0
+ if prev_metric == 0: # Prevent division by zero
+ improvement = 0
+ else:
+ improvement = (prev_metric - metric_value) / prev_metric
+ if improvement < self.improvement_threshold:
+ if self.current_patient == 0:
+ new_value *= self.annealing_factor
+ self.current_patient = self.patient
+ else:
+ self.current_patient -= 1
+
+ # Store relevant info
+ self.metric_values.append(metric_value)
+ self.hyperparam_value = new_value
+
+ return new_value
+
+ def save(self):
+        """Return the scheduler state (epoch index, learning rate, metric history, patience) as a dict for checkpointing."""
+ data = {
+ "current_epoch_index": self.last_epoch,
+ "hyperparam_value": self.hyperparam_value,
+ "metric_values": self.metric_values,
+ "current_patient": self.current_patient
+ }
+ return data
+
+ def load(self, data):
+        """Restore the scheduler state saved by ``save`` from a path readable by ``paddle.load``."""
+ data = paddle.load(data)
+ self.last_epoch = data["current_epoch_index"]
+ self.hyperparam_value = data["hyperparam_value"]
+ self.metric_values = data["metric_values"]
+ self.current_patient = data["current_patient"]
+
+
def dynamic_import_scheduler(module):
"""Import Scheduler class dynamically.
diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py
index 5840c0699..e0ae20bb1 100644
--- a/paddlespeech/t2s/exps/inference.py
+++ b/paddlespeech/t2s/exps/inference.py
@@ -145,7 +145,7 @@ def main():
# warmup
for utt_id, sentence in sentences[:3]:
with timer() as t:
- am_output_data = get_am_output(
+ mel = get_am_output(
input=sentence,
am_predictor=am_predictor,
am=args.am,
@@ -154,12 +154,11 @@ def main():
merge_sentences=merge_sentences,
speaker_dict=args.speaker_dict,
spk_id=args.spk_id, )
- wav = get_voc_output(
- voc_predictor=voc_predictor, input=am_output_data)
+ wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
speed = wav.size / t.elapse
rtf = fs / speed
print(
- f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
print("warm up done!")
@@ -168,7 +167,7 @@ def main():
T = 0
for utt_id, sentence in sentences:
with timer() as t:
- am_output_data = get_am_output(
+ mel = get_am_output(
input=sentence,
am_predictor=am_predictor,
am=args.am,
@@ -177,8 +176,7 @@ def main():
merge_sentences=merge_sentences,
speaker_dict=args.speaker_dict,
spk_id=args.spk_id, )
- wav = get_voc_output(
- voc_predictor=voc_predictor, input=am_output_data)
+ wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
N += wav.size
T += t.elapse
@@ -187,7 +185,7 @@ def main():
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
print(
- f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
print(f"{utt_id} done!")
diff --git a/paddlespeech/t2s/exps/lite_predict.py b/paddlespeech/t2s/exps/lite_predict.py
new file mode 100644
index 000000000..bd0c732b1
--- /dev/null
+++ b/paddlespeech/t2s/exps/lite_predict.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from pathlib import Path
+
+import soundfile as sf
+from timer import timer
+
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_lite_am_output
+from paddlespeech.t2s.exps.syn_utils import get_lite_predictor
+from paddlespeech.t2s.exps.syn_utils import get_lite_voc_output
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+        description="Paddle-Lite inference with acoustic model & vocoder.")
+ # acoustic model
+ parser.add_argument(
+ '--am',
+ type=str,
+ default='fastspeech2_csmsc',
+ choices=[
+ 'speedyspeech_csmsc',
+ 'fastspeech2_csmsc',
+ 'fastspeech2_aishell3',
+ 'fastspeech2_ljspeech',
+ 'fastspeech2_vctk',
+ 'fastspeech2_mix',
+ ],
+ help='Choose acoustic model type of tts task.')
+ parser.add_argument(
+ "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+ parser.add_argument(
+ "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+ parser.add_argument(
+ "--speaker_dict", type=str, default=None, help="speaker id map file.")
+ parser.add_argument(
+ '--spk_id',
+ type=int,
+ default=0,
+ help='spk id for multi speaker acoustic model')
+ # voc
+ parser.add_argument(
+ '--voc',
+ type=str,
+ default='pwgan_csmsc',
+ choices=[
+ 'pwgan_csmsc',
+ 'pwgan_aishell3',
+ 'pwgan_ljspeech',
+ 'pwgan_vctk',
+ 'mb_melgan_csmsc',
+ 'hifigan_csmsc',
+ 'hifigan_aishell3',
+ 'hifigan_ljspeech',
+ 'hifigan_vctk',
+ ],
+ help='Choose vocoder type of tts task.')
+ # other
+ parser.add_argument(
+ '--lang',
+ type=str,
+ default='zh',
+ help='Choose model language. zh or en or mix')
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="text to synthesize, a 'utt_id sentence' pair per line")
+ parser.add_argument(
+ "--inference_dir", type=str, help="dir to save inference models")
+ parser.add_argument("--output_dir", type=str, help="output dir")
+
+ args, _ = parser.parse_known_args()
+ return args
+
+
+# only inference for models trained with csmsc now
+def main():
+ args = parse_args()
+
+ # frontend
+ frontend = get_frontend(
+ lang=args.lang,
+ phones_dict=args.phones_dict,
+ tones_dict=args.tones_dict)
+
+ # am_predictor
+ am_predictor = get_lite_predictor(
+ model_dir=args.inference_dir, model_file=args.am + "_x86.nb")
+ # model: {model_name}_{dataset}
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+
+ # voc_predictor
+ voc_predictor = get_lite_predictor(
+ model_dir=args.inference_dir, model_file=args.voc + "_x86.nb")
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ sentences = get_sentences(text_file=args.text, lang=args.lang)
+
+ merge_sentences = True
+ fs = 24000 if am_dataset != 'ljspeech' else 22050
+ # warmup
+ for utt_id, sentence in sentences[:3]:
+ with timer() as t:
+ mel = get_lite_am_output(
+ input=sentence,
+ am_predictor=am_predictor,
+ am=args.am,
+ frontend=frontend,
+ lang=args.lang,
+ merge_sentences=merge_sentences,
+ speaker_dict=args.speaker_dict,
+ spk_id=args.spk_id, )
+ wav = get_lite_voc_output(voc_predictor=voc_predictor, input=mel)
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+ print(
+ f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print("warm up done!")
+
+ N = 0
+ T = 0
+ for utt_id, sentence in sentences:
+ with timer() as t:
+ mel = get_lite_am_output(
+ input=sentence,
+ am_predictor=am_predictor,
+ am=args.am,
+ frontend=frontend,
+ lang=args.lang,
+ merge_sentences=merge_sentences,
+ speaker_dict=args.speaker_dict,
+ spk_id=args.spk_id, )
+ wav = get_lite_voc_output(voc_predictor=voc_predictor, input=mel)
+
+ N += wav.size
+ T += t.elapse
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+
+ sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
+ print(
+ f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print(f"{utt_id} done!")
+ print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/paddlespeech/t2s/exps/lite_predict_streaming.py b/paddlespeech/t2s/exps/lite_predict_streaming.py
new file mode 100644
index 000000000..37b600512
--- /dev/null
+++ b/paddlespeech/t2s/exps/lite_predict_streaming.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+from timer import timer
+
+from paddlespeech.t2s.exps.syn_utils import denorm
+from paddlespeech.t2s.exps.syn_utils import get_chunks
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_lite_am_sublayer_output
+from paddlespeech.t2s.exps.syn_utils import get_lite_predictor
+from paddlespeech.t2s.exps.syn_utils import get_lite_streaming_am_output
+from paddlespeech.t2s.exps.syn_utils import get_lite_voc_output
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.exps.syn_utils import run_frontend
+from paddlespeech.t2s.utils import str2bool
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+        description="Paddle-Lite inference with acoustic model & vocoder.")
+ # acoustic model
+ parser.add_argument(
+ '--am',
+ type=str,
+ default='fastspeech2_csmsc',
+ choices=['fastspeech2_csmsc'],
+ help='Choose acoustic model type of tts task.')
+ parser.add_argument(
+ "--am_stat",
+ type=str,
+ default=None,
+ help="mean and standard deviation used to normalize spectrogram when training acoustic model."
+ )
+ parser.add_argument(
+ "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+ parser.add_argument(
+ "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+ parser.add_argument(
+ "--speaker_dict", type=str, default=None, help="speaker id map file.")
+ parser.add_argument(
+ '--spk_id',
+ type=int,
+ default=0,
+ help='spk id for multi speaker acoustic model')
+ # voc
+ parser.add_argument(
+ '--voc',
+ type=str,
+ default='pwgan_csmsc',
+ choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'],
+ help='Choose vocoder type of tts task.')
+ # other
+ parser.add_argument(
+ '--lang',
+ type=str,
+ default='zh',
+ help='Choose model language. zh or en')
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="text to synthesize, a 'utt_id sentence' pair per line")
+ parser.add_argument(
+ "--inference_dir", type=str, help="dir to save inference models")
+ parser.add_argument("--output_dir", type=str, help="output dir")
+ # inference
+
+ # streaming related
+ parser.add_argument(
+ "--am_streaming",
+ type=str2bool,
+ default=False,
+ help="whether use streaming acoustic model")
+ parser.add_argument(
+ "--block_size", type=int, default=42, help="block size of am streaming")
+ parser.add_argument(
+ "--pad_size", type=int, default=12, help="pad size of am streaming")
+
+ args, _ = parser.parse_known_args()
+ return args
+
+
+# only inference for models trained with csmsc now
+def main():
+ args = parse_args()
+
+ # frontend
+ frontend = get_frontend(
+ lang=args.lang,
+ phones_dict=args.phones_dict,
+ tones_dict=args.tones_dict)
+
+ # am_predictor
+ am_encoder_infer_predictor = get_lite_predictor(
+ model_dir=args.inference_dir,
+ model_file=args.am + "_am_encoder_infer" + "_x86.nb")
+ am_decoder_predictor = get_lite_predictor(
+ model_dir=args.inference_dir,
+ model_file=args.am + "_am_decoder" + "_x86.nb")
+ am_postnet_predictor = get_lite_predictor(
+ model_dir=args.inference_dir,
+ model_file=args.am + "_am_postnet" + "_x86.nb")
+ am_mu, am_std = np.load(args.am_stat)
+ # model: {model_name}_{dataset}
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+
+ # voc_predictor
+ voc_predictor = get_lite_predictor(
+ model_dir=args.inference_dir, model_file=args.voc + "_x86.nb")
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ sentences = get_sentences(text_file=args.text, lang=args.lang)
+
+ merge_sentences = True
+
+ fs = 24000 if am_dataset != 'ljspeech' else 22050
+ # warmup
+ for utt_id, sentence in sentences[:3]:
+ with timer() as t:
+ normalized_mel = get_lite_streaming_am_output(
+ input=sentence,
+ am_encoder_infer_predictor=am_encoder_infer_predictor,
+ am_decoder_predictor=am_decoder_predictor,
+ am_postnet_predictor=am_postnet_predictor,
+ frontend=frontend,
+ lang=args.lang,
+ merge_sentences=merge_sentences, )
+ mel = denorm(normalized_mel, am_mu, am_std)
+ wav = get_lite_voc_output(voc_predictor=voc_predictor, input=mel)
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+ print(
+ f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print("warm up done!")
+
+ N = 0
+ T = 0
+ block_size = args.block_size
+ pad_size = args.pad_size
+ get_tone_ids = False
+ for utt_id, sentence in sentences:
+ with timer() as t:
+ # frontend
+ frontend_dict = run_frontend(
+ frontend=frontend,
+ text=sentence,
+ merge_sentences=merge_sentences,
+ get_tone_ids=get_tone_ids,
+ lang=args.lang)
+ phone_ids = frontend_dict['phone_ids']
+ phones = phone_ids[0].numpy()
+ # acoustic model
+ orig_hs = get_lite_am_sublayer_output(
+ am_encoder_infer_predictor, input=phones)
+
+ if args.am_streaming:
+ hss = get_chunks(orig_hs, block_size, pad_size)
+ chunk_num = len(hss)
+ mel_list = []
+ for i, hs in enumerate(hss):
+ am_decoder_output = get_lite_am_sublayer_output(
+ am_decoder_predictor, input=hs)
+ am_postnet_output = get_lite_am_sublayer_output(
+ am_postnet_predictor,
+ input=np.transpose(am_decoder_output, (0, 2, 1)))
+ am_output_data = am_decoder_output + np.transpose(
+ am_postnet_output, (0, 2, 1))
+ normalized_mel = am_output_data[0]
+
+ sub_mel = denorm(normalized_mel, am_mu, am_std)
+ # clip output part of pad
+ if i == 0:
+ sub_mel = sub_mel[:-pad_size]
+ elif i == chunk_num - 1:
+                            # the last chunk never has a full right pad, so only trim the left pad
+ sub_mel = sub_mel[pad_size:]
+ else:
+                            # chunks near the end may also lack a full right pad
+ sub_mel = sub_mel[pad_size:(block_size + pad_size) -
+ sub_mel.shape[0]]
+ mel_list.append(sub_mel)
+ mel = np.concatenate(mel_list, axis=0)
+
+ else:
+ am_decoder_output = get_lite_am_sublayer_output(
+ am_decoder_predictor, input=orig_hs)
+ am_postnet_output = get_lite_am_sublayer_output(
+ am_postnet_predictor,
+ input=np.transpose(am_decoder_output, (0, 2, 1)))
+ am_output_data = am_decoder_output + np.transpose(
+ am_postnet_output, (0, 2, 1))
+ normalized_mel = am_output_data[0]
+ mel = denorm(normalized_mel, am_mu, am_std)
+ # vocoder
+ wav = get_lite_voc_output(voc_predictor=voc_predictor, input=mel)
+
+ N += wav.size
+ T += t.elapse
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+
+            sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
+ print(
+ f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print(f"{utt_id} done!")
+ print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
+
+
+if __name__ == "__main__":
+ main()
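The pad-trimming rule in the streaming branch above can be checked in isolation. A small NumPy sketch, assuming each chunk carries up to `pad` extra frames on both sides (block and pad sizes are illustrative):

```python
import numpy as np

block, pad = 4, 2
seq = np.arange(20)
# overlapping chunks: up to `pad` extra frames on each side of every block
chunks = [seq[max(0, s - pad):s + block + pad] for s in range(0, len(seq), block)]

pieces = []
for i, c in enumerate(chunks):
    if i == 0:
        pieces.append(c[:-pad])                           # drop the right pad only
    elif i == len(chunks) - 1:
        pieces.append(c[pad:])                            # drop the left pad only
    else:
        pieces.append(c[pad:(block + pad) - c.shape[0]])  # drop both pads
print(np.array_equal(np.concatenate(pieces), seq))        # True
```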
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 41663891e..cea125291 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -26,6 +26,8 @@ import paddle
from paddle import inference
from paddle import jit
from paddle.static import InputSpec
+from paddlelite.lite import create_paddle_predictor
+from paddlelite.lite import MobileConfig
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable
@@ -510,3 +512,105 @@ def get_sess(model_path: Optional[os.PathLike],
sess = ort.InferenceSession(
model_path, providers=providers, sess_options=sess_options)
return sess
+
+
+# Paddle-Lite
+def get_lite_predictor(model_dir: Optional[os.PathLike]=None,
+ model_file: Optional[os.PathLike]=None,
+ cpu_threads: int=1):
+ config = MobileConfig()
+ config.set_model_from_file(str(Path(model_dir) / model_file))
+ predictor = create_paddle_predictor(config)
+ return predictor
+
+
+def get_lite_am_output(
+ input: str,
+ am_predictor,
+ am: str,
+ frontend: object,
+ lang: str='zh',
+ merge_sentences: bool=True,
+ speaker_dict: Optional[os.PathLike]=None,
+ spk_id: int=0, ):
+ am_name = am[:am.rindex('_')]
+ am_dataset = am[am.rindex('_') + 1:]
+ get_spk_id = False
+ get_tone_ids = False
+ if am_name == 'speedyspeech':
+ get_tone_ids = True
+ if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict:
+ get_spk_id = True
+ spk_id = np.array([spk_id])
+
+ frontend_dict = run_frontend(
+ frontend=frontend,
+ text=input,
+ merge_sentences=merge_sentences,
+ get_tone_ids=get_tone_ids,
+ lang=lang)
+
+ if get_tone_ids:
+ tone_ids = frontend_dict['tone_ids']
+ tones = tone_ids[0].numpy()
+ tones_handle = am_predictor.get_input(1)
+ tones_handle.from_numpy(tones)
+
+ if get_spk_id:
+ spk_id_handle = am_predictor.get_input(1)
+ spk_id_handle.from_numpy(spk_id)
+ phone_ids = frontend_dict['phone_ids']
+ phones = phone_ids[0].numpy()
+ phones_handle = am_predictor.get_input(0)
+ phones_handle.from_numpy(phones)
+ am_predictor.run()
+ am_output_handle = am_predictor.get_output(0)
+ am_output_data = am_output_handle.numpy()
+ return am_output_data
+
+
+def get_lite_voc_output(voc_predictor, input):
+ mel_handle = voc_predictor.get_input(0)
+ mel_handle.from_numpy(input)
+ voc_predictor.run()
+ voc_output_handle = voc_predictor.get_output(0)
+ wav = voc_output_handle.numpy()
+ return wav
+
+
+def get_lite_am_sublayer_output(am_sublayer_predictor, input):
+ input_handle = am_sublayer_predictor.get_input(0)
+ input_handle.from_numpy(input)
+
+ am_sublayer_predictor.run()
+ am_sublayer_handle = am_sublayer_predictor.get_output(0)
+ am_sublayer_output = am_sublayer_handle.numpy()
+ return am_sublayer_output
+
+
+def get_lite_streaming_am_output(input: str,
+ am_encoder_infer_predictor,
+ am_decoder_predictor,
+ am_postnet_predictor,
+ frontend,
+ lang: str='zh',
+ merge_sentences: bool=True):
+ get_tone_ids = False
+ frontend_dict = run_frontend(
+ frontend=frontend,
+ text=input,
+ merge_sentences=merge_sentences,
+ get_tone_ids=get_tone_ids,
+ lang=lang)
+ phone_ids = frontend_dict['phone_ids']
+ phones = phone_ids[0].numpy()
+ am_encoder_infer_output = get_lite_am_sublayer_output(
+ am_encoder_infer_predictor, input=phones)
+ am_decoder_output = get_lite_am_sublayer_output(
+ am_decoder_predictor, input=am_encoder_infer_output)
+ am_postnet_output = get_lite_am_sublayer_output(
+ am_postnet_predictor, input=np.transpose(am_decoder_output, (0, 2, 1)))
+ am_output_data = am_decoder_output + np.transpose(am_postnet_output,
+ (0, 2, 1))
+ normalized_mel = am_output_data[0]
+ return normalized_mel
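A hedged end-to-end sketch of the new Paddle-Lite helpers for non-streaming synthesis; the model directory, the `.nb` file names, and the phone dictionary path are placeholders for whatever was exported and optimized beforehand:

```python
import soundfile as sf

from paddlespeech.t2s.exps.syn_utils import (get_frontend, get_lite_am_output,
                                             get_lite_predictor,
                                             get_lite_voc_output)

# placeholders: point these at your optimized .nb models and phone dict
frontend = get_frontend(
    lang="zh", phones_dict="phone_id_map.txt", tones_dict=None)
am_predictor = get_lite_predictor(
    model_dir="./pdlite_models", model_file="fastspeech2_csmsc_x86.nb")
voc_predictor = get_lite_predictor(
    model_dir="./pdlite_models", model_file="pwgan_csmsc_x86.nb")

mel = get_lite_am_output(
    input="你好,欢迎使用飞桨。",
    am_predictor=am_predictor,
    am="fastspeech2_csmsc",
    frontend=frontend,
    lang="zh")
wav = get_lite_voc_output(voc_predictor=voc_predictor, input=mel)
sf.write("demo.wav", wav, samplerate=24000)
```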
diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py
index 4e6fad4e5..47c26a610 100644
--- a/paddlespeech/t2s/frontend/g2pw/onnx_api.py
+++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py
@@ -100,7 +100,7 @@ class G2PWOnnxConverter:
]
self.non_polyphonic = {
'一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
- '肖', '瘙', '誒', '泊'
+ '肖', '瘙', '誒', '泊', '听'
}
self.non_monophonic = {'似', '攢'}
self.monophonic_chars = [
diff --git a/paddlespeech/t2s/frontend/zh_normalization/constants.py b/paddlespeech/t2s/frontend/zh_normalization/constants.py
index 5d2b0b34e..6423ad74a 100644
--- a/paddlespeech/t2s/frontend/zh_normalization/constants.py
+++ b/paddlespeech/t2s/frontend/zh_normalization/constants.py
@@ -19,7 +19,7 @@ from pypinyin.constants import SUPPORT_UCS4
# 全角半角转换
# 英文字符全角 -> 半角映射表 (num: 52)
F2H_ASCII_LETTERS = {
- chr(ord(char) + 65248): char
+ ord(char) + 65248: ord(char)
for char in string.ascii_letters
}
@@ -27,12 +27,12 @@ F2H_ASCII_LETTERS = {
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
# 数字字符全角 -> 半角映射表 (num: 10)
-F2H_DIGITS = {chr(ord(char) + 65248): char for char in string.digits}
+F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits}
# 数字字符半角 -> 全角映射表
H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}
# 标点符号全角 -> 半角映射表 (num: 32)
-F2H_PUNCTUATIONS = {chr(ord(char) + 65248): char for char in string.punctuation}
+F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
# 标点符号半角 -> 全角映射表
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
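The keys of these tables are switched from characters to code points, presumably so they can be passed straight to `str.translate`. A minimal standalone sketch of the same construction:

```python
import string

# full-width ASCII letters -> half-width, keyed by code point for str.translate
F2H_ASCII_LETTERS = {ord(ch) + 65248: ord(ch) for ch in string.ascii_letters}

print("ＰａｄｄｌｅＳｐｅｅｃｈ".translate(F2H_ASCII_LETTERS))  # PaddleSpeech
```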
diff --git a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
index 8f8e3b07d..1250e96ca 100644
--- a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
+++ b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
@@ -65,7 +65,7 @@ class TextNormalizer():
if lang == "zh":
text = text.replace(" ", "")
# 过滤掉特殊字符
- text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text)
+ text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|…\\]', '', text)
text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
@@ -74,7 +74,44 @@ class TextNormalizer():
def _post_replace(self, sentence: str) -> str:
sentence = sentence.replace('/', '每')
sentence = sentence.replace('~', '至')
-
+ sentence = sentence.replace('~', '至')
+ sentence = sentence.replace('①', '一')
+ sentence = sentence.replace('②', '二')
+ sentence = sentence.replace('③', '三')
+ sentence = sentence.replace('④', '四')
+ sentence = sentence.replace('⑤', '五')
+ sentence = sentence.replace('⑥', '六')
+ sentence = sentence.replace('⑦', '七')
+ sentence = sentence.replace('⑧', '八')
+ sentence = sentence.replace('⑨', '九')
+ sentence = sentence.replace('⑩', '十')
+ sentence = sentence.replace('α', '阿尔法')
+ sentence = sentence.replace('β', '贝塔')
+ sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛')
+ sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔')
+ sentence = sentence.replace('ε', '艾普西龙')
+ sentence = sentence.replace('ζ', '捷塔')
+ sentence = sentence.replace('η', '依塔')
+ sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔')
+ sentence = sentence.replace('ι', '艾欧塔')
+ sentence = sentence.replace('κ', '喀帕')
+ sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达')
+ sentence = sentence.replace('μ', '缪')
+ sentence = sentence.replace('ν', '拗')
+ sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西')
+ sentence = sentence.replace('ο', '欧米克伦')
+ sentence = sentence.replace('π', '派').replace('Π', '派')
+ sentence = sentence.replace('ρ', '肉')
+ sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace(
+ 'σ', '西格玛')
+ sentence = sentence.replace('τ', '套')
+ sentence = sentence.replace('υ', '宇普西龙')
+ sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾')
+ sentence = sentence.replace('χ', '器')
+ sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
+ sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
+        # filter special characters again; this regex also removes "-", unlike the one in _split
+ sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|…\\]', '', sentence)
return sentence
def normalize_sentence(self, sentence: str) -> str:
@@ -113,6 +150,5 @@ class TextNormalizer():
def normalize(self, text: str) -> List[str]:
sentences = self._split(text)
-
sentences = [self.normalize_sentence(sent) for sent in sentences]
return sentences
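A short sketch of the extended `_post_replace` table in action; a hedged example, since the exact verbalization of the digits and punctuation is handled by the rest of the normalization pipeline:

```python
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer

tn = TextNormalizer()
# circled digits and Greek letters are now spelled out, e.g. "③" -> "三", "π" -> "派"
print(tn.normalize("圆周率π约等于3.14,详见第③章。"))
```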
diff --git a/setup.py b/setup.py
index 35668bddb..7fb4c70be 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ base = [
"braceexpand",
"pyyaml",
"pybind11",
+ "paddlelite",
"paddleslim==2.3.4",
]
diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md
index e6ab953c8..b98b74b6f 100644
--- a/speechx/examples/ds2_ol/onnx/README.md
+++ b/speechx/examples/ds2_ol/onnx/README.md
@@ -1,11 +1,8 @@
-# DeepSpeech2 to ONNX model
+# Convert DeepSpeech2 model to ONNX format
-1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
-2. check paddleinference and onnxruntime output equal.
-3. optimize onnx model
-4. check paddleinference and optimized onnxruntime output equal.
-5. quantize onnx model
-4. check paddleinference and optimized onnxruntime output equal.
+> We recommend using the U2/U2++ model instead of DS2; please see [here](../../u2pp_ol/wenetspeech/).
+
+This example demonstrates converting the DS2 model to ONNX format.
Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct.
@@ -25,18 +22,24 @@ onnxoptimizer 0.2.7
onnxruntime 1.11.0
```
+
## Using
```
bash run.sh --stage 0 --stop_stage 5
```
+1. Convert the DeepSpeech2 model to ONNX with Paddle2ONNX.
+2. Check that Paddle Inference and ONNX Runtime outputs are equal.
+3. Optimize the ONNX model.
+4. Check that Paddle Inference and the optimized ONNX Runtime outputs are equal.
+5. Quantize the ONNX model.
+6. Check that Paddle Inference and the quantized ONNX Runtime outputs are equal.
+
For more details please see `run.sh`.
## Outputs
-The optimized onnx model is `exp/model.opt.onnx`, quanted model is `$exp/model.optset11.quant.onnx`.
-
-To show the graph, please using `local/netron.sh`.
+The optimized ONNX model is `exp/model.opt.onnx`; the quantized model is `exp/model.optset11.quant.onnx`.
## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)
diff --git a/speechx/examples/u2pp_ol/wenetspeech/README.md b/speechx/examples/u2pp_ol/wenetspeech/README.md
index b90b8e201..6ca8f6dd8 100644
--- a/speechx/examples/u2pp_ol/wenetspeech/README.md
+++ b/speechx/examples/u2pp_ol/wenetspeech/README.md
@@ -1,27 +1,77 @@
-# u2/u2pp Streaming ASR
+# U2/U2++ Streaming ASR
+
+A C++ deployment example for the `PaddleSpeech/examples/wenetspeech/asr1` recipe. The model is the static model produced by `export`; for how to export it, please see [here](../../../../examples/wenetspeech/asr1/). If you want to use an already exported model, `run.sh` will download it; the model link can be found in `run.sh`.
+
+This example demonstrates how to use the U2/U2++ model to recognize `wav` files and compute the `CER`, with AISHELL-1 as test data.
## Testing with Aishell Test Data
-### Download wav and model
+### Source `path.sh` first
+
+```bash
+source path.sh
+```
+
+All binaries are located under the directory printed by `echo $SPEECHX_BUILD`.
+
+### Download dataset and model
```
./run.sh --stop_stage 0
```
-### compute feature
+### Process `cmvn` and compute features
-```
+```bash
./run.sh --stage 1 --stop_stage 1
```
-### decoding using feature
+If you only want to convert the `cmvn` file format, use this command:
+
+```bash
+./local/feat.sh --stage 1 --stop_stage 1
+```
+
+### Decoding using `feature` input
```
./run.sh --stage 2 --stop_stage 2
```
-### decoding using wav
+### Decoding using `wav` input
```
./run.sh --stage 3 --stop_stage 3
```
+
+This stage uses `u2_recognizer_main` to recognize wav files.
+
+The input is an `scp` file, which looks like this:
+```text
+# head data/split1/1/aishell_test.scp
+BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
+BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
+...
+BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
+```
+
+If you want to recognize a single wav, you can write an `scp` file like this:
+```text
+key path/to/wav/file
+```
+
+Then pass it to the `u2_recognizer_main` binary via the `--wav_rspecifier=` flag. For the meaning of the other flags, please see the help:
+```bash
+u2_recognizer_main --help
+```
+
+For an example of using the `u2_recognizer_main` binary, please see `local/recognizer.sh`.
+
+### Decoding using `wav` input with the quantized model
+
+`local/recognizer_quant.sh` is the same as `local/recognizer.sh`, but uses the quantized model.
+
+
+## Results
+
+Please see [here](./RESULTS.md).
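For convenience, a tiny Python sketch that writes the single-entry `scp` file described above; the wav path and key are placeholders:

```python
from pathlib import Path

# hypothetical single-utterance scp: "<key> <absolute wav path>"
wav = Path("data/demo.wav").resolve()
Path("one_wav.scp").write_text(f"{wav.stem} {wav}\n", encoding="utf-8")
```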
diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/speechx/examples/u2pp_ol/wenetspeech/run.sh
index 870c5deeb..711d68083 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/run.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/run.sh
@@ -72,13 +72,16 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # process cmvn and compute fbank feat
./local/feat.sh
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # decode with fbank feat input
./local/decode.sh
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ # decode with wav input
    ./local/recognizer.sh
fi
diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh
index d571aa78f..3a58626d2 100755
--- a/tests/unit/cli/test_cli.sh
+++ b/tests/unit/cli/test_cli.sh
@@ -9,6 +9,10 @@ paddlespeech cls --input ./cat.wav --topk 10
# Punctuation_restoration
paddlespeech text --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭 --model ernie_linear_p3_wudao_fast
+# Speech SSL
+paddlespeech ssl --task asr --lang en --input ./en.wav
+paddlespeech ssl --task vector --lang en --input ./en.wav
+
# Speech_recognition
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
paddlespeech asr --input ./zh.wav
@@ -54,6 +58,7 @@ paddlespeech tts --am fastspeech2_vctk --voc hifigan_vctk --input "Life was like
paddlespeech tts --am tacotron2_csmsc --input "你好,欢迎使用百度飞桨深度学习框架!"
paddlespeech tts --am tacotron2_csmsc --voc wavernn_csmsc --input "你好,欢迎使用百度飞桨深度学习框架!"
paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input "Life was like a box of chocolates, you never know what you're gonna get."
+paddlespeech tts --am fastspeech2_male --voc pwgan_male --input "你好,欢迎使用百度飞桨深度学习框架!"
# mix tts
# The `am` must be `fastspeech2_mix`!
# The `lang` must be `mix`!
@@ -93,5 +98,11 @@ paddlespeech stats --task text
paddlespeech stats --task vector
paddlespeech stats --task st
+# Whisper speech recognition
+paddlespeech whisper --task transcribe --input ./zh.wav
+
+# Whisper: recognize speech and translate to English
+paddlespeech whisper --task translate --input ./zh.wav
+
echo -e "\033[32mTest success !!!\033[0m"