mv nlp tokenizer to tiktoken.

pull/4101/head
zxcd 2 weeks ago
parent 8a1d5085f9
commit fe4860550a

@@ -192,14 +192,15 @@ class WhisperExecutor(BaseExecutor):
self.resource_path = os.path.join(
DATA_HOME, self.task_resource.version, 'whisper')
# self.download_resource(resource_url, self.resource_path,
# resource_md5)
self.download_resource(resource_url, self.resource_path,
resource_md5)
else:
raise Exception("wrong type")
# load model
model_dict = paddle.load(self.ckpt_path)
dims = ModelDimensions(**model_dict["dims"])
self.dims = dims
self.model = Whisper(dims)
self.model.load_dict(model_dict)
self.model.eval()
@@ -252,8 +253,10 @@ class WhisperExecutor(BaseExecutor):
logger.debug(f"audio shape: {audio.shape}")
# fbank
audio = log_mel_spectrogram(
audio, resource_path=self.resource_path, n_mels=128, padding=480000)
print(audio)
audio,
resource_path=self.resource_path,
n_mels=self.dims.n_mels,
padding=480000)
audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0)
self._inputs["audio"] = audio

@@ -4,6 +4,7 @@
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
import base64
import os
import string
from dataclasses import dataclass
from dataclasses import field
from functools import cached_property
@@ -12,12 +13,8 @@ from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import paddle
import tiktoken
# from paddlenlp.transformers import GPTTokenizer
LANGUAGES = {
"en": "english",
@@ -142,7 +139,7 @@ TO_LANGUAGE_CODE = {
@dataclass
class Tokenizer:
"""A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
num_languages: int
@@ -159,7 +156,6 @@ class Tokenizer:
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())[:self.num_languages]
sot_sequence = [sot]
if self.language is not None:
@@ -167,51 +163,22 @@ class Tokenizer:
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
print("sot_sequence", sot_sequence)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs):
return self.encoding.encode(text, **kwargs)
# def decode(self,
# token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
# **kwargs):
# if len(token_ids) > 1:
# ids_list = []
# for ids in token_ids:
# if paddle.is_tensor(ids):
# ids = ids.item()
# if ids < len(self.tokenizer):
# ids_list.append(ids)
# token_ids = ids_list
# elif len(token_ids) == 1:
# token_ids = token_ids[0]
# else:
# raise ValueError(f"token_ids {token_ids} load error.")
# return self.tokenizer.decode(token_ids, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, tokens) -> str:
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
outputs = [
s if isinstance(s, str) else self.tokenizer.decode(s)
for s in outputs
]
return "".join(outputs)
return self.encoding.decode(token_ids, **kwargs)
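
Hedged usage sketch of the simplified decode paths (assumes a Tokenizer built by get_tokenizer below; because the timestamp strings such as "<|0.00|>" are registered as tiktoken special tokens, encoding.decode() can render them directly instead of the old manual stitching):

ids = tokenizer.encode("hello")
ids = [tokenizer.timestamp_begin] + ids + [tokenizer.timestamp_begin + 150]
print(tokenizer.decode_with_timestamps(ids))  # e.g. "<|0.00|>hello<|3.00|>"
print(tokenizer.decode(ids))                  # "hello" -- timestamp ids are filtered out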
@cached_property
def eot(self) -> int:
@@ -368,41 +335,16 @@ class Tokenizer:
return words, word_tokens
# @lru_cache(maxsize=None)
# def build_tokenizer(resource_path: str, name: str="gpt2"):
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# path = os.path.join(resource_path, "assets", name)
# tokenizer = GPTTokenizer.from_pretrained(path)
# specials = [
# "<|startoftranscript|>",
# * [f"<|{lang}|>" for lang in LANGUAGES.keys()],
# "<|translate|>",
# "<|transcribe|>",
# "<|startoflm|>",
# "<|startofprev|>",
# "<|nospeech|>",
# "<|notimestamps|>",
# ]
# tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
# return tokenizer
@lru_cache(maxsize=None)
def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
print("resource_path", resource_path)
print("name", name)
vocab_path = os.path.join(resource_path, "assets", f"{name}.tiktoken")
# vocab_path = os.path.join(resource_path, "assets", sname)
# vocab_path += ".tiktoken"
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
@@ -415,11 +357,9 @@ def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
"<|notimestamps|>",
* [f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
@@ -428,50 +368,6 @@ def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
special_tokens=special_tokens, )
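
For reference, a minimal self-contained sketch of how a Whisper-style tiktoken.Encoding is assembled (assumptions: the *.tiktoken asset uses the upstream "base64-token rank" line format, and the split pattern is the GPT-2 regex used by upstream Whisper; neither is visible in this hunk):

import base64
import tiktoken

def build_encoding(vocab_path: str, language_codes: list, name: str = "whisper") -> tiktoken.Encoding:
    with open(vocab_path) as f:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in f if line.strip())
        }
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{code}|>" for code in language_codes],
        "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>",
        "<|nospeech|>", "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # timestamp tokens
    ]
    special_tokens = {tok: len(ranks) + i for i, tok in enumerate(specials)}
    return tiktoken.Encoding(
        name=name,
        explicit_n_vocab=len(ranks) + len(special_tokens),
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens, )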
# @lru_cache(maxsize=None)
# def get_tokenizer(
# multilingual: bool,
# resource_path: str,
# *,
# task: Optional[str]=None, # Literal["transcribe", "translate", None]
# language: Optional[str]=None, ) -> Tokenizer:
# if language is not None:
# language = language.lower()
# if language not in LANGUAGES:
# if language in TO_LANGUAGE_CODE:
# language = TO_LANGUAGE_CODE[language]
# else:
# raise ValueError(f"Unsupported language: {language}")
# if multilingual:
# tokenizer_name = "multilingual"
# task = task or "transcribe"
# language = language or "en"
# else:
# tokenizer_name = "gpt2"
# task = None
# language = None
# tokenizer = build_tokenizer(
# resource_path=resource_path, name=tokenizer_name)
# all_special_ids: List[int] = tokenizer.all_special_ids
# sot: int = all_special_ids[1]
# translate: int = all_special_ids[-6]
# transcribe: int = all_special_ids[-5]
# langs = tuple(LANGUAGES.keys())
# sot_sequence = [sot]
# if language is not None:
# sot_sequence.append(sot + 1 + langs.index(language))
# if task is not None:
# sot_sequence.append(transcribe if task == "transcribe" else translate)
# return Tokenizer(
# tokenizer=tokenizer,
# language=language,
# sot_sequence=tuple(sot_sequence))
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
@@ -481,8 +377,8 @@ def get_tokenizer(
language: Optional[str]=None,
task: Optional[str]=None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
if language is not None:
print(language)
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
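
A hedged end-to-end usage sketch of the reworked get_tokenizer (the resource path is hypothetical, and the num_languages keyword is assumed from the DecodingTask call further down; it is not guaranteed to match the final signature):

tokenizer = get_tokenizer(
    multilingual=True,
    resource_path="/path/to/whisper_resources",  # hypothetical location of assets/multilingual.tiktoken
    language="zh",
    task="transcribe",
    num_languages=100, )
ids = tokenizer.encode("hello world")
assert tokenizer.decode(ids) == "hello world"
print(tokenizer.sot_sequence)  # (sot id, <|zh|> id, <|transcribe|> id)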

@@ -397,7 +397,9 @@ def detect_language(
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
mel = model.encoder(
mel
) # TODO zhaoxi: torch returns float16, which causes an e-3 diff against paddle's float32
# forward pass using a single token, startoftranscript
batch_size = mel.shape[0]
@@ -407,6 +409,7 @@ def detect_language(
# collect detected languages; suppress all non-language tokens
mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
mask[list(tokenizer.all_language_tokens)] = False
logits.contiguous()
logits[:, mask] = -np.inf
language_tokens = paddle.argmax(logits, axis=-1)
language_token_probs = F.softmax(logits, axis=-1)
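
The masking above is the standard "suppress everything except language tokens" trick; a standalone NumPy sketch of the same idea (vocab size and token ids are made up):

import numpy as np

vocab_size = 51866                           # illustrative only
language_token_ids = [50259, 50260, 50261]   # hypothetical <|en|>, <|zh|>, <|de|> ids
logits = np.random.randn(1, vocab_size).astype("float32")

mask = np.ones(vocab_size, dtype=bool)
mask[language_token_ids] = False
logits[:, mask] = -np.inf                    # only language tokens survive
language_token = logits.argmax(axis=-1)      # guaranteed to be one of language_token_ids
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)   # softmax over the masked logits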
@@ -492,8 +495,6 @@ def transcribe(
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
# import pdb
# pdb.set_trace()
if decode_options.get("language", None) in {None, "None"}:
if not model.is_multilingual:
decode_options["language"] = "en"
@@ -512,9 +513,7 @@ def transcribe(
)
language: str = decode_options["language"]
print("language", language)
task: str = decode_options.get("task", "transcribe")
print("model.num_languages", model.num_languages)
tokenizer = get_tokenizer(
multilingual=model.is_multilingual,
resource_path=resource_path,
@@ -652,7 +651,6 @@ def transcribe(
"prompt"] = initial_prompt_tokens + remaining_prompt
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = paddle.to_tensor(result.tokens)
@@ -704,7 +702,6 @@ def transcribe(
consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
1:])[0]
print("consecutive", consecutive)
consecutive = paddle.add(consecutive, paddle.to_tensor(1))
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
@@ -755,82 +752,6 @@ def transcribe(
result=result, ))
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp, )
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(
first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap *
FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1:])
if next_segment is not None:
hal_next_start = next_segment["words"][0][
"start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold or
segment["start"] < threshold or
segment["start"] - time_offset < 2.0)
silence_after = (
hal_next_start - segment["end"] > threshold or
is_segment_anomaly(next_segment) or
window_end_time - segment["end"] < 2.0)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"]) *
FRAMES_PER_SECOND)
if content_duration - segment[
"end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment[
@@ -1138,6 +1059,7 @@ class SuppressBlank(LogitFilter):
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
if tokens.shape[1] == self.sample_begin:
logits.contiguous()
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot
]] = -np.inf
@@ -1147,6 +1069,7 @@ class SuppressTokens(LogitFilter):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
logits.contiguous()
logits[:, self.suppress_tokens] = -np.inf
@@ -1162,6 +1085,7 @@ class ApplyTimestampRules(LogitFilter):
def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits.contiguous()
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@@ -1182,6 +1106,7 @@ class ApplyTimestampRules(LogitFilter):
if tokens.shape[
1] == self.sample_begin and self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits.contiguous()
logits[:, last_allowed + 1:] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
@@ -1218,12 +1143,14 @@ class DecodingTask:
multilingual=model.is_multilingual,
resource_path=resource_path,
language=language,
task=options.task)
task=options.task,
num_languages=model.num_languages)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.resource_path: str = resource_path
self.beam_size: int = options.beam_size or options.best_of or 1
# self.beam_size: int = options.beam_size or options.best_of or 1
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
@@ -1368,10 +1295,10 @@ class DecodingTask:
sum_logprobs: paddle.Tensor = paddle.zeros(
paddle.to_tensor(n_batch), dtype=paddle.float32)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
logits.contiguous()
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = F.softmax(
@@ -1407,17 +1334,8 @@ class DecodingTask:
audio_features: paddle.Tensor = self._get_audio_features(
mel) # encoder forward pass
tokens: paddle.Tensor
if batch_size > 1:
for i in range(batch_size):
tokens = paddle.concat(
x=[
paddle.to_tensor([self.initial_tokens]),
paddle.to_tensor([self.initial_tokens])
],
axis=0)
elif batch_size == 1:
tokens = paddle.to_tensor([self.initial_tokens])
tokens: Tensor = paddle.to_tensor([self.initial_tokens]).repeat(
batch_size, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(
@@ -1434,30 +1352,26 @@ class DecodingTask:
language_probs)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = paddle.repeat_interleave(
audio_features, self.beam_size, axis=0)
tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
# repeat text tensors by the group size, for beam search or best-of-n sampling
tokens = tokens.repeat_interleave(self.n_group, axis=0)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features,
tokens)
# reshape the tensors to have (batch_size, beam_size) as the first two dimensions
audio_features = audio_features[::self.beam_size]
no_speech_probs = no_speech_probs[::self.beam_size]
# reshape the tensors to have (batch_size, n_group) as the first two dimensions
audio_features = audio_features[::self.n_group]
no_speech_probs = no_speech_probs[::self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == batch_size
tokens = tokens.reshape([batch_size, self.beam_size, -1])
sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
tokens = tokens.reshape([batch_size, self.n_group, -1])
sum_logprobs = sum_logprobs.reshape([batch_size, self.n_group])
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[paddle.Tensor]] = [[
tokens: List[List[Tensor]] = [[
t[self.sample_begin:(t == tokenizer.eot).nonzero()[0, 0]] for t in s
] for s in tokens]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[
@@ -1466,11 +1380,12 @@ class DecodingTask:
sum_logprobs: List[
float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[
float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (texts, languages, tokens, audio_features, avg_logprobs,
no_speech_probs)
no_speech_probs, )
if len(set(map(len, fields))) != 1:
raise RuntimeError(
f"inconsistent result lengths: {list(map(len, fields))}")
@@ -1504,7 +1419,7 @@ def decode(
model: Whisper
the Whisper model instance
mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000) or (128, 3000) or (*, 128, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
@@ -1660,9 +1575,6 @@ def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
# assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
# with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
# return paddle.to_tensor(f[f"mel_{n_mels}"])
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(resource_path, "assets", "mel_filters.npz")
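
The assert now admits both filter-bank sizes; a standalone librosa sketch of the two banks the mel_filters.npz asset is expected to hold (per the docstring above):

import librosa

for n_mels in (80, 128):
    fbank = librosa.filters.mel(sr=16000, n_fft=400, n_mels=n_mels)
    print(n_mels, fbank.shape)  # (80, 201) and (128, 201)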
@@ -1683,11 +1595,11 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
The number of Mel-frequency filters; only 80 and 128 are supported
Returns
-------
paddle.Tensor, shape = (80, n_frames)
paddle.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not paddle.is_tensor(audio):

@@ -128,6 +128,7 @@ base = [
"flatten_dict",
"pyloudnorm",
"rich",
"tiktoken",
]
server = ["pattern_singleton", "websockets"]
