diff --git a/demos/whisper/README.md b/demos/whisper/README.md
new file mode 100644
index 000000000..d332fca13
--- /dev/null
+++ b/demos/whisper/README.md
@@ -0,0 +1,89 @@
+([简体中文](./README_cn.md)|English)
+
+# Whisper Model
+## Introduction
+Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio, and it is also a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.
+
+The Whisper model used here is trained and open-sourced by OpenAI: https://github.com/openai/whisper
+
+## Usage
+### 1. Installation
+See [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
+
+You can choose one of the easy, medium and hard ways to install paddlespeech.
+
+### 2. Prepare Input File
+The input of this demo should be a WAV file (`.wav`) whose sample rate matches that of the model.
+
+Here is a sample file for this demo that can be downloaded:
+```bash
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+```
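+
+If your audio has a different sample rate, you can pass `--yes` to let the CLI resample it automatically, or convert it beforehand, for example with `sox` (assuming `sox` is installed):
+```bash
+# resample to 16 kHz, 16 bit, mono, which is the format the model expects
+sox input_audio.wav --rate 16k --bits 16 --channels 1 output_audio.wav
+```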
+
+### 3. Usage
+- Command Line (Recommended)
+  ```bash
+  # to recognize text
+  paddlespeech whisper --task transcribe --input ./zh.wav
+
+  # to recognize text and translate it to English
+  paddlespeech whisper --task translate --input ./zh.wav
+  ```
+
+  Usage:
+  ```bash
+  paddlespeech whisper --help
+  ```
+  Arguments:
+  - `input` (required): Audio file to recognize.
+  - `model`: Model type of the ASR task. Default: `whisper-large`.
+  - `task`: Output type. Default: `transcribe`.
+  - `lang`: Model language. Default: `None`. Forcibly sets the recognized language; by default the language is detected by the model itself.
+  - `sample_rate`: Sample rate of the model. Default: `16000`. Other sample rates are not supported yet.
+  - `config`: Config of the ASR task. Use the pretrained model when it is `None`. Default: `None`.
+  - `ckpt_path`: Model checkpoint. Use the pretrained model when it is `None`. Default: `None`.
+  - `yes`: No additional parameters required. Once this parameter is set, it means accepting the requests of the program by default, which includes transforming the audio sample rate. Default: `False`.
+  - `device`: Device to execute model inference on. Default: default device of paddlepaddle in the current environment.
+  - `verbose`: Show the log information.
+
+- Python API
+  ```python
+  import paddle
+  from paddlespeech.cli.whisper import WhisperExecutor
+
+  whisper_executor = WhisperExecutor()
+
+  # to recognize text
+  text = whisper_executor(
+      model='whisper-large',
+      task='transcribe',
+      sample_rate=16000,
+      config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
+      ckpt_path=None,
+      audio_file='./zh.wav',
+      device=paddle.get_device())
+  print('ASR Result: \n{}'.format(text))
+
+  # to recognize text and translate it to English
+  text = whisper_executor(
+      model='whisper-large',
+      task='translate',
+      sample_rate=16000,
+      config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
+      ckpt_path=None,
+      audio_file='./zh.wav',
+      device=paddle.get_device())
+  print('Translate Result: \n{}'.format(text))
+  ```
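+
+Note that the executor returns a result dict rather than a plain string; its structure is shown in the Output section below. A minimal, illustrative sketch for reading it, with field names taken from the sample output:
+```python
+result = text  # the dict returned by whisper_executor above
+print(result['text'])  # full transcription
+for seg in result['segments']:
+    # per-segment timestamps, in seconds
+    print('[{:.2f}s -> {:.2f}s] {}'.format(seg['start'], seg['end'], seg['text']))
+print(result['language'])  # detected language code, e.g. 'zh'
+```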
+
+Output:
+```bash
+Transcribe Result:
+Detected language: Chinese
+[00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康
+{'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
+
+Translate Result:
+Detected language: Chinese
+[00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health.
+{'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
+```
\ No newline at end of file
diff --git a/demos/whisper/README_cn.md b/demos/whisper/README_cn.md
new file mode 100644
index 000000000..e559ba041
--- /dev/null
+++ b/demos/whisper/README_cn.md
@@ -0,0 +1,91 @@
+(简体中文|[English](./README.md))
+
+# Whisper 模型
+## 介绍
+Whisper 是一种通用的语音识别模型。它是在多种音频的大型数据集上训练的,也是一个多任务模型,可以执行多语言语音识别、语音翻译和语言识别。
+
+Whisper 模型由 OpenAI 训练并开源:https://github.com/openai/whisper
+
+## 使用方法
+### 1. 安装
+请看[安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md)。
+
+你可以从 easy、medium、hard 三种方式中选择一种方式安装。
+
+### 2. 准备输入
+这个 demo 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
+
+可以下载此 demo 的示例音频:
+```bash
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+```
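+
+如果你的音频采样率不同,可以使用 `--yes` 参数让 CLI 自动重采样,也可以预先转换格式,例如使用 `sox`(假设已安装 sox):
+```bash
+# 重采样为 16 kHz、16 bit、单声道,即模型要求的格式
+sox input_audio.wav --rate 16k --bits 16 --channels 1 output_audio.wav
+```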
+
+### 3. 使用方法
+- 命令行(推荐使用)
+  ```bash
+  # 识别文本
+  paddlespeech whisper --task transcribe --input ./zh.wav
+
+  # 将语音翻译成英语
+  paddlespeech whisper --task translate --input ./zh.wav
+  ```
+  使用方法:
+  ```bash
+  paddlespeech whisper --help
+  ```
+  参数:
+  - `input`(必须输入):用于识别的音频文件。
+  - `model`:ASR 任务的模型,默认值:`whisper-large`。
+  - `task`:输出类别,默认值:`transcribe`。
+  - `lang`:模型语言,默认值:`None`,强制设定识别出的语言,默认由模型自行判定。
+  - `sample_rate`:音频采样率,默认值:`16000`,目前 Whisper 暂不支持其他采样率。
+  - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。
+  - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。
+  - `yes`:不需要设置额外的参数,一旦设置了该参数,说明你默认同意程序的所有请求,其中包括自动转换输入音频的采样率。默认值:`False`。
+  - `device`:执行预测的设备,默认值:当前系统下 paddlepaddle 的默认 device。
+  - `verbose`:如果使用,显示 logger 信息。
+
+- Python API
+  ```python
+  import paddle
+  from paddlespeech.cli.whisper import WhisperExecutor
+
+  whisper_executor = WhisperExecutor()
+
+  # 识别文本
+  text = whisper_executor(
+      model='whisper-large',
+      task='transcribe',
+      sample_rate=16000,
+      config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
+      ckpt_path=None,
+      audio_file='./zh.wav',
+      device=paddle.get_device())
+  print('ASR Result: \n{}'.format(text))
+
+  # 将语音翻译成英语
+  text = whisper_executor(
+      model='whisper-large',
+      task='translate',
+      sample_rate=16000,
+      config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
+      ckpt_path=None,
+      audio_file='./zh.wav',
+      device=paddle.get_device())
+  print('Translate Result: \n{}'.format(text))
+  ```
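+
+注意,executor 返回的是一个结果字典而不是纯文本,其结构见下方“输出”部分。下面是一个简单的示例(字段名取自示例输出),用于读取分段时间戳:
+```python
+result = text  # 上面 whisper_executor 返回的字典
+print(result['text'])  # 完整识别文本
+for seg in result['segments']:
+    # 每个分段的起止时间,单位为秒
+    print('[{:.2f}s -> {:.2f}s] {}'.format(seg['start'], seg['end'], seg['text']))
+print(result['language'])  # 检测到的语言,如 'zh'
+```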
+
+输出:
+```bash
+Transcribe Result:
+Detected language: Chinese
+[00:00.000 --> 00:05.000] 我认为跑步最重要的就是给我带来了身体健康
+{'text': '我认为跑步最重要的就是给我带来了身体健康', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': '我认为跑步最重要的就是给我带来了身体健康', 'tokens': [50364, 1654, 7422, 97, 13992, 32585, 31429, 8661, 24928, 1546, 5620, 49076, 4845, 99, 34912, 19847, 29485, 44201, 6346, 115, 50614], 'temperature': 0.0, 'avg_logprob': -0.23577967557040128, 'compression_ratio': 0.28169014084507044, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
+
+Translate Result:
+Detected language: Chinese
+[00:00.000 --> 00:05.000] I think the most important thing about running is that it brings me good health.
+{'text': ' I think the most important thing about running is that it brings me good health.', 'segments': [{'id': 0, 'seek': 0, 'start': 0.0, 'end': 5.0, 'text': ' I think the most important thing about running is that it brings me good health.', 'tokens': [50364, 286, 519, 264, 881, 1021, 551, 466, 2614, 307, 300, 309, 5607, 385, 665, 1585, 13, 50614], 'temperature': 0.0, 'avg_logprob': -0.47945233395225123, 'compression_ratio': 1.095890410958904, 'no_speech_prob': 0.028302080929279327}], 'language': 'zh'}
+```
\ No newline at end of file
diff --git a/demos/whisper/run.sh b/demos/whisper/run.sh
new file mode 100644
index 000000000..095743bb0
--- /dev/null
+++ b/demos/whisper/run.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# audio download
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
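+
+# Note: the first run downloads the whisper-large checkpoint and its assets,
+# which are large files, so it may take a while.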
+
+# to recognize text
+paddlespeech whisper --task transcribe --input ./zh.wav
+
+# to recognize text and translate it to English
+paddlespeech whisper --task translate --input ./zh.wav
\ No newline at end of file
diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py
index 7210091a9..7551b6c02 100644
--- a/paddlespeech/cli/base_commands.py
+++ b/paddlespeech/cli/base_commands.py
@@ -83,7 +83,8 @@ model_name_format = {
     'st': 'Model-Source language-Target language',
     'text': 'Model-Task-Language',
     'tts': 'Model-Language',
-    'vector': 'Model-Sample Rate'
+    'vector': 'Model-Sample Rate',
+    'whisper': 'Model-Language-Sample Rate'
 }
 
 
@@ -94,7 +95,9 @@ class StatsCommand:
     def __init__(self):
         self.parser = argparse.ArgumentParser(
             prog='paddlespeech.stats', add_help=True)
-        self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws']
+        self.task_choices = [
+            'asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper'
+        ]
         self.parser.add_argument(
             '--task',
             type=str,
@@ -141,6 +144,10 @@ _commands = {
     'tts': ['Text to Speech infer command.', 'TTSExecutor'],
     'vector': ['Speech to vector embedding infer command.', 'VectorExecutor'],
     'kws': ['Keyword Spotting infer command.', 'KWSExecutor'],
+    'whisper': [
+        'Whisper model for speech to text or translate speech to English.',
+        'WhisperExecutor'
+    ]
 }
 
 for com, info in _commands.items():
diff --git a/paddlespeech/cli/whisper/__init__.py b/paddlespeech/cli/whisper/__init__.py
new file mode 100644
index 000000000..3bafc10d2
--- /dev/null
+++ b/paddlespeech/cli/whisper/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .infer import WhisperExecutor
diff --git a/paddlespeech/cli/whisper/infer.py b/paddlespeech/cli/whisper/infer.py
new file mode 100644
index 000000000..ada888bfc
--- /dev/null
+++ b/paddlespeech/cli/whisper/infer.py
@@ -0,0 +1,473 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import io
+import os
+import sys
+import time
+from collections import OrderedDict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import librosa
+import numpy as np
+import paddle
+import soundfile
+from yacs.config import CfgNode
+
+from ..download import get_path_from_url
+from ..executor import BaseExecutor
+from ..log import logger
+from ..utils import CLI_TIMER
+from ..utils import stats_wrapper
+from ..utils import timer_register
+from paddlespeech.s2t.models.whisper import log_mel_spectrogram
+from paddlespeech.s2t.models.whisper import ModelDimensions
+from paddlespeech.s2t.models.whisper import Whisper
+from paddlespeech.s2t.utils.utility import UpdateConfig
+
+__all__ = ['WhisperExecutor']
+
+
+@timer_register
+class WhisperExecutor(BaseExecutor):
+    def __init__(self):
+        super().__init__('whisper')
+        self.parser = argparse.ArgumentParser(
+            prog='paddlespeech.whisper', add_help=True)
+        self.parser.add_argument(
+            '--input', type=str, default=None, help='Audio file to recognize.')
+        self.parser.add_argument(
+            '--model',
+            type=str,
+            default='whisper',
+            choices=[
+                tag[:tag.index('-')]
+                for tag in self.task_resource.pretrained_models.keys()
+            ],
+            help='Choose model type of asr task.')
+        self.parser.add_argument(
+            '--lang',
+            type=str,
+            default='None',
+            help='Choose model decode language. Default is None: the language is detected by the model.'
+        )
+        self.parser.add_argument(
+            '--task',
+            type=str,
+            default='transcribe',
+            choices=["transcribe", "translate"],
+            help='Choose task type: transcribe or translate.')
+        self.parser.add_argument(
+            '--size',
+            type=str,
+            default='large',
+            help='Choose model size. Now only large is supported: [whisper-large-16k].'
+        )
+        self.parser.add_argument(
+            "--sample_rate",
+            type=int,
+            default=16000,
+            choices=[16000],
+            help='Choose the audio sample rate of the model. Only 16000 is supported.')
+        self.parser.add_argument(
+            '--config',
+            type=str,
+            default=None,
+            help='Config of asr task. Use default config when it is None.')
+        self.parser.add_argument(
+            '--decode_method',
+            type=str,
+            default='ctc_prefix_beam_search',
+            choices=['ctc_greedy_search', 'ctc_prefix_beam_search'],
+            help='only support transformer and conformer model')
+        self.parser.add_argument(
+            '--ckpt_path',
+            type=str,
+            default=None,
+            help='Checkpoint file of model.')
+        self.parser.add_argument(
+            '--yes',
+            '-y',
+            action="store_true",
+            default=False,
+            help='No additional parameters required. \
+                Once this parameter is set, it means accepting the requests of the program by default, \
+                which includes transforming the audio sample rate.')
+        self.parser.add_argument(
+            '--rtf',
+            action="store_true",
+            default=False,
+            help='Show Real-time Factor (RTF).')
+        self.parser.add_argument(
+            '--device',
+            type=str,
+            default=paddle.get_device(),
+            help='Choose device to execute model inference.')
+        self.parser.add_argument(
+            '-d',
+            '--job_dump_result',
+            action='store_true',
+            help='Save job result into file.')
+        self.parser.add_argument(
+            '-v',
+            '--verbose',
+            action='store_true',
+            help='Increase logger verbosity of current task.')
+
+    def _init_from_path(self,
+                        model_type: str='whisper',
+                        lang: str='None',
+                        task: str='transcribe',
+                        size: str='large',
+                        sample_rate: int=16000,
+                        cfg_path: Optional[os.PathLike]=None,
+                        decode_method: str='ctc_prefix_beam_search',
+                        num_decoding_left_chunks: int=-1,
+                        ckpt_path: Optional[os.PathLike]=None):
+        """
+        Init model and other resources from a specific path.
+        """
+        logger.debug("start to init the model")
+        # default max_len, unit: second
+        self.max_len = 50
+        if hasattr(self, 'model'):
+            logger.debug('Model had been initialized.')
+            return
+
+        if cfg_path is None or ckpt_path is None:
+            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
+            tag = model_type + '-' + size + '-' + sample_rate_str
+            self.task_resource.set_task_model(tag, version=None)
+            self.res_path = self.task_resource.res_dir
+
+            self.cfg_path = os.path.join(
+                self.res_path, self.task_resource.res_dict['cfg_path'])
+            self.ckpt_path = os.path.join(
+                self.res_path,
+                self.task_resource.res_dict['ckpt_path'] + ".pdparams")
+            logger.debug(self.res_path)
+        else:
+            self.cfg_path = os.path.abspath(cfg_path)
+            self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
+            self.res_path = os.path.dirname(
+                os.path.dirname(os.path.abspath(self.cfg_path)))
+        logger.debug(self.cfg_path)
+        logger.debug(self.ckpt_path)
+
+        # Init body.
+        self.config = CfgNode(new_allowed=True)
+        self.config.merge_from_file(self.cfg_path)
+
+        with UpdateConfig(self.config):
+            if "whisper" in model_type:
+                # download the assets required by the whisper model
+                resource_url = self.task_resource.res_dict['resource_data']
+                resource_md5 = self.task_resource.res_dict['resource_data_md5']
+                resource_path = self.task_resource.res_dict['resource_path']
+                self.download_resource(resource_url, resource_path,
+                                       resource_md5)
+            else:
+                raise Exception("wrong model type")
+
+        # load model
+        model_dict = paddle.load(self.ckpt_path)
+        dims = ModelDimensions(**model_dict["dims"])
+        self.model = Whisper(dims)
+        self.model.load_dict(model_dict)
+        self.model.eval()
+
+        # set task
+        if task is not None:
+            self.task = task
+
+        # set language
+        if lang is not None:
+            self.language = lang
+
+    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
+        """
+        Input preprocess and return paddle.Tensor stored in self.input.
+        Input content can be a text (tts), a file (asr, cls), or streaming (not supported yet).
+        """
+        audio_file = input
+        if isinstance(audio_file, (str, os.PathLike)):
+            logger.debug("Preprocess audio_file: " + audio_file)
+        elif isinstance(audio_file, io.BytesIO):
+            audio_file.seek(0)
+
+        # Get the object for feature extraction
+        # whisper hard-coded audio hyperparameters, params in paddlespeech/s2t/models/whisper/whipser.py
+        logger.debug("read the audio file")
+        audio, audio_sample_rate = soundfile.read(
+            audio_file, dtype="float32", always_2d=True)
+        if self.change_format:
+            # audio is already float32 here, so merge the channels and
+            # resample in float directly; no pcm16 round trip is needed
+            if audio.shape[1] >= 2:
+                audio = audio.mean(axis=1)
+            else:
+                audio = audio[:, 0]
+            audio = librosa.resample(
+                audio, orig_sr=audio_sample_rate, target_sr=self.sample_rate)
+            audio_sample_rate = self.sample_rate
+        else:
+            audio = audio[:, 0]
+
+        logger.debug(f"audio shape: {audio.shape}")
+        # extract the log-mel spectrogram feature
+        audio = log_mel_spectrogram(audio)
+
+        audio_len = paddle.to_tensor(audio.shape[0])
+
+        self._inputs["audio"] = audio
+        self._inputs["audio_len"] = audio_len
+        logger.debug(f"audio feat shape: {audio.shape}")
+
+        logger.debug("audio feat process success")
+
+    @paddle.no_grad()
+    def infer(self, model_type: str):
+        """
+        Model inference and result stored in self.output.
+        """
+        logger.debug("start to infer the model to get the output")
+        cfg = self.config
+        audio = self._inputs["audio"]
+        # Build the temperature fallback schedule: decoding starts at
+        # cfg.temperature and, if a segment fails the quality thresholds,
+        # transcribe() retries it with increasingly higher temperatures up to 1.0.
+        if cfg.temperature_increment_on_fallback is not None:
+            temperature = tuple(
+                np.arange(cfg.temperature, 1.0 + 1e-6,
+                          cfg.temperature_increment_on_fallback))
+        else:
+            temperature = [cfg.temperature]
+
+        self._outputs["result"] = self.model.transcribe(
+            audio,
+            verbose=cfg.verbose,
+            task=self.task,
+            language=self.language,
+            temperature=temperature,
+            compression_ratio_threshold=cfg.compression_ratio_threshold,
+            logprob_threshold=cfg.logprob_threshold,
+            best_of=cfg.best_of,
+            beam_size=cfg.beam_size,
+            patience=cfg.patience,
+            length_penalty=cfg.length_penalty,
+            initial_prompt=cfg.initial_prompt,
+            condition_on_previous_text=cfg.condition_on_previous_text,
+            no_speech_threshold=cfg.no_speech_threshold)
+
+    def postprocess(self) -> Union[str, os.PathLike]:
+        """
+        Output postprocess and return the transcription result, a dict
+        with the recognized text and its segments.
+        """
+        return self._outputs["result"]
+
+    def download_resource(self, url, lm_dir, md5sum):
+        get_path_from_url(
+            url=url,
+            root_dir=lm_dir,
+            md5sum=md5sum,
+            decompress=True)
+
+    def _pcm16to32(self, audio):
+        assert (audio.dtype == np.int16)
+        audio = audio.astype("float32")
+        bits = np.iinfo(np.int16).bits
+        audio = audio / (2**(bits - 1))
+        return audio
+
+    def _pcm32to16(self, audio):
+        assert (audio.dtype == np.float32)
+        bits = np.iinfo(np.int16).bits
+        audio = audio * (2**(bits - 1))
+        audio = np.round(audio).astype("int16")
+        return audio
+
+    def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False):
+        self.sample_rate = sample_rate
+        if self.sample_rate != 16000 and self.sample_rate != 8000:
+            logger.error(
+                "invalid sample rate, please input --sample_rate 8000 or --sample_rate 16000"
+            )
+            return False
+
+        if isinstance(audio_file, (str, os.PathLike)):
+            if not os.path.isfile(audio_file):
+                logger.error("Please input the right audio file path")
+                return False
+        elif isinstance(audio_file, io.BytesIO):
+            audio_file.seek(0)
+
+        logger.debug("checking the audio file format......")
+        try:
+            audio, audio_sample_rate = soundfile.read(
+                audio_file, dtype="int16", always_2d=True)
+            audio_duration = audio.shape[0] / audio_sample_rate
+            if audio_duration > self.max_len:
+                logger.error(
+                    f"Please input an audio file shorter than {self.max_len} seconds.\n"
+                )
+                return False
+        except Exception as e:
+            logger.exception(e)
+            logger.error(
+                f"cannot open the audio file, please check that the audio file ({audio_file}) format is 'wav'. \n \
+                you can try to use sox to change the file format.\n \
+                For example: \n \
+                sample rate: 16k \n \
+                sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+                sample rate: 8k \n \
+                sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+                ")
+            return False
+        logger.debug("The sample rate is %d" % audio_sample_rate)
+        if audio_sample_rate != self.sample_rate:
+            logger.warning("The sample rate of the input file is not {}.\n \
+                The program will resample the wav file to {}.\n \
+                If the result does not meet your expectations,\n \
+                please input a 16k 16 bit 1 channel wav file. \
+                ".format(self.sample_rate, self.sample_rate))
+            if force_yes is False:
+                while True:
+                    logger.debug(
+                        "Whether to change the sample rate and the channel. Y: change the sample rate. N: exit the program."
+                    )
+                    content = input("Input(Y/N):")
+                    if content.strip() in ("Y", "y", "yes", "Yes"):
+                        logger.debug(
+                            "change the sample rate and channel to 16k and 1 channel"
+                        )
+                        break
+                    elif content.strip() in ("N", "n", "no", "No"):
+                        logger.debug("Exit the program")
+                        return False
+                    else:
+                        logger.warning("Invalid input, please input again")
+
+            self.change_format = True
+        else:
+            logger.debug("The audio file format is right")
+            self.change_format = False
+
+        return True
+
+    def execute(self, argv: List[str]) -> bool:
+        """
+        Command line entry.
+        """
+        parser_args = self.parser.parse_args(argv)
+
+        model = parser_args.model
+        lang = parser_args.lang
+        task = parser_args.task
+        size = parser_args.size
+        sample_rate = parser_args.sample_rate
+        config = parser_args.config
+        ckpt_path = parser_args.ckpt_path
+        decode_method = parser_args.decode_method
+        force_yes = parser_args.yes
+        rtf = parser_args.rtf
+        device = parser_args.device
+
+        if not parser_args.verbose:
+            self.disable_task_loggers()
+
+        task_source = self.get_input_source(parser_args.input)
+        task_results = OrderedDict()
+        has_exceptions = False
+
+        for id_, input_ in task_source.items():
+            try:
+                res = self(
+                    audio_file=input_,
+                    model=model,
+                    lang=lang,
+                    task=task,
+                    size=size,
+                    sample_rate=sample_rate,
+                    config=config,
+                    ckpt_path=ckpt_path,
+                    decode_method=decode_method,
+                    force_yes=force_yes,
+                    rtf=rtf,
+                    device=device)
+                task_results[id_] = res
+            except Exception as e:
+                has_exceptions = True
+                task_results[id_] = f'{e.__class__.__name__}: {e}'
+
+        if rtf:
+            self.show_rtf(CLI_TIMER[self.__class__.__name__])
+
+        self.process_task_results(parser_args.input, task_results,
+                                  parser_args.job_dump_result)
+
+        if has_exceptions:
+            return False
+        else:
+            return True
+
+    @stats_wrapper
+    def __call__(self,
+                 audio_file: os.PathLike,
+                 model: str='whisper',
+                 lang: str='None',
+                 task: str='transcribe',
+                 size: str='large',
+                 sample_rate: int=16000,
+                 config: os.PathLike=None,
+                 ckpt_path: os.PathLike=None,
+                 decode_method: str='attention_rescoring',
+                 num_decoding_left_chunks: int=-1,
+                 force_yes: bool=False,
+                 rtf: bool=False,
+                 device=paddle.get_device()):
+        """
+        Python API to call an executor.
+        """
+        audio_file = os.path.abspath(audio_file)
+        paddle.set_device(device)
+        self._init_from_path(model, lang, task, size, sample_rate, config,
+                             decode_method, num_decoding_left_chunks, ckpt_path)
+        if not self._check(audio_file, sample_rate, force_yes):
+            sys.exit(-1)
+        if rtf:
+            k = self.__class__.__name__
+            CLI_TIMER[k]['start'].append(time.time())
+
+        self.preprocess(model, audio_file)
+        self.infer(model)
+        res = self.postprocess()  # Retrieve result of whisper task.
+
+        if rtf:
+            CLI_TIMER[k]['end'].append(time.time())
+            audio, audio_sample_rate = soundfile.read(
+                audio_file, dtype="int16", always_2d=True)
+            CLI_TIMER[k]['extra'].append(audio.shape[0] / audio_sample_rate)
+
+        return res
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index 8e9ecc4ba..ce7fa662f 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -29,6 +29,11 @@ model_alias = {
     "transformer": ["paddlespeech.s2t.models.u2:U2Model"],
     "wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
 
+    # ---------------------------------
+    # ------------ Whisper ------------
+    # ---------------------------------
+    "whisper": ["paddlespeech.s2t.models.whisper:Whisper"],
+
     # ---------------------------------
     # -------------- CLS --------------
     # ---------------------------------
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index df50a6a9d..b83d66f26 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -25,6 +25,7 @@ __all__ = [
     'tts_static_pretrained_models',
     'tts_onnx_pretrained_models',
     'vector_dynamic_pretrained_models',
+    'whisper_dynamic_pretrained_models',
 ]
 
 # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
@@ -424,6 +425,31 @@ asr_onnx_pretrained_models = {
     },
 }
 
+whisper_dynamic_pretrained_models = {
+    "whisper-large-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/whisper-large-model.tar.gz',
+            'md5':
+            '364c4d670835e5ca489045e1c29d75fe',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-large-model',
+            'model':
+            'whisper-large-model.pdparams',
+            'params':
+            'whisper-large-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
+            'resource_data_md5':
+            '37a0a8abdb3641a51194f79567a93b61',
+            'resource_path':
+            'paddlespeech/s2t/models/whisper',
+        },
+    },
+}
+
 # ---------------------------------
 # -------------- CLS --------------
 # ---------------------------------
diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py
index 8e9914b2e..d3d89f4de 100644
--- a/paddlespeech/resource/resource.py
+++ b/paddlespeech/resource/resource.py
@@ -22,7 +22,7 @@ from ..utils.dynamic_import import dynamic_import
 from ..utils.env import MODEL_HOME
 from .model_alias import model_alias
 
-task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws']
+task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector', 'kws', 'whisper']
 model_format_supported = ['dynamic', 'static', 'onnx']
 inference_mode_supported = ['online', 'offline']
diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py
index 95f7ac051..63945b9eb 100644
--- a/paddlespeech/s2t/exps/whisper/test_wav.py
+++ b/paddlespeech/s2t/exps/whisper/test_wav.py
@@ -21,6 +21,7 @@ import paddle
 import soundfile
 from yacs.config import CfgNode
 
+from paddlespeech.s2t.models.whisper import log_mel_spectrogram
 from paddlespeech.s2t.models.whisper import ModelDimensions
 from paddlespeech.s2t.models.whisper import transcribe
 from paddlespeech.s2t.models.whisper import Whisper
@@ -60,12 +61,14 @@ class WhisperInfer():
         else:
             temperature = [temperature]
 
+        # load audio and compute the log-mel feature expected by transcribe()
+        mel = log_mel_spectrogram(args.audio_file)
+
         result = transcribe(
-            self.model, args.audio_file, temperature=temperature, **config)
+            self.model, mel, temperature=temperature, **config)
         if args.result_file is not None:
             with open(args.result_file, 'w') as f:
                 f.write(str(result))
-        print("result", result)
 
         return result
diff --git a/paddlespeech/s2t/models/whisper/__init__.py b/paddlespeech/s2t/models/whisper/__init__.py
index 98ab23610..1c8adba56 100644
--- a/paddlespeech/s2t/models/whisper/__init__.py
+++ b/paddlespeech/s2t/models/whisper/__init__.py
@@ -2,11 +2,4 @@
 # Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
 #
 # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/__init__.py)
-from paddlespeech.s2t.models.whisper.whipser import decode
-from paddlespeech.s2t.models.whisper.whipser import DecodingOptions
-from paddlespeech.s2t.models.whisper.whipser import DecodingResult
-from paddlespeech.s2t.models.whisper.whipser import detect_language
-from paddlespeech.s2t.models.whisper.whipser import log_mel_spectrogram
-from paddlespeech.s2t.models.whisper.whipser import ModelDimensions
-from paddlespeech.s2t.models.whisper.whipser import transcribe
-from paddlespeech.s2t.models.whisper.whipser import Whisper
+from .whipser import *
diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py
index 39c07073f..1c58c94c7 100644
--- a/paddlespeech/s2t/models/whisper/tokenizer.py
+++ b/paddlespeech/s2t/models/whisper/tokenizer.py
@@ -13,7 +13,6 @@ from typing import Union
 import numpy as np
 import paddle
 from paddlenlp.transformers import GPTTokenizer
-#from transformers import GPT2TokenizerFast
 
 LANGUAGES = {
     "en": "english",
diff --git a/paddlespeech/s2t/models/whisper/whipser.py b/paddlespeech/s2t/models/whisper/whipser.py
index 3b2eed8d6..5a495f634 100644
--- a/paddlespeech/s2t/models/whisper/whipser.py
+++ b/paddlespeech/s2t/models/whisper/whipser.py
@@ -1,7 +1,7 @@
 # MIT License, Copyright (c) 2022 OpenAI.
 # Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
 #
-# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/__init__.py)
+# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
 import os
 from dataclasses import dataclass
 from dataclasses import field
@@ -418,7 +418,7 @@ def detect_language(model: "Whisper",
 
 def transcribe(
         model: "Whisper",
-        audio: Union[str, np.ndarray, paddle.Tensor],
+        mel: paddle.Tensor,
         *,
         verbose: Optional[bool]=None,
         temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8,
@@ -436,8 +436,8 @@ def transcribe(
     model: Whisper
         The Whisper model instance
 
-    audio: Union[str, np.ndarray, torch.Tensor]
-        The path to the audio file to open, or the audio waveform
+    mel: paddle.Tensor
+        The log-mel spectrogram of the audio
 
     verbose: bool
         Whether to display the text being decoded to the console. If True, displays all the details,
@@ -475,8 +475,6 @@ def transcribe(
     if dtype == np.float32:
         decode_options["fp16"] = False
 
-    mel = log_mel_spectrogram(audio)
-
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
             decode_options["language"] = "en"
@@ -1193,9 +1191,8 @@ class DecodingTask:
                 DecodingResult(
                     audio_features=features,
                     language=language,
-                    language_probs=probs)
-                for features, language, probs in zip(audio_features, languages,
-                                                     language_probs)
+                    language_probs=probs) for features, language, probs in
+                zip(audio_features, languages, language_probs)
             ]
 
         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh
index d571aa78f..644008d52 100755
--- a/tests/unit/cli/test_cli.sh
+++ b/tests/unit/cli/test_cli.sh
@@ -93,5 +93,11 @@ paddlespeech stats --task text
 paddlespeech stats --task vector
 paddlespeech stats --task st
 
+# whisper: recognize text
+paddlespeech whisper --task transcribe --input ./zh.wav
+
+# whisper: recognize text and translate it to English
+paddlespeech whisper --task translate --input ./zh.wav
+
 echo -e "\033[32mTest success !!!\033[0m"