add all whisper model size support, test=asr (#2677)

* add all whisper model size support

* add choices in parser.
zxcd 2 years ago committed by GitHub
parent 0b4cf2211d
commit b71f1428c7

@ -157,9 +157,11 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
- 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), supporting multilingual recognition and translation.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), supporting ASR and feature extraction.
- 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech).
- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3).
- 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR-en](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.

@ -164,6 +164,7 @@
### Recent Update
- 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), supporting multilingual recognition and translation.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), supporting ASR and feature extraction.
- 🎉 2022.11.17: Add a [high-quality male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech).

@ -17,3 +17,5 @@ This directory contains many speech applications in multiple scenarios.
* story talker - book reader based on OCR and TTS
* style_fs2 - multi style control for FastSpeech2 model
* text_to_speech - convert text into speech
* self-supervised pretraining - speech feature extraction and speech recognition based on wav2vec2
* Whisper - speech recognition and translation based on the Whisper model

@ -17,3 +17,5 @@
* story talker - a talking storybook based on OCR and TTS.
* style_fs2 - personalized speech synthesis based on the FastSpeech2 model.
* text to speech - generate speech audio from the given text.
* self-supervised pretraining - speech feature extraction and speech recognition based on wav2vec2.
* Whisper - speech recognition and translation based on the Whisper model.

@ -25,8 +25,12 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper
# to recognize text
paddlespeech whisper --task transcribe --input ./zh.wav
# to use the English-only model at base size
paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav
# to recognize text and translate to English
paddlespeech whisper --task translate --input ./zh.wav
```
Usage:
@ -37,7 +41,9 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper
- `input`(required): Audio file to recognize.
- `model`: Model type of asr task. Default: `whisper-large`.
- `task`: Output type. Default: `transcribe`.
- `lang`: Model language. Default: `None`. Forcibly set the recognized language, which is determined by the model itself by default.
- `lang`: Model language. Default: ``. Set `en` to choose an English-only model; English-only models are currently available in the [medium,base,small,tiny] sizes.
- `size`: Model size for decoding. Default: `large`. Supported sizes: [large,medium,base,small,tiny].
- `language`: Decode language. Default: `None`. Forcibly sets the recognized language; by default the model detects it itself.
- `sample_rate`: Sample rate of the model. Default: `16000`. Other sampling rates are not supported now.
- `config`: Config of asr task. Use pretrained model when it is None. Default: `None`.
- `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
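
For scripted use, the same options are exposed through the Python API. Below is a minimal sketch based on the `__call__` signature added in this PR; the import path follows PaddleSpeech's CLI package layout, and the device handling is an assumption:

```python
import paddle
from paddlespeech.cli.whisper import WhisperExecutor

whisper_executor = WhisperExecutor()

# Equivalent of:
#   paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav
text = whisper_executor(
    model='whisper',
    lang='en',
    size='base',
    task='transcribe',
    sample_rate=16000,
    audio_file='./en.wav',
    device=paddle.get_device())
print(text)
```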

@ -27,6 +27,9 @@ Whisper model trained by OpenAI Whisper https://github.com/openai/whisper
# recognize text
paddlespeech whisper --task transcribe --input ./zh.wav
# choose the English-only model and change the model size
paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav
# translate the audio into English
paddlespeech whisper --task translate --input ./zh.wav
```
@ -38,7 +41,9 @@ Whisper model trained by OpenAI Whisper https://github.com/openai/whisper
- `input` (required): Audio file to recognize.
- `model`: Model type of asr task. Default: `whisper-large`.
- `task`: Output type. Default: `transcribe`.
- `lang`: Model language. Default: `None`. Forcibly sets the recognized language; by default the model decides on its own.
- `lang`: Model language. Default: ``. Set `en` to choose an English-only model; English-only models are currently available in the [medium,base,small,tiny] sizes.
- `size`: Model size for decoding. Default: `large`. Supported sizes: [large,medium,base,small,tiny].
- `language`: Decode language. Default: `None`. Forcibly sets the recognized language; by default the model detects it itself.
- `sample_rate`: Audio sample rate. Default: `16000`. Whisper does not support other sample rates yet.
- `config`: Config of the asr task. Uses the pretrained model's default config when not set. Default: `None`.
- `ckpt_path`: Model checkpoint. Downloads and uses the pretrained model when not set. Default: `None`.

@ -1,10 +1,13 @@
#!/bin/bash
# audio download
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# to recognize text
paddlespeech whisper --task transcribe --input ./zh.wav
# to recognize text and translate to English
paddlespeech whisper --task translate --input ./zh.wav
# to use the English-only model
paddlespeech whisper --lang en --size base --task transcribe --input ./en.wav

@ -36,6 +36,8 @@ from ..utils import timer_register
from paddlespeech.s2t.models.whisper import log_mel_spectrogram
from paddlespeech.s2t.models.whisper import ModelDimensions
from paddlespeech.s2t.models.whisper import Whisper
from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES
from paddlespeech.s2t.models.whisper.tokenizer import TO_LANGUAGE_CODE
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['WhisperExecutor']
@ -53,16 +55,14 @@ class WhisperExecutor(BaseExecutor):
'--model',
type=str,
default='whisper',
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
choices=['whisper'],
help='Choose model type of asr task.')
self.parser.add_argument(
'--lang',
type=str,
default='None',
help='Choose model decode language. Default is None, recognized by model.'
default='',
choices=['', 'en'],
help='Choose model language. Default is "" (multilingual); set "en" to use an English-only model.'
)
self.parser.add_argument(
'--task',
@ -74,8 +74,17 @@ class WhisperExecutor(BaseExecutor):
'--size',
type=str,
default='large',
choices=['large', 'medium', 'base', 'small', 'tiny'],
help='Choose model size. Supported sizes: large, medium, base, small, tiny.'
)
self.parser.add_argument(
'--language',
type=str,
default='None',
choices=sorted(LANGUAGES.keys()) + sorted(
[k.title() for k in TO_LANGUAGE_CODE.keys()]),
help='Choose model decode language. Default is None, i.e. detected automatically by the model.'
)
self.parser.add_argument(
"--sample_rate",
type=int,
@ -129,9 +138,10 @@ class WhisperExecutor(BaseExecutor):
def _init_from_path(self,
model_type: str='whisper',
lang: str='None',
lang: str='',
task: str='transcribe',
size: str='large',
language: str='None',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='ctc_prefix_beam_search',
@ -149,7 +159,10 @@ class WhisperExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + size + '-' + sample_rate_str
if lang == "":
tag = model_type + '-' + size + '-' + sample_rate_str
else:
tag = model_type + '-' + size + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
self.res_path = self.task_resource.res_dir
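
For reference, the `tag` assembled above is the key used to look up a checkpoint in the pretrained-model registry added further down. A standalone illustration (plain Python, not the executor's code) of the names this logic produces for 16 kHz models:

```python
# Illustration only: tags produced by the size/lang logic above.
for size in ['large', 'medium', 'base', 'small', 'tiny']:
    print('whisper-' + size + '-16k')         # multilingual, e.g. whisper-base-16k
    if size != 'large':  # no English-only variant of the large model
        print('whisper-' + size + '-en-16k')  # English-only, e.g. whisper-base-en-16k
```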
@ -194,8 +207,13 @@ class WhisperExecutor(BaseExecutor):
self.task = task
#set language
if lang is not None:
self.language = lang
if language is not None:
if lang == 'en' and language != 'en':
logger.info(
"English-only model, forcing decode language to English.")
self.language = 'en'
else:
self.language = language
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
"""
@ -234,7 +252,6 @@ class WhisperExecutor(BaseExecutor):
audio = log_mel_spectrogram(audio)
audio_len = paddle.to_tensor(audio.shape[0])
#audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
@ -381,6 +398,7 @@ class WhisperExecutor(BaseExecutor):
lang = parser_args.lang
task = parser_args.task
size = parser_args.size
language = parser_args.language
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
@ -404,6 +422,7 @@ class WhisperExecutor(BaseExecutor):
lang=lang,
task=task,
size=size,
language=language,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
@ -431,9 +450,10 @@ class WhisperExecutor(BaseExecutor):
def __call__(self,
audio_file: os.PathLike,
model: str='whisper',
lang: str='None',
lang: str='',
task: str='transcribe',
size: str='large',
language: str='None',
sample_rate: int=16000,
config: os.PathLike=None,
ckpt_path: os.PathLike=None,
@ -447,8 +467,9 @@ class WhisperExecutor(BaseExecutor):
"""
audio_file = os.path.abspath(audio_file)
paddle.set_device(device)
self._init_from_path(model, lang, task, size, sample_rate, config,
decode_method, num_decoding_left_chunks, ckpt_path)
self._init_from_path(model, lang, task, size, language, sample_rate,
config, decode_method, num_decoding_left_chunks,
ckpt_path)
if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1)
if rtf:

@ -487,6 +487,182 @@ whisper_dynamic_pretrained_models = {
'paddlespeech/s2t/models/whisper',
},
},
"whisper-base-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-en-model.tar.gz',
'md5':
'f5bb8cdff42c7031d9e4c0ea20f7ceee',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-base-en-model',
'model':
'whisper-base-en-model.pdparams',
'params':
'whisper-base-en-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-base-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-model.tar.gz',
'md5':
'46f254e89a01b71586af1a46d28d7ce9',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-base-model',
'model':
'whisper-base-model.pdparams',
'params':
'whisper-base-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-medium-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-en-model.tar.gz',
'md5':
'98228f3ba94636c2760b51e5f3d6885f',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-medium-en-model',
'model':
'whisper-medium-en-model.pdparams',
'params':
'whisper-medium-en-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-medium-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-model.tar.gz',
'md5':
'51ac154b264db75492ed1cc5280baebf',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-medium-model',
'model':
'whisper-medium-model.pdparams',
'params':
'whisper-medium-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-small-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-en-model.tar.gz',
'md5':
'973b784a335580a393e13a13995b110a',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-small-en-model',
'model':
'whisper-small-en-model.pdparams',
'params':
'whisper-small-en-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-small-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-model.tar.gz',
'md5':
'57a7530851cc98631c6fb29c606489c6',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-small-model',
'model':
'whisper-small-model.pdparams',
'params':
'whisper-small-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-tiny-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-en-model.tar.gz',
'md5':
'3ef5c0777e0bd4a1a240895167b0eb0d',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-tiny-en-model',
'model':
'whisper-tiny-en-model.pdparams',
'params':
'whisper-tiny-en-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-tiny-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-model.tar.gz',
'md5':
'ddf232cd16c85120e89c870a53451e53',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-tiny-model',
'model':
'whisper-tiny-model.pdparams',
'params':
'whisper-tiny-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
}
# ---------------------------------
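
Every entry added above shares the same shape, so resolving a tag is a plain dict lookup. A hedged sketch, assuming `whisper_dynamic_pretrained_models` is in scope as in this module:

```python
# Illustration: metadata the executor downloads for a resolved tag.
tag, version = 'whisper-tiny-en-16k', '1.3'
entry = whisper_dynamic_pretrained_models[tag][version]
print(entry['url'])        # URL of the model tarball
print(entry['md5'])        # checksum used to verify the download
print(entry['ckpt_path'])  # checkpoint prefix inside the extracted archive
```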

@ -475,7 +475,8 @@ def transcribe(
if dtype == np.float32:
decode_options["fp16"] = False
if decode_options.get("language", None) is None:
if decode_options.get("language", None) is None or decode_options.get("language") == 'None':
if not model.is_multilingual:
decode_options["language"] = "en"
else:

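The rewritten condition above is meant to treat both a missing `language` key and the CLI's string default `'None'` as unset. A compact restatement of the intended check (a sketch, not the patch itself):

```python
def language_is_unset(decode_options: dict) -> bool:
    # Both an absent key and the CLI default string 'None' count as unset;
    # unset means English for English-only models, auto-detect otherwise.
    lang = decode_options.get("language")
    return lang is None or lang == 'None'
```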