【ASR】whisper large v3 (#4101)

* whisper large v3

* add convert.py

* replace the paddlenlp GPT tokenizer with tiktoken.

* fix bug

* remove convert.py

* add new model file.

* fix

* fix version number

* fix version number

* fix some bugs

* fix bug
zxcd committed 538f260061 (parent 8f367b056f)

@ -74,10 +74,9 @@ class WhisperExecutor(BaseExecutor):
self.parser.add_argument(
'--size',
type=str,
default='large',
choices=['large', 'medium', 'base', 'small', 'tiny'],
help='Choose model size. now only support large, large:[whisper-large-16k]'
)
default='turbo',
choices=['large', 'medium', 'base', 'small', 'tiny', 'turbo'],
help='Choose model size.')
self.parser.add_argument(
'--language',
type=str,
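
For context, a minimal usage sketch of the new default. This is an assumed invocation: the import path, keyword names, and audio file below are assumptions based on the flags in this diff, not taken from the PR itself.

# Hedged sketch: select the new turbo checkpoint through the Python executor.
from paddlespeech.cli.whisper import WhisperExecutor

whisper = WhisperExecutor()
text = whisper(
    audio_file='./zh.wav',   # placeholder input, 16 kHz wav
    size='turbo',            # new default added in this PR; 'large', 'medium', ... still work
    task='transcribe',
    sample_rate=16000)
print(text)
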
@ -141,7 +140,7 @@ class WhisperExecutor(BaseExecutor):
model_type: str='whisper',
lang: str='',
task: str='transcribe',
size: str='large',
size: str='turbo',
language: str='None',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
@ -200,6 +199,7 @@ class WhisperExecutor(BaseExecutor):
# load model
model_dict = paddle.load(self.ckpt_path)
dims = ModelDimensions(**model_dict["dims"])
self.dims = dims
self.model = Whisper(dims)
self.model.load_dict(model_dict)
self.model.eval()
@ -251,8 +251,11 @@ class WhisperExecutor(BaseExecutor):
logger.debug(f"audio shape: {audio.shape}")
# fbank
audio = log_mel_spectrogram(audio, resource_path=self.resource_path)
audio = log_mel_spectrogram(
audio,
resource_path=self.resource_path,
n_mels=self.dims.n_mels,
padding=480000)
audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0)
self._inputs["audio"] = audio
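
The two new arguments mirror the upstream large-v3 changes: `n_mels` is now read from the checkpoint dimensions (128 for large-v3/turbo, 80 for the earlier sizes), and `padding=480000` appends one 30-second window of silence before the mel transform. A small sketch of where that constant comes from, using the names defined in whisper.py later in this diff:

# Why padding=480000: one 30 s Whisper window, expressed in samples.
SAMPLE_RATE = 16000                     # Hz
CHUNK_LENGTH = 30                       # seconds per decoding window
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # the value passed as `padding` above
assert N_SAMPLES == 480000

# n_mels comes from the checkpoint, so 80-bin and 128-bin models share one code path:
# log_mel_spectrogram(audio, resource_path=..., n_mels=self.dims.n_mels, padding=N_SAMPLES)
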
@ -275,7 +278,6 @@ class WhisperExecutor(BaseExecutor):
cfg.temperature_increment_on_fallback))
else:
temperature = [cfg.temperature]
self._outputs["result"] = self.model.transcribe(
audio,
verbose=cfg.verbose,

@ -617,6 +617,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-large-model.tar.gz',
'md5':
'9ebbd228fa07ca4557e5da863dac2982',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-large-model',
'model':
'whisper-large-model.pdparams',
'params':
'whisper-large-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-base-en-16k": {
'1.3': {
@ -637,6 +655,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-base-en-model.tar.gz',
'md5':
'376617a9c5f36404f50dde3708bac0c6',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-base-en-model',
'model':
'whisper-base-en-model.pdparams',
'params':
'whisper-base-en-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-base-16k": {
'1.3': {
@ -657,6 +693,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-base-model.tar.gz',
'md5':
'61836cb29c93048621f83364d83b532b',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-base-model',
'model':
'whisper-base-model.pdparams',
'params':
'whisper-base-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-medium-en-16k": {
'1.3': {
@ -677,6 +731,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-medium-en-model.tar.gz',
'md5':
'ac01145c5de962f1416f3d98171be559',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-medium-en-model',
'model':
'whisper-medium-en-model.pdparams',
'params':
'whisper-medium-en-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-medium-16k": {
'1.3': {
@ -697,6 +769,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-medium-model.tar.gz',
'md5':
'07770819961d1fe795facd3666f8db17',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-medium-model',
'model':
'whisper-medium-model.pdparams',
'params':
'whisper-medium-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-small-en-16k": {
'1.3': {
@ -717,6 +807,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-small-en-model.tar.gz',
'md5':
'67af14156b93f49ae738a17204189e46',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-small-en-model',
'model':
'whisper-small-en-model.pdparams',
'params':
'whisper-small-en-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-small-16k": {
'1.3': {
@ -737,6 +845,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-small-model.tar.gz',
'md5':
'db53c4bf39a9ad46ef77e6f9a37200b6',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-small-model',
'model':
'whisper-small-model.pdparams',
'params':
'whisper-small-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-tiny-en-16k": {
'1.3': {
@ -757,6 +883,24 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-tiny-en-model.tar.gz',
'md5':
'f91f8447d8b37ed13f4327ef6565b094',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-tiny-en-model',
'model':
'whisper-tiny-en-model.pdparams',
'params':
'whisper-tiny-en-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-tiny-16k": {
'1.3': {
@ -777,6 +921,44 @@ whisper_dynamic_pretrained_models = {
'resource_data_md5':
'37a0a8abdb3641a51194f79567a93b61',
},
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-tiny-model.tar.gz',
'md5':
'6f2209ac656ff12de085c824363316e2',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-tiny-model',
'model':
'whisper-tiny-model.pdparams',
'params':
'whisper-tiny-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
"whisper-turbo-16k": {
'1.5': {
'url':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-turbo-model.tar.gz',
'md5':
'fe2dd1a1d6eb8e6d017aafc7d5f62336',
'cfg_path':
'whisper.yaml',
'ckpt_path':
'whisper-turbo-model',
'model':
'whisper-turbo-model.pdparams',
'params':
'whisper-turbo-model.pdparams',
'resource_data':
'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar',
'resource_data_md5':
'dd61d092d362f1fdbae6ede08282e177',
},
},
}
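
All of the new '1.5' entries share a single assets.tar (tiktoken vocabularies plus mel filterbanks) and differ only in the per-size checkpoint tarball. A short sketch of reading one entry back from the registry defined above; the values in the comments are copied from this diff:

# Resolving the new turbo entry from the registry above.
entry = whisper_dynamic_pretrained_models["whisper-turbo-16k"]["1.5"]
print(entry["url"])            # .../whisper_model_20250825/whisper-turbo-model.tar.gz
print(entry["md5"])            # fe2dd1a1d6eb8e6d017aafc7d5f62336
print(entry["resource_data"])  # shared assets.tar (tokenizer vocab + mel filters)
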

@ -45,6 +45,7 @@ class WhisperInfer():
model_dict = paddle.load(self.config.model_file)
config.pop("model_file")
dims = ModelDimensions(**model_dict["dims"])
self.dims = dims
self.model = Whisper(dims)
self.model.load_dict(model_dict)
@ -64,8 +65,10 @@ class WhisperInfer():
#load audio
mel = log_mel_spectrogram(
args.audio_file, resource_path=config.resource_path)
args.audio_file,
resource_path=config.resource_path,
n_mels=self.dims.n_mels,
padding=480000)
result = transcribe(
self.model, mel, temperature=temperature, **config)
if args.result_file is not None:

@ -2,17 +2,19 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
import base64
import os
import string
from dataclasses import dataclass
from dataclasses import field
from functools import cached_property
from functools import lru_cache
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import paddle
from paddlenlp.transformers import GPTTokenizer
import tiktoken
LANGUAGES = {
"en": "english",
@ -35,7 +37,7 @@ LANGUAGES = {
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
@ -114,6 +116,7 @@ LANGUAGES = {
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
@ -130,134 +133,122 @@ TO_LANGUAGE_CODE = {
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
@dataclass(frozen=True)
@dataclass
class Tokenizer:
"""A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
tokenizer: "GPTTokenizer"
language: Optional[str]
sot_sequence: Tuple[int]
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())[:self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self,
token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
**kwargs):
if len(token_ids) > 1:
ids_list = []
for ids in token_ids:
if paddle.is_tensor(ids):
ids = ids.item()
if ids < len(self.tokenizer):
ids_list.append(ids)
token_ids = ids_list
elif len(token_ids) == 1:
token_ids = token_ids[0]
else:
raise ValueError(f"token_ids {token_ids} load error.")
return self.tokenizer.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
return self.encoding.encode(text, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [
s if isinstance(s, str) else self.tokenizer.decode(s)
for s in outputs
]
return "".join(outputs)
@property
@lru_cache()
return self.encoding.decode(token_ids, **kwargs)
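
To make the split between the two decode paths concrete: `decode()` drops every id at or above `timestamp_begin`, while `decode_with_timestamps()` hands all ids to tiktoken, which renders the registered special strings such as "<|1.08|>". A minimal sketch, assuming `resource_path` points at the unpacked assets:

# Hedged illustration of the two decode paths above.
tok = get_tokenizer(multilingual=True, resource_path=resource_path, num_languages=100)
ts = tok.timestamp_begin + 54            # <|1.08|>, since each timestamp step is 0.02 s
ids = tok.encode(" hello") + [ts]
print(tok.decode(ids))                   # " hello"         (timestamp id filtered out)
print(tok.decode_with_timestamps(ids))   # " hello<|1.08|>" (special string kept)
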
@cached_property
def eot(self) -> int:
return self.tokenizer.eos_token_id
return self.encoding.eot_token
@cached_property
def transcribe(self) -> int:
return self.special_tokens["<|transcribe|>"]
@property
@lru_cache()
@cached_property
def translate(self) -> int:
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
return self.special_tokens["<|startoftranscript|>"]
@property
@lru_cache()
@cached_property
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
return self.special_tokens["<|startoflm|>"]
@property
@lru_cache()
@cached_property
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
return self.special_tokens["<|startofprev|>"]
@property
@lru_cache()
@cached_property
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
return self.special_tokens["<|nospeech|>"]
@property
@lru_cache()
@cached_property
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
return self.special_tokens["<|notimestamps|>"]
@property
@lru_cache()
@cached_property
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
return self.special_tokens["<|0.00|>"]
@property
@lru_cache()
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(
"This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids, ))
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
return self.to_language_token(self.language)
def to_language_token(self, language):
if token := self.special_tokens.get(f"<|{language}|>", None):
return token
raise KeyError(f"Language {self.language} not found in tokenizer.")
raise KeyError(f"Language {language} not found in tokenizer.")
@property
@lru_cache()
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids, ):
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
return tuple(result)[:self.num_languages]
@property
@lru_cache()
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(
self.decode([l]).strip("<|>") for l in self.all_language_tokens)
self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
@property
@lru_cache()
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
@ -269,9 +260,10 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split(
)
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".
split())
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@ -281,45 +273,99 @@ class Tokenizer:
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {
self.tokenizer.encode(" -").input_ids[0],
self.tokenizer.encode(" '").input_ids[0]
}
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.tokenizer.encode(symbol).input_ids,
self.tokenizer.encode(" " + symbol).input_ids
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text).input_ids
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
@lru_cache(maxsize=None)
def build_tokenizer(resource_path: str, name: str="gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(resource_path, "assets", name)
tokenizer = GPTTokenizer.from_pretrained(path)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (replacement_char not in decoded or
decoded_full[unicode_offset + decoded.index(
replacement_char)] == replacement_char):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
@lru_cache(maxsize=None)
def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
vocab_path = os.path.join(resource_path, "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
* [f"<|{lang}|>" for lang in LANGUAGES.keys()],
* [f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
* [f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens, )
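
The `.tiktoken` files shipped in assets.tar are plain text with one `base64(token_bytes) rank` pair per line, which is exactly what the dict comprehension above parses. A hedged sketch of inspecting such a file; the path assumes the default assets layout and a `resource_path` pointing at the unpacked archive:

import base64
import os

# Assumed location of the multilingual vocabulary once assets.tar is extracted.
vocab_path = os.path.join(resource_path, "assets", "multilingual.tiktoken")
with open(vocab_path) as f:
    for line in list(f)[:3]:
        token_b64, rank = line.split()
        print(int(rank), base64.b64decode(token_b64))
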
@lru_cache(maxsize=None)
@ -327,8 +373,11 @@ def get_tokenizer(
multilingual: bool,
resource_path: str,
*,
num_languages: int=99,
language: Optional[str]=None,
task: Optional[str]=None, # Literal["transcribe", "translate", None]
language: Optional[str]=None, ) -> Tokenizer:
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
@ -338,29 +387,21 @@ def get_tokenizer(
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
tokenizer_name = "gpt2"
task = None
encoding_name = "gpt2"
language = None
task = None
tokenizer = build_tokenizer(
resource_path=resource_path, name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
encoding = get_encoding(
resource_path=resource_path,
name=encoding_name,
num_languages=num_languages)
return Tokenizer(
tokenizer=tokenizer,
encoding=encoding,
num_languages=num_languages,
language=language,
sot_sequence=tuple(sot_sequence))
task=task)
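
A quick sketch of the new entry point: `get_tokenizer` now takes keyword-only `num_languages`, `language`, and `task`, and `Tokenizer.__post_init__` assembles the start-of-transcript sequence as (<|startoftranscript|>, <|{language}|>, task token). The snippet below assumes `resource_path` and a loaded Whisper `model` are available:

# Hedged usage sketch of the rewritten get_tokenizer().
tokenizer = get_tokenizer(
    multilingual=True,
    resource_path=resource_path,
    num_languages=model.num_languages,  # 100 for large-v3/turbo assets, 99 otherwise
    language="zh",
    task="transcribe")

# __post_init__ builds: start-of-transcript, language token, task token.
assert tokenizer.sot_sequence == (
    tokenizer.sot,
    tokenizer.to_language_token("zh"),
    tokenizer.transcribe)
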

@ -33,13 +33,18 @@ logger = Log(__name__).getlog()
_MODELS = ["large"]
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = utils.exact_div(
N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = utils.exact_div(SAMPLE_RATE,
HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = utils.exact_div(SAMPLE_RATE,
N_SAMPLES_PER_TOKEN) # 20ms per audio token
@dataclass
class ModelDimensions:
@ -378,7 +383,9 @@ def detect_language(
"""
if tokenizer is None:
tokenizer = get_tokenizer(
model.is_multilingual, resource_path=resource_path)
multilingual=model.is_multilingual,
resource_path=resource_path,
num_languages=model.num_languages)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
@ -400,6 +407,7 @@ def detect_language(
# collect detected languages; suppress all non-language tokens
mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
mask[list(tokenizer.all_language_tokens)] = False
logits.contiguous()
logits[:, mask] = -np.inf
language_tokens = paddle.argmax(logits, axis=-1)
language_token_probs = F.softmax(logits, axis=-1)
@ -428,6 +436,13 @@ def transcribe(
logprob_threshold: Optional[float]=-1.0,
no_speech_threshold: Optional[float]=0.6,
condition_on_previous_text: bool=True,
initial_prompt: Optional[str]=None,
carry_initial_prompt: bool=False,
word_timestamps: bool=False,
prepend_punctuations: str="\"'“¿([{-",
append_punctuations: str="\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]]="0",
hallucination_silence_threshold: Optional[float]=None,
**decode_options, ):
"""
Transcribe an audio file using Whisper
@ -476,8 +491,9 @@ def transcribe(
if dtype == np.float32:
decode_options["fp16"] = False
if decode_options.get("language") == 'None' or decode_options.get(
"language", None) is None:
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) in {None, "None"}:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
@ -485,25 +501,47 @@ def transcribe(
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
segment = pad_or_trim(mel, N_FRAMES)
_, probs = model.detect_language(segment, resource_path)
mel_segment = pad_or_trim(mel,
N_FRAMES).to(model.device).astype(dtype)
_, probs = model.detect_language(mel_segment, resource_path)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(
model.is_multilingual,
multilingual=model.is_multilingual,
resource_path=resource_path,
num_languages=model.num_languages,
language=language,
task=task)
task=task, )
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts)
for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[
int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(
zip(seek_points[::2], seek_points[1::2]))
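
The new `clip_timestamps` option takes a comma-separated string (or a list) of seconds, converts the points to mel-frame indices at 100 frames per second, and pairs them into (start, end) clips, appending `content_frames` when the count is odd so the last clip runs to the end of the audio. A worked example of that conversion, with illustrative values:

# Worked example of the clip_timestamps parsing above.
FRAMES_PER_SECOND = 100                  # SAMPLE_RATE // HOP_LENGTH = 16000 // 160
content_frames = 9000                    # e.g. roughly 90 s of audio after the padding window

clip_timestamps = "0,12.5,30"            # transcribe 0-12.5 s and 30 s to the end
points = [round(float(ts) * FRAMES_PER_SECOND) for ts in clip_timestamps.split(",")]
if len(points) % 2 == 1:
    points.append(content_frames)
seek_clips = list(zip(points[::2], points[1::2]))
print(seek_clips)                        # [(0, 1250), (3000, 9000)]
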
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn(
"Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (
int, float)) else temperature
temperatures = ([temperature] if isinstance(temperature, (int, float))
else temperature)
decode_result = None
for t in temperatures:
@ -517,20 +555,29 @@ def transcribe(
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options, resource_path)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
if (compression_ratio_threshold is not None and
decode_result.compression_ratio >
compression_ratio_threshold):
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
if (logprob_threshold is not None and
decode_result.avg_logprob < logprob_threshold):
needs_fallback = True # average log probability is too low
if (no_speech_threshold is not None and
decode_result.no_speech_prob > no_speech_threshold and
logprob_threshold is not None and
decode_result.avg_logprob < logprob_threshold):
needs_fallback = False # silence
if not needs_fallback:
break
return decode_result
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = utils.exact_div(
N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2
time_precision = (input_stride * HOP_LENGTH /
@ -539,127 +586,210 @@ def transcribe(
all_segments = []
prompt_reset_since = 0
initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " +
initial_prompt.strip()).input_ids
all_tokens.extend(initial_prompt)
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def add_segment(*,
def new_segment(*,
start: float,
end: float,
text_tokens: paddle.Tensor,
tokens: paddle.Tensor,
result: DecodingResult):
text = tokenizer.decode(
[token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
all_segments.append({
"id": len(all_segments),
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": result.tokens,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
})
if verbose:
print(
f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
)
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=num_frames, unit='frames',
total=content_frames, unit="frames",
disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
window_end_time = float(
(seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek,
seek_clip_end - seek)
mel_segment = mel[:, seek:seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment,
N_FRAMES).to(model.device).astype(dtype)
if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][
-remaining_prompt_length:]
decode_options[
"prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = paddle.to_tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
if (logprob_threshold is not None and
result.avg_logprob > logprob_threshold):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment.shape[
-1] # fast-forward to the next segment boundary
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [
w for w in segment["words"] if w["word"] not in punctuation
]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: paddle.Tensor = tokens.greater_equal(
paddle.to_tensor(tokenizer.timestamp_begin))
single_timestamp_ending = timestamp_tokens[
-2:].tolist() == [False, True]
consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
1:])[0]
if len(
consecutive
) > 0: # if the output contains two consecutive timestamp tokens
if consecutive.numel() != 0:
consecutive = paddle.add(consecutive, paddle.to_tensor(1))
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in consecutive:
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin)
end_timestamp_position = (
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin)
add_segment(
start=timestamp_offset + start_timestamp_position *
time_precision,
end=timestamp_offset + end_timestamp_position *
time_precision,
text_tokens=sliced_tokens[1:-1],
result=result, )
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos *
time_precision,
end=time_offset + end_timestamp_pos *
time_precision,
tokens=sliced_tokens,
result=result, ))
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[:last_slice + 1].tolist())
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (tokens[last_slice - 1].item() -
tokenizer.timestamp_begin)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[
-1].item() != tokenizer.timestamp_begin:
if (len(timestamps) > 0 and
timestamps[-1].item() != tokenizer.timestamp_begin):
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[
-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result, ))
seek += segment_size
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result, )
seek += segment.shape[-1]
all_tokens.extend(tokens.tolist())
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment[
"end"], segment["text"]
line = f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
print(line)
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment[
"text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend(
[{
"id": i,
**
segment
}
for i, segment in enumerate(
current_segments, start=len(all_segments))])
all_tokens.extend([
token
for segment in current_segments for token in segment["tokens"]
])
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt):]),
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
segments=all_segments,
language=language)
language=language, )
class SequenceRanker:
@ -776,11 +906,11 @@ class GreedyDecoder(TokenDecoder):
next_tokens.shape[0] * next_tokens.shape[1],
])
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
logprobs = F.log_softmax(logits, axis=-1, dtype="float32")
current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
next_tokens]
sum_logprobs += current_logprobs * paddle.to_tensor(
(tokens[:, -1] != self.eot), dtype=paddle.float32)
(tokens[:, -1] != self.eot), dtype="float32")
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
@ -928,8 +1058,9 @@ class SuppressBlank(LogitFilter):
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ").input_ids +
[self.tokenizer.eot]] = -np.inf
logits.contiguous()
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot
]] = -np.inf
class SuppressTokens(LogitFilter):
@ -937,6 +1068,7 @@ class SuppressTokens(LogitFilter):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
logits.contiguous()
logits[:, self.suppress_tokens] = -np.inf
@ -952,6 +1084,7 @@ class ApplyTimestampRules(LogitFilter):
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits.contiguous()
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@ -972,6 +1105,7 @@ class ApplyTimestampRules(LogitFilter):
if tokens.shape[
1] == self.sample_begin and self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits.contiguous()
logits[:, last_allowed + 1:] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
@ -1005,15 +1139,16 @@ class DecodingTask:
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual,
multilingual=model.is_multilingual,
resource_path=resource_path,
language=language,
task=options.task)
task=options.task,
num_languages=model.num_languages)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.resource_path: str = resource_path
self.beam_size: int = options.beam_size or options.best_of or 1
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
@ -1158,10 +1293,10 @@ class DecodingTask:
sum_logprobs: paddle.Tensor = paddle.zeros(
paddle.to_tensor(n_batch), dtype=paddle.float32)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
logits.contiguous()
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = F.softmax(
@ -1197,17 +1332,8 @@ class DecodingTask:
audio_features: paddle.Tensor = self._get_audio_features(
mel) # encoder forward pass
tokens: paddle.Tensor
if batch_size > 1:
for i in range(batch_size):
tokens = paddle.concat(
x=[
paddle.to_tensor([self.initial_tokens]),
paddle.to_tensor([self.initial_tokens])
],
axis=0)
elif batch_size == 1:
tokens = paddle.to_tensor([self.initial_tokens])
tokens: Tensor = paddle.to_tensor([self.initial_tokens]).repeat(
batch_size, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(
@ -1224,30 +1350,26 @@ class DecodingTask:
language_probs)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = paddle.repeat_interleave(
audio_features, self.beam_size, axis=0)
tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
# repeat text tensors by the group size, for beam search or best-of-n sampling
tokens = tokens.repeat_interleave(self.n_group, axis=0)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features,
tokens)
# reshape the tensors to have (batch_size, beam_size) as the first two dimensions
audio_features = audio_features[::self.beam_size]
no_speech_probs = no_speech_probs[::self.beam_size]
# reshape the tensors to have (batch_size, n_group) as the first two dimensions
audio_features = audio_features[::self.n_group]
no_speech_probs = no_speech_probs[::self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == batch_size
tokens = tokens.reshape([batch_size, self.beam_size, -1])
sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
tokens = tokens.reshape([batch_size, self.n_group, -1])
sum_logprobs = sum_logprobs.reshape([batch_size, self.n_group])
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[paddle.Tensor]] = [[
tokens: List[List[Tensor]] = [[
t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s
] for s in tokens]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[
@ -1256,11 +1378,12 @@ class DecodingTask:
sum_logprobs: List[
float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[
float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (texts, languages, tokens, audio_features, avg_logprobs,
no_speech_probs)
no_speech_probs, )
if len(set(map(len, fields))) != 1:
raise RuntimeError(
f"inconsistent result lengths: {list(map(len, fields))}")
@ -1294,7 +1417,7 @@ def decode(
model: Whisper
the Whisper model instance
mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000) or (128, 3000) or (*, 128, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
@ -1350,7 +1473,11 @@ class Whisper(nn.Layer):
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
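
These two properties recover the model's capabilities from the vocabulary size alone: 51864 is the English-only vocabulary, 51865 the original multilingual one (99 languages), and 51866 the large-v3/turbo vocabulary that adds Cantonese. A quick check of the arithmetic in `num_languages`:

# num_languages = n_vocab - 51765 - int(is_multilingual), per the property above.
for n_vocab in (51864, 51865, 51866):
    is_multilingual = n_vocab >= 51865
    print(n_vocab, is_multilingual, n_vocab - 51765 - int(is_multilingual))
# 51864 False 99   (English-only checkpoints)
# 51865 True  99   (original multilingual vocabulary)
# 51866 True  100  (large-v3 / turbo, adds "yue")
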
def install_kv_cache_hooks(self, cache: Optional[dict]=None):
"""
@ -1364,7 +1491,7 @@ class Whisper(nn.Layer):
cache : Dict[nn.Layer, paddle.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
List of Paddle RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
@ -1431,11 +1558,11 @@ def hann_window(n_fft: int=N_FFT):
"""
return paddle.to_tensor(
[0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
dtype=paddle.float32)
dtype="float32")
@lru_cache(maxsize=None)
def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@ -1445,13 +1572,16 @@ def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(resource_path, "assets", "mel_filters.npz")
with np.load(filters_path, allow_pickle=False) as f:
return paddle.to_tensor(f[f"mel_{n_mels}"])
def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
n_mels: int=N_MELS,
n_mels: int=80,
padding: int=0,
resource_path: str=None):
"""
Compute the log-Mel spectrogram of
@ -1462,11 +1592,11 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
The number of Mel-frequency filters; only 80 and 128 are supported
Returns
-------
paddle.Tensor, shape = (80, n_frames)
paddle.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not paddle.is_tensor(audio):
@ -1475,7 +1605,8 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
audio = paddle.to_tensor(audio)
if padding > 0:
audio = F.pad(audio, (0, padding), data_format="NLC")
window = hann_window(N_FFT)
stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)

@ -128,6 +128,7 @@ base = [
"flatten_dict",
"pyloudnorm",
"rich",
"tiktoken",
]
server = ["pattern_singleton", "websockets"]
