[s2t] add whisper asr large model (#2640)

* add whisper asr large model decoding, test=asr

* fix code style.

* fix json code style.

* remove resource and fix code style.

* fix yapf

* add cli and demos, fix some code style.

* fix some problem by comment.

* fix yapf
pull/2670/head
zxcd 2 years ago committed by GitHub
parent dc9d3baf51
commit b1d3f59bcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,89 @@
([简体中文](./README_cn.md)|English)
## Introduction
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
Whisper model trained by OpenAI whisper https://github.com/openai/whisper
## Usage
### 1. Installation
see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
You can choose one way from easy, meduim and hard to install paddlespeech.
### 2. Prepare Input File
The input of this demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
Here are sample files for this demo that can be downloaded:
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
```
### 3. Usage
- Command Line(Recommended)
```bash
# to recognize text
paddlespeech whisper --task transcribe --input ./zh.wav
# to recognize text and translate to English
paddlespeech whisper --task translate --input ./zh.wav
```
Usage:
```bash
paddlespeech whisper --help
```
Arguments:
- `input`(required): Audio file to recognize.
- `model`: Model type of asr task. Default: `whisper-large`.
- `task`: Output type. Default: `transcribe`.
- `lang`: Model language. Default: `None`. Forcibly set the recognized language, which is determined by the model itself by default.
- `sample_rate`: Sample rate of the model. Default: `16000`. Other sampling rates are not supported now.
- `config`: Config of asr task. Use pretrained model when it is None. Default: `None`.
- `ckpt_path`: Model checkpoint. Use pretrained model when it is None. Default: `None`.
- `yes`: No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate. Default: `False`.
- `device`: Choose device to execute model inference. Default: default device of paddlepaddle in current environment.
- `verbose`: Show the log information.
- Python API
```python
import paddle
from paddlespeech.cli.whisper import WhisperExecutor
whisper_executor = WhisperExecutor()
# to recognize text
text = whisper_executor(
model='whisper-large',
task='transcribe',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./zh.wav',
device=paddle.get_device())
print('ASR Result: \n{}'.format(text))
# to recognize text and translate to English
feature = whisper_executor(
model='whisper-large',
task='translate',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./zh.wav',
device=paddle.get_device())
print('Representation: \n{}'.format(feature))
```
Output:
```bash
Transcribe Result:
Detected language: Chinese
[00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康
{'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
Translate Result:
Detected language: Chinese
[00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health.
{'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}

@ -0,0 +1,91 @@
(简体中文|[English](./README.md))
# Whisper模型
## 介绍
Whisper是一种通用的语音识别模型。它是在多种音频的大数据集上训练的也是一个多任务模型可以执行多语言语音识别以及语音翻译和语言识别。
Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper
## 使用方法
### 1. 安装
请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。
你可以从 easymediumhard 三中方式中选择一种方式安装。
### 2. 准备输入
这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
可以下载此 demo 的示例音频:
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
```
### 3. 使用方法
- 命令行 (推荐使用)
```bash
# 识别文本
paddlespeech whisper --task transcribe --input ./zh.wav
# 将语音翻译成英语
paddlespeech whisper --task translate --input ./zh.wav
```
使用方法:
```bash
paddlespeech whisper --help
```
参数:
- `input`(必须输入):用于识别的音频文件。
- `model`ASR 任务的模型,默认值:`whisper-large`。
- `task`:输出类别,默认值:`transcribe`。
- `lang`:模型语言,默认值:`None`,强制设定识别出的语言,默认为模型自行判定。
- `sample_rate`:音频采样率,默认值:`16000`目前Whisper暂不支持其他采样率。
- `config`ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。
- `ckpt_path`:模型参数文件,若不设置则下载解码模型使用,默认值:`None`。
- `yes`;不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。
- `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。
- `verbose`: 如果使用,显示 logger 信息。
- Python API
```python
import paddle
from paddlespeech.cli.whisper import WhisperExecutor
whisper_executor = WhisperExecutor()
# 识别文本
text = whisper_executor(
model='whisper-large',
task='transcribe',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./zh.wav',
device=paddle.get_device())
print('ASR Result: \n{}'.format(text))
# 将语音翻译成英语
feature = whisper_executor(
model='whisper-large',
task='translate',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file='./zh.wav',
device=paddle.get_device())
print('Representation: \n{}'.format(feature))
```
输出:
```bash
Transcribe Result:
Detected language: Chinese
[00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康
{'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
Translate Result:
Detected language: Chinese
[00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health.
{'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}

@ -0,0 +1,10 @@
#!/bin/bash
# audio download
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# to recognize text
paddlespeech whisper --task transcribe --input ./zh.wav
# to recognize text and translate to English
paddlespeech whisper --task translate --input ./zh.wav

@ -83,7 +83,8 @@ model_name_format = {
'st': 'Model-Source language-Target language', 'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language', 'text': 'Model-Task-Language',
'tts': 'Model-Language', 'tts': 'Model-Language',
'vector': 'Model-Sample Rate' 'vector': 'Model-Sample Rate',
'whisper': 'Model-Language-Sample Rate'
} }
@ -94,7 +95,9 @@ class StatsCommand:
def __init__(self): def __init__(self):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True) prog='paddlespeech.stats', add_help=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] self.task_choices = [
'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper'
]
self.parser.add_argument( self.parser.add_argument(
'--task', '--task',
type=str, type=str,
@ -141,6 +144,10 @@ _commands = {
'tts': ['Text to Speech infer command.', 'TTSExecutor'], 'tts': ['Text to Speech infer command.', 'TTSExecutor'],
'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'], 'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'],
'kws': ['Keyword Spotting infer command.', 'KWSExecutor'], 'kws': ['Keyword Spotting infer command.', 'KWSExecutor'],
'whisper': [
'Whisper model for speech to text or translate speech to English.',
'WhisperExecutor'
]
} }
for com, info in _commands.items(): for com, info in _commands.items():

@ -0,0 +1,14 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import WhisperExecutor

@ -0,0 +1,468 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import os
import sys
import time
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
import librosa
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
from ..download import get_path_from_url
from ..executor import BaseExecutor
from ..log import logger
from ..utils import CLI_TIMER
from ..utils import stats_wrapper
from ..utils import timer_register
from paddlespeech.s2t.models.whisper import log_mel_spectrogram
from paddlespeech.s2t.models.whisper import ModelDimensions
from paddlespeech.s2t.models.whisper import Whisper
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['WhisperExecutor']
@timer_register
class WhisperExecutor(BaseExecutor):
def __init__(self):
super().__init__('whisper')
self.parser = argparse.ArgumentParser(
prog='paddlespeech.whisper', add_help=True)
self.parser.add_argument(
'--input', type=str, default=None, help='Audio file to recognize.')
self.parser.add_argument(
'--model',
type=str,
default='whisper',
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help='Choose model type of asr task.')
self.parser.add_argument(
'--lang',
type=str,
default='None',
help='Choose model decode language. Default is None, recognized by model.'
)
self.parser.add_argument(
'--task',
type=str,
default='transcribe',
choices=["transcribe", "translate"],
help='Choose task tpye for transcribe or translate.')
self.parser.add_argument(
'--size',
type=str,
default='large',
help='Choose model size. now only support large, large:[whisper-large-16k]'
)
self.parser.add_argument(
"--sample_rate",
type=int,
default=16000,
choices=[16000],
help='Choose the audio sample rate of the model. only support 16000')
self.parser.add_argument(
'--config',
type=str,
default=None,
help='Config of asr task. Use deault config when it is None.')
self.parser.add_argument(
'--decode_method',
type=str,
default='ctc_prefix_beam_search',
choices=['ctc_greedy_search', 'ctc_prefix_beam_search'],
help='only support transformer and conformer model')
self.parser.add_argument(
'--ckpt_path',
type=str,
default=None,
help='Checkpoint file of model.')
self.parser.add_argument(
'--yes',
'-y',
action="store_true",
default=False,
help='No additional parameters required. \
Once set this parameter, it means accepting the request of the program by default, \
which includes transforming the audio sample rate')
self.parser.add_argument(
'--rtf',
action="store_true",
default=False,
help='Show Real-time Factor(RTF).')
self.parser.add_argument(
'--device',
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'-d',
'--job_dump_result',
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _init_from_path(self,
model_type: str='whisper',
lang: str='None',
task: str='transcribe',
size: str='large',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='ctc_prefix_beam_search',
num_decoding_left_chunks: int=-1,
ckpt_path: Optional[os.PathLike]=None):
"""
Init model and other resources from a specific path.
"""
logger.debug("start to init the model")
# default max_len: unit:second
self.max_len = 50
if hasattr(self, 'model'):
logger.debug('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + size + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
self.res_path,
self.task_resource.res_dict['ckpt_path'] + ".pdparams")
logger.debug(self.res_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.debug(self.cfg_path)
logger.debug(self.ckpt_path)
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
with UpdateConfig(self.config):
if "whisper" in model_type:
resource_url = self.task_resource.res_dict['resuource_data']
resource_md5 = self.task_resource.res_dict['resuource_data_md5']
resuource_path = self.task_resource.res_dict['resuource_path']
self.download_resource(resource_url, resuource_path,
resource_md5)
else:
raise Exception("wrong type")
# load model
model_dict = paddle.load(self.ckpt_path)
dims = ModelDimensions(**model_dict["dims"])
self.model = Whisper(dims)
self.model.load_dict(model_dict)
self.model.eval()
#set task
if task is not None:
self.task = task
#set language
if lang is not None:
self.language = lang
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
"""
audio_file = input
if isinstance(audio_file, (str, os.PathLike)):
logger.debug("Preprocess audio_file:" + audio_file)
elif isinstance(audio_file, io.BytesIO):
audio_file.seek(0)
# Get the object for feature extraction
# whisper hard-coded audio hyperparameters, params in paddlespeech/s2t/models/whisper/whisper.py
logger.debug("read the audio file")
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="float32", always_2d=True)
if self.change_format:
if audio.shape[1] >= 2:
audio = audio.mean(axis=1, dtype=np.int16)
else:
audio = audio[:, 0]
# pcm16 -> pcm 32
audio = self._pcm16to32(audio)
audio = librosa.resample(
audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate)
audio_sample_rate = self.sample_rate
# pcm32 -> pcm 16
audio = self._pcm32to16(audio)
else:
audio = audio[:, 0]
logger.debug(f"audio shape: {audio.shape}")
# fbank
audio = log_mel_spectrogram(audio)
audio_len = paddle.to_tensor(audio.shape[0])
#audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
logger.debug(f"audio feat shape: {audio.shape}")
logger.debug("audio feat process success")
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
logger.debug("start to infer the model to get the output")
cfg = self.config
audio = self._inputs["audio"]
if cfg.temperature_increment_on_fallback is not None:
temperature = tuple(
np.arange(cfg.temperature, 1.0 + 1e-6,
cfg.temperature_increment_on_fallback))
else:
temperature = [cfg.temperature]
self._outputs["result"] = self.model.transcribe(
audio,
verbose=cfg.verbose,
task=self.task,
language=self.language,
temperature=temperature,
compression_ratio_threshold=cfg.compression_ratio_threshold,
logprob_threshold=cfg.logprob_threshold,
best_of=cfg.best_of,
beam_size=cfg.beam_size,
patience=cfg.patience,
length_penalty=cfg.length_penalty,
initial_prompt=cfg.initial_prompt,
condition_on_previous_text=cfg.condition_on_previous_text,
no_speech_threshold=cfg.no_speech_threshold)
def postprocess(self) -> Union[str, os.PathLike]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
return self._outputs["result"]
def download_resource(self, url, lm_dir, md5sum):
download_path = get_path_from_url(
url=url,
root_dir=lm_dir,
md5sum=md5sum,
decompress=True, )
def _pcm16to32(self, audio):
assert (audio.dtype == np.int16)
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
return audio
def _pcm32to16(self, audio):
assert (audio.dtype == np.float32)
bits = np.iinfo(np.int16).bits
audio = audio * (2**(bits - 1))
audio = np.round(audio).astype("int16")
return audio
def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
"invalid sample rate, please input --sr 8000 or --sr 16000")
return False
if isinstance(audio_file, (str, os.PathLike)):
if not os.path.isfile(audio_file):
logger.error("Please input the right audio file path")
return False
elif isinstance(audio_file, io.BytesIO):
audio_file.seek(0)
logger.debug("checking the audio file format......")
try:
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
audio_duration = audio.shape[0] / audio_sample_rate
if audio_duration > self.max_len:
logger.error(
f"Please input audio file less then {self.max_len} seconds.\n"
)
return False
except Exception as e:
logger.exception(e)
logger.error(
f"can not open the audio file, please check the audio file({audio_file}) format is 'wav'. \n \
you can try to use sox to change the file format.\n \
For example: \n \
sample rate: 16k \n \
sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
sample rate: 8k \n \
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
return False
logger.debug("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate))
if force_yes is False:
while (True):
logger.debug(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip(
) == "Yes":
logger.debug(
"change the sampele rate, channel to 16k and 1 channel"
)
break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip(
) == "No":
logger.debug("Exit the program")
return False
else:
logger.warning("Not regular input, please input again")
self.change_format = True
else:
logger.debug("The audio file format is right")
self.change_format = False
return True
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
model = parser_args.model
lang = parser_args.lang
task = parser_args.task
size = parser_args.size
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
decode_method = parser_args.decode_method
force_yes = parser_args.yes
rtf = parser_args.rtf
device = parser_args.device
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(
audio_file=input_,
model=model,
lang=lang,
task=task,
size=size,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
decode_method=decode_method,
force_yes=force_yes,
rtf=rtf,
device=device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
if rtf:
self.show_rtf(CLI_TIMER[self.__class__.__name__])
self.process_task_results(parser_args.input, task_results,
parser_args.job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
model: str='whisper',
lang: str='None',
task: str='transcribe',
size: str='large',
sample_rate: int=16000,
config: os.PathLike=None,
ckpt_path: os.PathLike=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
force_yes: bool=False,
rtf: bool=False,
device=paddle.get_device()):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
paddle.set_device(device)
self._init_from_path(model, lang, task, size, sample_rate, config,
decode_method, num_decoding_left_chunks, ckpt_path)
if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1)
if rtf:
k = self.__class__.__name__
CLI_TIMER[k]['start'].append(time.time())
self.preprocess(model, audio_file)
self.infer(model)
res = self.postprocess() # Retrieve result of asr.
if rtf:
CLI_TIMER[k]['end'].append(time.time())
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate)
return res

@ -29,6 +29,11 @@ model_alias = {
"transformer": ["paddlespeech.s2t.models.u2:U2Model"], "transformer": ["paddlespeech.s2t.models.u2:U2Model"],
"wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"], "wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
# ---------------------------------
# ------------ Whisper ------------
# ---------------------------------
"whisper": ["paddlespeech.s2t.models.whisper:Whisper"],
# --------------------------------- # ---------------------------------
# -------------- CLS -------------- # -------------- CLS --------------
# --------------------------------- # ---------------------------------

@ -25,6 +25,7 @@ __all__ = [
'tts_static_pretrained_models', 'tts_static_pretrained_models',
'tts_onnx_pretrained_models', 'tts_onnx_pretrained_models',
'vector_dynamic_pretrained_models', 'vector_dynamic_pretrained_models',
'whisper_dynamic_pretrained_models',
] ]
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
@ -424,6 +425,31 @@ asr_onnx_pretrained_models = {
}, },
} }
whisper_dynamic_pretrained_models = {
"whisper-large-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/whisper-large-model.tar.gz',
'md5':
'364c4d670835e5ca489045e1c29d75fe',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-large-model',
'model':
'whisper-large-model.pdparams',
'params':
'whisper-large-model.pdparams',
'resuource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
}
# --------------------------------- # ---------------------------------
# -------------- CLS -------------- # -------------- CLS --------------
# --------------------------------- # ---------------------------------

@ -22,7 +22,7 @@ from ..utils.dynamic_import import dynamic_import
from ..utils.env import MODEL_HOME from ..utils.env import MODEL_HOME
from .model_alias import model_alias from .model_alias import model_alias
task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws'] task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper']
model_format_supported = ['dynamic', 'static', 'onnx'] model_format_supported = ['dynamic', 'static', 'onnx']
inference_mode_supported = ['online', 'offline'] inference_mode_supported = ['online', 'offline']

@ -0,0 +1,122 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.∏
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from Whisper (https://github.com/openai/whisper/whisper/)
import os.path
import sys
import distutils
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.s2t.models.whisper import log_mel_spectrogram
from paddlespeech.s2t.models.whisper import ModelDimensions
from paddlespeech.s2t.models.whisper import transcribe
from paddlespeech.s2t.models.whisper import Whisper
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
class WhisperInfer():
def __init__(self, config, args):
self.args = args
self.config = config
self.audio_file = args.audio_file
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
config.pop("ngpu")
#load_model
model_dict = paddle.load(self.config.model_file)
config.pop("model_file")
dims = ModelDimensions(**model_dict["dims"])
self.model = Whisper(dims)
self.model.load_dict(model_dict)
def run(self):
check(args.audio_file)
with paddle.no_grad():
temperature = config.pop("temperature")
temperature_increment_on_fallback = config.pop(
"temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(
np.arange(temperature, 1.0 + 1e-6,
temperature_increment_on_fallback))
else:
temperature = [temperature]
#load audio
mel = log_mel_spectrogram(args.audio)
result = transcribe(
self.model, mel, temperature=temperature, **config)
if args.result_file is not None:
with open(args.result_file, 'w') as f:
f.write(str(result))
return result
def check(audio_file: str):
if not os.path.isfile(audio_file):
print("Please input the right audio file path")
sys.exit(-1)
logger.info("checking the audio file format......")
try:
_, sample_rate = soundfile.read(audio_file)
except Exception as e:
logger.error(str(e))
logger.error(
"can not open the wav file, please check the audio file format")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
assert (sample_rate == 16000)
logger.info("The audio file format is right")
def main(config, args):
WhisperInfer(config, args).run()
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
main(config, args)

@ -0,0 +1,12 @@
# MIT License, Copyright (c) 2022 OpenAI.
# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
#
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/__init__.py)
from paddlespeech.s2t.models.whisper.whipser import decode
from paddlespeech.s2t.models.whisper.whipser import DecodingOptions
from paddlespeech.s2t.models.whisper.whipser import DecodingResult
from paddlespeech.s2t.models.whisper.whipser import detect_language
from paddlespeech.s2t.models.whisper.whipser import log_mel_spectrogram
from paddlespeech.s2t.models.whisper.whipser import ModelDimensions
from paddlespeech.s2t.models.whisper.whipser import transcribe
from paddlespeech.s2t.models.whisper.whipser import Whisper

@ -0,0 +1,360 @@
# MIT License, Copyright (c) 2022 OpenAI.
# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
#
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import paddle
from paddlenlp.transformers import GPTTokenizer
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
@dataclass(frozen=True)
class Tokenizer:
"""A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
tokenizer: "GPTTokenizer"
language: Optional[str]
sot_sequence: Tuple[int]
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self,
token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
**kwargs):
if len(token_ids) > 1:
ids_list = []
for ids in token_ids:
if paddle.is_tensor(ids):
ids = ids.item()
if ids < len(self.tokenizer):
ids_list.append(ids)
token_ids = ids_list
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [
s if isinstance(s, str) else self.tokenizer.decode(s)
for s in outputs
]
return "".join(outputs)
@property
@lru_cache()
def eot(self) -> int:
return self.tokenizer.eos_token_id
@property
@lru_cache()
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
@property
@lru_cache()
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
@property
@lru_cache()
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
@property
@lru_cache()
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
@property
@lru_cache()
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
@property
@lru_cache()
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
@property
@lru_cache()
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(
"This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids, ))
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
raise KeyError(f"Language {self.language} not found in tokenizer.")
@property
@lru_cache()
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids, ):
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
@property
@lru_cache()
def all_language_codes(self) -> Tuple[str]:
return tuple(
self.decode([l]).strip("<|>") for l in self.all_language_tokens)
@property
@lru_cache()
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
-
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split(
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {
self.tokenizer.encode(" -").input_ids[0],
self.tokenizer.encode(" '").input_ids[0]
}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.tokenizer.encode(symbol).input_ids,
self.tokenizer.encode(" " + symbol).input_ids
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text).input_ids
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
@lru_cache(maxsize=None)
def build_tokenizer(name: str="gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(os.path.dirname(__file__), "assets", name)
tokenizer = GPTTokenizer.from_pretrained(path)
specials = [
"<|startoftranscript|>",
* [f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
task: Optional[str]=None, # Literal["transcribe", "translate", None]
language: Optional[str]=None, ) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
language = language or "en"
else:
tokenizer_name = "gpt2"
task = None
language = None
tokenizer = build_tokenizer(name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
return Tokenizer(
tokenizer=tokenizer,
language=language,
sot_sequence=tuple(sot_sequence))

@ -0,0 +1,92 @@
# MIT License, Copyright (c) 2022 OpenAI.
# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
#
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/utils.py)
import zlib
from typing import Iterator
from typing import TextIO
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
return len(text) / len(zlib.compress(text.encode("utf-8")))
def format_timestamp(seconds: float,
always_include_hours: bool=False,
decimal_marker: str='.'):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def write_txt(transcript: Iterator[dict], file: TextIO):
for segment in transcript:
print(segment['text'].strip(), file=file, flush=True)
def write_vtt(transcript: Iterator[dict], file: TextIO):
print("WEBVTT\n", file=file)
for segment in transcript:
print(
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True, )
def write_srt(transcript: Iterator[dict], file: TextIO):
"""
Write a transcript to a file in SRT format.
Example usage:
from pathlib import Path
from whisper.utils import write_srt
result = transcribe(model, audio_path, temperature=temperature, **args)
# save SRT
audio_basename = Path(audio_path).stem
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True, )

File diff suppressed because it is too large Load Diff

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -94,5 +94,11 @@ paddlespeech stats --task text
paddlespeech stats --task vector paddlespeech stats --task vector
paddlespeech stats --task st paddlespeech stats --task st
# whisper text recognize
paddlespeech whisper --task transcribe --input ./zh.wav
# whisper recognize text and translate to English
paddlespeech whisper --task translate --input ./zh.wav
echo -e "\033[32mTest success !!!\033[0m" echo -e "\033[32mTest success !!!\033[0m"

Loading…
Cancel
Save