whisper large v3

pull/4101/head
zxcd 2 months ago
parent f032b3811a
commit 4c83b49a8d

@@ -192,8 +192,8 @@ class WhisperExecutor(BaseExecutor):
self.resource_path = os.path.join(
DATA_HOME, self.task_resource.version, 'whisper')
self.download_resource(resource_url, self.resource_path,
resource_md5)
else:
raise Exception("wrong type")
@@ -251,8 +251,9 @@ class WhisperExecutor(BaseExecutor):
logger.debug(f"audio shape: {audio.shape}")
# fbank
audio = log_mel_spectrogram(audio, resource_path=self.resource_path)
audio = log_mel_spectrogram(
audio, resource_path=self.resource_path, n_mels=128, padding=480000)
audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0)
self._inputs["audio"] = audio
@@ -275,7 +276,6 @@ class WhisperExecutor(BaseExecutor):
cfg.temperature_increment_on_fallback))
else:
temperature = [cfg.temperature]
self._outputs["result"] = self.model.transcribe(
audio,
verbose=cfg.verbose,

@@ -63,9 +63,13 @@ class WhisperInfer():
temperature = [temperature]
#load audio
mel = log_mel_spectrogram(
args.audio_file, resource_path=config.resource_path)
mel = log_mel_spectrogram(
    args.audio_file,
    resource_path=config.resource_path,
    n_mels=128,
    padding=480000)
result = transcribe(
self.model, mel, temperature=temperature, **config)
if args.result_file is not None:

@@ -2,9 +2,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
import base64
import os
import string
from dataclasses import dataclass
from dataclasses import field
from functools import cached_property
from functools import lru_cache
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
@@ -12,7 +16,8 @@ from typing import Union
import numpy as np
import paddle
from paddlenlp.transformers import GPTTokenizer
import tiktoken
LANGUAGES = {
"en": "english",
@@ -35,7 +40,7 @@ LANGUAGES = {
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
@@ -114,6 +119,7 @@ LANGUAGES = {
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
@@ -130,37 +136,63 @@ TO_LANGUAGE_CODE = {
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
@dataclass(frozen=True)
@dataclass
class Tokenizer:
"""A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
tokenizer: "GPTTokenizer"
language: Optional[str]
sot_sequence: Tuple[int]
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())[:self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
print("sot_sequence", sot_sequence)
self.sot_sequence = tuple(sot_sequence)
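# Worked example of the sequence built above, hedged against the
# large-v3 id layout where <|startoftranscript|> = 50258 and language
# tokens directly follow it in LANGUAGES order:
sot = 50258
langs = ("en", "zh")  # first two entries of LANGUAGES
assert sot + 1 + langs.index("zh") == 50260  # id of <|zh|>
# so language="zh", task="transcribe" yields sot_sequence
# (50258, 50260, <|transcribe|>), with <|transcribe|> = 50360 in large-v3.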
def encode(self, text, **kwargs):
return self.tokenizer.encode(text, **kwargs)
def decode(self,
token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
**kwargs):
if len(token_ids) > 1:
ids_list = []
for ids in token_ids:
if paddle.is_tensor(ids):
ids = ids.item()
if ids < len(self.tokenizer):
ids_list.append(ids)
token_ids = ids_list
elif len(token_ids) == 1:
token_ids = token_ids[0]
else:
raise ValueError(f"token_ids {token_ids} load error.")
return self.tokenizer.decode(token_ids, **kwargs)
return self.encoding.encode(text, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
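# decode() above drops every id >= timestamp_begin, so timestamp and
# trailing special tokens never leak into the text. A hedged standalone
# sketch (50365 as the id of <|0.00|> assumes the large-v3 layout):
timestamp_begin = 50365
ids = [50365, 1234, 5678, 50415]  # <|0.00|>, text, text, <|1.00|>
assert [t for t in ids if t < timestamp_begin] == [1234, 5678]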
def decode_with_timestamps(self, tokens) -> str:
"""
@@ -181,83 +213,75 @@ class Tokenizer:
]
return "".join(outputs)
@property
@lru_cache()
@cached_property
def eot(self) -> int:
return self.tokenizer.eos_token_id
return self.encoding.eot_token
@property
@lru_cache()
@cached_property
def transcribe(self) -> int:
return self.special_tokens["<|transcribe|>"]
@cached_property
def translate(self) -> int:
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self._get_single_token_id("<|startoftranscript|>")
return self.special_tokens["<|startoftranscript|>"]
@property
@lru_cache()
@cached_property
def sot_lm(self) -> int:
return self._get_single_token_id("<|startoflm|>")
return self.special_tokens["<|startoflm|>"]
@property
@lru_cache()
@cached_property
def sot_prev(self) -> int:
return self._get_single_token_id("<|startofprev|>")
return self.special_tokens["<|startofprev|>"]
@property
@lru_cache()
@cached_property
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")
return self.special_tokens["<|nospeech|>"]
@property
@lru_cache()
@cached_property
def no_timestamps(self) -> int:
return self._get_single_token_id("<|notimestamps|>")
return self.special_tokens["<|notimestamps|>"]
@property
@lru_cache()
@cached_property
def timestamp_begin(self) -> int:
return self.tokenizer.all_special_ids[-1] + 1
return self.special_tokens["<|0.00|>"]
@property
@lru_cache()
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError(
"This tokenizer does not have language token configured")
additional_tokens = dict(
zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids, ))
candidate = f"<|{self.language}|>"
if candidate in additional_tokens:
return additional_tokens[candidate]
return self.to_language_token(self.language)
def to_language_token(self, language):
if token := self.special_tokens.get(f"<|{language}|>", None):
return token
raise KeyError(f"Language {self.language} not found in tokenizer.")
raise KeyError(f"Language {language} not found in tokenizer.")
@property
@lru_cache()
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
self.tokenizer.additional_special_tokens,
self.tokenizer.additional_special_tokens_ids, ):
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
return tuple(result)[:self.num_languages]
@property
@lru_cache()
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(
self.decode([l]).strip("<|>") for l in self.all_language_tokens)
self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
@property
@lru_cache()
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@property
@lru_cache()
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
@@ -269,9 +293,10 @@ class Tokenizer:
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split(
)
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".
split())
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@@ -281,45 +306,170 @@ class Tokenizer:
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {
self.tokenizer.encode(" -").input_ids[0],
self.tokenizer.encode(" '").input_ids[0]
}
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.tokenizer.encode(symbol).input_ids,
self.tokenizer.encode(" " + symbol).input_ids
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def _get_single_token_id(self, text) -> int:
tokens = self.tokenizer.encode(text).input_ids
assert len(tokens) == 1, f"{text} is not encoded as a single token"
return tokens[0]
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (replacement_char not in decoded or
decoded_full[unicode_offset + decoded.index(
replacement_char)] == replacement_char):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
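# split_tokens_on_unicode works because an incomplete multi-byte UTF-8
# sequence decodes to U+FFFD: tokens accumulate until the growing prefix
# decodes cleanly. Self-contained illustration, with raw byte chunks
# standing in for token pieces:
chunks = [b"\xe4\xbd", b"\xa0", b"\xe5\xa5\xbd"]  # "你好" split mid-character
buf, words = b"", []
for chunk in chunks:
    buf += chunk
    decoded = buf.decode("utf-8", errors="replace")
    if "\ufffd" not in decoded:
        words.append(decoded)
        buf = b""
assert words == ["你", "好"]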
@lru_cache(maxsize=None)
def build_tokenizer(resource_path: str, name: str="gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(resource_path, "assets", name)
tokenizer = GPTTokenizer.from_pretrained(path)
def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
print("resource_path", resource_path)
print("name", name)
vocab_path = os.path.join(resource_path, "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
* [f"<|{lang}|>" for lang in LANGUAGES.keys()],
* [f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
* [f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens, )
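# The "<name>.tiktoken" vocab read above is plain text, one
# "<base64-encoded token> <rank>" pair per line. A hedged round-trip of
# that format ("!" having rank 0 matches the GPT-2 BPE ordering):
import base64
token_b64, rank = b"IQ== 0".split()
assert base64.b64decode(token_b64) == b"!" and int(rank) == 0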
@lru_cache(maxsize=None)
@@ -327,9 +477,12 @@ def get_tokenizer(
multilingual: bool,
resource_path: str,
*,
num_languages: int=99,
language: Optional[str]=None,
task: Optional[str]=None, # Literal["transcribe", "translate", None]
language: Optional[str]=None, ) -> Tokenizer:
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
@@ -338,29 +491,21 @@ def get_tokenizer(
raise ValueError(f"Unsupported language: {language}")
if multilingual:
tokenizer_name = "multilingual"
task = task or "transcribe"
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
tokenizer_name = "gpt2"
task = None
encoding_name = "gpt2"
language = None
task = None
tokenizer = build_tokenizer(
resource_path=resource_path, name=tokenizer_name)
all_special_ids: List[int] = tokenizer.all_special_ids
sot: int = all_special_ids[1]
translate: int = all_special_ids[-6]
transcribe: int = all_special_ids[-5]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if language is not None:
sot_sequence.append(sot + 1 + langs.index(language))
if task is not None:
sot_sequence.append(transcribe if task == "transcribe" else translate)
encoding = get_encoding(
resource_path=resource_path,
name=encoding_name,
num_languages=num_languages)
return Tokenizer(
tokenizer=tokenizer,
encoding=encoding,
num_languages=num_languages,
language=language,
sot_sequence=tuple(sot_sequence))
task=task)

@@ -33,13 +33,18 @@ logger = Log(__name__).getlog()
_MODELS = ["large"]
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = utils.exact_div(
N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = utils.exact_div(SAMPLE_RATE,
HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = utils.exact_div(SAMPLE_RATE,
N_SAMPLES_PER_TOKEN) # 20ms per audio token
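# The timing arithmetic behind these constants: 160 / 16000 = 0.01 s of
# audio per mel frame and 320 / 16000 = 0.02 s per decoded token, which
# is also why get_encoding() emits timestamp tokens <|0.00|> ... <|30.00|>
# in 0.02 s steps (range(1501)). Quick self-check:
assert 16000 // 160 == 100  # frames per second (10 ms each)
assert 16000 // (160 * 2) == 50  # tokens per second (20 ms each)
assert round(1500 * 0.02, 2) == 30.0  # last timestamp spans one full chunk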
@dataclass
class ModelDimensions:
@@ -378,7 +383,9 @@ def detect_language(
"""
if tokenizer is None:
tokenizer = get_tokenizer(
model.is_multilingual, resource_path=resource_path)
multilingual=model.is_multilingual,
resource_path=resource_path,
num_languages=model.num_languages)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
@@ -428,6 +435,13 @@ def transcribe(
logprob_threshold: Optional[float]=-1.0,
no_speech_threshold: Optional[float]=0.6,
condition_on_previous_text: bool=True,
initial_prompt: Optional[str]=None,
carry_initial_prompt: bool=False,
word_timestamps: bool=False,
prepend_punctuations: str="\"'“¿([{-",
append_punctuations: str="\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]]="0",
hallucination_silence_threshold: Optional[float]=None,
**decode_options, ):
"""
Transcribe an audio file using Whisper
@@ -476,8 +490,11 @@ def transcribe(
if dtype == np.float32:
decode_options["fp16"] = False
if decode_options.get("language") == 'None' or decode_options.get(
"language", None) is None:
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) in {None, "None"}:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
@@ -485,25 +502,49 @@ def transcribe(
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
segment = pad_or_trim(mel, N_FRAMES)
_, probs = model.detect_language(segment, resource_path)
mel_segment = pad_or_trim(mel,
N_FRAMES).to(model.device).astype(dtype)
_, probs = model.detect_language(mel_segment, resource_path)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language = decode_options["language"]
task = decode_options.get("task", "transcribe")
language: str = decode_options["language"]
print("language", language)
task: str = decode_options.get("task", "transcribe")
print("model.num_languages", model.num_languages)
tokenizer = get_tokenizer(
model.is_multilingual,
multilingual=model.is_multilingual,
resource_path=resource_path,
num_languages=model.num_languages,
language=language,
task=task)
task=task, )
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts)
for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[
int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(
zip(seek_points[::2], seek_points[1::2]))
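# With the default clip_timestamps="0" this produces one clip covering
# the whole file. A hedged sketch of the parsing above (parse_clips is a
# hypothetical name; fps mirrors FRAMES_PER_SECOND = 100):
def parse_clips(spec, content_frames, fps=100):
    points = [
        round(float(ts) * fps) for ts in (spec.split(",") if spec else [])
    ]
    if not points:
        points.append(0)
    if len(points) % 2 == 1:
        points.append(content_frames)
    return list(zip(points[::2], points[1::2]))

assert parse_clips("0", 3000) == [(0, 3000)]  # whole file
assert parse_clips("5,20", 3000) == [(500, 2000)]  # only 5 s .. 20 s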
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn(
"Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
temperatures = [temperature] if isinstance(temperature, (
int, float)) else temperature
temperatures = ([temperature] if isinstance(temperature, (int, float))
else temperature)
decode_result = None
for t in temperatures:
@@ -517,20 +558,29 @@ def transcribe(
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options, resource_path)
needs_fallback = False
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
if (compression_ratio_threshold is not None and
decode_result.compression_ratio >
compression_ratio_threshold):
needs_fallback = True # too repetitive
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
if (logprob_threshold is not None and
decode_result.avg_logprob < logprob_threshold):
needs_fallback = True # average log probability is too low
if (no_speech_threshold is not None and
decode_result.no_speech_prob > no_speech_threshold and
logprob_threshold is not None and
decode_result.avg_logprob < logprob_threshold):
needs_fallback = False # silence
if not needs_fallback:
break
return decode_result
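# The retry policy in one place, hedged (upstream Whisper defaults:
# compression_ratio_threshold=2.4, logprob_threshold=-1.0,
# no_speech_threshold=0.6). The silence check deliberately overrides a
# logprob failure so silent windows are accepted rather than re-decoded:
def needs_fallback_check(compression_ratio, avg_logprob, no_speech_prob,
                         cr=2.4, lp=-1.0, ns=0.6):
    fallback = compression_ratio > cr or avg_logprob < lp
    if no_speech_prob > ns and avg_logprob < lp:
        fallback = False  # probable silence: keep result, skip the window
    return fallback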
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = utils.exact_div(
N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2
time_precision = (input_stride * HOP_LENGTH /
@@ -539,127 +589,287 @@ def transcribe(
all_segments = []
prompt_reset_since = 0
initial_prompt = decode_options.pop("initial_prompt", None) or []
if initial_prompt:
initial_prompt = tokenizer.encode(" " +
initial_prompt.strip()).input_ids
all_tokens.extend(initial_prompt)
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
remaining_prompt_length -= len(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def add_segment(*,
def new_segment(*,
start: float,
end: float,
text_tokens: paddle.Tensor,
tokens: paddle.Tensor,
result: DecodingResult):
text = tokenizer.decode(
[token for token in text_tokens if token < tokenizer.eot])
if len(text.strip()) == 0: # skip empty text output
return
all_segments.append({
"id": len(all_segments),
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": text,
"tokens": result.tokens,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
})
if verbose:
print(
f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
)
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
num_frames = mel.shape[-1]
previous_seek_value = seek
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=num_frames, unit='frames',
total=content_frames, unit="frames",
disable=verbose is not False) as pbar:
while seek < num_frames:
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
segment = pad_or_trim(mel[:, seek:], N_FRAMES)
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
window_end_time = float(
(seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek,
seek_clip_end - seek)
mel_segment = mel[:, seek:seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment,
N_FRAMES).to(model.device).astype(dtype)
if carry_initial_prompt:
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][
-remaining_prompt_length:]
decode_options[
"prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(segment)
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = paddle.to_tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
if (logprob_threshold is not None and
result.avg_logprob > logprob_threshold):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment.shape[
-1] # fast-forward to the next segment boundary
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
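# e.g. a 0.05 s word with probability 0.10 scores
# 1.0 + (0.133 - 0.05) * 15 ≈ 2.25, pushing its segment toward the
# anomaly threshold checked in is_segment_anomaly below.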
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [
w for w in segment["words"] if w["word"] not in punctuation
]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: paddle.Tensor = tokens.greater_equal(
paddle.to_tensor(tokenizer.timestamp_begin))
single_timestamp_ending = timestamp_tokens[
-2:].tolist() == [False, True]
consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
1:])[0]
if len(
consecutive
) > 0: # if the output contains two consecutive timestamp tokens
consecutive = paddle.add(consecutive, paddle.to_tensor(1))
print("consecutive", consecutive)
consecutive = paddle.add(consecutive, paddle.to_tensor(1))
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in consecutive:
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = (
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin)
end_timestamp_position = (
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin)
add_segment(
start=timestamp_offset + start_timestamp_position *
time_precision,
end=timestamp_offset + end_timestamp_position *
time_precision,
text_tokens=sliced_tokens[1:-1],
result=result, )
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos *
time_precision,
end=time_offset + end_timestamp_pos *
time_precision,
tokens=sliced_tokens,
result=result, ))
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
seek += last_timestamp_position * input_stride
all_tokens.extend(tokens[:last_slice + 1].tolist())
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (tokens[last_slice - 1].item() -
tokenizer.timestamp_begin)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if len(timestamps) > 0 and timestamps[
-1].item() != tokenizer.timestamp_begin:
if (len(timestamps) > 0 and
timestamps[-1].item() != tokenizer.timestamp_begin):
# no consecutive timestamps but it has a timestamp; use the last one.
# single timestamp at the end means no speech after the last timestamp.
last_timestamp_position = timestamps[
-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_position * time_precision
add_segment(
start=timestamp_offset,
end=timestamp_offset + duration,
text_tokens=tokens,
result=result, )
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result, ))
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp, )
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(
first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap *
FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1:])
if next_segment is not None:
hal_next_start = next_segment["words"][0][
"start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold or
segment["start"] < threshold or
segment["start"] - time_offset < 2.0)
silence_after = (
hal_next_start - segment["end"] > threshold or
is_segment_anomaly(next_segment) or
window_end_time - segment["end"] < 2.0)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"]) *
FRAMES_PER_SECOND)
if content_duration - segment[
"end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
seek += segment.shape[-1]
all_tokens.extend(tokens.tolist())
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment[
"end"], segment["text"]
line = f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
print(line)
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment[
"text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend([{
    "id": i,
    **segment
} for i, segment in enumerate(
    current_segments, start=len(all_segments))])
all_tokens.extend([
token
for segment in current_segments for token in segment["tokens"]
])
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(num_frames, seek) - previous_seek_value)
previous_seek_value = seek
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt):]),
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
segments=all_segments,
language=language)
language=language, )
class SequenceRanker:
@@ -776,11 +986,11 @@ class GreedyDecoder(TokenDecoder):
next_tokens.shape[0] * next_tokens.shape[1],
])
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
logprobs = F.log_softmax(logits, axis=-1, dtype="float32")
current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
next_tokens]
sum_logprobs += current_logprobs * paddle.to_tensor(
(tokens[:, -1] != self.eot), dtype=paddle.float32)
(tokens[:, -1] != self.eot), dtype="float32")
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
@@ -928,8 +1138,8 @@ class SuppressBlank(LogitFilter):
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ").input_ids +
[self.tokenizer.eot]] = -np.inf
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot
]] = -np.inf
class SuppressTokens(LogitFilter):
@@ -1005,7 +1215,7 @@ class DecodingTask:
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual,
multilingual=model.is_multilingual,
resource_path=resource_path,
language=language,
task=options.task)
@@ -1346,11 +1556,16 @@ class Whisper(nn.Layer):
@property
def device(self):
# return str(paddle.device.get_device()).split(":")[0]
return paddle.device.get_device()
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
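# The arithmetic, checked against known checkpoints (hedged): v1/v2
# multilingual models have n_vocab == 51865 -> 51865 - 51765 - 1 = 99
# languages, while large-v3 has n_vocab == 51866 -> 100 (adding "yue").
for n_vocab, expected in [(51865, 99), (51866, 100)]:
    assert n_vocab - 51765 - int(n_vocab >= 51865) == expected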
def install_kv_cache_hooks(self, cache: Optional[dict]=None):
"""
@@ -1364,7 +1579,7 @@ class Whisper(nn.Layer):
cache : Dict[nn.Layer, paddle.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
List of Paddle RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
@@ -1435,7 +1650,7 @@ def hann_window(n_fft: int=N_FFT):
@lru_cache(maxsize=None)
def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@@ -1445,13 +1660,19 @@ def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(resource_path, "assets", "mel_filters.npz")
with np.load(filters_path, allow_pickle=False) as f:
return paddle.to_tensor(f[f"mel_{n_mels}"])
def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
n_mels: int=N_MELS,
n_mels: int=80,
padding: int=0,
resource_path: str=None):
"""
Compute the log-Mel spectrogram of
@@ -1475,7 +1696,8 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
audio = paddle.to_tensor(audio)
if padding > 0:
audio = F.pad(audio, (0, padding), data_format="NLC")
window = hann_window(N_FFT)
stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
