【ASR】whisper large v3 (#4101)

* whisper large v3

* add convert.py

* move the tokenizer from paddlenlp to tiktoken.

* fix bug

* remove convert.py

* add new model file.

* fix

* fix version number

* fix version number

* fix some bugs

* fix bug
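
For reference, a minimal usage sketch of the new `turbo` size through the Python API. This assumes the usual `WhisperExecutor` entry point; the keyword names follow the argparse options changed in the diff below and may differ slightly in practice:

from paddlespeech.cli.whisper import WhisperExecutor

whisper_executor = WhisperExecutor()
# 'size' now accepts 'turbo' in addition to large/medium/base/small/tiny;
# 'task' is 'transcribe' or 'translate', 'language' selects the decoding language.
text = whisper_executor(
    model='whisper',
    task='transcribe',
    size='turbo',
    sample_rate=16000,
    audio_file='./zh.wav')
print(text)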
zxcd committed via GitHub
commit 538f260061 (parent 8f367b056f)

@@ -74,10 +74,9 @@ class WhisperExecutor(BaseExecutor):
         self.parser.add_argument(
             '--size',
             type=str,
-            default='large',
-            choices=['large', 'medium', 'base', 'small', 'tiny'],
-            help='Choose model size. now only support large, large:[whisper-large-16k]'
-        )
+            default='turbo',
+            choices=['large', 'medium', 'base', 'small', 'tiny', 'turbo'],
+            help='Choose model size.')
         self.parser.add_argument(
             '--language',
             type=str,
@@ -141,7 +140,7 @@ class WhisperExecutor(BaseExecutor):
             model_type: str='whisper',
             lang: str='',
             task: str='transcribe',
-            size: str='large',
+            size: str='turbo',
             language: str='None',
             sample_rate: int=16000,
             cfg_path: Optional[os.PathLike]=None,
@@ -200,6 +199,7 @@ class WhisperExecutor(BaseExecutor):
             # load model
             model_dict = paddle.load(self.ckpt_path)
             dims = ModelDimensions(**model_dict["dims"])
+            self.dims = dims
             self.model = Whisper(dims)
             self.model.load_dict(model_dict)
             self.model.eval()
@@ -251,8 +251,11 @@ class WhisperExecutor(BaseExecutor):
             logger.debug(f"audio shape: {audio.shape}")
             # fbank
-            audio = log_mel_spectrogram(audio, resource_path=self.resource_path)
+            audio = log_mel_spectrogram(
+                audio,
+                resource_path=self.resource_path,
+                n_mels=self.dims.n_mels,
+                padding=480000)

             audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0)

             self._inputs["audio"] = audio
@@ -275,7 +278,6 @@ class WhisperExecutor(BaseExecutor):
                         cfg.temperature_increment_on_fallback))
             else:
                 temperature = [cfg.temperature]
-
             self._outputs["result"] = self.model.transcribe(
                 audio,
                 verbose=cfg.verbose,

@@ -617,6 +617,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-large-model.tar.gz',
+            'md5':
+            '9ebbd228fa07ca4557e5da863dac2982',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-large-model',
+            'model':
+            'whisper-large-model.pdparams',
+            'params':
+            'whisper-large-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-base-en-16k": {
         '1.3': {
@@ -637,6 +655,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-base-en-model.tar.gz',
+            'md5':
+            '376617a9c5f36404f50dde3708bac0c6',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-base-en-model',
+            'model':
+            'whisper-base-en-model.pdparams',
+            'params':
+            'whisper-base-en-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-base-16k": {
         '1.3': {
@@ -657,6 +693,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-base-model.tar.gz',
+            'md5':
+            '61836cb29c93048621f83364d83b532b',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-base-model',
+            'model':
+            'whisper-base-model.pdparams',
+            'params':
+            'whisper-base-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-medium-en-16k": {
         '1.3': {
@@ -677,6 +731,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-medium-en-model.tar.gz',
+            'md5':
+            'ac01145c5de962f1416f3d98171be559',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-medium-en-model',
+            'model':
+            'whisper-medium-en-model.pdparams',
+            'params':
+            'whisper-medium-en-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-medium-16k": {
         '1.3': {
@@ -697,6 +769,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-medium-model.tar.gz',
+            'md5':
+            '07770819961d1fe795facd3666f8db17',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-medium-model',
+            'model':
+            'whisper-medium-model.pdparams',
+            'params':
+            'whisper-medium-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-small-en-16k": {
         '1.3': {
@@ -717,6 +807,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-small-en-model.tar.gz',
+            'md5':
+            '67af14156b93f49ae738a17204189e46',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-small-en-model',
+            'model':
+            'whisper-small-en-model.pdparams',
+            'params':
+            'whisper-small-en-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-small-16k": {
         '1.3': {
@@ -737,6 +845,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-small-model.tar.gz',
+            'md5':
+            'db53c4bf39a9ad46ef77e6f9a37200b6',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-small-model',
+            'model':
+            'whisper-small-model.pdparams',
+            'params':
+            'whisper-small-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-tiny-en-16k": {
         '1.3': {
@@ -757,6 +883,24 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-tiny-en-model.tar.gz',
+            'md5':
+            'f91f8447d8b37ed13f4327ef6565b094',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-tiny-en-model',
+            'model':
+            'whisper-tiny-en-model.pdparams',
+            'params':
+            'whisper-tiny-en-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
     "whisper-tiny-16k": {
         '1.3': {
@@ -777,6 +921,44 @@ whisper_dynamic_pretrained_models = {
             'resource_data_md5':
             '37a0a8abdb3641a51194f79567a93b61',
         },
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-tiny-model.tar.gz',
+            'md5':
+            '6f2209ac656ff12de085c824363316e2',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-tiny-model',
+            'model':
+            'whisper-tiny-model.pdparams',
+            'params':
+            'whisper-tiny-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
     },
+    "whisper-turbo-16k": {
+        '1.5': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-turbo-model.tar.gz',
+            'md5':
+            'fe2dd1a1d6eb8e6d017aafc7d5f62336',
+            'cfg_path':
+            'whisper.yaml',
+            'ckpt_path':
+            'whisper-turbo-model',
+            'model':
+            'whisper-turbo-model.pdparams',
+            'params':
+            'whisper-turbo-model.pdparams',
+            'resource_data':
+            'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
+            'resource_data_md5':
+            'dd61d092d362f1fdbae6ede08282e177',
+        },
+    },
 }
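
All of the new '1.5' entries point at the whisper_model_20250825 tarballs plus a shared assets.tar. As a quick sanity check, a downloaded archive can be verified against the md5 recorded above with the standard library (illustrative only; PaddleSpeech normally performs this check itself when it downloads a resource):

import hashlib

def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # read in chunks so large model tarballs do not need to fit in memory
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

assert md5_of("whisper-turbo-model.tar.gz") == "fe2dd1a1d6eb8e6d017aafc7d5f62336"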

@@ -45,6 +45,7 @@ class WhisperInfer():
         model_dict = paddle.load(self.config.model_file)
         config.pop("model_file")
         dims = ModelDimensions(**model_dict["dims"])
+        self.dims = dims
         self.model = Whisper(dims)
         self.model.load_dict(model_dict)
@@ -64,8 +65,10 @@ class WhisperInfer():
         #load audio
         mel = log_mel_spectrogram(
-            args.audio_file, resource_path=config.resource_path)
+            args.audio_file,
+            resource_path=config.resource_path,
+            n_mels=self.dims.n_mels,
+            padding=480000)

         result = transcribe(
             self.model, mel, temperature=temperature, **config)
         if args.result_file is not None:
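
The hard-coded padding of 480000 samples is simply one 30-second Whisper window at 16 kHz, i.e. CHUNK_LENGTH * SAMPLE_RATE from the constants defined further down in whisper.py:

SAMPLE_RATE = 16000
CHUNK_LENGTH = 30           # seconds per Whisper window
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE
assert N_SAMPLES == 480000  # the padding value passed to log_mel_spectrogram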

@@ -2,17 +2,19 @@
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
+import base64
 import os
+import string
 from dataclasses import dataclass
+from dataclasses import field
+from functools import cached_property
 from functools import lru_cache
+from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
-from typing import Union

-import numpy as np
-import paddle
-from paddlenlp.transformers import GPTTokenizer
+import tiktoken

 LANGUAGES = {
     "en": "english",
@@ -35,7 +37,7 @@ LANGUAGES = {
     "hi": "hindi",
     "fi": "finnish",
     "vi": "vietnamese",
-    "iw": "hebrew",
+    "he": "hebrew",
     "uk": "ukrainian",
     "el": "greek",
     "ms": "malay",
@@ -114,6 +116,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }

 # language code lookup by name, with a few language aliases
@@ -130,134 +133,122 @@ TO_LANGUAGE_CODE = {
     "moldovan": "ro",
     "sinhalese": "si",
     "castilian": "es",
+    "mandarin": "zh",
 }


-@dataclass(frozen=True)
+@dataclass
 class Tokenizer:
-    """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
+    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

-    tokenizer: "GPTTokenizer"
-    language: Optional[str]
-    sot_sequence: Tuple[int]
+    encoding: tiktoken.Encoding
+    num_languages: int
+    language: Optional[str] = None
+    task: Optional[str] = None
+    sot_sequence: Tuple[int] = ()
+    special_tokens: Dict[str, int] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for special in self.encoding.special_tokens_set:
+            special_token = self.encoding.encode_single_token(special)
+            self.special_tokens[special] = special_token
+
+        sot: int = self.special_tokens["<|startoftranscript|>"]
+        translate: int = self.special_tokens["<|translate|>"]
+        transcribe: int = self.special_tokens["<|transcribe|>"]
+
+        langs = tuple(LANGUAGES.keys())[:self.num_languages]
+        sot_sequence = [sot]
+        if self.language is not None:
+            sot_sequence.append(sot + 1 + langs.index(self.language))
+        if self.task is not None:
+            task_token: int = transcribe if self.task == "transcribe" else translate
+            sot_sequence.append(task_token)
+
+        self.sot_sequence = tuple(sot_sequence)

     def encode(self, text, **kwargs):
-        return self.tokenizer.encode(text, **kwargs)
+        return self.encoding.encode(text, **kwargs)

-    def decode(self,
-               token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
-               **kwargs):
-        if len(token_ids) > 1:
-            ids_list = []
-            for ids in token_ids:
-                if paddle.is_tensor(ids):
-                    ids = ids.item()
-                if ids < len(self.tokenizer):
-                    ids_list.append(ids)
-            token_ids = ids_list
-        elif len(token_ids) == 1:
-            token_ids = token_ids[0]
-        else:
-            raise ValueError(f"token_ids {token_ids} load error.")
-
-        return self.tokenizer.decode(token_ids, **kwargs)
+    def decode(self, token_ids: List[int], **kwargs) -> str:
+        token_ids = [t for t in token_ids if t < self.timestamp_begin]
+        return self.encoding.decode(token_ids, **kwargs)

-    def decode_with_timestamps(self, tokens) -> str:
+    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
         """
-        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
+        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
         This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
         """
-        outputs = [[]]
-        for token in tokens:
-            if token >= self.timestamp_begin:
-                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
-                outputs.append(timestamp)
-                outputs.append([])
-            else:
-                outputs[-1].append(token)
-        outputs = [
-            s if isinstance(s, str) else self.tokenizer.decode(s)
-            for s in outputs
-        ]
-        return "".join(outputs)
+        return self.encoding.decode(token_ids, **kwargs)

-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
-        return self.tokenizer.eos_token_id
+        return self.encoding.eot_token
+
+    @cached_property
+    def transcribe(self) -> int:
+        return self.special_tokens["<|transcribe|>"]
+
+    @cached_property
+    def translate(self) -> int:
+        return self.special_tokens["<|translate|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
-        return self._get_single_token_id("<|startoftranscript|>")
+        return self.special_tokens["<|startoftranscript|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
-        return self._get_single_token_id("<|startoflm|>")
+        return self.special_tokens["<|startoflm|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
-        return self._get_single_token_id("<|startofprev|>")
+        return self.special_tokens["<|startofprev|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
-        return self._get_single_token_id("<|nospeech|>")
+        return self.special_tokens["<|nospeech|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
-        return self._get_single_token_id("<|notimestamps|>")
+        return self.special_tokens["<|notimestamps|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
-        return self.tokenizer.all_special_ids[-1] + 1
+        return self.special_tokens["<|0.00|>"]

-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
             raise ValueError(
                 "This tokenizer does not have language token configured")

-        additional_tokens = dict(
-            zip(
-                self.tokenizer.additional_special_tokens,
-                self.tokenizer.additional_special_tokens_ids, ))
-        candidate = f"<|{self.language}|>"
-        if candidate in additional_tokens:
-            return additional_tokens[candidate]
-
-        raise KeyError(f"Language {self.language} not found in tokenizer.")
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
+            return token
+
+        raise KeyError(f"Language {language} not found in tokenizer.")

-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
-        for token, token_id in zip(
-                self.tokenizer.additional_special_tokens,
-                self.tokenizer.additional_special_tokens_ids, ):
+        for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
-        return tuple(result)
+        return tuple(result)[:self.num_languages]

-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(
-            self.decode([l]).strip("<|>") for l in self.all_language_tokens)
+            self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])

-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
@@ -269,9 +260,10 @@ class Tokenizer:
         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
         """
-        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
-        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split(
-        )
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".
+            split())

         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # In case they're multiple tokens, suppress the first token, which is safe because:
@@ -281,45 +273,99 @@ class Tokenizer:
         assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-        result = {
-            self.tokenizer.encode(" -").input_ids[0],
-            self.tokenizer.encode(" '").input_ids[0]
-        }
+        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
         for symbol in symbols + list(miscellaneous):
             for tokens in [
-                    self.tokenizer.encode(symbol).input_ids,
-                    self.tokenizer.encode(" " + symbol).input_ids
+                    self.encoding.encode(symbol),
+                    self.encoding.encode(" " + symbol),
             ]:
                 if len(tokens) == 1 or symbol in miscellaneous:
                     result.add(tokens[0])

         return tuple(sorted(result))

-    def _get_single_token_id(self, text) -> int:
-        tokens = self.tokenizer.encode(text).input_ids
-        assert len(tokens) == 1, f"{text} is not encoded as a single token"
-        return tokens[0]
+    def split_to_word_tokens(self, tokens: List[int]):
+        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(self, tokens: List[int]):
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
+        words = []
+        word_tokens = []
+        current_tokens = []
+        unicode_offset = 0
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+
+            if (replacement_char not in decoded or
+                    decoded_full[unicode_offset + decoded.index(
+                        replacement_char)] == replacement_char):
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+                unicode_offset += len(decoded)
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(self, tokens: List[int]):
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens


 @lru_cache(maxsize=None)
-def build_tokenizer(resource_path: str, name: str="gpt2"):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    path = os.path.join(resource_path, "assets", name)
-    tokenizer = GPTTokenizer.from_pretrained(path)
+def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
+    vocab_path = os.path.join(resource_path, "assets", f"{name}.tiktoken")
+    ranks = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in open(vocab_path) if line)
+    }
+    n_vocab = len(ranks)
+    special_tokens = {}

     specials = [
+        "<|endoftext|>",
         "<|startoftranscript|>",
-        * [f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        * [f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
         "<|translate|>",
         "<|transcribe|>",
         "<|startoflm|>",
         "<|startofprev|>",
         "<|nospeech|>",
         "<|notimestamps|>",
+        * [f"<|{i * 0.02:.2f}|>" for i in range(1501)],
     ]

-    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
-    return tokenizer
+    for token in specials:
+        special_tokens[token] = n_vocab
+        n_vocab += 1
+
+    return tiktoken.Encoding(
+        name=os.path.basename(vocab_path),
+        explicit_n_vocab=n_vocab,
+        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        mergeable_ranks=ranks,
+        special_tokens=special_tokens, )


 @lru_cache(maxsize=None)
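
get_encoding expects an assets/<name>.tiktoken file in which every line is a base64-encoded token followed by its integer rank, the format used by tiktoken vocabularies. A toy illustration of that parsing step (the two tokens here are made up for the example):

import base64

lines = [
    base64.b64encode(b"hello").decode() + " 0",
    base64.b64encode(b" world").decode() + " 1",
]
ranks = {
    base64.b64decode(token): int(rank)
    for token, rank in (line.split() for line in lines if line)
}
assert ranks[b"hello"] == 0 and ranks[b" world"] == 1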
@@ -327,8 +373,11 @@ def get_tokenizer(
         multilingual: bool,
         resource_path: str,
         *,
+        num_languages: int=99,
+        language: Optional[str]=None,
         task: Optional[str]=None,  # Literal["transcribe", "translate", None]
-        language: Optional[str]=None, ) -> Tokenizer:
+) -> Tokenizer:
     if language is not None:
         language = language.lower()
         if language not in LANGUAGES:
@@ -338,29 +387,21 @@ def get_tokenizer(
             raise ValueError(f"Unsupported language: {language}")

     if multilingual:
-        tokenizer_name = "multilingual"
-        task = task or "transcribe"
+        encoding_name = "multilingual"
         language = language or "en"
+        task = task or "transcribe"
     else:
-        tokenizer_name = "gpt2"
-        task = None
+        encoding_name = "gpt2"
         language = None
+        task = None

-    tokenizer = build_tokenizer(
-        resource_path=resource_path, name=tokenizer_name)
-    all_special_ids: List[int] = tokenizer.all_special_ids
-    sot: int = all_special_ids[1]
-    translate: int = all_special_ids[-6]
-    transcribe: int = all_special_ids[-5]
-
-    langs = tuple(LANGUAGES.keys())
-    sot_sequence = [sot]
-    if language is not None:
-        sot_sequence.append(sot + 1 + langs.index(language))
-    if task is not None:
-        sot_sequence.append(transcribe if task == "transcribe" else translate)
+    encoding = get_encoding(
+        resource_path=resource_path,
+        name=encoding_name,
+        num_languages=num_languages)

     return Tokenizer(
-        tokenizer=tokenizer,
+        encoding=encoding,
+        num_languages=num_languages,
         language=language,
-        sot_sequence=tuple(sot_sequence))
+        task=task)
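
Because get_encoding appends the special tokens in a fixed order (<|endoftext|>, <|startoftranscript|>, one token per language, then the task tokens), Tokenizer.__post_init__ can derive the language token as sot + 1 + langs.index(language). A small sketch of that arithmetic with made-up sizes (the real ids come from the .tiktoken file):

n_base = 50257                      # assumed size of the multilingual BPE table
sot = n_base + 1                    # "<|startoftranscript|>" follows "<|endoftext|>"
langs = ["en", "zh", "de"]          # truncated LANGUAGES list, for illustration only
language, task = "zh", "transcribe"

sot_sequence = [sot, sot + 1 + langs.index(language)]
transcribe_token = sot + 1 + len(langs) + 1   # past the language block and "<|translate|>"
if task == "transcribe":
    sot_sequence.append(transcribe_token)
print(sot_sequence)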

@@ -33,13 +33,18 @@ logger = Log(__name__).getlog()
 _MODELS = ["large"]
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
 N_FRAMES = utils.exact_div(
     N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
+FRAMES_PER_SECOND = utils.exact_div(SAMPLE_RATE,
+                                    HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = utils.exact_div(SAMPLE_RATE,
+                                    N_SAMPLES_PER_TOKEN)  # 20ms per audio token


 @dataclass
 class ModelDimensions:
@@ -378,7 +383,9 @@ def detect_language(
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(
-            model.is_multilingual, resource_path=resource_path)
+            multilingual=model.is_multilingual,
+            resource_path=resource_path,
+            num_languages=model.num_languages)
     if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
         raise ValueError(
             "This model doesn't have language tokens so it can't perform lang id"
@@ -400,6 +407,7 @@ def detect_language(
     # collect detected languages; suppress all non-language tokens
     mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
     mask[list(tokenizer.all_language_tokens)] = False
+    logits.contiguous()
     logits[:, mask] = -np.inf
     language_tokens = paddle.argmax(logits, axis=-1)
     language_token_probs = F.softmax(logits, axis=-1)
@@ -428,6 +436,13 @@ def transcribe(
         logprob_threshold: Optional[float]=-1.0,
         no_speech_threshold: Optional[float]=0.6,
         condition_on_previous_text: bool=True,
+        initial_prompt: Optional[str]=None,
+        carry_initial_prompt: bool=False,
+        word_timestamps: bool=False,
+        prepend_punctuations: str="\"'“¿([{-",
+        append_punctuations: str="\"'.。,!?::”)]}、",
+        clip_timestamps: Union[str, List[float]]="0",
+        hallucination_silence_threshold: Optional[float]=None,
         **decode_options, ):
     """
     Transcribe an audio file using Whisper
@@ -476,8 +491,9 @@ def transcribe(
     if dtype == np.float32:
         decode_options["fp16"] = False

-    if decode_options.get("language") == 'None' or decode_options.get(
-            "language", None) is None:
+    content_frames = mel.shape[-1] - N_FRAMES
+    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+    if decode_options.get("language", None) in {None, "None"}:
         if not model.is_multilingual:
             decode_options["language"] = "en"
         else:
@@ -485,25 +501,47 @@ def transcribe(
             print(
                 "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
             )
-            segment = pad_or_trim(mel, N_FRAMES)
-            _, probs = model.detect_language(segment, resource_path)
+            mel_segment = pad_or_trim(mel,
+                                      N_FRAMES).to(model.device).astype(dtype)
+            _, probs = model.detect_language(mel_segment, resource_path)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(
                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                 )

-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
+    language: str = decode_options["language"]
+    task: str = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(
-        model.is_multilingual,
+        multilingual=model.is_multilingual,
         resource_path=resource_path,
+        num_languages=model.num_languages,
         language=language,
-        task=task)
+        task=task, )
+
+    if isinstance(clip_timestamps, str):
+        clip_timestamps = [
+            float(ts)
+            for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+        ]
+    seek_points: List[
+        int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+    if len(seek_points) == 0:
+        seek_points.append(0)
+    if len(seek_points) % 2 == 1:
+        seek_points.append(content_frames)
+    seek_clips: List[Tuple[int, int]] = list(
+        zip(seek_points[::2], seek_points[1::2]))
+
+    punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
+
+    if word_timestamps and task == "translate":
+        warnings.warn(
+            "Word-level timestamps on translations may not be reliable.")

     def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
-        temperatures = [temperature] if isinstance(temperature, (
-            int, float)) else temperature
+        temperatures = ([temperature] if isinstance(temperature, (int, float))
+                        else temperature)
         decode_result = None

         for t in temperatures:
@@ -517,20 +555,29 @@ def transcribe(
             kwargs.pop("best_of", None)

             options = DecodingOptions(**kwargs, temperature=t)
             decode_result = model.decode(segment, options, resource_path)

             needs_fallback = False
-            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+            if (compression_ratio_threshold is not None and
+                    decode_result.compression_ratio >
+                    compression_ratio_threshold):
                 needs_fallback = True  # too repetitive
-            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+            if (logprob_threshold is not None and
+                    decode_result.avg_logprob < logprob_threshold):
                 needs_fallback = True  # average log probability is too low
+            if (no_speech_threshold is not None and
+                    decode_result.no_speech_prob > no_speech_threshold and
+                    logprob_threshold is not None and
+                    decode_result.avg_logprob < logprob_threshold):
+                needs_fallback = False  # silence
             if not needs_fallback:
                 break

         return decode_result

-    seek = 0
+    clip_idx = 0
+    seek = seek_clips[clip_idx][0]
     input_stride = utils.exact_div(
         N_FRAMES, model.dims.n_audio_ctx)  # mel frames per output token: 2
     time_precision = (input_stride * HOP_LENGTH /
@@ -539,127 +586,210 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0

-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " +
-                                          initial_prompt.strip()).input_ids
-        all_tokens.extend(initial_prompt)
+    remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+        remaining_prompt_length -= len(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []

-    def add_segment(*,
-                    start: float,
-                    end: float,
-                    text_tokens: paddle.Tensor,
-                    result: DecodingResult):
-        text = tokenizer.decode(
-            [token for token in text_tokens if token < tokenizer.eot])
-        if len(text.strip()) == 0:  # skip empty text output
-            return
-
-        all_segments.append({
-            "id": len(all_segments),
+    def new_segment(*,
+                    start: float,
+                    end: float,
+                    tokens: paddle.Tensor,
+                    result: DecodingResult):
+        tokens = tokens.tolist()
+        text_tokens = [token for token in tokens if token < tokenizer.eot]
+        return {
             "seek": seek,
             "start": start,
             "end": end,
-            "text": text,
-            "tokens": result.tokens,
+            "text": tokenizer.decode(text_tokens),
+            "tokens": tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
             "compression_ratio": result.compression_ratio,
             "no_speech_prob": result.no_speech_prob,
-        })
-        if verbose:
-            print(
-                f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
-            )
-
-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
-    num_frames = mel.shape[-1]
-    previous_seek_value = seek
+        }

+    # show the progress bar when verbose is False (if True, transcribed text will be printed)
     with tqdm.tqdm(
-            total=num_frames, unit='frames',
+            total=content_frames, unit="frames",
             disable=verbose is not False) as pbar:
-        while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+        last_speech_timestamp = 0.0
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            window_end_time = float(
+                (seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+            segment_size = min(N_FRAMES, content_frames - seek,
+                               seek_clip_end - seek)
+            mel_segment = mel[:, seek:seek + segment_size]
+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+            mel_segment = pad_or_trim(mel_segment,
+                                      N_FRAMES).to(model.device).astype(dtype)

-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            if carry_initial_prompt:
+                nignored = max(len(initial_prompt_tokens), prompt_reset_since)
+                remaining_prompt = all_tokens[nignored:][
+                    -remaining_prompt_length:]
+                decode_options[
+                    "prompt"] = initial_prompt_tokens + remaining_prompt
+            else:
+                decode_options["prompt"] = all_tokens[prompt_reset_since:]

-            result: DecodingResult = decode_with_fallback(segment)
+            result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = paddle.to_tensor(result.tokens)

             if no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                if (logprob_threshold is not None and
+                        result.avg_logprob > logprob_threshold):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False

                 if should_skip:
-                    seek += segment.shape[
-                        -1]  # fast-forward to the next segment boundary
+                    seek += segment_size  # fast-forward to the next segment boundary
                     continue

+            previous_seek = seek
+            current_segments = []
+
+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [
+                    w for w in segment["words"] if w["word"] not in punctuation
+                ]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                 paddle.to_tensor(tokenizer.timestamp_begin))
+            single_timestamp_ending = timestamp_tokens[
+                -2:].tolist() == [False, True]

             consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
                 1:])[0]
-            if len(
-                    consecutive
-            ) > 0:  # if the output contains two consecutive timestamp tokens
+            if consecutive.numel() != 0:
                 consecutive = paddle.add(consecutive, paddle.to_tensor(1))
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
+
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_position = (
+                    start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin)
-                    end_timestamp_position = (
+                    end_timestamp_pos = (
                         sliced_tokens[-1].item() - tokenizer.timestamp_begin)
-                    add_segment(
-                        start=timestamp_offset + start_timestamp_position *
-                        time_precision,
-                        end=timestamp_offset + end_timestamp_position *
-                        time_precision,
-                        text_tokens=sliced_tokens[1:-1],
-                        result=result, )
+                    current_segments.append(
+                        new_segment(
+                            start=time_offset + start_timestamp_pos *
+                            time_precision,
+                            end=time_offset + end_timestamp_pos *
+                            time_precision,
+                            tokens=sliced_tokens,
+                            result=result, ))
                     last_slice = current_slice
-                last_timestamp_position = (
-                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
-                seek += last_timestamp_position * input_stride
-                all_tokens.extend(tokens[:last_slice + 1].tolist())
+
+                if single_timestamp_ending:
+                    # single timestamp at the end means no speech after the last timestamp.
+                    seek += segment_size
+                else:
+                    # otherwise, ignore the unfinished segment and seek to the last timestamp
+                    last_timestamp_pos = (tokens[last_slice - 1].item() -
+                                          tokenizer.timestamp_begin)
+                    seek += last_timestamp_pos * input_stride
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if len(timestamps) > 0 and timestamps[
-                        -1].item() != tokenizer.timestamp_begin:
+                if (len(timestamps) > 0 and
+                        timestamps[-1].item() != tokenizer.timestamp_begin):
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    # single timestamp at the end means no speech after the last timestamp.
-                    last_timestamp_position = timestamps[
-                        -1].item() - tokenizer.timestamp_begin
-                    duration = last_timestamp_position * time_precision
-
-                add_segment(
-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
-                    text_tokens=tokens,
-                    result=result, )
-
-                seek += segment.shape[-1]
-                all_tokens.extend(tokens.tolist())
+                    last_timestamp_pos = (
+                        timestamps[-1].item() - tokenizer.timestamp_begin)
+                    duration = last_timestamp_pos * time_precision
+
+                current_segments.append(
+                    new_segment(
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                        result=result, ))
+                seek += segment_size
+
+            if verbose:
+                for segment in current_segments:
+                    start, end, text = segment["start"], segment[
+                        "end"], segment["text"]
+                    line = f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
+                    print(line)
+
+            # if a segment is instantaneous or does not contain text, clear it
+            for i, segment in enumerate(current_segments):
+                if segment["start"] == segment["end"] or segment[
+                        "text"].strip() == "":
+                    segment["text"] = ""
+                    segment["tokens"] = []
+                    segment["words"] = []
+
+            all_segments.extend(
+                [{
+                    "id": i,
+                    **
+                    segment
+                }
+                 for i, segment in enumerate(
+                     current_segments, start=len(all_segments))])
+            all_tokens.extend([
+                token
+                for segment in current_segments for token in segment["tokens"]
+            ])

             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
                 prompt_reset_since = len(all_tokens)

             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+            pbar.update(min(content_frames, seek) - previous_seek)

     return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt):]),
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
         segments=all_segments,
-        language=language)
+        language=language, )


 class SequenceRanker:
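
The new clip_timestamps option accepts either a comma-separated string or a list of seconds, which the loop above turns into (start, end) frame pairs; an unpaired final value is closed with content_frames. A standalone sketch of that parsing, with content_frames assumed to be 3000 for the example:

FRAMES_PER_SECOND = 100          # 16000 / 160, i.e. one mel frame every 10 ms
content_frames = 3000            # assumed length of the example recording

clip_timestamps = "0,10,25"      # seconds: one full clip plus an open-ended one
points = [round(float(ts) * FRAMES_PER_SECOND) for ts in clip_timestamps.split(",")]
if len(points) % 2 == 1:
    points.append(content_frames)
seek_clips = list(zip(points[::2], points[1::2]))
print(seek_clips)                # [(0, 1000), (2500, 3000)]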
@@ -776,11 +906,11 @@ class GreedyDecoder(TokenDecoder):
             next_tokens.shape[0] * next_tokens.shape[1],
         ])

-        logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
+        logprobs = F.log_softmax(logits, axis=-1, dtype="float32")
         current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
                                     next_tokens]
         sum_logprobs += current_logprobs * paddle.to_tensor(
-            (tokens[:, -1] != self.eot), dtype=paddle.float32)
+            (tokens[:, -1] != self.eot), dtype="float32")

         next_tokens[tokens[:, -1] == self.eot] = self.eot
         tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
@@ -928,8 +1058,9 @@ class SuppressBlank(LogitFilter):
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         if tokens.shape[1] == self.sample_begin:
-            logits[:, self.tokenizer.encode(" ").input_ids +
-                   [self.tokenizer.eot]] = -np.inf
+            logits.contiguous()
+            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot
+                                                    ]] = -np.inf


 class SuppressTokens(LogitFilter):
@@ -937,6 +1068,7 @@ class SuppressTokens(LogitFilter):
         self.suppress_tokens = list(suppress_tokens)

     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+        logits.contiguous()
         logits[:, self.suppress_tokens] = -np.inf
@@ -952,6 +1084,7 @@ class ApplyTimestampRules(LogitFilter):
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         # suppress <|notimestamps|> which is handled by without_timestamps
         if self.tokenizer.no_timestamps is not None:
+            logits.contiguous()
             logits[:, self.tokenizer.no_timestamps] = -np.inf

         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@@ -972,6 +1105,7 @@ class ApplyTimestampRules(LogitFilter):
         if tokens.shape[
                 1] == self.sample_begin and self.max_initial_timestamp_index is not None:
             last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+            logits.contiguous()
             logits[:, last_allowed + 1:] = -np.inf

         # if sum of probability over timestamps is above any other token, sample timestamp
@@ -1005,15 +1139,16 @@ class DecodingTask:
         language = options.language or "en"
         tokenizer = get_tokenizer(
-            model.is_multilingual,
+            multilingual=model.is_multilingual,
             resource_path=resource_path,
             language=language,
-            task=options.task)
+            task=options.task,
+            num_languages=model.num_languages)
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
         self.resource_path: str = resource_path

-        self.beam_size: int = options.beam_size or options.best_of or 1
+        self.n_group: int = options.beam_size or options.best_of or 1
         self.n_ctx: int = model.dims.n_text_ctx
         self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
@@ -1158,10 +1293,10 @@ class DecodingTask:
         sum_logprobs: paddle.Tensor = paddle.zeros(
             paddle.to_tensor(n_batch), dtype=paddle.float32)
         no_speech_probs = [np.nan] * n_batch

         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
+                logits.contiguous()

                 if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = F.softmax(
@@ -1197,17 +1332,8 @@ class DecodingTask:
         audio_features: paddle.Tensor = self._get_audio_features(
             mel)  # encoder forward pass

-        tokens: paddle.Tensor
-        if batch_size > 1:
-            for i in range(batch_size):
-                tokens = paddle.concat(
-                    x=[
-                        paddle.to_tensor([self.initial_tokens]),
-                        paddle.to_tensor([self.initial_tokens])
-                    ],
-                    axis=0)
-        elif batch_size == 1:
-            tokens = paddle.to_tensor([self.initial_tokens])
+        tokens: Tensor = paddle.to_tensor([self.initial_tokens]).repeat(
+            batch_size, 1)

         # detect language if requested, overwriting the language token
         languages, language_probs = self._detect_language(
@@ -1224,30 +1350,26 @@ class DecodingTask:
                 language_probs)
             ]

-        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
-        audio_features = paddle.repeat_interleave(
-            audio_features, self.beam_size, axis=0)
-        tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
+        # repeat text tensors by the group size, for beam search or best-of-n sampling
+        tokens = tokens.repeat_interleave(self.n_group, axis=0)

         # call the main sampling loop
         tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features,
                                                                 tokens)

-        # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
-        audio_features = audio_features[::self.beam_size]
-        no_speech_probs = no_speech_probs[::self.beam_size]
+        # reshape the tensors to have (batch_size, n_group) as the first two dimensions
+        audio_features = audio_features[::self.n_group]
+        no_speech_probs = no_speech_probs[::self.n_group]
         assert audio_features.shape[0] == len(no_speech_probs) == batch_size

-        tokens = tokens.reshape([batch_size, self.beam_size, -1])
-        sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
+        tokens = tokens.reshape([batch_size, self.n_group, -1])
+        sum_logprobs = sum_logprobs.reshape([batch_size, self.n_group])

         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens: List[List[paddle.Tensor]] = [[
+        tokens: List[List[Tensor]] = [[
             t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s
         ] for s in tokens]

         # select the top-ranked sample in each group
         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
         tokens: List[List[
@@ -1256,11 +1378,12 @@ class DecodingTask:
         sum_logprobs: List[
             float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[
-            float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]

         fields = (texts, languages, tokens, audio_features, avg_logprobs,
-                  no_speech_probs)
+                  no_speech_probs, )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(
                 f"inconsistent result lengths: {list(map(len, fields))}")
@@ -1294,7 +1417,7 @@ def decode(
     model: Whisper
         the Whisper model instance

-    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
+    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000) or (128, 3000) or (*, 128, 3000)
         A tensor containing the Mel spectrogram(s)

     options: DecodingOptions
@@ -1350,7 +1473,11 @@ class Whisper(nn.Layer):
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

     def install_kv_cache_hooks(self, cache: Optional[dict]=None):
         """
@@ -1364,7 +1491,7 @@ class Whisper(nn.Layer):
         cache : Dict[nn.Layer, paddle.Tensor]
             A dictionary object mapping the key/value projection modules to its cache
         hooks : List[RemovableHandle]
-            List of PyTorch RemovableHandle objects to stop the hooks to be called
+            List of Paddle RemovableHandle objects to stop the hooks to be called
         """
         cache = {**cache} if cache is not None else {}
         hooks = []
@@ -1431,11 +1558,11 @@ def hann_window(n_fft: int=N_FFT):
     """
     return paddle.to_tensor(
         [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
-        dtype=paddle.float32)
+        dtype="float32")


 @lru_cache(maxsize=None)
-def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
+def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -1445,13 +1572,16 @@ def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
         mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
     )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+    filters_path = os.path.join(resource_path, "assets", "mel_filters.npz")
+    with np.load(filters_path, allow_pickle=False) as f:
         return paddle.to_tensor(f[f"mel_{n_mels}"])


 def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
-                        n_mels: int=N_MELS,
+                        n_mels: int=80,
+                        padding: int=0,
                         resource_path: str=None):
     """
     Compute the log-Mel spectrogram of
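
With 128-bin filterbanks now accepted, the mel_filters.npz shipped in the assets archive needs a mel_128 entry alongside mel_80. Following the recipe already quoted in the docstring above, such a file could be regenerated with librosa like this (a sketch only; the released assets.tar already contains the file):

import numpy as np
import librosa

np.savez_compressed(
    "mel_filters.npz",
    mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
    mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), )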
@@ -1462,11 +1592,11 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
         The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

     n_mels: int
-        The number of Mel-frequency filters, only 80 is supported
+        The number of Mel-frequency filters, only 80 and 128 is supported

     Returns
     -------
-    paddle.Tensor, shape = (80, n_frames)
+    paddle.Tensor, shape = (n_mels, n_frames)
         A Tensor that contains the Mel spectrogram
     """
     if not paddle.is_tensor(audio):
@@ -1475,7 +1605,8 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
             audio = audio[:, 0]
         logger.info(f"audio shape: {audio.shape}")
         audio = paddle.to_tensor(audio)

+    if padding > 0:
+        audio = F.pad(audio, (0, padding), data_format="NLC")
     window = hann_window(N_FFT)
     stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)

@@ -128,6 +128,7 @@ base = [
     "flatten_dict",
     "pyloudnorm",
     "rich",
+    "tiktoken",
 ]

 server = ["pattern_singleton", "websockets"]
