You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1478 lines
56 KiB
1478 lines
56 KiB
# MIT License, Copyright (c) 2022 OpenAI.
|
|
# Copyright (c) 2022 PaddlePaddle Authors and . All Rights Reserved.
|
|
#
|
|
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
|
|
import os
|
|
from dataclasses import dataclass
|
|
from dataclasses import field
|
|
from functools import lru_cache
|
|
from typing import Dict
|
|
from typing import Iterable
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
from typing import Tuple
|
|
from typing import Union
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn.functional as F
|
|
import paddlespeech.s2t.modules.align as paddlespeech_nn
|
|
import soundfile
|
|
import tqdm
|
|
from paddle import nn
|
|
from paddle.distribution import Categorical
|
|
from paddlespeech.s2t.models.whisper import utils
|
|
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
|
|
from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES
|
|
from paddlespeech.s2t.models.whisper.tokenizer import Tokenizer
|
|
from paddlespeech.s2t.utils.log import Log
|
|
logger = Log(__name__).getlog()
|
|
|
|
_MODELS = ["large"]
|
|
SAMPLE_RATE = 16000
|
|
N_FFT = 400
|
|
N_MELS = 80
|
|
HOP_LENGTH = 160
|
|
CHUNK_LENGTH = 30
|
|
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
|
N_FRAMES = utils.exact_div(
|
|
N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
|
|
|
|
|
@dataclass
|
|
class ModelDimensions:
|
|
n_mels: int
|
|
n_audio_ctx: int
|
|
n_audio_state: int
|
|
n_audio_head: int
|
|
n_audio_layer: int
|
|
n_vocab: int
|
|
n_text_ctx: int
|
|
n_text_state: int
|
|
n_text_head: int
|
|
n_text_layer: int
|
|
|
|
|
|
class LayerNorm(paddlespeech_nn.LayerNorm):
|
|
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
|
|
return super().forward(x)
|
|
|
|
|
|
class Linear(paddlespeech_nn.Linear):
|
|
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
|
|
return F.linear(x, self.weight, None
|
|
if self.bias is None else self.bias)
|
|
|
|
|
|
class Conv1d(paddlespeech_nn.Conv1D):
|
|
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
|
|
return super().forward(x)
|
|
|
|
|
|
class MultiHeadAttention(nn.Layer):
|
|
def __init__(self, n_state: int, n_head: int):
|
|
super().__init__()
|
|
self.n_head = n_head
|
|
self.query = Linear(n_state, n_state, bias_attr=True)
|
|
self.key = Linear(n_state, n_state, bias_attr=False)
|
|
self.value = Linear(n_state, n_state, bias_attr=True)
|
|
self.out = Linear(n_state, n_state, bias_attr=True)
|
|
|
|
def forward(
|
|
self,
|
|
x: paddle.Tensor,
|
|
xa: Optional[paddle.Tensor]=None,
|
|
mask: Optional[paddle.Tensor]=None,
|
|
kv_cache: Optional[dict]=None, ):
|
|
q = self.query(x)
|
|
|
|
if kv_cache is None or xa is None or self.key not in kv_cache:
|
|
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
|
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
|
k = self.key(x if xa is None else xa)
|
|
v = self.value(x if xa is None else xa)
|
|
else:
|
|
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
|
k = kv_cache[self.key]
|
|
v = kv_cache[self.value]
|
|
|
|
wv = self.qkv_attention(q, k, v, mask)
|
|
return self.out(wv)
|
|
|
|
def qkv_attention(self,
|
|
q: paddle.Tensor,
|
|
k: paddle.Tensor,
|
|
v: paddle.Tensor,
|
|
mask: Optional[paddle.Tensor]=None):
|
|
n_batch, n_ctx, n_state = q.shape
|
|
scale = (n_state // self.n_head)**-0.25
|
|
q = paddle.transpose(
|
|
q.view(*q.shape[:2], self.n_head, -1), (0, 2, 1, 3)) * scale
|
|
k = paddle.transpose(
|
|
k.view(*k.shape[:2], self.n_head, -1), (0, 2, 3, 1)) * scale
|
|
v = paddle.transpose(
|
|
v.view(*v.shape[:2], self.n_head, -1), (0, 2, 1, 3))
|
|
|
|
qk = q @ k
|
|
if mask is not None:
|
|
qk = qk + mask[:n_ctx, :n_ctx]
|
|
|
|
w = F.softmax(qk.float(), axis=-1).to(q.dtype)
|
|
return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
|
|
|
|
|
|
class ResidualAttentionBlock(nn.Layer):
|
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool=False):
|
|
super().__init__()
|
|
|
|
self.attn = MultiHeadAttention(n_state, n_head)
|
|
self.attn_ln = LayerNorm(n_state)
|
|
|
|
self.cross_attn = MultiHeadAttention(
|
|
n_state, n_head) if cross_attention else None
|
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
|
|
|
n_mlp = n_state * 4
|
|
self.mlp = nn.Sequential(
|
|
Linear(n_state, n_mlp, bias_attr=True),
|
|
nn.GELU(), Linear(n_mlp, n_state, bias_attr=True))
|
|
self.mlp_ln = LayerNorm(n_state)
|
|
|
|
def forward(
|
|
self,
|
|
x: paddle.Tensor,
|
|
xa: Optional[paddle.Tensor]=None,
|
|
mask: Optional[paddle.Tensor]=None,
|
|
kv_cache: Optional[dict]=None, ):
|
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
|
if self.cross_attn:
|
|
x = x + self.cross_attn(
|
|
self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
|
x = x + self.mlp(self.mlp_ln(x))
|
|
return x
|
|
|
|
|
|
def sinusoids(length, channels, max_timescale=10000):
|
|
"""Returns sinusoids for positional embedding"""
|
|
assert channels % 2 == 0
|
|
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
|
inv_timescales = paddle.exp(-log_timescale_increment * paddle.arange(
|
|
channels // 2, dtype=paddle.float32))
|
|
scaled_time = paddle.arange(
|
|
length,
|
|
dtype=paddle.float32)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
|
return paddle.to_tensor(
|
|
paddle.concat(
|
|
[paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1))
|
|
|
|
|
|
class AudioEncoder(nn.Layer):
|
|
def __init__(self,
|
|
n_mels: int,
|
|
n_ctx: int,
|
|
n_state: int,
|
|
n_head: int,
|
|
n_layer: int):
|
|
super().__init__()
|
|
self.conv1 = Conv1d(
|
|
n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True)
|
|
self.conv2 = Conv1d(
|
|
n_state,
|
|
n_state,
|
|
kernel_size=3,
|
|
stride=2,
|
|
padding=1,
|
|
bias_attr=True)
|
|
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
|
|
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.LayerList(
|
|
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
|
|
self.ln_post = LayerNorm(n_state)
|
|
|
|
def forward(self, x: paddle.Tensor):
|
|
"""
|
|
x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
|
|
the mel spectrogram of the audio
|
|
"""
|
|
x = F.gelu(self.conv1(x))
|
|
x = F.gelu(self.conv2(x))
|
|
x = paddle.transpose(x, (0, 2, 1))
|
|
|
|
assert x.shape[
|
|
1:] == self.positional_embedding.shape, "incorrect audio shape"
|
|
x = (x + self.positional_embedding)
|
|
|
|
for block in self.blocks:
|
|
x = block(x)
|
|
|
|
x = self.ln_post(x)
|
|
return x
|
|
|
|
|
|
class TextDecoder(nn.Layer):
|
|
def __init__(self,
|
|
n_vocab: int,
|
|
n_ctx: int,
|
|
n_state: int,
|
|
n_head: int,
|
|
n_layer: int):
|
|
super().__init__()
|
|
|
|
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
|
self.positional_embedding = paddle.create_parameter(
|
|
shape=[n_ctx, n_state], dtype='float32')
|
|
|
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.LayerList([
|
|
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
|
for _ in range(n_layer)
|
|
])
|
|
self.ln = LayerNorm(n_state)
|
|
|
|
mask = paddle.full(
|
|
shape=[n_ctx, n_state], fill_value=-np.inf, dtype='float32')
|
|
mask = paddle.triu(mask, diagonal=1)
|
|
self.register_buffer("mask", mask, persistable=False)
|
|
|
|
def forward(self,
|
|
x: paddle.Tensor,
|
|
xa: paddle.Tensor,
|
|
kv_cache: Optional[dict]=None):
|
|
"""
|
|
x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
|
|
the text tokens
|
|
xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
|
the encoded audio features to be attended on
|
|
"""
|
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
|
x = self.token_embedding(x) + self.positional_embedding[offset:offset +
|
|
x.shape[-1]]
|
|
x = x.to(xa.dtype)
|
|
|
|
for block in self.blocks:
|
|
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
|
|
|
x = self.ln(x)
|
|
logits = (x @ paddle.transpose(self.token_embedding.weight, (1, 0)))
|
|
|
|
return logits
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class DecodingOptions:
|
|
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
|
language: Optional[
|
|
str] = None # language that the audio is in; uses detected language if None
|
|
# sampling-related options
|
|
temperature: float = 0.0
|
|
sample_len: Optional[int] = None # maximum number of tokens to sample
|
|
best_of: Optional[
|
|
int] = None # number of independent samples to collect, when t > 0
|
|
beam_size: Optional[
|
|
int] = None # number of beams in beam search, when t == 0
|
|
patience: Optional[
|
|
float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
|
|
|
# options for ranking generations (either beams or best-of-N samples)
|
|
length_penalty: Optional[
|
|
float] = None # "alpha" in Google NMT, None defaults to length norm
|
|
|
|
# prompt, prefix, and token suppression
|
|
prompt: Optional[Union[str, List[
|
|
int]]] = None # text or tokens for the previous context
|
|
prefix: Optional[Union[str, List[
|
|
int]]] = None # text or tokens to prefix the current context
|
|
suppress_blank: bool = True # this will suppress blank outputs
|
|
|
|
# list of tokens ids (or comma-separated token ids) to suppress
|
|
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
|
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
|
|
|
# timestamp sampling options
|
|
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
|
max_initial_timestamp: Optional[
|
|
float] = 1.0 # the initial timestamp cannot be later than this
|
|
|
|
# implementation details
|
|
fp16: bool = False # use fp16 for most of the calculation
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class DecodingResult:
|
|
audio_features: paddle.Tensor
|
|
language: str
|
|
language_probs: Optional[Dict[str, float]] = None
|
|
tokens: List[int] = field(default_factory=list)
|
|
text: str = ""
|
|
avg_logprob: float = np.nan
|
|
no_speech_prob: float = np.nan
|
|
temperature: float = np.nan
|
|
compression_ratio: float = np.nan
|
|
|
|
|
|
class Inference:
|
|
def logits(self, tokens: paddle.Tensor,
|
|
audio_features: paddle.Tensor) -> paddle.Tensor:
|
|
"""Perform a forward pass on the decoder and return per-token logits"""
|
|
raise NotImplementedError
|
|
|
|
def rearrange_kv_cache(self, source_indices) -> None:
|
|
"""Update the key-value cache according to the updated beams"""
|
|
raise NotImplementedError
|
|
|
|
def cleanup_caching(self) -> None:
|
|
"""Clean up any resources or hooks after decoding is finished"""
|
|
pass
|
|
|
|
|
|
class WhisperInference(Inference):
|
|
def __init__(self, model: "Whisper", initial_token_length: int):
|
|
self.model: "Whisper" = model
|
|
self.initial_token_length = initial_token_length
|
|
self.kv_cache = {}
|
|
self.hooks = []
|
|
|
|
def logits(self, tokens: paddle.Tensor,
|
|
audio_features: paddle.Tensor) -> paddle.Tensor:
|
|
if not self.kv_cache:
|
|
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
|
|
|
if tokens.shape[-1] > self.initial_token_length:
|
|
# only need to use the last token except in the first forward pass
|
|
tokens = tokens[:, -1:]
|
|
|
|
return self.model.decoder(
|
|
tokens, audio_features, kv_cache=self.kv_cache)
|
|
|
|
def cleanup_caching(self):
|
|
for hook in self.hooks:
|
|
hook.remove()
|
|
|
|
self.kv_cache = {}
|
|
self.hooks = []
|
|
|
|
def rearrange_kv_cache(self, source_indices):
|
|
for module, tensor in self.kv_cache.items():
|
|
# update the key/value cache to contain the selected sequences
|
|
self.kv_cache[module] = tensor[source_indices].detach()
|
|
|
|
|
|
@paddle.no_grad()
|
|
def detect_language(
|
|
model: "Whisper",
|
|
mel: paddle.Tensor,
|
|
resource_path: str,
|
|
tokenizer: Tokenizer=None) -> Tuple[paddle.Tensor, List[dict]]:
|
|
"""
|
|
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
|
of the most probable language tokens and the probability distribution over all language tokens.
|
|
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
|
|
|
Returns
|
|
-------
|
|
language_tokens : Tensor, shape = (batch_size,)
|
|
ids of the most probable language tokens, which appears after the startoftranscript token.
|
|
language_probs : List[Dict[str, float]], length = batch_size
|
|
list of dictionaries containing the probability distribution over all languages.
|
|
"""
|
|
if tokenizer is None:
|
|
tokenizer = get_tokenizer(
|
|
model.is_multilingual, resource_path=resource_path)
|
|
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
|
raise ValueError(
|
|
"This model doesn't have language tokens so it can't perform lang id"
|
|
)
|
|
|
|
single = mel.ndim == 2
|
|
if single:
|
|
mel = mel.unsqueeze(0)
|
|
|
|
# skip encoder forward pass if already-encoded audio features were given
|
|
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
|
mel = model.encoder(mel)
|
|
|
|
# forward pass using a single token, startoftranscript
|
|
batch_size = mel.shape[0]
|
|
x = paddle.to_tensor([[tokenizer.sot]] * batch_size) # [batch_size, 1]
|
|
logits = model.logits(x, mel)[:, 0]
|
|
|
|
# collect detected languages; suppress all non-language tokens
|
|
mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
|
|
mask[list(tokenizer.all_language_tokens)] = False
|
|
logits[:, mask] = -np.inf
|
|
language_tokens = paddle.argmax(logits, axis=-1)
|
|
language_token_probs = F.softmax(logits, axis=-1)
|
|
language_probs = [{
|
|
c: language_token_probs[i, j].tolist()
|
|
for j, c in zip(tokenizer.all_language_tokens,
|
|
tokenizer.all_language_codes)
|
|
} for i in range(batch_size)]
|
|
|
|
if single:
|
|
language_tokens = language_tokens[0]
|
|
language_probs = language_probs[0]
|
|
|
|
return language_tokens, language_probs
|
|
|
|
|
|
def transcribe(
|
|
model: "Whisper",
|
|
mel: paddle.Tensor,
|
|
resource_path: str,
|
|
*,
|
|
verbose: Optional[bool]=None,
|
|
temperature: Union[float, Tuple[float, ...]]=(0.0, 0.2, 0.4, 0.6, 0.8,
|
|
1.0),
|
|
compression_ratio_threshold: Optional[float]=2.4,
|
|
logprob_threshold: Optional[float]=-1.0,
|
|
no_speech_threshold: Optional[float]=0.6,
|
|
condition_on_previous_text: bool=True,
|
|
**decode_options, ):
|
|
"""
|
|
Transcribe an audio file using Whisper
|
|
|
|
Parameters
|
|
----------
|
|
model: Whisper
|
|
The Whisper model instance
|
|
|
|
mel: paddle.Tensor
|
|
The audio feature
|
|
|
|
verbose: bool
|
|
Whether to display the text being decoded to the console. If True, displays all the details,
|
|
If False, displays minimal details. If None, does not display anything
|
|
|
|
temperature: Union[float, Tuple[float, ...]]
|
|
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
|
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
|
|
|
compression_ratio_threshold: float
|
|
If the gzip compression ratio is above this value, treat as failed
|
|
|
|
logprob_threshold: float
|
|
If the average log probability over sampled tokens is below this value, treat as failed
|
|
|
|
no_speech_threshold: float
|
|
If the no_speech probability is higher than this value AND the average log probability
|
|
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
|
|
|
condition_on_previous_text: bool
|
|
if True, the previous output of the model is provided as a prompt for the next window;
|
|
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
|
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
|
|
|
decode_options: dict
|
|
Keyword arguments to construct `DecodingOptions` instances
|
|
|
|
Returns
|
|
-------
|
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
|
"""
|
|
dtype = np.float32 #paddle only support float32
|
|
|
|
if dtype == np.float32:
|
|
decode_options["fp16"] = False
|
|
|
|
if decode_options.get(
|
|
"language") == 'None' or decode_options.get("language", None) is None:
|
|
if not model.is_multilingual:
|
|
decode_options["language"] = "en"
|
|
else:
|
|
if verbose:
|
|
print(
|
|
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
|
)
|
|
segment = pad_or_trim(mel, N_FRAMES)
|
|
_, probs = model.detect_language(segment, resource_path)
|
|
decode_options["language"] = max(probs, key=probs.get)
|
|
if verbose is not None:
|
|
print(
|
|
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
|
)
|
|
|
|
language = decode_options["language"]
|
|
task = decode_options.get("task", "transcribe")
|
|
tokenizer = get_tokenizer(
|
|
model.is_multilingual,
|
|
resource_path=resource_path,
|
|
language=language,
|
|
task=task)
|
|
|
|
def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
|
|
temperatures = [temperature] if isinstance(temperature, (
|
|
int, float)) else temperature
|
|
decode_result = None
|
|
|
|
for t in temperatures:
|
|
kwargs = {**decode_options}
|
|
if t > 0:
|
|
# disable beam_size and patience when t > 0
|
|
kwargs.pop("beam_size", None)
|
|
kwargs.pop("patience", None)
|
|
else:
|
|
# disable best_of when t == 0
|
|
kwargs.pop("best_of", None)
|
|
|
|
options = DecodingOptions(**kwargs, temperature=t)
|
|
decode_result = model.decode(segment, options, resource_path)
|
|
|
|
needs_fallback = False
|
|
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
|
needs_fallback = True # too repetitive
|
|
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
|
needs_fallback = True # average log probability is too low
|
|
|
|
if not needs_fallback:
|
|
break
|
|
|
|
return decode_result
|
|
|
|
seek = 0
|
|
input_stride = utils.exact_div(
|
|
N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2
|
|
time_precision = (input_stride * HOP_LENGTH /
|
|
SAMPLE_RATE) # time per output token: 0.02 (seconds)
|
|
all_tokens = []
|
|
all_segments = []
|
|
prompt_reset_since = 0
|
|
|
|
initial_prompt = decode_options.pop("initial_prompt", None) or []
|
|
if initial_prompt:
|
|
initial_prompt = tokenizer.encode(" " +
|
|
initial_prompt.strip()).input_ids
|
|
all_tokens.extend(initial_prompt)
|
|
|
|
def add_segment(*,
|
|
start: float,
|
|
end: float,
|
|
text_tokens: paddle.Tensor,
|
|
result: DecodingResult):
|
|
text = tokenizer.decode(
|
|
[token for token in text_tokens if token < tokenizer.eot])
|
|
if len(text.strip()) == 0: # skip empty text output
|
|
return
|
|
|
|
all_segments.append({
|
|
"id": len(all_segments),
|
|
"seek": seek,
|
|
"start": start,
|
|
"end": end,
|
|
"text": text,
|
|
"tokens": result.tokens,
|
|
"temperature": result.temperature,
|
|
"avg_logprob": result.avg_logprob,
|
|
"compression_ratio": result.compression_ratio,
|
|
"no_speech_prob": result.no_speech_prob,
|
|
})
|
|
if verbose:
|
|
print(
|
|
f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
|
|
)
|
|
|
|
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
|
num_frames = mel.shape[-1]
|
|
previous_seek_value = seek
|
|
|
|
with tqdm.tqdm(
|
|
total=num_frames, unit='frames',
|
|
disable=verbose is not False) as pbar:
|
|
while seek < num_frames:
|
|
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
|
segment = pad_or_trim(mel[:, seek:], N_FRAMES)
|
|
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
|
|
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
|
result: DecodingResult = decode_with_fallback(segment)
|
|
tokens = paddle.to_tensor(result.tokens)
|
|
|
|
if no_speech_threshold is not None:
|
|
# no voice activity check
|
|
should_skip = result.no_speech_prob > no_speech_threshold
|
|
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
|
|
# don't skip if the logprob is high enough, despite the no_speech_prob
|
|
should_skip = False
|
|
|
|
if should_skip:
|
|
seek += segment.shape[
|
|
-1] # fast-forward to the next segment boundary
|
|
continue
|
|
|
|
timestamp_tokens: paddle.Tensor = tokens.greater_equal(
|
|
paddle.to_tensor(tokenizer.timestamp_begin))
|
|
|
|
consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
|
|
1:])[0]
|
|
if len(
|
|
consecutive
|
|
) > 0: # if the output contains two consecutive timestamp tokens
|
|
consecutive = paddle.add(consecutive, paddle.to_tensor(1))
|
|
last_slice = 0
|
|
for current_slice in consecutive:
|
|
sliced_tokens = tokens[last_slice:current_slice]
|
|
start_timestamp_position = (
|
|
sliced_tokens[0].item() - tokenizer.timestamp_begin)
|
|
end_timestamp_position = (
|
|
sliced_tokens[-1].item() - tokenizer.timestamp_begin)
|
|
add_segment(
|
|
start=timestamp_offset + start_timestamp_position *
|
|
time_precision,
|
|
end=timestamp_offset + end_timestamp_position *
|
|
time_precision,
|
|
text_tokens=sliced_tokens[1:-1],
|
|
result=result, )
|
|
last_slice = current_slice
|
|
last_timestamp_position = (
|
|
tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
|
|
seek += last_timestamp_position * input_stride
|
|
all_tokens.extend(tokens[:last_slice + 1].tolist())
|
|
else:
|
|
duration = segment_duration
|
|
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
|
if len(timestamps) > 0 and timestamps[
|
|
-1].item() != tokenizer.timestamp_begin:
|
|
# no consecutive timestamps but it has a timestamp; use the last one.
|
|
# single timestamp at the end means no speech after the last timestamp.
|
|
last_timestamp_position = timestamps[
|
|
-1].item() - tokenizer.timestamp_begin
|
|
duration = last_timestamp_position * time_precision
|
|
|
|
add_segment(
|
|
start=timestamp_offset,
|
|
end=timestamp_offset + duration,
|
|
text_tokens=tokens,
|
|
result=result, )
|
|
|
|
seek += segment.shape[-1]
|
|
all_tokens.extend(tokens.tolist())
|
|
|
|
if not condition_on_previous_text or result.temperature > 0.5:
|
|
# do not feed the prompt tokens if a high temperature was used
|
|
prompt_reset_since = len(all_tokens)
|
|
|
|
# update progress bar
|
|
pbar.update(min(num_frames, seek) - previous_seek_value)
|
|
previous_seek_value = seek
|
|
|
|
return dict(
|
|
text=tokenizer.decode(all_tokens[len(initial_prompt):]),
|
|
segments=all_segments,
|
|
language=language)
|
|
|
|
|
|
class SequenceRanker:
|
|
def rank(self,
|
|
tokens: List[List[paddle.Tensor]],
|
|
sum_logprobs: List[List[float]]) -> List[int]:
|
|
"""
|
|
Given a list of groups of samples and their cumulative log probabilities,
|
|
return the indices of the samples in each group to select as the final result
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
|
|
class MaximumLikelihoodRanker(SequenceRanker):
|
|
"""
|
|
Select the sample with the highest log probabilities, penalized using either
|
|
a simple length normalization or Google NMT paper's length penalty
|
|
"""
|
|
|
|
def __init__(self, length_penalty: Optional[float]):
|
|
self.length_penalty = length_penalty
|
|
|
|
def rank(self,
|
|
tokens: List[List[paddle.Tensor]],
|
|
sum_logprobs: List[List[float]]):
|
|
def scores(logprobs, lengths):
|
|
result = []
|
|
for logprob, length in zip(logprobs, lengths):
|
|
if self.length_penalty is None:
|
|
penalty = length
|
|
else:
|
|
# from the Google NMT paper
|
|
penalty = ((5 + length) / 6)**self.length_penalty
|
|
result.append(logprob / penalty)
|
|
return result
|
|
|
|
# get the sequence with the highest score
|
|
lengths = [[len(t) for t in s] for s in tokens]
|
|
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
|
|
|
|
|
class TokenDecoder:
|
|
def reset(self):
|
|
"""Initialize any stateful variables for decoding a new sequence"""
|
|
|
|
def update(self,
|
|
tokens: paddle.Tensor,
|
|
logits: paddle.Tensor,
|
|
sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]:
|
|
"""Specify how to select the next token, based on the current trace and logits
|
|
|
|
Parameters
|
|
----------
|
|
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
|
all tokens in the context so far, including the prefix and sot_sequence tokens
|
|
|
|
logits : Tensor, shape = (n_batch, vocab_size)
|
|
per-token logits of the probability distribution at the current step
|
|
|
|
sum_logprobs : Tensor, shape = (n_batch)
|
|
cumulative log probabilities for each sequence
|
|
|
|
Returns
|
|
-------
|
|
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
|
the tokens, appended with the selected next token
|
|
|
|
completed : bool
|
|
True if all sequences has reached the end of text
|
|
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def finalize(
|
|
self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
|
|
) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
|
|
"""Finalize search and return the final candidate sequences
|
|
|
|
Parameters
|
|
----------
|
|
tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
|
|
all tokens in the context so far, including the prefix and sot_sequence
|
|
|
|
sum_logprobs : Tensor, shape = (batch_size, beam_size)
|
|
cumulative log probabilities for each sequence
|
|
|
|
Returns
|
|
-------
|
|
tokens : Sequence[Sequence[Tensor]], length = batch_size
|
|
sequence of Tensors containing candidate token sequences, for each audio input
|
|
|
|
sum_logprobs : List[List[float]], length = batch_size
|
|
sequence of cumulative log probabilities corresponding to the above
|
|
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
|
|
class GreedyDecoder(TokenDecoder):
|
|
def __init__(self, temperature: float, eot: int):
|
|
self.temperature = temperature
|
|
self.eot = eot
|
|
|
|
def update(self,
|
|
tokens: paddle.Tensor,
|
|
logits: paddle.Tensor,
|
|
sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]:
|
|
temperature = self.temperature
|
|
if temperature == 0:
|
|
next_tokens = paddle.argmax(logits, axis=-1)
|
|
else:
|
|
next_tokens = Categorical(logits=logits / temperature).sample([1])
|
|
next_tokens = paddle.reshape(next_tokens, [
|
|
next_tokens.shape[0] * next_tokens.shape[1],
|
|
])
|
|
|
|
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
|
|
current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
|
|
next_tokens]
|
|
sum_logprobs += current_logprobs * paddle.to_tensor(
|
|
(tokens[:, -1] != self.eot), dtype=paddle.float32)
|
|
|
|
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
|
tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
|
|
|
|
completed = paddle.all((tokens[:, -1] == self.eot))
|
|
return tokens, completed
|
|
|
|
def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
|
|
# make sure each sequence has at least one EOT token at the end
|
|
tokens = F.pad(tokens, (0, 1), value=self.eot, data_format="NCL")
|
|
return tokens, sum_logprobs.tolist()
|
|
|
|
|
|
class BeamSearchDecoder(TokenDecoder):
|
|
def __init__(self,
|
|
beam_size: int,
|
|
eot: int,
|
|
inference: Inference,
|
|
patience: Optional[float]=None):
|
|
self.beam_size = beam_size
|
|
self.eot = eot
|
|
self.inference = inference
|
|
self.patience = patience or 1.0
|
|
self.max_candidates: int = round(beam_size * self.patience)
|
|
self.finished_sequences = None
|
|
|
|
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
|
|
|
def reset(self):
|
|
self.finished_sequences = None
|
|
|
|
def update(self,
|
|
tokens: paddle.Tensor,
|
|
logits: paddle.Tensor,
|
|
sum_logprobs: paddle.Tensor) -> Tuple[paddle.Tensor, bool]:
|
|
if tokens.shape[0] % self.beam_size != 0:
|
|
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
|
|
|
batch_size = tokens.shape[0] // self.beam_size
|
|
if self.finished_sequences is None: # for the first update
|
|
self.finished_sequences = [{} for _ in range(batch_size)]
|
|
|
|
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
|
|
next_tokens, source_indices, finished_sequences = [], [], []
|
|
for i in range(batch_size):
|
|
scores, sources, finished = {}, {}, {}
|
|
|
|
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
|
for j in range(self.beam_size):
|
|
idx = i * self.beam_size + j
|
|
prefix = tokens[idx].tolist()
|
|
logprob, token = paddle.topk(
|
|
logprobs[idx], k=self.beam_size + 1)
|
|
for logprob, token in zip(logprob, token):
|
|
new_logprob = (sum_logprobs[idx] + logprob).tolist()[0]
|
|
sequence = tuple(prefix + [token.tolist()[0]])
|
|
scores[sequence] = new_logprob
|
|
sources[sequence] = idx
|
|
|
|
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
|
saved = 0
|
|
for sequence in sorted(scores, key=scores.get, reverse=True):
|
|
if sequence[-1] == self.eot:
|
|
finished[sequence] = scores[sequence]
|
|
else:
|
|
sum_logprobs[len(next_tokens)] = scores[sequence]
|
|
next_tokens.append(sequence)
|
|
source_indices.append(sources[sequence])
|
|
|
|
saved += 1
|
|
if saved == self.beam_size:
|
|
break
|
|
|
|
finished_sequences.append(finished)
|
|
|
|
tokens = paddle.to_tensor(next_tokens)
|
|
self.inference.rearrange_kv_cache(source_indices)
|
|
|
|
# add newly finished sequences to self.finished_sequences
|
|
assert len(self.finished_sequences) == len(finished_sequences)
|
|
for previously_finished, newly_finished in zip(self.finished_sequences,
|
|
finished_sequences):
|
|
for seq in sorted(
|
|
newly_finished, key=newly_finished.get, reverse=True):
|
|
if len(previously_finished) >= self.max_candidates:
|
|
break # the candidate list is full
|
|
previously_finished[seq] = newly_finished[seq]
|
|
|
|
# mark as completed if all audio has enough number of samples
|
|
completed = all(
|
|
len(sequences) >= self.max_candidates
|
|
for sequences in self.finished_sequences)
|
|
return tokens, completed
|
|
|
|
def finalize(self,
|
|
preceding_tokens: paddle.Tensor,
|
|
sum_logprobs: paddle.Tensor):
|
|
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
|
sum_logprobs = sum_logprobs.cpu()
|
|
for i, sequences in enumerate(self.finished_sequences):
|
|
if len(sequences
|
|
) < self.beam_size: # when not enough sequences are finished
|
|
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
|
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
|
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
|
if len(sequences) >= self.beam_size:
|
|
break
|
|
|
|
tokens: List[List[paddle.Tensor]] = [
|
|
[paddle.to_tensor(seq) for seq in sequences.keys()]
|
|
for sequences in self.finished_sequences
|
|
]
|
|
sum_logprobs: List[List[float]] = [
|
|
list(sequences.values()) for sequences in self.finished_sequences
|
|
]
|
|
return tokens, sum_logprobs
|
|
|
|
|
|
class LogitFilter:
|
|
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
|
|
"""Apply any filtering or masking to logits in-place
|
|
|
|
Parameters
|
|
----------
|
|
logits : Tensor, shape = (n_batch, vocab_size)
|
|
per-token logits of the probability distribution at the current step
|
|
|
|
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
|
all tokens in the context so far, including the prefix and sot_sequence tokens
|
|
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
|
|
class SuppressBlank(LogitFilter):
|
|
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
|
self.tokenizer = tokenizer
|
|
self.sample_begin = sample_begin
|
|
|
|
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
if tokens.shape[1] == self.sample_begin:
|
|
logits[:, self.tokenizer.encode(" ").input_ids +
|
|
[self.tokenizer.eot]] = -np.inf
|
|
|
|
|
|
class SuppressTokens(LogitFilter):
|
|
def __init__(self, suppress_tokens: Sequence[int]):
|
|
self.suppress_tokens = list(suppress_tokens)
|
|
|
|
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
logits[:, self.suppress_tokens] = -np.inf
|
|
|
|
|
|
class ApplyTimestampRules(LogitFilter):
|
|
def __init__(self,
|
|
tokenizer: Tokenizer,
|
|
sample_begin: int,
|
|
max_initial_timestamp_index: Optional[int]):
|
|
self.tokenizer = tokenizer
|
|
self.sample_begin = sample_begin
|
|
self.max_initial_timestamp_index = max_initial_timestamp_index
|
|
|
|
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
|
|
# suppress <|notimestamps|> which is handled by without_timestamps
|
|
if self.tokenizer.no_timestamps is not None:
|
|
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
|
|
|
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
|
for k in range(tokens.shape[0]):
|
|
seq = [t for t in tokens[k, self.sample_begin:].tolist()]
|
|
last_was_timestamp = len(seq) >= 1 and seq[
|
|
-1] >= self.tokenizer.timestamp_begin
|
|
penultimate_was_timestamp = len(seq) < 2 or seq[
|
|
-2] >= self.tokenizer.timestamp_begin
|
|
|
|
if last_was_timestamp:
|
|
if penultimate_was_timestamp: # has to be non-timestamp
|
|
logits[k, self.tokenizer.timestamp_begin:] = -np.inf
|
|
else: # cannot be normal text tokens
|
|
logits[k, :self.tokenizer.eot] = -np.inf
|
|
|
|
# apply the `max_initial_timestamp` option
|
|
if tokens.shape[
|
|
1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
|
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
|
logits[:, last_allowed + 1:] = -np.inf
|
|
|
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
|
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
|
|
for k in range(tokens.shape[0]):
|
|
timestamp_logprob = paddle.logsumexp(
|
|
logprobs[k, self.tokenizer.timestamp_begin:], axis=-1)
|
|
max_text_token_logprob = paddle.max(
|
|
logprobs[k, :self.tokenizer.timestamp_begin])
|
|
if timestamp_logprob > max_text_token_logprob:
|
|
logits[k, :self.tokenizer.timestamp_begin] = -np.inf
|
|
|
|
|
|
class DecodingTask:
|
|
inference: Inference
|
|
sequence_ranker: SequenceRanker
|
|
decoder: TokenDecoder
|
|
logit_filters: List[LogitFilter]
|
|
|
|
def __init__(self,
|
|
model: "Whisper",
|
|
options: DecodingOptions,
|
|
resource_path: str):
|
|
self.model = model
|
|
|
|
language = options.language or "en"
|
|
tokenizer = get_tokenizer(
|
|
model.is_multilingual,
|
|
resource_path=resource_path,
|
|
language=language,
|
|
task=options.task)
|
|
self.tokenizer: Tokenizer = tokenizer
|
|
self.options: DecodingOptions = self._verify_options(options)
|
|
self.resource_path: str = resource_path
|
|
|
|
self.beam_size: int = options.beam_size or options.best_of or 1
|
|
self.n_ctx: int = model.dims.n_text_ctx
|
|
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
|
|
|
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
|
if self.options.without_timestamps:
|
|
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
|
|
|
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
|
self.sample_begin: int = len(self.initial_tokens)
|
|
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
|
|
|
# inference: implements the forward pass through the decoder, including kv caching
|
|
self.inference = WhisperInference(model, len(self.initial_tokens))
|
|
|
|
# sequence ranker: implements how to rank a group of sampled sequences
|
|
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
|
|
|
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
|
if options.beam_size is not None:
|
|
self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot,
|
|
self.inference, options.patience)
|
|
else:
|
|
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
|
|
|
# logit filters: applies various rules to suppress or penalize certain tokens
|
|
self.logit_filters = []
|
|
if self.options.suppress_blank:
|
|
self.logit_filters.append(
|
|
SuppressBlank(self.tokenizer, self.sample_begin))
|
|
if self.options.suppress_tokens:
|
|
self.logit_filters.append(
|
|
SuppressTokens(self._get_suppress_tokens()))
|
|
if not options.without_timestamps:
|
|
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
|
max_initial_timestamp_index = None
|
|
if options.max_initial_timestamp:
|
|
max_initial_timestamp_index = round(
|
|
self.options.max_initial_timestamp / precision)
|
|
self.logit_filters.append(
|
|
ApplyTimestampRules(tokenizer, self.sample_begin,
|
|
max_initial_timestamp_index))
|
|
|
|
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
|
if options.beam_size is not None and options.best_of is not None:
|
|
raise ValueError("beam_size and best_of can't be given together")
|
|
if options.temperature == 0:
|
|
if options.best_of is not None:
|
|
raise ValueError(
|
|
"best_of with greedy sampling (T=0) is not compatible")
|
|
if options.patience is not None and options.beam_size is None:
|
|
raise ValueError("patience requires beam_size to be given")
|
|
if options.length_penalty is not None and not (
|
|
0 <= options.length_penalty <= 1):
|
|
raise ValueError(
|
|
"length_penalty (alpha) should be a value between 0 and 1")
|
|
|
|
return options
|
|
|
|
def _get_initial_tokens(self) -> Tuple[int]:
|
|
tokens = list(self.sot_sequence)
|
|
prefix = self.options.prefix
|
|
prompt = self.options.prompt
|
|
|
|
if prefix:
|
|
prefix_tokens = (
|
|
self.tokenizer.encode(" " + prefix.strip().input_ids)
|
|
if isinstance(prefix, str) else prefix)
|
|
if self.sample_len is not None:
|
|
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
|
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
|
tokens = tokens + prefix_tokens
|
|
|
|
if prompt:
|
|
prompt_tokens = (
|
|
self.tokenizer.encode(" " + prompt.strip().input_ids)
|
|
if isinstance(prompt, str) else prompt)
|
|
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2
|
|
- 1):] + tokens
|
|
|
|
return tuple(tokens)
|
|
|
|
def _get_suppress_tokens(self) -> Tuple[int]:
|
|
suppress_tokens = self.options.suppress_tokens
|
|
|
|
if isinstance(suppress_tokens, str):
|
|
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
|
|
|
if -1 in suppress_tokens:
|
|
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
|
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
|
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
|
suppress_tokens = [] # interpret empty string as an empty list
|
|
else:
|
|
assert isinstance(suppress_tokens,
|
|
list), "suppress_tokens must be a list"
|
|
|
|
suppress_tokens.extend([
|
|
self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm
|
|
])
|
|
if self.tokenizer.no_speech is not None:
|
|
# no-speech probability is collected separately
|
|
suppress_tokens.append(self.tokenizer.no_speech)
|
|
|
|
return tuple(sorted(set(suppress_tokens)))
|
|
|
|
def _get_audio_features(self, mel: paddle.Tensor):
|
|
#if self.options.fp16:
|
|
# mel = mel.half()
|
|
|
|
if mel.shape[-2:] == (self.model.dims.n_audio_ctx,
|
|
self.model.dims.n_audio_state):
|
|
# encoded audio features are given; skip audio encoding
|
|
audio_features = mel
|
|
else:
|
|
audio_features = self.model.encoder(mel)
|
|
|
|
#if audio_features.dtype != (np.float16 if self.options.fp16 else np.float32):
|
|
# return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
|
|
|
return audio_features
|
|
|
|
def _detect_language(self,
|
|
audio_features: paddle.Tensor,
|
|
tokens: paddle.Tensor,
|
|
resource_path: str):
|
|
languages = [self.options.language] * audio_features.shape[0]
|
|
lang_probs = None
|
|
|
|
if self.options.language is None or self.options.task == "lang_id":
|
|
lang_tokens, lang_probs = self.model.detect_language(
|
|
audio_features, self.tokenizer, self.resource_path)
|
|
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
|
if self.options.language is None:
|
|
tokens[:, self.sot_index +
|
|
1] = lang_tokens # write language tokens
|
|
|
|
return languages, lang_probs
|
|
|
|
def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
|
|
assert audio_features.shape[0] == tokens.shape[0]
|
|
n_batch = tokens.shape[0]
|
|
sum_logprobs: paddle.Tensor = paddle.zeros(
|
|
paddle.to_tensor(n_batch), dtype=paddle.float32)
|
|
no_speech_probs = [np.nan] * n_batch
|
|
|
|
try:
|
|
for i in range(self.sample_len):
|
|
logits = self.inference.logits(tokens, audio_features)
|
|
|
|
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
|
probs_at_sot = F.softmax(
|
|
logits[:, self.sot_index],
|
|
axis=-1,
|
|
dtype=paddle.float32)
|
|
no_speech_probs = probs_at_sot[:, self.tokenizer.
|
|
no_speech].tolist()
|
|
|
|
# now we need to consider the logits at the last token only
|
|
logits = logits[:, -1]
|
|
|
|
# apply the logit filters, e.g. for suppressing or applying penalty to
|
|
for logit_filter in self.logit_filters:
|
|
logit_filter.apply(logits, tokens)
|
|
|
|
# expand the tokens tensor with the selected next tokens
|
|
tokens, completed = self.decoder.update(tokens, logits,
|
|
sum_logprobs)
|
|
if completed or tokens.shape[-1] > self.n_ctx:
|
|
break
|
|
finally:
|
|
self.inference.cleanup_caching()
|
|
|
|
return tokens, sum_logprobs, no_speech_probs
|
|
|
|
@paddle.no_grad()
|
|
def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
|
|
self.decoder.reset()
|
|
tokenizer: Tokenizer = self.tokenizer
|
|
batch_size: int = mel.shape[0]
|
|
|
|
audio_features: paddle.Tensor = self._get_audio_features(
|
|
mel) # encoder forward pass
|
|
|
|
tokens: paddle.Tensor
|
|
if batch_size > 1:
|
|
for i in range(batch_size):
|
|
tokens = paddle.concat(
|
|
x=[
|
|
paddle.to_tensor([self.initial_tokens]),
|
|
paddle.to_tensor([self.initial_tokens])
|
|
],
|
|
axis=0)
|
|
elif batch_size == 1:
|
|
tokens = paddle.to_tensor([self.initial_tokens])
|
|
|
|
# detect language if requested, overwriting the language token
|
|
languages, language_probs = self._detect_language(
|
|
paddle.to_tensor(audio_features),
|
|
paddle.to_tensor(tokens), self.resource_path)
|
|
|
|
if self.options.task == "lang_id":
|
|
return [
|
|
DecodingResult(
|
|
audio_features=features,
|
|
language=language,
|
|
language_probs=probs) for features, language, probs in
|
|
zip(audio_features, languages, language_probs)
|
|
]
|
|
|
|
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
|
|
|
audio_features = paddle.repeat_interleave(
|
|
audio_features, self.beam_size, axis=0)
|
|
tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
|
|
|
|
# call the main sampling loop
|
|
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features,
|
|
tokens)
|
|
|
|
# reshape the tensors to have (batch_size, beam_size) as the first two dimensions
|
|
audio_features = audio_features[::self.beam_size]
|
|
no_speech_probs = no_speech_probs[::self.beam_size]
|
|
assert audio_features.shape[0] == len(no_speech_probs) == batch_size
|
|
|
|
tokens = tokens.reshape([batch_size, self.beam_size, -1])
|
|
sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
|
|
|
|
# get the final candidates for each group, and slice between the first sampled token and EOT
|
|
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
|
tokens: List[List[paddle.Tensor]] = [[
|
|
t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s
|
|
] for s in tokens]
|
|
|
|
# select the top-ranked sample in each group
|
|
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
|
tokens: List[List[
|
|
int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
|
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
|
|
|
sum_logprobs: List[
|
|
float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
|
avg_logprobs: List[
|
|
float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
|
|
|
fields = (texts, languages, tokens, audio_features, avg_logprobs,
|
|
no_speech_probs)
|
|
if len(set(map(len, fields))) != 1:
|
|
raise RuntimeError(
|
|
f"inconsistent result lengths: {list(map(len, fields))}")
|
|
|
|
return [
|
|
DecodingResult(
|
|
audio_features=features,
|
|
language=language,
|
|
tokens=tokens,
|
|
text=text,
|
|
avg_logprob=avg_logprob,
|
|
no_speech_prob=no_speech_prob,
|
|
temperature=self.options.temperature,
|
|
compression_ratio=utils.compression_ratio(text), )
|
|
for text, language, tokens, features, avg_logprob, no_speech_prob in
|
|
zip(*fields)
|
|
]
|
|
|
|
|
|
@paddle.no_grad()
|
|
def decode(
|
|
model: "Whisper",
|
|
mel: paddle.Tensor,
|
|
options: DecodingOptions=DecodingOptions(),
|
|
resource_path=str, ) -> Union[DecodingResult, List[DecodingResult]]:
|
|
"""
|
|
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
|
|
|
Parameters
|
|
----------
|
|
model: Whisper
|
|
the Whisper model instance
|
|
|
|
mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
|
A tensor containing the Mel spectrogram(s)
|
|
|
|
options: DecodingOptions
|
|
A dataclass that contains all necessary options for decoding 30-second segments
|
|
|
|
Returns
|
|
-------
|
|
result: Union[DecodingResult, List[DecodingResult]]
|
|
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
|
"""
|
|
single = mel.ndim == 2
|
|
if single:
|
|
mel = mel.unsqueeze(0)
|
|
|
|
result = DecodingTask(model, options, resource_path).run(mel)
|
|
|
|
if single:
|
|
result = result[0]
|
|
|
|
return result
|
|
|
|
|
|
class Whisper(nn.Layer):
|
|
def __init__(self, dims: ModelDimensions):
|
|
super().__init__()
|
|
self.dims = dims
|
|
self.encoder = AudioEncoder(
|
|
self.dims.n_mels,
|
|
self.dims.n_audio_ctx,
|
|
self.dims.n_audio_state,
|
|
self.dims.n_audio_head,
|
|
self.dims.n_audio_layer, )
|
|
self.decoder = TextDecoder(
|
|
self.dims.n_vocab,
|
|
self.dims.n_text_ctx,
|
|
self.dims.n_text_state,
|
|
self.dims.n_text_head,
|
|
self.dims.n_text_layer, )
|
|
|
|
def embed_audio(self, mel: paddle.Tensor):
|
|
return self.encoder.forward(mel)
|
|
|
|
def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
|
|
return self.decoder.forward(tokens, audio_features)
|
|
|
|
def forward(self, mel: paddle.Tensor,
|
|
tokens: paddle.Tensor) -> Dict[str, paddle.Tensor]:
|
|
return self.decoder(tokens, self.encoder(mel))
|
|
|
|
@property
|
|
def device(self):
|
|
return paddle.device.get_device()
|
|
|
|
@property
|
|
def is_multilingual(self):
|
|
return self.dims.n_vocab == 51865
|
|
|
|
def install_kv_cache_hooks(self, cache: Optional[dict]=None):
|
|
"""
|
|
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
|
tensors calculated for the previous positions. This method returns a dictionary that stores
|
|
all caches, and the necessary hooks for the key and value projection modules that save the
|
|
intermediate tensors to be reused during later calculations.
|
|
|
|
Returns
|
|
-------
|
|
cache : Dict[nn.Layer, paddle.Tensor]
|
|
A dictionary object mapping the key/value projection modules to its cache
|
|
hooks : List[RemovableHandle]
|
|
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
|
"""
|
|
cache = {**cache} if cache is not None else {}
|
|
hooks = []
|
|
|
|
def save_to_cache(module, _, output):
|
|
if module not in cache or output.shape[
|
|
1] > self.decoder.positional_embedding.shape[0]:
|
|
cache[
|
|
module] = output # save as-is, for the first token or cross attention
|
|
else:
|
|
cache[module] = paddle.concat(
|
|
[cache[module], output], axis=1).detach()
|
|
return cache[module]
|
|
|
|
def install_hooks(layer: nn.Layer):
|
|
if isinstance(layer, MultiHeadAttention):
|
|
hooks.append(
|
|
layer.key.register_forward_post_hook(save_to_cache))
|
|
hooks.append(
|
|
layer.value.register_forward_post_hook(save_to_cache))
|
|
|
|
self.decoder.apply(install_hooks)
|
|
return cache, hooks
|
|
|
|
detect_language = detect_language
|
|
transcribe = transcribe
|
|
decode = decode
|
|
|
|
|
|
def pad_or_trim(array, length: int=N_SAMPLES, *, axis: int=-1):
|
|
"""
|
|
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
|
"""
|
|
if paddle.is_tensor(array):
|
|
if array.shape[axis] > length:
|
|
array = array.index_select(axis=axis, index=paddle.arange(length))
|
|
|
|
if array.shape[axis] < length:
|
|
pad_widths = [(0, 0)] * array.ndim
|
|
pad_widths[axis] = (0, length - array.shape[axis])
|
|
array = paddle.transpose(array, (1, 0))
|
|
array = F.pad(
|
|
array, [pad for sizes in pad_widths[::-1] for pad in sizes],
|
|
data_format='NLC')
|
|
array = paddle.transpose(array, (1, 0))
|
|
else:
|
|
if array.shape[axis] > length:
|
|
array = array.take(indices=range(length), axis=axis)
|
|
|
|
if array.shape[axis] < length:
|
|
pad_widths = [(0, 0)] * array.ndim
|
|
pad_widths[axis] = (0, length - array.shape[axis])
|
|
array = paddle.transpose(array, (1, 0))
|
|
array = np.pad(array, pad_widths)
|
|
array = paddle.transpose(array, (1, 0))
|
|
|
|
return array
|
|
|
|
|
|
def hann_window(n_fft: int=N_FFT):
|
|
"""
|
|
hanning window
|
|
n_fft: The number of frequency components of the discrete Fourier transform.
|
|
"""
|
|
return paddle.to_tensor(
|
|
[0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
|
|
dtype=paddle.float32)
|
|
|
|
|
|
@lru_cache(maxsize=None)
|
|
def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
|
|
"""
|
|
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
|
Allows decoupling librosa dependency; saved using:
|
|
|
|
np.savez_compressed(
|
|
"mel_filters.npz",
|
|
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
|
)
|
|
"""
|
|
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
|
with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
|
|
return paddle.to_tensor(f[f"mel_{n_mels}"])
|
|
|
|
|
|
def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
|
|
n_mels: int=N_MELS,
|
|
resource_path: str=None):
|
|
"""
|
|
Compute the log-Mel spectrogram of
|
|
|
|
Parameters
|
|
----------
|
|
audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
|
|
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
|
|
|
n_mels: int
|
|
The number of Mel-frequency filters, only 80 is supported
|
|
|
|
Returns
|
|
-------
|
|
paddle.Tensor, shape = (80, n_frames)
|
|
A Tensor that contains the Mel spectrogram
|
|
"""
|
|
if not paddle.is_tensor(audio):
|
|
if isinstance(audio, str):
|
|
audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
|
|
audio = audio[:, 0]
|
|
logger.info(f"audio shape: {audio.shape}")
|
|
audio = paddle.to_tensor(audio)
|
|
|
|
window = hann_window(N_FFT)
|
|
stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
|
|
|
|
magnitudes = stft[:, :-1].abs()**2
|
|
|
|
filters = mel_filters(resource_path, n_mels)
|
|
mel_spec = filters @ magnitudes
|
|
mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
|
|
|
|
log_spec = paddle.clip(mel_spec, min=1e-10).log10()
|
|
log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
|
|
log_spec = (log_spec + 4.0) / 4.0
|
|
return log_spec
|