From 538f26006182c54ea9c4c47ade5af14a1e076419 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Mon, 1 Sep 2025 11:20:40 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90ASR=E3=80=91whisper=20large=20v3=20(#4?= =?UTF-8?q?101)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * whisper large v3 * add convert.py * mv nlp tokenizer to tiktoken. * fix bug * remove convert.py * add new model file. * fix * fix version number * fix version number * fix some bug * fix bug --- paddlespeech/cli/whisper/infer.py | 18 +- paddlespeech/resource/pretrained_models.py | 182 +++++++++ paddlespeech/s2t/exps/whisper/test_wav.py | 7 +- paddlespeech/s2t/models/whisper/tokenizer.py | 307 +++++++------- paddlespeech/s2t/models/whisper/whisper.py | 395 ++++++++++++------- setup.py | 1 + 6 files changed, 635 insertions(+), 275 deletions(-) diff --git a/paddlespeech/cli/whisper/infer.py b/paddlespeech/cli/whisper/infer.py index 5649b757f..7056e7f94 100644 --- a/paddlespeech/cli/whisper/infer.py +++ b/paddlespeech/cli/whisper/infer.py @@ -74,10 +74,9 @@ class WhisperExecutor(BaseExecutor): self.parser.add_argument( '--size', type=str, - default='large', - choices=['large', 'medium', 'base', 'small', 'tiny'], - help='Choose model size. now only support large, large:[whisper-large-16k]' - ) + default='turbo', + choices=['large', 'medium', 'base', 'small', 'tiny', 'turbo'], + help='Choose model size.') self.parser.add_argument( '--language', type=str, @@ -141,7 +140,7 @@ class WhisperExecutor(BaseExecutor): model_type: str='whisper', lang: str='', task: str='transcribe', - size: str='large', + size: str='turbo', language: str='None', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, @@ -200,6 +199,7 @@ class WhisperExecutor(BaseExecutor): # load model model_dict = paddle.load(self.ckpt_path) dims = ModelDimensions(**model_dict["dims"]) + self.dims = dims self.model = Whisper(dims) self.model.load_dict(model_dict) self.model.eval() @@ -251,8 +251,11 @@ class WhisperExecutor(BaseExecutor): logger.debug(f"audio shape: {audio.shape}") # fbank - audio = log_mel_spectrogram(audio, resource_path=self.resource_path) - + audio = log_mel_spectrogram( + audio, + resource_path=self.resource_path, + n_mels=self.dims.n_mels, + padding=480000) audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0) self._inputs["audio"] = audio @@ -275,7 +278,6 @@ class WhisperExecutor(BaseExecutor): cfg.temperature_increment_on_fallback)) else: temperature = [cfg.temperature] - self._outputs["result"] = self.model.transcribe( audio, verbose=cfg.verbose, diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 9b8c4d2b5..d67d17894 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -617,6 +617,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-large-model.tar.gz', + 'md5': + '9ebbd228fa07ca4557e5da863dac2982', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-large-model', + 'model': + 'whisper-large-model.pdparams', + 'params': + 'whisper-large-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-base-en-16k": { '1.3': { @@ -637,6 +655,24 @@ whisper_dynamic_pretrained_models = { 
'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-base-en-model.tar.gz', + 'md5': + '376617a9c5f36404f50dde3708bac0c6', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-base-en-model', + 'model': + 'whisper-base-en-model.pdparams', + 'params': + 'whisper-base-en-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-base-16k": { '1.3': { @@ -657,6 +693,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-base-model.tar.gz', + 'md5': + '61836cb29c93048621f83364d83b532b', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-base-model', + 'model': + 'whisper-base-model.pdparams', + 'params': + 'whisper-base-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-medium-en-16k": { '1.3': { @@ -677,6 +731,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-medium-en-model.tar.gz', + 'md5': + 'ac01145c5de962f1416f3d98171be559', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-medium-en-model', + 'model': + 'whisper-medium-en-model.pdparams', + 'params': + 'whisper-medium-en-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-medium-16k": { '1.3': { @@ -697,6 +769,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-medium-model.tar.gz', + 'md5': + '07770819961d1fe795facd3666f8db17', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-medium-model', + 'model': + 'whisper-medium-model.pdparams', + 'params': + 'whisper-medium-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-small-en-16k": { '1.3': { @@ -717,6 +807,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-small-en-model.tar.gz', + 'md5': + '67af14156b93f49ae738a17204189e46', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-small-en-model', + 'model': + 'whisper-small-en-model.pdparams', + 'params': + 'whisper-small-en-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-small-16k": { '1.3': { @@ -737,6 +845,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-small-model.tar.gz', + 'md5': + 'db53c4bf39a9ad46ef77e6f9a37200b6', + 'cfg_path': + 'whisper.yaml', + 
'ckpt_path': + 'whisper-small-model', + 'model': + 'whisper-small-model.pdparams', + 'params': + 'whisper-small-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-tiny-en-16k": { '1.3': { @@ -757,6 +883,24 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-tiny-en-model.tar.gz', + 'md5': + 'f91f8447d8b37ed13f4327ef6565b094', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-tiny-en-model', + 'model': + 'whisper-tiny-en-model.pdparams', + 'params': + 'whisper-tiny-en-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, "whisper-tiny-16k": { '1.3': { @@ -777,6 +921,44 @@ whisper_dynamic_pretrained_models = { 'resource_data_md5': '37a0a8abdb3641a51194f79567a93b61', }, + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-tiny-model.tar.gz', + 'md5': + '6f2209ac656ff12de085c824363316e2', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-tiny-model', + 'model': + 'whisper-tiny-model.pdparams', + 'params': + 'whisper-tiny-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, + }, + "whisper-turbo-16k": { + '1.5': { + 'url': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/whisper-turbo-model.tar.gz', + 'md5': + 'fe2dd1a1d6eb8e6d017aafc7d5f62336', + 'cfg_path': + 'whisper.yaml', + 'ckpt_path': + 'whisper-turbo-model', + 'model': + 'whisper-turbo-model.pdparams', + 'params': + 'whisper-turbo-model.pdparams', + 'resource_data': + 'https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20250825/assets.tar', + 'resource_data_md5': + 'dd61d092d362f1fdbae6ede08282e177', + }, }, } diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py index d9c32a406..f9fc11302 100644 --- a/paddlespeech/s2t/exps/whisper/test_wav.py +++ b/paddlespeech/s2t/exps/whisper/test_wav.py @@ -45,6 +45,7 @@ class WhisperInfer(): model_dict = paddle.load(self.config.model_file) config.pop("model_file") dims = ModelDimensions(**model_dict["dims"]) + self.dims = dims self.model = Whisper(dims) self.model.load_dict(model_dict) @@ -64,8 +65,10 @@ class WhisperInfer(): #load audio mel = log_mel_spectrogram( - args.audio_file, resource_path=config.resource_path) - + args.audio_file, + resource_path=config.resource_path, + n_mels=self.dims.n_mels, + padding=480000) result = transcribe( self.model, mel, temperature=temperature, **config) if args.result_file is not None: diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py index e8b201bcc..552837e8a 100644 --- a/paddlespeech/s2t/models/whisper/tokenizer.py +++ b/paddlespeech/s2t/models/whisper/tokenizer.py @@ -2,17 +2,19 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
# # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py) +import base64 import os +import string from dataclasses import dataclass +from dataclasses import field +from functools import cached_property from functools import lru_cache +from typing import Dict from typing import List from typing import Optional from typing import Tuple -from typing import Union -import numpy as np -import paddle -from paddlenlp.transformers import GPTTokenizer +import tiktoken LANGUAGES = { "en": "english", @@ -35,7 +37,7 @@ LANGUAGES = { "hi": "hindi", "fi": "finnish", "vi": "vietnamese", - "iw": "hebrew", + "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", @@ -114,6 +116,7 @@ LANGUAGES = { "ba": "bashkir", "jw": "javanese", "su": "sundanese", + "yue": "cantonese", } # language code lookup by name, with a few language aliases @@ -130,134 +133,122 @@ TO_LANGUAGE_CODE = { "moldovan": "ro", "sinhalese": "si", "castilian": "es", + "mandarin": "zh", } -@dataclass(frozen=True) +@dataclass class Tokenizer: - """A thin wrapper around `GPTTokenizer` providing quick access to special tokens""" - - tokenizer: "GPTTokenizer" - language: Optional[str] - sot_sequence: Tuple[int] + """A thin wrapper around `tiktoken` providing quick access to special tokens""" + + encoding: tiktoken.Encoding + num_languages: int + language: Optional[str] = None + task: Optional[str] = None + sot_sequence: Tuple[int] = () + special_tokens: Dict[str, int] = field(default_factory=dict) + + def __post_init__(self): + for special in self.encoding.special_tokens_set: + special_token = self.encoding.encode_single_token(special) + self.special_tokens[special] = special_token + + sot: int = self.special_tokens["<|startoftranscript|>"] + translate: int = self.special_tokens["<|translate|>"] + transcribe: int = self.special_tokens["<|transcribe|>"] + langs = tuple(LANGUAGES.keys())[:self.num_languages] + sot_sequence = [sot] + if self.language is not None: + sot_sequence.append(sot + 1 + langs.index(self.language)) + if self.task is not None: + task_token: int = transcribe if self.task == "transcribe" else translate + sot_sequence.append(task_token) + + self.sot_sequence = tuple(sot_sequence) def encode(self, text, **kwargs): - return self.tokenizer.encode(text, **kwargs) - - def decode(self, - token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], - **kwargs): - if len(token_ids) > 1: - ids_list = [] - for ids in token_ids: - if paddle.is_tensor(ids): - ids = ids.item() - if ids < len(self.tokenizer): - ids_list.append(ids) - token_ids = ids_list - elif len(token_ids) == 1: - token_ids = token_ids[0] - else: - raise ValueError(f"token_ids {token_ids} load error.") - - return self.tokenizer.decode(token_ids, **kwargs) - - def decode_with_timestamps(self, tokens) -> str: + return self.encoding.encode(text, **kwargs) + + def decode(self, token_ids: List[int], **kwargs) -> str: + token_ids = [t for t in token_ids if t < self.timestamp_begin] + return self.encoding.decode(token_ids, **kwargs) + + def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str: """ - Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. + Timestamp tokens are above other special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
""" - outputs = [[]] - for token in tokens: - if token >= self.timestamp_begin: - timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" - outputs.append(timestamp) - outputs.append([]) - else: - outputs[-1].append(token) - outputs = [ - s if isinstance(s, str) else self.tokenizer.decode(s) - for s in outputs - ] - return "".join(outputs) - - @property - @lru_cache() + return self.encoding.decode(token_ids, **kwargs) + + @cached_property def eot(self) -> int: - return self.tokenizer.eos_token_id + return self.encoding.eot_token + + @cached_property + def transcribe(self) -> int: + return self.special_tokens["<|transcribe|>"] - @property - @lru_cache() + @cached_property + def translate(self) -> int: + return self.special_tokens["<|translate|>"] + + @cached_property def sot(self) -> int: - return self._get_single_token_id("<|startoftranscript|>") + return self.special_tokens["<|startoftranscript|>"] - @property - @lru_cache() + @cached_property def sot_lm(self) -> int: - return self._get_single_token_id("<|startoflm|>") + return self.special_tokens["<|startoflm|>"] - @property - @lru_cache() + @cached_property def sot_prev(self) -> int: - return self._get_single_token_id("<|startofprev|>") + return self.special_tokens["<|startofprev|>"] - @property - @lru_cache() + @cached_property def no_speech(self) -> int: - return self._get_single_token_id("<|nospeech|>") + return self.special_tokens["<|nospeech|>"] - @property - @lru_cache() + @cached_property def no_timestamps(self) -> int: - return self._get_single_token_id("<|notimestamps|>") + return self.special_tokens["<|notimestamps|>"] - @property - @lru_cache() + @cached_property def timestamp_begin(self) -> int: - return self.tokenizer.all_special_ids[-1] + 1 + return self.special_tokens["<|0.00|>"] - @property - @lru_cache() + @cached_property def language_token(self) -> int: """Returns the token id corresponding to the value of the `language` field""" if self.language is None: raise ValueError( "This tokenizer does not have language token configured") - additional_tokens = dict( - zip( - self.tokenizer.additional_special_tokens, - self.tokenizer.additional_special_tokens_ids, )) - candidate = f"<|{self.language}|>" - if candidate in additional_tokens: - return additional_tokens[candidate] + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): + return token - raise KeyError(f"Language {self.language} not found in tokenizer.") + raise KeyError(f"Language {language} not found in tokenizer.") - @property - @lru_cache() + @cached_property def all_language_tokens(self) -> Tuple[int]: result = [] - for token, token_id in zip( - self.tokenizer.additional_special_tokens, - self.tokenizer.additional_special_tokens_ids, ): + for token, token_id in self.special_tokens.items(): if token.strip("<|>") in LANGUAGES: result.append(token_id) - return tuple(result) + return tuple(result)[:self.num_languages] - @property - @lru_cache() + @cached_property def all_language_codes(self) -> Tuple[str]: return tuple( - self.decode([l]).strip("<|>") for l in self.all_language_tokens) + self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) - @property - @lru_cache() + @cached_property def sot_sequence_including_notimestamps(self) -> Tuple[int]: return tuple(list(self.sot_sequence) + [self.no_timestamps]) - @property - @lru_cache() + @cached_property def non_speech_tokens(self) -> Tuple[int]: """ Returns the list of tokens to suppress in order to 
avoid any speaker tags or non-speech @@ -269,9 +260,10 @@ class Tokenizer: keeping basic punctuations like commas, periods, question marks, exclamation points, etc. """ - symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") - symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split( - ) + symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') + symbols += ( + "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪". + split()) # symbols that may be a single token or multiple tokens depending on the tokenizer. # In case they're multiple tokens, suppress the first token, which is safe because: @@ -281,45 +273,99 @@ class Tokenizer: assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - result = { - self.tokenizer.encode(" -").input_ids[0], - self.tokenizer.encode(" '").input_ids[0] - } + result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} for symbol in symbols + list(miscellaneous): for tokens in [ - self.tokenizer.encode(symbol).input_ids, - self.tokenizer.encode(" " + symbol).input_ids + self.encoding.encode(symbol), + self.encoding.encode(" " + symbol), ]: if len(tokens) == 1 or symbol in miscellaneous: result.add(tokens[0]) return tuple(sorted(result)) - def _get_single_token_id(self, text) -> int: - tokens = self.tokenizer.encode(text).input_ids - assert len(tokens) == 1, f"{text} is not encoded as a single token" - return tokens[0] + def split_to_word_tokens(self, tokens: List[int]): + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + return self.split_tokens_on_unicode(tokens) + return self.split_tokens_on_spaces(tokens) -@lru_cache(maxsize=None) -def build_tokenizer(resource_path: str, name: str="gpt2"): - os.environ["TOKENIZERS_PARALLELISM"] = "false" - path = os.path.join(resource_path, "assets", name) - tokenizer = GPTTokenizer.from_pretrained(path) + def split_tokens_on_unicode(self, tokens: List[int]): + decoded_full = self.decode_with_timestamps(tokens) + replacement_char = "\ufffd" + words = [] + word_tokens = [] + current_tokens = [] + unicode_offset = 0 + + for token in tokens: + current_tokens.append(token) + decoded = self.decode_with_timestamps(current_tokens) + + if (replacement_char not in decoded or + decoded_full[unicode_offset + decoded.index( + replacement_char)] == replacement_char): + words.append(decoded) + word_tokens.append(current_tokens) + current_tokens = [] + unicode_offset += len(decoded) + + return words, word_tokens + + def split_tokens_on_spaces(self, tokens: List[int]): + subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) + words = [] + word_tokens = [] + + for subword, subword_tokens in zip(subwords, subword_tokens_list): + special = subword_tokens[0] >= self.eot + with_space = subword.startswith(" ") + punctuation = subword.strip() in string.punctuation + if special or with_space or punctuation or len(words) == 0: + words.append(subword) + word_tokens.append(subword_tokens) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + + return words, word_tokens + + +@lru_cache(maxsize=None) +def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99): + vocab_path = os.path.join(resource_path, "assets", f"{name}.tiktoken") 
+ ranks = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in open(vocab_path) if line) + } + + n_vocab = len(ranks) + special_tokens = {} specials = [ + "<|endoftext|>", "<|startoftranscript|>", - * [f"<|{lang}|>" for lang in LANGUAGES.keys()], + * [f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>", + * [f"<|{i * 0.02:.2f}|>" for i in range(1501)], ] - - tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) - return tokenizer + for token in specials: + special_tokens[token] = n_vocab + n_vocab += 1 + return tiktoken.Encoding( + name=os.path.basename(vocab_path), + explicit_n_vocab=n_vocab, + pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + mergeable_ranks=ranks, + special_tokens=special_tokens, ) @lru_cache(maxsize=None) @@ -327,8 +373,11 @@ def get_tokenizer( multilingual: bool, resource_path: str, *, + num_languages: int=99, + language: Optional[str]=None, task: Optional[str]=None, # Literal["transcribe", "translate", None] - language: Optional[str]=None, ) -> Tokenizer: +) -> Tokenizer: + if language is not None: language = language.lower() if language not in LANGUAGES: @@ -338,29 +387,21 @@ def get_tokenizer( raise ValueError(f"Unsupported language: {language}") if multilingual: - tokenizer_name = "multilingual" - task = task or "transcribe" + encoding_name = "multilingual" language = language or "en" + task = task or "transcribe" else: - tokenizer_name = "gpt2" - task = None + encoding_name = "gpt2" language = None + task = None - tokenizer = build_tokenizer( - resource_path=resource_path, name=tokenizer_name) - all_special_ids: List[int] = tokenizer.all_special_ids - sot: int = all_special_ids[1] - translate: int = all_special_ids[-6] - transcribe: int = all_special_ids[-5] - - langs = tuple(LANGUAGES.keys()) - sot_sequence = [sot] - if language is not None: - sot_sequence.append(sot + 1 + langs.index(language)) - if task is not None: - sot_sequence.append(transcribe if task == "transcribe" else translate) + encoding = get_encoding( + resource_path=resource_path, + name=encoding_name, + num_languages=num_languages) return Tokenizer( - tokenizer=tokenizer, + encoding=encoding, + num_languages=num_languages, language=language, - sot_sequence=tuple(sot_sequence)) + task=task) diff --git a/paddlespeech/s2t/models/whisper/whisper.py b/paddlespeech/s2t/models/whisper/whisper.py index fdd3a6974..54aef0956 100644 --- a/paddlespeech/s2t/models/whisper/whisper.py +++ b/paddlespeech/s2t/models/whisper/whisper.py @@ -33,13 +33,18 @@ logger = Log(__name__).getlog() _MODELS = ["large"] SAMPLE_RATE = 16000 N_FFT = 400 -N_MELS = 80 HOP_LENGTH = 160 CHUNK_LENGTH = 30 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk N_FRAMES = utils.exact_div( N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = utils.exact_div(SAMPLE_RATE, + HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = utils.exact_div(SAMPLE_RATE, + N_SAMPLES_PER_TOKEN) # 20ms per audio token + @dataclass class ModelDimensions: @@ -378,7 +383,9 @@ def detect_language( """ if tokenizer is None: tokenizer = get_tokenizer( - model.is_multilingual, resource_path=resource_path) + multilingual=model.is_multilingual, + resource_path=resource_path, + num_languages=model.num_languages) 
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: raise ValueError( "This model doesn't have language tokens so it can't perform lang id" @@ -400,6 +407,7 @@ def detect_language( # collect detected languages; suppress all non-language tokens mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool) mask[list(tokenizer.all_language_tokens)] = False + logits.contiguous() logits[:, mask] = -np.inf language_tokens = paddle.argmax(logits, axis=-1) language_token_probs = F.softmax(logits, axis=-1) @@ -428,6 +436,13 @@ def transcribe( logprob_threshold: Optional[float]=-1.0, no_speech_threshold: Optional[float]=0.6, condition_on_previous_text: bool=True, + initial_prompt: Optional[str]=None, + carry_initial_prompt: bool=False, + word_timestamps: bool=False, + prepend_punctuations: str="\"'“¿([{-", + append_punctuations: str="\"'.。,,!!??::”)]}、", + clip_timestamps: Union[str, List[float]]="0", + hallucination_silence_threshold: Optional[float]=None, **decode_options, ): """ Transcribe an audio file using Whisper @@ -476,8 +491,9 @@ def transcribe( if dtype == np.float32: decode_options["fp16"] = False - if decode_options.get("language") == 'None' or decode_options.get( - "language", None) is None: + content_frames = mel.shape[-1] - N_FRAMES + content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) + if decode_options.get("language", None) in {None, "None"}: if not model.is_multilingual: decode_options["language"] = "en" else: @@ -485,25 +501,47 @@ def transcribe( print( "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" ) - segment = pad_or_trim(mel, N_FRAMES) - _, probs = model.detect_language(segment, resource_path) + mel_segment = pad_or_trim(mel, + N_FRAMES).to(model.device).astype(dtype) + _, probs = model.detect_language(mel_segment, resource_path) decode_options["language"] = max(probs, key=probs.get) if verbose is not None: print( f"Detected language: {LANGUAGES[decode_options['language']].title()}" ) - language = decode_options["language"] - task = decode_options.get("task", "transcribe") + language: str = decode_options["language"] + task: str = decode_options.get("task", "transcribe") tokenizer = get_tokenizer( - model.is_multilingual, + multilingual=model.is_multilingual, resource_path=resource_path, + num_languages=model.num_languages, language=language, - task=task) + task=task, ) + + if isinstance(clip_timestamps, str): + clip_timestamps = [ + float(ts) + for ts in (clip_timestamps.split(",") if clip_timestamps else []) + ] + seek_points: List[ + int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] + if len(seek_points) == 0: + seek_points.append(0) + if len(seek_points) % 2 == 1: + seek_points.append(content_frames) + seek_clips: List[Tuple[int, int]] = list( + zip(seek_points[::2], seek_points[1::2])) + + punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" + + if word_timestamps and task == "translate": + warnings.warn( + "Word-level timestamps on translations may not be reliable.") def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult: - temperatures = [temperature] if isinstance(temperature, ( - int, float)) else temperature + temperatures = ([temperature] if isinstance(temperature, (int, float)) + else temperature) decode_result = None for t in temperatures: @@ -517,20 +555,29 @@ def transcribe( kwargs.pop("best_of", None) options = DecodingOptions(**kwargs, temperature=t) + decode_result = model.decode(segment, options, resource_path) needs_fallback = False 
- if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: + if (compression_ratio_threshold is not None and + decode_result.compression_ratio > + compression_ratio_threshold): needs_fallback = True # too repetitive - if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold: + if (logprob_threshold is not None and + decode_result.avg_logprob < logprob_threshold): needs_fallback = True # average log probability is too low - + if (no_speech_threshold is not None and + decode_result.no_speech_prob > no_speech_threshold and + logprob_threshold is not None and + decode_result.avg_logprob < logprob_threshold): + needs_fallback = False # silence if not needs_fallback: break return decode_result - seek = 0 + clip_idx = 0 + seek = seek_clips[clip_idx][0] input_stride = utils.exact_div( N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2 time_precision = (input_stride * HOP_LENGTH / @@ -539,127 +586,210 @@ def transcribe( all_segments = [] prompt_reset_since = 0 - initial_prompt = decode_options.pop("initial_prompt", None) or [] - if initial_prompt: - initial_prompt = tokenizer.encode(" " + - initial_prompt.strip()).input_ids - all_tokens.extend(initial_prompt) + remaining_prompt_length = model.dims.n_text_ctx // 2 - 1 + if initial_prompt is not None: + initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) + all_tokens.extend(initial_prompt_tokens) + remaining_prompt_length -= len(initial_prompt_tokens) + else: + initial_prompt_tokens = [] - def add_segment(*, + def new_segment(*, start: float, end: float, - text_tokens: paddle.Tensor, + tokens: paddle.Tensor, result: DecodingResult): - text = tokenizer.decode( - [token for token in text_tokens if token < tokenizer.eot]) - if len(text.strip()) == 0: # skip empty text output - return - - all_segments.append({ - "id": len(all_segments), + tokens = tokens.tolist() + text_tokens = [token for token in tokens if token < tokenizer.eot] + return { "seek": seek, "start": start, "end": end, - "text": text, - "tokens": result.tokens, + "text": tokenizer.decode(text_tokens), + "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob, - }) - if verbose: - print( - f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}" - ) - - # show the progress bar when verbose is False (otherwise the transcribed text will be printed) - num_frames = mel.shape[-1] - previous_seek_value = seek + } + # show the progress bar when verbose is False (if True, transcribed text will be printed) with tqdm.tqdm( - total=num_frames, unit='frames', + total=content_frames, unit="frames", disable=verbose is not False) as pbar: - while seek < num_frames: - timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - segment = pad_or_trim(mel[:, seek:], N_FRAMES) - segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE - - decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(segment) + last_speech_timestamp = 0.0 + # NOTE: This loop is obscurely flattened to make the diff readable. + # A later commit should turn this into a simpler nested loop. 
+ # for seek_clip_start, seek_clip_end in seek_clips: + # while seek < seek_clip_end + while clip_idx < len(seek_clips): + seek_clip_start, seek_clip_end = seek_clips[clip_idx] + if seek < seek_clip_start: + seek = seek_clip_start + if seek >= seek_clip_end: + clip_idx += 1 + if clip_idx < len(seek_clips): + seek = seek_clips[clip_idx][0] + continue + time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + window_end_time = float( + (seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) + segment_size = min(N_FRAMES, content_frames - seek, + seek_clip_end - seek) + mel_segment = mel[:, seek:seek + segment_size] + segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE + mel_segment = pad_or_trim(mel_segment, + N_FRAMES).to(model.device).astype(dtype) + + if carry_initial_prompt: + nignored = max(len(initial_prompt_tokens), prompt_reset_since) + remaining_prompt = all_tokens[nignored:][ + -remaining_prompt_length:] + decode_options[ + "prompt"] = initial_prompt_tokens + remaining_prompt + else: + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment) tokens = paddle.to_tensor(result.tokens) if no_speech_threshold is not None: # no voice activity check should_skip = result.no_speech_prob > no_speech_threshold - if logprob_threshold is not None and result.avg_logprob > logprob_threshold: + if (logprob_threshold is not None and + result.avg_logprob > logprob_threshold): # don't skip if the logprob is high enough, despite the no_speech_prob should_skip = False if should_skip: - seek += segment.shape[ - -1] # fast-forward to the next segment boundary + seek += segment_size # fast-forward to the next segment boundary continue + previous_seek = seek + current_segments = [] + + # anomalous words are very long/short/improbable + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score + + def is_segment_anomaly(segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [ + w for w in segment["words"] if w["word"] not in punctuation + ] + words = words[:8] + score = sum(word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) + + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) + timestamp_tokens: paddle.Tensor = tokens.greater_equal( paddle.to_tensor(tokenizer.timestamp_begin)) + single_timestamp_ending = timestamp_tokens[ + -2:].tolist() == [False, True] consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[ 1:])[0] - if len( - consecutive - ) > 0: # if the output contains two consecutive timestamp tokens + if consecutive.numel() != 0: consecutive = paddle.add(consecutive, paddle.to_tensor(1)) + if len(consecutive) > 0: + # if the output contains two consecutive timestamp tokens + slices = consecutive.tolist() + if single_timestamp_ending: + slices.append(len(tokens)) + last_slice = 0 - for current_slice in consecutive: + for current_slice in slices: sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_position = ( + start_timestamp_pos = ( sliced_tokens[0].item() - tokenizer.timestamp_begin) - end_timestamp_position = ( + end_timestamp_pos = ( sliced_tokens[-1].item() - tokenizer.timestamp_begin) - add_segment( - 
start=timestamp_offset + start_timestamp_position * - time_precision, - end=timestamp_offset + end_timestamp_position * - time_precision, - text_tokens=sliced_tokens[1:-1], - result=result, ) + current_segments.append( + new_segment( + start=time_offset + start_timestamp_pos * + time_precision, + end=time_offset + end_timestamp_pos * + time_precision, + tokens=sliced_tokens, + result=result, )) last_slice = current_slice - last_timestamp_position = ( - tokens[last_slice - 1].item() - tokenizer.timestamp_begin) - seek += last_timestamp_position * input_stride - all_tokens.extend(tokens[:last_slice + 1].tolist()) + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_pos = (tokens[last_slice - 1].item() - + tokenizer.timestamp_begin) + seek += last_timestamp_pos * input_stride else: duration = segment_duration timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if len(timestamps) > 0 and timestamps[ - -1].item() != tokenizer.timestamp_begin: + if (len(timestamps) > 0 and + timestamps[-1].item() != tokenizer.timestamp_begin): # no consecutive timestamps but it has a timestamp; use the last one. - # single timestamp at the end means no speech after the last timestamp. - last_timestamp_position = timestamps[ - -1].item() - tokenizer.timestamp_begin - duration = last_timestamp_position * time_precision + last_timestamp_pos = ( + timestamps[-1].item() - tokenizer.timestamp_begin) + duration = last_timestamp_pos * time_precision + + current_segments.append( + new_segment( + start=time_offset, + end=time_offset + duration, + tokens=tokens, + result=result, )) + seek += segment_size - add_segment( - start=timestamp_offset, - end=timestamp_offset + duration, - text_tokens=tokens, - result=result, ) - - seek += segment.shape[-1] - all_tokens.extend(tokens.tolist()) + if verbose: + for segment in current_segments: + start, end, text = segment["start"], segment[ + "end"], segment["text"] + line = f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}" + print(line) + + # if a segment is instantaneous or does not contain text, clear it + for i, segment in enumerate(current_segments): + if segment["start"] == segment["end"] or segment[ + "text"].strip() == "": + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] + + all_segments.extend( + [{ + "id": i, + ** + segment + } + for i, segment in enumerate( + current_segments, start=len(all_segments))]) + all_tokens.extend([ + token + for segment in current_segments for token in segment["tokens"] + ]) if not condition_on_previous_text or result.temperature > 0.5: # do not feed the prompt tokens if a high temperature was used prompt_reset_since = len(all_tokens) # update progress bar - pbar.update(min(num_frames, seek) - previous_seek_value) - previous_seek_value = seek + pbar.update(min(content_frames, seek) - previous_seek) return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt):]), + text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), segments=all_segments, - language=language) + language=language, ) class SequenceRanker: @@ -776,11 +906,11 @@ class GreedyDecoder(TokenDecoder): next_tokens.shape[0] * next_tokens.shape[1], ]) - logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32) + logprobs = F.log_softmax(logits, axis=-1, dtype="float32") current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), 
next_tokens] sum_logprobs += current_logprobs * paddle.to_tensor( - (tokens[:, -1] != self.eot), dtype=paddle.float32) + (tokens[:, -1] != self.eot), dtype="float32") next_tokens[tokens[:, -1] == self.eot] = self.eot tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1) @@ -928,8 +1058,9 @@ class SuppressBlank(LogitFilter): def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor): if tokens.shape[1] == self.sample_begin: - logits[:, self.tokenizer.encode(" ").input_ids + - [self.tokenizer.eot]] = -np.inf + logits.contiguous() + logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot + ]] = -np.inf class SuppressTokens(LogitFilter): @@ -937,6 +1068,7 @@ class SuppressTokens(LogitFilter): self.suppress_tokens = list(suppress_tokens) def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor): + logits.contiguous() logits[:, self.suppress_tokens] = -np.inf @@ -952,6 +1084,7 @@ class ApplyTimestampRules(LogitFilter): def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor): # suppress <|notimestamps|> which is handled by without_timestamps if self.tokenizer.no_timestamps is not None: + logits.contiguous() logits[:, self.tokenizer.no_timestamps] = -np.inf # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly @@ -972,6 +1105,7 @@ class ApplyTimestampRules(LogitFilter): if tokens.shape[ 1] == self.sample_begin and self.max_initial_timestamp_index is not None: last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + logits.contiguous() logits[:, last_allowed + 1:] = -np.inf # if sum of probability over timestamps is above any other token, sample timestamp @@ -1005,15 +1139,16 @@ class DecodingTask: language = options.language or "en" tokenizer = get_tokenizer( - model.is_multilingual, + multilingual=model.is_multilingual, resource_path=resource_path, language=language, - task=options.task) + task=options.task, + num_languages=model.num_languages) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) self.resource_path: str = resource_path - self.beam_size: int = options.beam_size or options.best_of or 1 + self.n_group: int = options.beam_size or options.best_of or 1 self.n_ctx: int = model.dims.n_text_ctx self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 @@ -1158,10 +1293,10 @@ class DecodingTask: sum_logprobs: paddle.Tensor = paddle.zeros( paddle.to_tensor(n_batch), dtype=paddle.float32) no_speech_probs = [np.nan] * n_batch - try: for i in range(self.sample_len): logits = self.inference.logits(tokens, audio_features) + logits.contiguous() if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs probs_at_sot = F.softmax( @@ -1197,17 +1332,8 @@ class DecodingTask: audio_features: paddle.Tensor = self._get_audio_features( mel) # encoder forward pass - tokens: paddle.Tensor - if batch_size > 1: - for i in range(batch_size): - tokens = paddle.concat( - x=[ - paddle.to_tensor([self.initial_tokens]), - paddle.to_tensor([self.initial_tokens]) - ], - axis=0) - elif batch_size == 1: - tokens = paddle.to_tensor([self.initial_tokens]) + tokens: Tensor = paddle.to_tensor([self.initial_tokens]).repeat( + batch_size, 1) # detect language if requested, overwriting the language token languages, language_probs = self._detect_language( @@ -1224,30 +1350,26 @@ class DecodingTask: language_probs) ] - # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling - - audio_features = paddle.repeat_interleave( - 
audio_features, self.beam_size, axis=0) - tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0) + # repeat text tensors by the group size, for beam search or best-of-n sampling + tokens = tokens.repeat_interleave(self.n_group, axis=0) # call the main sampling loop tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) - # reshape the tensors to have (batch_size, beam_size) as the first two dimensions - audio_features = audio_features[::self.beam_size] - no_speech_probs = no_speech_probs[::self.beam_size] + # reshape the tensors to have (batch_size, n_group) as the first two dimensions + audio_features = audio_features[::self.n_group] + no_speech_probs = no_speech_probs[::self.n_group] assert audio_features.shape[0] == len(no_speech_probs) == batch_size - tokens = tokens.reshape([batch_size, self.beam_size, -1]) - sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size]) + tokens = tokens.reshape([batch_size, self.n_group, -1]) + sum_logprobs = sum_logprobs.reshape([batch_size, self.n_group]) # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) - tokens: List[List[paddle.Tensor]] = [[ + tokens: List[List[Tensor]] = [[ t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s ] for s in tokens] - # select the top-ranked sample in each group selected = self.sequence_ranker.rank(tokens, sum_logprobs) tokens: List[List[ @@ -1256,11 +1378,12 @@ class DecodingTask: sum_logprobs: List[ float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] - avg_logprobs: List[ - float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + avg_logprobs: List[float] = [ + lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs) + ] fields = (texts, languages, tokens, audio_features, avg_logprobs, - no_speech_probs) + no_speech_probs, ) if len(set(map(len, fields))) != 1: raise RuntimeError( f"inconsistent result lengths: {list(map(len, fields))}") @@ -1294,7 +1417,7 @@ def decode( model: Whisper the Whisper model instance - mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000) + mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000) or (128, 3000) or (*, 128, 3000) A tensor containing the Mel spectrogram(s) options: DecodingOptions @@ -1350,7 +1473,11 @@ class Whisper(nn.Layer): @property def is_multilingual(self): - return self.dims.n_vocab == 51865 + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) def install_kv_cache_hooks(self, cache: Optional[dict]=None): """ @@ -1364,7 +1491,7 @@ class Whisper(nn.Layer): cache : Dict[nn.Layer, paddle.Tensor] A dictionary object mapping the key/value projection modules to its cache hooks : List[RemovableHandle] - List of PyTorch RemovableHandle objects to stop the hooks to be called + List of Paddle RemovableHandle objects to stop the hooks to be called """ cache = {**cache} if cache is not None else {} hooks = [] @@ -1431,11 +1558,11 @@ def hann_window(n_fft: int=N_FFT): """ return paddle.to_tensor( [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)], - dtype=paddle.float32) + dtype="float32") @lru_cache(maxsize=None) -def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor: +def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor: """ load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
Allows decoupling librosa dependency; saved using: @@ -1445,13 +1572,16 @@ def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor: mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), ) """ - assert n_mels == 80, f"Unsupported n_mels: {n_mels}" - with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f: + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + + filters_path = os.path.join(resource_path, "assets", "mel_filters.npz") + with np.load(filters_path, allow_pickle=False) as f: return paddle.to_tensor(f[f"mel_{n_mels}"]) def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor], - n_mels: int=N_MELS, + n_mels: int=80, + padding: int=0, resource_path: str=None): """ Compute the log-Mel spectrogram of @@ -1462,11 +1592,11 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor], The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz n_mels: int - The number of Mel-frequency filters, only 80 is supported + The number of Mel-frequency filters, only 80 and 128 is supported Returns ------- - paddle.Tensor, shape = (80, n_frames) + paddle.Tensor, shape = (n_mels, n_frames) A Tensor that contains the Mel spectrogram """ if not paddle.is_tensor(audio): @@ -1475,7 +1605,8 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor], audio = audio[:, 0] logger.info(f"audio shape: {audio.shape}") audio = paddle.to_tensor(audio) - + if padding > 0: + audio = F.pad(audio, (0, padding), data_format="NLC") window = hann_window(N_FFT) stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window) diff --git a/setup.py b/setup.py index 47e543ce4..581260a19 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,7 @@ base = [ "flatten_dict", "pyloudnorm", "rich", + "tiktoken", ] server = ["pattern_singleton", "websockets"]
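
A minimal end-to-end sketch of the inference path this patch enables, mirroring the calls touched in paddlespeech/s2t/exps/whisper/test_wav.py above. The checkpoint file, resource directory, and audio path are placeholders, and the import locations and the exact transcribe() keyword names are assumed to follow the modules edited in this diff rather than verified against a release.

    import paddle

    from paddlespeech.s2t.models.whisper import log_mel_spectrogram
    from paddlespeech.s2t.models.whisper import ModelDimensions
    from paddlespeech.s2t.models.whisper import transcribe
    from paddlespeech.s2t.models.whisper import Whisper

    # Placeholders: a checkpoint downloaded from one of the '1.5' entries above,
    # and the directory containing the unpacked assets/ folder from assets.tar
    # (which now ships the *.tiktoken vocabulary files read by get_encoding()).
    model_file = "whisper-turbo-model.pdparams"
    resource_path = "."

    model_dict = paddle.load(model_file)
    dims = ModelDimensions(**model_dict["dims"])  # n_mels is 128 for large-v3/turbo
    model = Whisper(dims)
    model.load_dict(model_dict)
    model.eval()

    # n_mels now comes from the checkpoint dims, and the audio is padded by one
    # 30-second chunk (480000 samples), matching infer.py and test_wav.py above.
    mel = log_mel_spectrogram(
        "./zh.wav",  # placeholder 16 kHz input file
        resource_path=resource_path,
        n_mels=dims.n_mels,
        padding=480000)

    result = transcribe(
        model,
        mel,
        resource_path=resource_path,
        temperature=(0.0, ),
        verbose=True)
    print(result["text"])

Because the 1.5 assets read tokenizer vocabularies through tiktoken instead of the paddlenlp GPTTokenizer, tiktoken is added to the base requirements in setup.py above.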