[ASR] fix Whisper cli model download path error. test=asr (#2679)

* add all whisper model size support

* add choices in parser.

* fix Whisper cli model download path error.

* fix resource download path.

* fix code style
pull/2680/head
zxcd 2 years ago committed by GitHub
parent fc02cd0540
commit 4542684694
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,7 +61,7 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper
# to recognize text
text = whisper_executor(
model='whisper-large',
model='whisper',
task='transcribe',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
@ -72,7 +72,7 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper
# to recognize text and translate to English
feature = whisper_executor(
model='whisper-large',
model='whisper',
task='translate',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.

@ -61,7 +61,7 @@ Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper
# 识别文本
text = whisper_executor(
model='whisper-large',
model='whisper',
task='transcribe',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
@ -72,7 +72,7 @@ Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper
# 将语音翻译成英语
feature = whisper_executor(
model='whisper-large',
model='whisper',
task='translate',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.

@ -27,6 +27,7 @@ import paddle
import soundfile
from yacs.config import CfgNode
from ...utils.env import DATA_HOME
from ..download import get_path_from_url
from ..executor import BaseExecutor
from ..log import logger
@ -187,10 +188,12 @@ class WhisperExecutor(BaseExecutor):
with UpdateConfig(self.config):
if "whisper" in model_type:
resource_url = self.task_resource.res_dict['resuource_data']
resource_md5 = self.task_resource.res_dict['resuource_data_md5']
resuource_path = self.task_resource.res_dict['resuource_path']
self.download_resource(resource_url, resuource_path,
resource_url = self.task_resource.res_dict['resource_data']
resource_md5 = self.task_resource.res_dict['resource_data_md5']
self.resource_path = os.path.join(
DATA_HOME, self.task_resource.version, 'whisper')
self.download_resource(resource_url, self.resource_path,
resource_md5)
else:
raise Exception("wrong type")
@ -249,7 +252,7 @@ class WhisperExecutor(BaseExecutor):
logger.debug(f"audio shape: {audio.shape}")
# fbank
audio = log_mel_spectrogram(audio)
audio = log_mel_spectrogram(audio, resource_path=self.resource_path)
audio_len = paddle.to_tensor(audio.shape[0])
@ -279,6 +282,7 @@ class WhisperExecutor(BaseExecutor):
verbose=cfg.verbose,
task=self.task,
language=self.language,
resource_path=self.resource_path,
temperature=temperature,
compression_ratio_threshold=cfg.compression_ratio_threshold,
logprob_threshold=cfg.logprob_threshold,

@ -468,9 +468,9 @@ whisper_dynamic_pretrained_models = {
"whisper-large-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/whisper-large-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-large-model.tar.gz',
'md5':
'364c4d670835e5ca489045e1c29d75fe',
'cf1557af9d8ffa493fefad9cb08ae189',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -479,20 +479,18 @@ whisper_dynamic_pretrained_models = {
'whisper-large-model.pdparams',
'params':
'whisper-large-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-base-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-en-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-en-model.tar.gz',
'md5':
'f5bb8cdff42c7031d9e4c0ea20f7ceee',
'b156529aefde6beb7726d2ea98fd067a',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -501,20 +499,18 @@ whisper_dynamic_pretrained_models = {
'whisper-base-en-model.pdparams',
'params':
'whisper-base-en-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-base-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-model.tar.gz',
'md5':
'46f254e89a01b71586af1a46d28d7ce9',
'6b012a5abd583db14398c3492e47120b',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -523,20 +519,18 @@ whisper_dynamic_pretrained_models = {
'whisper-base-model.pdparams',
'params':
'whisper-base-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-medium-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-en-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-en-model.tar.gz',
'md5':
'98228f3ba94636c2760b51e5f3d6885f',
'c7f57d270bd20c7b170ba9dcf6c16f74',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -545,20 +539,18 @@ whisper_dynamic_pretrained_models = {
'whisper-medium-en-model.pdparams',
'params':
'whisper-medium-en-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-medium-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-model.tar.gz',
'md5':
'51ac154b264db75492ed1cc5280baebf',
'4c7dcd0df25f408199db4a4548336786',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -567,20 +559,18 @@ whisper_dynamic_pretrained_models = {
'whisper-medium-model.pdparams',
'params':
'whisper-medium-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-small-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-en-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-en-model.tar.gz',
'md5':
'973b784a335580a393e13a13995b110a',
'2b24efcb2e93f3275af7c0c7f598ff1c',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -589,20 +579,18 @@ whisper_dynamic_pretrained_models = {
'whisper-small-en-model.pdparams',
'params':
'whisper-small-en-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-small-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-model.tar.gz',
'md5':
'57a7530851cc98631c6fb29c606489c6',
'5a57911dd41651dd6ed78c5763912825',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -611,20 +599,18 @@ whisper_dynamic_pretrained_models = {
'whisper-small-model.pdparams',
'params':
'whisper-small-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-tiny-en-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-en-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-en-model.tar.gz',
'md5':
'3ef5c0777e0bd4a1a240895167b0eb0d',
'14969164a3f713fd58e56978c34188f6',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -633,20 +619,18 @@ whisper_dynamic_pretrained_models = {
'whisper-tiny-en-model.pdparams',
'params':
'whisper-tiny-en-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
"whisper-tiny-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-model.tar.gz',
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-model.tar.gz',
'md5':
'ddf232cd16c85120e89c870a53451e53',
'a5b82a1f2067a2ca400f17fabd62b81b',
'cfg_path':
'whisper.yaml',
'ckpt_path':
@ -655,12 +639,10 @@ whisper_dynamic_pretrained_models = {
'whisper-tiny-model.pdparams',
'params':
'whisper-tiny-model.pdparams',
'resuource_data':
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
'resuource_data_md5':
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
'resuource_path':
'paddlespeech/s2t/models/whisper',
},
},
}

@ -62,7 +62,8 @@ class WhisperInfer():
temperature = [temperature]
#load audio
mel = log_mel_spectrogram(args.audio)
mel = log_mel_spectrogram(
args.audio_file, resource_path=config.resource_path)
result = transcribe(
self.model, mel, temperature=temperature, **config)

@ -298,9 +298,9 @@ class Tokenizer:
@lru_cache(maxsize=None)
def build_tokenizer(name: str="gpt2"):
def build_tokenizer(resource_path: str, name: str="gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(os.path.dirname(__file__), "assets", name)
path = os.path.join(resource_path, "assets", name)
tokenizer = GPTTokenizer.from_pretrained(path)
specials = [
@ -321,6 +321,7 @@ def build_tokenizer(name: str="gpt2"):
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
resource_path: str,
*,
task: Optional[str]=None, # Literal["transcribe", "translate", None]
language: Optional[str]=None, ) -> Tokenizer:
@ -341,7 +342,8 @@ def get_tokenizer(
task = None
language = None
tokenizer = build_tokenizer(name=tokenizer_name)
tokenizer = build_tokenizer(
resource_path=resource_path, name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]

@ -1,6 +1,6 @@
# MIT License, Copyright (c) 2022 OpenAI.
# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
#
#
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
import os
from dataclasses import dataclass
@ -265,7 +265,6 @@ class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[
str] = None # language that the audio is in; uses detected language if None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
@ -361,10 +360,11 @@ class WhisperInference(Inference):
@paddle.no_grad()
def detect_language(model: "Whisper",
mel: paddle.Tensor,
tokenizer: Tokenizer=None
) -> Tuple[paddle.Tensor, List[dict]]:
def detect_language(
model: "Whisper",
mel: paddle.Tensor,
resource_path: str,
tokenizer: Tokenizer=None) -> Tuple[paddle.Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
@ -378,7 +378,8 @@ def detect_language(model: "Whisper",
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
tokenizer = get_tokenizer(
model.is_multilingual, resource_path=resource_path)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
@ -419,6 +420,7 @@ def detect_language(model: "Whisper",
def transcribe(
model: "Whisper",
mel: paddle.Tensor,
resource_path: str,
*,
verbose: Optional[bool]=None,
temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8,
@ -485,7 +487,7 @@ def transcribe(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
segment = pad_or_trim(mel, N_FRAMES)
_, probs = model.detect_language(segment)
_, probs = model.detect_language(segment, resource_path)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
@ -495,7 +497,10 @@ def transcribe(
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=task)
model.is_multilingual,
resource_path=resource_path,
language=language,
task=task)
def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (
@ -513,7 +518,7 @@ def transcribe(
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
decode_result = model.decode(segment, options, resource_path)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
@ -979,14 +984,21 @@ class DecodingTask:
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
def __init__(self,
model: "Whisper",
options: DecodingOptions,
resource_path: str):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task)
model.is_multilingual,
resource_path=resource_path,
language=language,
task=options.task)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.resource_path: str = resource_path
self.beam_size: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
@ -1112,13 +1124,14 @@ class DecodingTask:
def _detect_language(self,
audio_features: paddle.Tensor,
tokens: paddle.Tensor):
tokens: paddle.Tensor,
resource_path: str):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features,
self.tokenizer)
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer, self.resource_path)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index +
@ -1185,7 +1198,8 @@ class DecodingTask:
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(
paddle.to_tensor(audio_features), paddle.to_tensor(tokens))
paddle.to_tensor(audio_features),
paddle.to_tensor(tokens), self.resource_path)
if self.options.task == "lang_id":
return [
@ -1254,10 +1268,11 @@ class DecodingTask:
@paddle.no_grad()
def decode(model: "Whisper",
mel: paddle.Tensor,
options: DecodingOptions=DecodingOptions()
) -> Union[DecodingResult, List[DecodingResult]]:
def decode(
model: "Whisper",
mel: paddle.Tensor,
options: DecodingOptions=DecodingOptions(),
resource_path=str, ) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@ -1281,7 +1296,7 @@ def decode(model: "Whisper",
if single:
mel = mel.unsqueeze(0)
result = DecodingTask(model, options).run(mel)
result = DecodingTask(model, options, resource_path).run(mel)
if single:
result = result[0]
@ -1407,7 +1422,7 @@ def hann_window(n_fft: int=N_FFT):
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor:
def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@ -1418,14 +1433,13 @@ def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor:
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(
os.path.join(
os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
return paddle.to_tensor(f[f"mel_{n_mels}"])
def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
n_mels: int=N_MELS):
n_mels: int=N_MELS,
resource_path: str=None):
"""
Compute the log-Mel spectrogram of
@ -1454,7 +1468,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
magnitudes = stft[:, :-1].abs()**2
filters = mel_filters(audio, n_mels)
filters = mel_filters(resource_path, n_mels)
mel_spec = filters @ magnitudes
mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())

Loading…
Cancel
Save