[ASR] fix Whisper cli model download path error. test=asr (#2679)

* add support for all Whisper model sizes

* add model choices to the argument parser

* fix Whisper CLI model download path error

* fix resource download path

* fix code style
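
For context, the updated Python usage from the demo README (see the README hunks below) looks like the following. This is a minimal sketch: the audio path and device are placeholders, not values taken from this PR.

```python
import paddle
from paddlespeech.cli.whisper import WhisperExecutor

whisper_executor = WhisperExecutor()

# Recognize text with the pretrained Whisper model; `config` and `ckpt_path`
# stay None so the pretrained weights and assets are downloaded automatically.
text = whisper_executor(
    model='whisper',           # was 'whisper-large' before this PR
    task='transcribe',
    sample_rate=16000,
    config=None,
    ckpt_path=None,
    audio_file='./input.wav',  # hypothetical path to a 16 kHz wav file
    device=paddle.get_device())
print(text)
```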
zxcd committed via GitHub
parent fc02cd0540
commit 4542684694

@@ -61,7 +61,7 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper
 # to recognize text
 text = whisper_executor(
-    model='whisper-large',
+    model='whisper',
     task='transcribe',
     sample_rate=16000,
     config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
@@ -72,7 +72,7 @@ Whisper model trained by OpenAI whisper https://github.com/openai/whisper
 # to recognize text and translate to English
 feature = whisper_executor(
-    model='whisper-large',
+    model='whisper',
     task='translate',
     sample_rate=16000,
     config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.

@@ -61,7 +61,7 @@ Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper
 # 识别文本
 text = whisper_executor(
-    model='whisper-large',
+    model='whisper',
     task='transcribe',
     sample_rate=16000,
     config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
@@ -72,7 +72,7 @@ Whisper模型由OpenAI Whisper训练 https://github.com/openai/whisper
 # 将语音翻译成英语
 feature = whisper_executor(
-    model='whisper-large',
+    model='whisper',
     task='translate',
     sample_rate=16000,
     config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.

@@ -27,6 +27,7 @@ import paddle
 import soundfile
 from yacs.config import CfgNode

+from ...utils.env import DATA_HOME
 from ..download import get_path_from_url
 from ..executor import BaseExecutor
 from ..log import logger
@@ -187,10 +188,12 @@ class WhisperExecutor(BaseExecutor):
             with UpdateConfig(self.config):
                 if "whisper" in model_type:
-                    resource_url = self.task_resource.res_dict['resuource_data']
-                    resource_md5 = self.task_resource.res_dict['resuource_data_md5']
-                    resuource_path = self.task_resource.res_dict['resuource_path']
-                    self.download_resource(resource_url, resuource_path,
+                    resource_url = self.task_resource.res_dict['resource_data']
+                    resource_md5 = self.task_resource.res_dict['resource_data_md5']
+                    self.resource_path = os.path.join(
+                        DATA_HOME, self.task_resource.version, 'whisper')
+                    self.download_resource(resource_url, self.resource_path,
                                            resource_md5)
                 else:
                     raise Exception("wrong type")
@@ -249,7 +252,7 @@ class WhisperExecutor(BaseExecutor):
         logger.debug(f"audio shape: {audio.shape}")

         # fbank
-        audio = log_mel_spectrogram(audio)
+        audio = log_mel_spectrogram(audio, resource_path=self.resource_path)

         audio_len = paddle.to_tensor(audio.shape[0])
@@ -279,6 +282,7 @@ class WhisperExecutor(BaseExecutor):
                 verbose=cfg.verbose,
                 task=self.task,
                 language=self.language,
+                resource_path=self.resource_path,
                 temperature=temperature,
                 compression_ratio_threshold=cfg.compression_ratio_threshold,
                 logprob_threshold=cfg.logprob_threshold,
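
The core of the path fix shown above: the assets archive is no longer located through the misspelled `resuource_path` entry (which pointed inside the installed package), but unpacked under `DATA_HOME/<version>/whisper` and carried around as `self.resource_path`. A minimal sketch of that resolution, with a placeholder standing in for the real `DATA_HOME` value:

```python
import os


def whisper_resource_dir(data_home: str, version: str) -> str:
    """Mirror the new download location: <DATA_HOME>/<version>/whisper.

    `data_home` stands in for paddlespeech's DATA_HOME; this only
    illustrates the fixed path, not the executor's full download logic.
    """
    return os.path.join(data_home, version, 'whisper')


# assets.tar (tokenizer files, mel_filters.npz) is extracted here and the
# directory is then passed down as `resource_path`.
print(whisper_resource_dir('/path/to/DATA_HOME', '1.3'))
```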

@@ -468,9 +468,9 @@ whisper_dynamic_pretrained_models = {
    "whisper-large-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/whisper-large-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-large-model.tar.gz',
            'md5':
-            '364c4d670835e5ca489045e1c29d75fe',
+            'cf1557af9d8ffa493fefad9cb08ae189',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -479,20 +479,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-large-model.pdparams',
            'params':
            'whisper-large-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-base-en-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-en-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-en-model.tar.gz',
            'md5':
-            'f5bb8cdff42c7031d9e4c0ea20f7ceee',
+            'b156529aefde6beb7726d2ea98fd067a',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -501,20 +499,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-base-en-model.pdparams',
            'params':
            'whisper-base-en-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-base-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-base-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-model.tar.gz',
            'md5':
-            '46f254e89a01b71586af1a46d28d7ce9',
+            '6b012a5abd583db14398c3492e47120b',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -523,20 +519,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-base-model.pdparams',
            'params':
            'whisper-base-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-medium-en-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-en-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-en-model.tar.gz',
            'md5':
-            '98228f3ba94636c2760b51e5f3d6885f',
+            'c7f57d270bd20c7b170ba9dcf6c16f74',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -545,20 +539,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-medium-en-model.pdparams',
            'params':
            'whisper-medium-en-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-medium-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-medium-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-model.tar.gz',
            'md5':
-            '51ac154b264db75492ed1cc5280baebf',
+            '4c7dcd0df25f408199db4a4548336786',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -567,20 +559,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-medium-model.pdparams',
            'params':
            'whisper-medium-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-small-en-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-en-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-en-model.tar.gz',
            'md5':
-            '973b784a335580a393e13a13995b110a',
+            '2b24efcb2e93f3275af7c0c7f598ff1c',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -589,20 +579,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-small-en-model.pdparams',
            'params':
            'whisper-small-en-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-small-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-small-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-model.tar.gz',
            'md5':
-            '57a7530851cc98631c6fb29c606489c6',
+            '5a57911dd41651dd6ed78c5763912825',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -611,20 +599,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-small-model.pdparams',
            'params':
            'whisper-small-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-tiny-en-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-en-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-en-model.tar.gz',
            'md5':
-            '3ef5c0777e0bd4a1a240895167b0eb0d',
+            '14969164a3f713fd58e56978c34188f6',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -633,20 +619,18 @@ whisper_dynamic_pretrained_models = {
            'whisper-tiny-en-model.pdparams',
            'params':
            'whisper-tiny-en-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
    "whisper-tiny-16k": {
        '1.3': {
            'url':
-            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221118/whisper-tiny-model.tar.gz',
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-model.tar.gz',
            'md5':
-            'ddf232cd16c85120e89c870a53451e53',
+            'a5b82a1f2067a2ca400f17fabd62b81b',
            'cfg_path':
            'whisper.yaml',
            'ckpt_path':
@@ -655,12 +639,10 @@ whisper_dynamic_pretrained_models = {
            'whisper-tiny-model.pdparams',
            'params':
            'whisper-tiny-model.pdparams',
-            'resuource_data':
+            'resource_data':
            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221108/assets.tar',
-            'resuource_data_md5':
+            'resource_data_md5':
            '37a0a8abdb3641a51194f79567a93b61',
-            'resuource_path':
-            'paddlespeech/s2t/models/whisper',
        },
    },
 }

@@ -62,7 +62,8 @@ class WhisperInfer():
            temperature = [temperature]

        #load audio
-        mel = log_mel_spectrogram(args.audio)
+        mel = log_mel_spectrogram(
+            args.audio_file, resource_path=config.resource_path)

        result = transcribe(
            self.model, mel, temperature=temperature, **config)

@@ -298,9 +298,9 @@ class Tokenizer:
 @lru_cache(maxsize=None)
-def build_tokenizer(name: str="gpt2"):
+def build_tokenizer(resource_path: str, name: str="gpt2"):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    path = os.path.join(os.path.dirname(__file__), "assets", name)
+    path = os.path.join(resource_path, "assets", name)
     tokenizer = GPTTokenizer.from_pretrained(path)

     specials = [
@@ -321,6 +321,7 @@ def build_tokenizer(name: str="gpt2"):
 @lru_cache(maxsize=None)
 def get_tokenizer(
        multilingual: bool,
+        resource_path: str,
        *,
        task: Optional[str]=None,  # Literal["transcribe", "translate", None]
        language: Optional[str]=None, ) -> Tokenizer:
@@ -341,7 +342,8 @@ def get_tokenizer(
        task = None
        language = None

-    tokenizer = build_tokenizer(name=tokenizer_name)
+    tokenizer = build_tokenizer(
+        resource_path=resource_path, name=tokenizer_name)
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
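
With `resource_path` threaded into the tokenizer helpers, the GPT-2 tokenizer files are now resolved under the downloaded resource directory instead of `os.path.dirname(__file__)`. A small illustrative sketch of that lookup; the directory value below is a placeholder, not output from this PR:

```python
import os


def tokenizer_asset_dir(resource_path: str, name: str = "gpt2") -> str:
    # Same join as the patched build_tokenizer(): assets live under
    # <resource_path>/assets/<name> rather than inside the installed package.
    return os.path.join(resource_path, "assets", name)


print(tokenizer_asset_dir("/path/to/DATA_HOME/1.3/whisper"))
# -> /path/to/DATA_HOME/1.3/whisper/assets/gpt2
```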

@@ -1,6 +1,6 @@
 # MIT License, Copyright (c) 2022 OpenAI.
 # Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
 #
 # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
 import os
 from dataclasses import dataclass
@ -265,7 +265,6 @@ class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[ language: Optional[
str] = None # language that the audio is in; uses detected language if None str] = None # language that the audio is in; uses detected language if None
# sampling-related options # sampling-related options
temperature: float = 0.0 temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample sample_len: Optional[int] = None # maximum number of tokens to sample
@@ -361,10 +360,11 @@ class WhisperInference(Inference):
 @paddle.no_grad()
-def detect_language(model: "Whisper",
-                    mel: paddle.Tensor,
-                    tokenizer: Tokenizer=None
-                    ) -> Tuple[paddle.Tensor, List[dict]]:
+def detect_language(
+        model: "Whisper",
+        mel: paddle.Tensor,
+        resource_path: str,
+        tokenizer: Tokenizer=None) -> Tuple[paddle.Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
     of the most probable language tokens and the probability distribution over all language tokens.
@@ -378,7 +378,8 @@ def detect_language(model: "Whisper",
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
-        tokenizer = get_tokenizer(model.is_multilingual)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, resource_path=resource_path)
    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
@@ -419,6 +420,7 @@ def detect_language(model: "Whisper",
 def transcribe(
        model: "Whisper",
        mel: paddle.Tensor,
+        resource_path: str,
        *,
        verbose: Optional[bool]=None,
        temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8,
@@ -485,7 +487,7 @@ def transcribe(
                "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
            )
            segment = pad_or_trim(mel, N_FRAMES)
-            _, probs = model.detect_language(segment)
+            _, probs = model.detect_language(segment, resource_path)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
@@ -495,7 +497,10 @@ def transcribe(
    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
-        model.is_multilingual, language=language, task=task)
+        model.is_multilingual,
+        resource_path=resource_path,
+        language=language,
+        task=task)

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        temperatures = [temperature] if isinstance(temperature, (
@@ -513,7 +518,7 @@ def transcribe(
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
-            decode_result = model.decode(segment, options)
+            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
@@ -979,14 +984,21 @@ class DecodingTask:
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

-    def __init__(self, model: "Whisper", options: DecodingOptions):
+    def __init__(self,
+                 model: "Whisper",
+                 options: DecodingOptions,
+                 resource_path: str):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
-            model.is_multilingual, language=language, task=options.task)
+            model.is_multilingual,
+            resource_path=resource_path,
+            language=language,
+            task=options.task)
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)
+        self.resource_path: str = resource_path

        self.beam_size: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
@@ -1112,13 +1124,14 @@ class DecodingTask:
    def _detect_language(self,
                         audio_features: paddle.Tensor,
-                         tokens: paddle.Tensor):
+                         tokens: paddle.Tensor,
+                         resource_path: str):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
-            lang_tokens, lang_probs = self.model.detect_language(audio_features,
-                                                                 self.tokenizer)
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer, self.resource_path)
            languages = [max(probs, key=probs.get) for probs in lang_probs]

            if self.options.language is None:
                tokens[:, self.sot_index +
@@ -1185,7 +1198,8 @@ class DecodingTask:
        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(
-            paddle.to_tensor(audio_features), paddle.to_tensor(tokens))
+            paddle.to_tensor(audio_features),
+            paddle.to_tensor(tokens), self.resource_path)

        if self.options.task == "lang_id":
            return [
@@ -1254,10 +1268,11 @@ class DecodingTask:
 @paddle.no_grad()
-def decode(model: "Whisper",
-           mel: paddle.Tensor,
-           options: DecodingOptions=DecodingOptions()
-           ) -> Union[DecodingResult, List[DecodingResult]]:
+def decode(
+        model: "Whisper",
+        mel: paddle.Tensor,
+        options: DecodingOptions=DecodingOptions(),
+        resource_path=str, ) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
@@ -1281,7 +1296,7 @@ def decode(model: "Whisper",
    if single:
        mel = mel.unsqueeze(0)

-    result = DecodingTask(model, options).run(mel)
+    result = DecodingTask(model, options, resource_path).run(mel)

    if single:
        result = result[0]
@@ -1407,7 +1422,7 @@ def hann_window(n_fft: int=N_FFT):
 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor:
+def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -1418,14 +1433,13 @@ def mel_filters(device, n_mels: int=N_MELS) -> paddle.Tensor:
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(
-            os.path.join(
-                os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
+    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
        return paddle.to_tensor(f[f"mel_{n_mels}"])


 def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
-                        n_mels: int=N_MELS):
+                        n_mels: int=N_MELS,
+                        resource_path: str=None):
     """
     Compute the log-Mel spectrogram of
@@ -1454,7 +1468,7 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
    magnitudes = stft[:, :-1].abs()**2

-    filters = mel_filters(audio, n_mels)
+    filters = mel_filters(resource_path, n_mels)
    mel_spec = filters @ magnitudes

    mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
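
Feature extraction needs the same directory so `mel_filters` can load `assets/mel_filters.npz`. A hedged usage sketch; the import path and file locations are assumptions based on the hunks above, not verified against the repository layout:

```python
from paddlespeech.s2t.models.whisper import log_mel_spectrogram  # import path assumed

# Placeholder paths; in the CLI these come from the executor / config.
resource_path = "/path/to/DATA_HOME/1.3/whisper"
mel = log_mel_spectrogram("./input.wav", resource_path=resource_path)
print(mel.shape)  # log-Mel feature matrix for the input audio
```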
