diff --git a/paddlespeech/cli/whisper/infer.py b/paddlespeech/cli/whisper/infer.py
index 5649b757f..bb174a1e6 100644
--- a/paddlespeech/cli/whisper/infer.py
+++ b/paddlespeech/cli/whisper/infer.py
@@ -192,8 +192,8 @@ class WhisperExecutor(BaseExecutor):
                 self.resource_path = os.path.join(
                     DATA_HOME, self.task_resource.version, 'whisper')
-                self.download_resource(resource_url, self.resource_path,
-                                       resource_md5)
+                # self.download_resource(resource_url, self.resource_path,
+                #                        resource_md5)
             else:
                 raise Exception("wrong type")
 
@@ -251,8 +251,9 @@ class WhisperExecutor(BaseExecutor):
             logger.debug(f"audio shape: {audio.shape}")
 
             # fbank
-            audio = log_mel_spectrogram(audio, resource_path=self.resource_path)
-
+            audio = log_mel_spectrogram(
+                audio, resource_path=self.resource_path, n_mels=128, padding=480000)
+            print(audio)
             audio_len = paddle.to_tensor(audio.shape[0]).unsqueeze(axis=0)
 
             self._inputs["audio"] = audio
@@ -275,7 +276,6 @@ class WhisperExecutor(BaseExecutor):
                     cfg.temperature_increment_on_fallback))
         else:
             temperature = [cfg.temperature]
-
         self._outputs["result"] = self.model.transcribe(
             audio,
             verbose=cfg.verbose,
diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py
index d9c32a406..27c167063 100644
--- a/paddlespeech/s2t/exps/whisper/test_wav.py
+++ b/paddlespeech/s2t/exps/whisper/test_wav.py
@@ -63,9 +63,13 @@ class WhisperInfer():
             temperature = [temperature]
 
         #load audio
-        mel = log_mel_spectrogram(
-            args.audio_file, resource_path=config.resource_path)
-
+        # mel = log_mel_spectrogram(
+        #     args.audio_file, resource_path=config.resource_path, n_mels=128)
+        mel = log_mel_spectrogram(
+            args.audio_file,
+            resource_path=config.resource_path,
+            n_mels=128,
+            padding=480000)
         result = transcribe(
             self.model, mel, temperature=temperature, **config)
         if args.result_file is not None:
diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py
index e8b201bcc..7b2722d41 100644
--- a/paddlespeech/s2t/models/whisper/tokenizer.py
+++ b/paddlespeech/s2t/models/whisper/tokenizer.py
@@ -2,9 +2,14 @@
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper/tokenizer.py)
+import base64
 import os
+import string
 from dataclasses import dataclass
+from dataclasses import field
+from functools import cached_property
 from functools import lru_cache
+from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
@@ -12,7 +17,8 @@ from typing import Union
 
 import numpy as np
 import paddle
-from paddlenlp.transformers import GPTTokenizer
+import tiktoken
+# from paddlenlp.transformers import GPTTokenizer
 
 LANGUAGES = {
     "en": "english",
@@ -35,7 +41,7 @@ LANGUAGES = {
     "hi": "hindi",
     "fi": "finnish",
     "vi": "vietnamese",
-    "iw": "hebrew",
+    "he": "hebrew",
     "uk": "ukrainian",
     "el": "greek",
     "ms": "malay",
@@ -114,6 +120,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }
 
 # language code lookup by name, with a few language aliases
@@ -130,37 +137,63 @@ TO_LANGUAGE_CODE = {
     "moldovan": "ro",
     "sinhalese": "si",
     "castilian": "es",
+    "mandarin": "zh",
 }
 
 
-@dataclass(frozen=True)
+@dataclass
class Tokenizer:
     """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
 
-    tokenizer: "GPTTokenizer"
-    language: Optional[str]
-    sot_sequence: Tuple[int]
+    encoding: tiktoken.Encoding
+    num_languages: int
+    language: Optional[str] = None
+    task: Optional[str] = None
+    sot_sequence: Tuple[int] = ()
+    special_tokens: Dict[str, int] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for special in self.encoding.special_tokens_set:
+            special_token = self.encoding.encode_single_token(special)
+            self.special_tokens[special] = special_token
+
+        sot: int = self.special_tokens["<|startoftranscript|>"]
+        translate: int = self.special_tokens["<|translate|>"]
+        transcribe: int = self.special_tokens["<|transcribe|>"]
+
+        langs = tuple(LANGUAGES.keys())[:self.num_languages]
+        sot_sequence = [sot]
+        if self.language is not None:
+            sot_sequence.append(sot + 1 + langs.index(self.language))
+        if self.task is not None:
+            task_token: int = transcribe if self.task == "transcribe" else translate
+            sot_sequence.append(task_token)
+        print("sot_sequence", sot_sequence)
+        self.sot_sequence = tuple(sot_sequence)
 
     def encode(self, text, **kwargs):
-        return self.tokenizer.encode(text, **kwargs)
-
-    def decode(self,
-               token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
-               **kwargs):
-        if len(token_ids) > 1:
-            ids_list = []
-            for ids in token_ids:
-                if paddle.is_tensor(ids):
-                    ids = ids.item()
-                if ids < len(self.tokenizer):
-                    ids_list.append(ids)
-            token_ids = ids_list
-        elif len(token_ids) == 1:
-            token_ids = token_ids[0]
-        else:
-            raise ValueError(f"token_ids {token_ids} load error.")
-
-        return self.tokenizer.decode(token_ids, **kwargs)
+        return self.encoding.encode(text, **kwargs)
+
+    # def decode(self,
+    #            token_ids: Union[int, List[int], np.ndarray, paddle.Tensor],
+    #            **kwargs):
+    #     if len(token_ids) > 1:
+    #         ids_list = []
+    #         for ids in token_ids:
+    #             if paddle.is_tensor(ids):
+    #                 ids = ids.item()
+    #             if ids < len(self.tokenizer):
+    #                 ids_list.append(ids)
+    #         token_ids = ids_list
+    #     elif len(token_ids) == 1:
+    #         token_ids = token_ids[0]
+    #     else:
+    #         raise ValueError(f"token_ids {token_ids} load error.")
+
+    #     return self.tokenizer.decode(token_ids, **kwargs)
+    def decode(self, token_ids: List[int], **kwargs) -> str:
+        token_ids = [t for t in token_ids if t < self.timestamp_begin]
+        return self.encoding.decode(token_ids, **kwargs)
 
     def decode_with_timestamps(self, tokens) -> str:
         """
@@ -181,83 +214,75 @@ class Tokenizer:
         ]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
-        return self.tokenizer.eos_token_id
+        return self.encoding.eot_token
 
-    @property
-    @lru_cache()
+    @cached_property
+    def transcribe(self) -> int:
+        return self.special_tokens["<|transcribe|>"]
+
+    @cached_property
+    def translate(self) -> int:
+        return self.special_tokens["<|translate|>"]
+
+    @cached_property
     def sot(self) -> int:
-        return self._get_single_token_id("<|startoftranscript|>")
+        return self.special_tokens["<|startoftranscript|>"]
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
-        return self._get_single_token_id("<|startoflm|>")
+        return self.special_tokens["<|startoflm|>"]
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
-        return self._get_single_token_id("<|startofprev|>")
+        return self.special_tokens["<|startofprev|>"]
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
-        return self._get_single_token_id("<|nospeech|>")
+        return self.special_tokens["<|nospeech|>"]
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
-        return self._get_single_token_id("<|notimestamps|>")
+        return self.special_tokens["<|notimestamps|>"]
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
-        return self.tokenizer.all_special_ids[-1] + 1
+        return self.special_tokens["<|0.00|>"]
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
             raise ValueError(
                 "This tokenizer does not have language token configured")
 
-        additional_tokens = dict(
-            zip(
-                self.tokenizer.additional_special_tokens,
-                self.tokenizer.additional_special_tokens_ids, ))
-        candidate = f"<|{self.language}|>"
-        if candidate in additional_tokens:
-            return additional_tokens[candidate]
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
+            return token
 
-        raise KeyError(f"Language {self.language} not found in tokenizer.")
+        raise KeyError(f"Language {language} not found in tokenizer.")
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
-        for token, token_id in zip(
-                self.tokenizer.additional_special_tokens,
-                self.tokenizer.additional_special_tokens_ids, ):
+        for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
-        return tuple(result)
+        return tuple(result)[:self.num_languages]
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(
-            self.decode([l]).strip("<|>") for l in self.all_language_tokens)
+            self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
@@ -269,9 +294,10 @@ class Tokenizer:
         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
""" - symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") - symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split( - ) + symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') + symbols += ( + "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪". + split()) # symbols that may be a single token or multiple tokens depending on the tokenizer. # In case they're multiple tokens, suppress the first token, which is safe because: @@ -281,45 +306,170 @@ class Tokenizer: assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - result = { - self.tokenizer.encode(" -").input_ids[0], - self.tokenizer.encode(" '").input_ids[0] - } + result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} for symbol in symbols + list(miscellaneous): for tokens in [ - self.tokenizer.encode(symbol).input_ids, - self.tokenizer.encode(" " + symbol).input_ids + self.encoding.encode(symbol), + self.encoding.encode(" " + symbol), ]: if len(tokens) == 1 or symbol in miscellaneous: result.add(tokens[0]) return tuple(sorted(result)) - def _get_single_token_id(self, text) -> int: - tokens = self.tokenizer.encode(text).input_ids - assert len(tokens) == 1, f"{text} is not encoded as a single token" - return tokens[0] + def split_to_word_tokens(self, tokens: List[int]): + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + return self.split_tokens_on_unicode(tokens) + + return self.split_tokens_on_spaces(tokens) + + def split_tokens_on_unicode(self, tokens: List[int]): + decoded_full = self.decode_with_timestamps(tokens) + replacement_char = "\ufffd" + + words = [] + word_tokens = [] + current_tokens = [] + unicode_offset = 0 + + for token in tokens: + current_tokens.append(token) + decoded = self.decode_with_timestamps(current_tokens) + + if (replacement_char not in decoded or + decoded_full[unicode_offset + decoded.index( + replacement_char)] == replacement_char): + words.append(decoded) + word_tokens.append(current_tokens) + current_tokens = [] + unicode_offset += len(decoded) + + return words, word_tokens + + def split_tokens_on_spaces(self, tokens: List[int]): + subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) + words = [] + word_tokens = [] + + for subword, subword_tokens in zip(subwords, subword_tokens_list): + special = subword_tokens[0] >= self.eot + with_space = subword.startswith(" ") + punctuation = subword.strip() in string.punctuation + if special or with_space or punctuation or len(words) == 0: + words.append(subword) + word_tokens.append(subword_tokens) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + + return words, word_tokens + + +# @lru_cache(maxsize=None) +# def build_tokenizer(resource_path: str, name: str="gpt2"): +# os.environ["TOKENIZERS_PARALLELISM"] = "false" +# path = os.path.join(resource_path, "assets", name) +# tokenizer = GPTTokenizer.from_pretrained(path) + +# specials = [ +# "<|startoftranscript|>", +# * [f"<|{lang}|>" for lang in LANGUAGES.keys()], +# "<|translate|>", +# "<|transcribe|>", +# "<|startoflm|>", +# "<|startofprev|>", +# "<|nospeech|>", +# "<|notimestamps|>", +# ] + +# tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) +# return tokenizer 
+
+
 @lru_cache(maxsize=None)
-def build_tokenizer(resource_path: str, name: str="gpt2"):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    path = os.path.join(resource_path, "assets", name)
-    tokenizer = GPTTokenizer.from_pretrained(path)
+def get_encoding(resource_path: str, name: str="gpt2", num_languages: int=99):
+    print("resource_path", resource_path)
+    print("name", name)
+    vocab_path = os.path.join(resource_path, "assets", f"{name}.tiktoken")
+    # vocab_path = os.path.join(resource_path, "assets", sname)
+    # vocab_path += ".tiktoken"
+    ranks = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in open(vocab_path) if line)
+    }
+    n_vocab = len(ranks)
+    special_tokens = {}
 
     specials = [
+        "<|endoftext|>",
         "<|startoftranscript|>",
-        * [f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        * [f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
         "<|translate|>",
         "<|transcribe|>",
         "<|startoflm|>",
        "<|startofprev|>",
         "<|nospeech|>",
         "<|notimestamps|>",
+        * [f"<|{i * 0.02:.2f}|>" for i in range(1501)],
     ]
 
-    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
-    return tokenizer
+    for token in specials:
+        special_tokens[token] = n_vocab
+        n_vocab += 1
+
+    return tiktoken.Encoding(
+        name=os.path.basename(vocab_path),
+        explicit_n_vocab=n_vocab,
+        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        mergeable_ranks=ranks,
+        special_tokens=special_tokens, )
+
+
+# @lru_cache(maxsize=None)
+# def get_tokenizer(
+#         multilingual: bool,
+#         resource_path: str,
+#         *,
+#         task: Optional[str]=None,  # Literal["transcribe", "translate", None]
+#         language: Optional[str]=None, ) -> Tokenizer:
+#     if language is not None:
+#         language = language.lower()
+#         if language not in LANGUAGES:
+#             if language in TO_LANGUAGE_CODE:
+#                 language = TO_LANGUAGE_CODE[language]
+#             else:
+#                 raise ValueError(f"Unsupported language: {language}")
+
+#     if multilingual:
+#         tokenizer_name = "multilingual"
+#         task = task or "transcribe"
+#         language = language or "en"
+#     else:
+#         tokenizer_name = "gpt2"
+#         task = None
+#         language = None
+
+#     tokenizer = build_tokenizer(
+#         resource_path=resource_path, name=tokenizer_name)
+#     all_special_ids: List[int] = tokenizer.all_special_ids
+#     sot: int = all_special_ids[1]
+#     translate: int = all_special_ids[-6]
+#     transcribe: int = all_special_ids[-5]
+
+#     langs = tuple(LANGUAGES.keys())
+#     sot_sequence = [sot]
+#     if language is not None:
+#         sot_sequence.append(sot + 1 + langs.index(language))
+#     if task is not None:
+#         sot_sequence.append(transcribe if task == "transcribe" else translate)
+
+#     return Tokenizer(
+#         tokenizer=tokenizer,
+#         language=language,
+#         sot_sequence=tuple(sot_sequence))
@@ -327,9 +478,12 @@ def get_tokenizer(
         multilingual: bool,
         resource_path: str,
         *,
+        num_languages: int=99,
+        language: Optional[str]=None,
         task: Optional[str]=None,  # Literal["transcribe", "translate", None]
-        language: Optional[str]=None, ) -> Tokenizer:
+) -> Tokenizer:
     if language is not None:
+        print(language)
         language = language.lower()
         if language not in LANGUAGES:
             if language in TO_LANGUAGE_CODE:
@@ -338,29 +492,21 @@ def get_tokenizer(
                 raise ValueError(f"Unsupported language: {language}")
 
     if multilingual:
-        tokenizer_name = "multilingual"
-        task = task or "transcribe"
+        encoding_name = "multilingual"
         language = language or "en"
+        task = task or "transcribe"
     else:
-        tokenizer_name = "gpt2"
-        task = None
+        encoding_name = "gpt2"
         language = None
+        task = None
 
-    tokenizer = build_tokenizer(
-        resource_path=resource_path, name=tokenizer_name)
-    all_special_ids: List[int] = tokenizer.all_special_ids
-    sot: int = all_special_ids[1]
-    translate: int = all_special_ids[-6]
-    transcribe: int = all_special_ids[-5]
-
-    langs = tuple(LANGUAGES.keys())
-    sot_sequence = [sot]
-    if language is not None:
-        sot_sequence.append(sot + 1 + langs.index(language))
-    if task is not None:
-        sot_sequence.append(transcribe if task == "transcribe" else translate)
+    encoding = get_encoding(
+        resource_path=resource_path,
+        name=encoding_name,
+        num_languages=num_languages)
 
     return Tokenizer(
-        tokenizer=tokenizer,
+        encoding=encoding,
+        num_languages=num_languages,
         language=language,
-        sot_sequence=tuple(sot_sequence))
+        task=task)
diff --git a/paddlespeech/s2t/models/whisper/whisper.py b/paddlespeech/s2t/models/whisper/whisper.py
index fdd3a6974..737276eeb 100644
--- a/paddlespeech/s2t/models/whisper/whisper.py
+++ b/paddlespeech/s2t/models/whisper/whisper.py
@@ -33,13 +33,18 @@ logger = Log(__name__).getlog()
 _MODELS = ["large"]
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
 N_FRAMES = utils.exact_div(
     N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
+FRAMES_PER_SECOND = utils.exact_div(SAMPLE_RATE,
+                                    HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = utils.exact_div(SAMPLE_RATE,
+                                    N_SAMPLES_PER_TOKEN)  # 20ms per audio token
+
 
 @dataclass
 class ModelDimensions:
@@ -378,7 +383,9 @@ def detect_language(
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(
-            model.is_multilingual, resource_path=resource_path)
+            multilingual=model.is_multilingual,
+            resource_path=resource_path,
+            num_languages=model.num_languages)
     if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
         raise ValueError(
             "This model doesn't have language tokens so it can't perform lang id"
         )
@@ -428,6 +435,13 @@ def transcribe(
         logprob_threshold: Optional[float]=-1.0,
         no_speech_threshold: Optional[float]=0.6,
         condition_on_previous_text: bool=True,
+        initial_prompt: Optional[str]=None,
+        carry_initial_prompt: bool=False,
+        word_timestamps: bool=False,
+        prepend_punctuations: str="\"'“¿([{-",
+        append_punctuations: str="\"'.。,,!!??::”)]}、",
+        clip_timestamps: Union[str, List[float]]="0",
+        hallucination_silence_threshold: Optional[float]=None,
         **decode_options, ):
     """
    Transcribe an audio file using Whisper
@@ -476,8 +490,11 @@ def transcribe(
     if dtype == np.float32:
         decode_options["fp16"] = False
 
-    if decode_options.get("language") == 'None' or decode_options.get(
-            "language", None) is None:
+    content_frames = mel.shape[-1] - N_FRAMES
+    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+    # import pdb
+    # pdb.set_trace()
+    if decode_options.get("language", None) in {None, "None"}:
         if not model.is_multilingual:
             decode_options["language"] = "en"
         else:
@@ -485,25 +502,49 @@ def transcribe(
                 print(
                     "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                 )
-            segment = pad_or_trim(mel, N_FRAMES)
-            _, probs = model.detect_language(segment, resource_path)
+            mel_segment = pad_or_trim(mel,
+                                      N_FRAMES).to(model.device).astype(dtype)
+            _, probs = model.detect_language(mel_segment, resource_path)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
                 print(
                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                 )
 
-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
+    language: str = decode_options["language"]
+    print("language", language)
+    task: str = decode_options.get("task", "transcribe")
+    print("model.num_languages", model.num_languages)
     tokenizer = get_tokenizer(
-        model.is_multilingual,
+        multilingual=model.is_multilingual,
         resource_path=resource_path,
+        num_languages=model.num_languages,
         language=language,
-        task=task)
+        task=task, )
+
+    if isinstance(clip_timestamps, str):
+        clip_timestamps = [
+            float(ts)
+            for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+        ]
+    seek_points: List[
+        int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+    if len(seek_points) == 0:
+        seek_points.append(0)
+    if len(seek_points) % 2 == 1:
+        seek_points.append(content_frames)
+    seek_clips: List[Tuple[int, int]] = list(
+        zip(seek_points[::2], seek_points[1::2]))
+
+    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
+    if word_timestamps and task == "translate":
+        warnings.warn(
+            "Word-level timestamps on translations may not be reliable.")
 
     def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
-        temperatures = [temperature] if isinstance(temperature, (
-            int, float)) else temperature
+        temperatures = ([temperature] if isinstance(temperature, (int, float))
+                        else temperature)
 
         decode_result = None
         for t in temperatures:
@@ -517,20 +558,29 @@ def transcribe(
             kwargs.pop("best_of", None)
 
             options = DecodingOptions(**kwargs, temperature=t)
+
             decode_result = model.decode(segment, options, resource_path)
 
             needs_fallback = False
-            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+            if (compression_ratio_threshold is not None and
+                    decode_result.compression_ratio >
+                    compression_ratio_threshold):
                 needs_fallback = True  # too repetitive
-            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+            if (logprob_threshold is not None and
+                    decode_result.avg_logprob < logprob_threshold):
                 needs_fallback = True  # average log probability is too low
-
+            if (no_speech_threshold is not None and
+                    decode_result.no_speech_prob > no_speech_threshold and
+                    logprob_threshold is not None and
+                    decode_result.avg_logprob < logprob_threshold):
+                needs_fallback = False  # silence
             if not needs_fallback:
                 break
 
         return decode_result
 
-    seek = 0
+    clip_idx = 0
+    seek = seek_clips[clip_idx][0]
     input_stride = utils.exact_div(
         N_FRAMES, model.dims.n_audio_ctx)  # mel frames per output token: 2
     time_precision = (input_stride * HOP_LENGTH /
@@ -539,127 +589,287 @@ def transcribe(
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " +
-                                          initial_prompt.strip()).input_ids
-        all_tokens.extend(initial_prompt)
+    remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+        remaining_prompt_length -= len(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
-    def add_segment(*,
+    def new_segment(*,
                     start: float,
                     end: float,
-                    text_tokens: paddle.Tensor,
+                    tokens: paddle.Tensor,
                     result: DecodingResult):
-        text = tokenizer.decode(
-            [token for token in text_tokens if token < tokenizer.eot])
-        if len(text.strip()) == 0:  # skip empty text output
-            return
-
-        all_segments.append({
-            "id": len(all_segments),
+        tokens = tokens.tolist()
+        text_tokens = [token for token in tokens if token < tokenizer.eot]
+        return {
             "seek": seek,
             "start": start,
             "end": end,
-            "text": text,
-            "tokens": result.tokens,
+            "text": tokenizer.decode(text_tokens),
+            "tokens": tokens,
             "temperature": result.temperature,
             "avg_logprob": result.avg_logprob,
             "compression_ratio": result.compression_ratio,
             "no_speech_prob": result.no_speech_prob,
-        })
-        if verbose:
-            print(
-                f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
-            )
-
-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
-    num_frames = mel.shape[-1]
-    previous_seek_value = seek
+        }
 
+    # show the progress bar when verbose is False (if True, transcribed text will be printed)
     with tqdm.tqdm(
-            total=num_frames, unit='frames',
+            total=content_frames, unit="frames",
             disable=verbose is not False) as pbar:
-        while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+        last_speech_timestamp = 0.0
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
+            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+            window_end_time = float(
+                (seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+            segment_size = min(N_FRAMES, content_frames - seek,
+                               seek_clip_end - seek)
+            mel_segment = mel[:, seek:seek + segment_size]
+            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+            mel_segment = pad_or_trim(mel_segment,
+                                      N_FRAMES).to(model.device).astype(dtype)
+
+            if carry_initial_prompt:
+                nignored = max(len(initial_prompt_tokens), prompt_reset_since)
+                remaining_prompt = all_tokens[nignored:][
+                    -remaining_prompt_length:]
+                decode_options[
+                    "prompt"] = initial_prompt_tokens + remaining_prompt
+            else:
+                decode_options["prompt"] = all_tokens[prompt_reset_since:]
 
-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result: DecodingResult = decode_with_fallback(segment)
+            result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = paddle.to_tensor(result.tokens)
 
             if no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                if (logprob_threshold is not None and
+                        result.avg_logprob > logprob_threshold):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
 
                 if should_skip:
-                    seek += segment.shape[
-                        -1]  # fast-forward to the next segment boundary
+                    seek += segment_size  # fast-forward to the next segment boundary
                     continue
 
+            previous_seek = seek
+            current_segments = []
+
+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [
+                    w for w in segment["words"] if w["word"] not in punctuation
+                ]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                 paddle.to_tensor(tokenizer.timestamp_begin))
+            single_timestamp_ending = timestamp_tokens[
+                -2:].tolist() == [False, True]
 
             consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
                 1:])[0]
-            if len(
-                    consecutive
-            ) > 0:  # if the output contains two consecutive timestamp tokens
-                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
+            print("consecutive", consecutive)
+            consecutive = paddle.add(consecutive, paddle.to_tensor(1))
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
+
                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_position = (
+                    start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin)
-                    end_timestamp_position = (
+                    end_timestamp_pos = (
                         sliced_tokens[-1].item() - tokenizer.timestamp_begin)
-                    add_segment(
-                        start=timestamp_offset + start_timestamp_position *
-                        time_precision,
-                        end=timestamp_offset + end_timestamp_position *
-                        time_precision,
-                        text_tokens=sliced_tokens[1:-1],
-                        result=result, )
+                    current_segments.append(
+                        new_segment(
+                            start=time_offset + start_timestamp_pos *
+                            time_precision,
+                            end=time_offset + end_timestamp_pos *
+                            time_precision,
+                            tokens=sliced_tokens,
+                            result=result, ))
                     last_slice = current_slice
-                last_timestamp_position = (
-                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
-                seek += last_timestamp_position * input_stride
-                all_tokens.extend(tokens[:last_slice + 1].tolist())
+
+                if single_timestamp_ending:
+                    # single timestamp at the end means no speech after the last timestamp.
+                    seek += segment_size
+                else:
+                    # otherwise, ignore the unfinished segment and seek to the last timestamp
+                    last_timestamp_pos = (tokens[last_slice - 1].item() -
+                                          tokenizer.timestamp_begin)
+                    seek += last_timestamp_pos * input_stride
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if len(timestamps) > 0 and timestamps[
-                        -1].item() != tokenizer.timestamp_begin:
+                if (len(timestamps) > 0 and
+                        timestamps[-1].item() != tokenizer.timestamp_begin):
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    # single timestamp at the end means no speech after the last timestamp.
-                    last_timestamp_position = timestamps[
-                        -1].item() - tokenizer.timestamp_begin
-                    duration = last_timestamp_position * time_precision
-
-                add_segment(
-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
-                    text_tokens=tokens,
-                    result=result, )
+                    last_timestamp_pos = (
+                        timestamps[-1].item() - tokenizer.timestamp_begin)
+                    duration = last_timestamp_pos * time_precision
+
+                current_segments.append(
+                    new_segment(
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                        result=result, ))
+                seek += segment_size
+
+            if word_timestamps:
+                add_word_timestamps(
+                    segments=current_segments,
+                    model=model,
+                    tokenizer=tokenizer,
+                    mel=mel_segment,
+                    num_frames=segment_size,
+                    prepend_punctuations=prepend_punctuations,
+                    append_punctuations=append_punctuations,
+                    last_speech_timestamp=last_speech_timestamp, )
+
+                if not single_timestamp_ending:
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                # skip silence before possible hallucinations
+                if hallucination_silence_threshold is not None:
+                    threshold = hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * FRAMES_PER_SECOND)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(
+                            first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap *
+                                                         FRAMES_PER_SECOND)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1:])
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0][
+                                    "start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold or
+                                segment["start"] < threshold or
+                                segment["start"] - time_offset < 2.0)
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold or
+                                is_segment_anomaly(next_segment) or
+                                window_end_time - segment["end"] < 2.0)
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"]) *
+                                    FRAMES_PER_SECOND)
+                                if content_duration - segment[
+                                        "end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end
 
-                seek += segment.shape[-1]
-                all_tokens.extend(tokens.tolist())
+            if verbose:
+                for segment in current_segments:
+                    start, end, text = segment["start"], segment[
+                        "end"], segment["text"]
+                    line = f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
+                    print(line)
+
+            # if a segment is instantaneous or does not contain text, clear it
+            for i, segment in enumerate(current_segments):
+                if segment["start"] == segment["end"] or segment[
+                        "text"].strip() == "":
+                    segment["text"] = ""
+                    segment["tokens"] = []
+                    segment["words"] = []
+
+            all_segments.extend(
+                [{
+                    "id": i,
+                    **
+                    segment
+                }
+                 for i, segment in enumerate(
+                     current_segments, start=len(all_segments))])
+            all_tokens.extend([
+                token
+                for segment in current_segments for token in segment["tokens"]
+            ])
 
             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
                 prompt_reset_since = len(all_tokens)
 
             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+            pbar.update(min(content_frames, seek) - previous_seek)
 
     return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt):]),
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
         segments=all_segments,
-        language=language)
+        language=language, )
 
 
 class SequenceRanker:
@@ -776,11 +986,11 @@ class GreedyDecoder(TokenDecoder):
             next_tokens.shape[0] * next_tokens.shape[1],
         ])
 
-        logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
+        logprobs = F.log_softmax(logits, axis=-1, dtype="float32")
         current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
                                     next_tokens]
         sum_logprobs += current_logprobs * paddle.to_tensor(
-            (tokens[:, -1] != self.eot), dtype=paddle.float32)
+            (tokens[:, -1] != self.eot), dtype="float32")
 
         next_tokens[tokens[:, -1] == self.eot] = self.eot
         tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
@@ -928,8 +1138,8 @@ class SuppressBlank(LogitFilter):
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         if tokens.shape[1] == self.sample_begin:
-            logits[:, self.tokenizer.encode(" ").input_ids +
-                   [self.tokenizer.eot]] = -np.inf
+            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot
+                                                    ]] = -np.inf
 
 
 class SuppressTokens(LogitFilter):
@@ -1005,7 +1215,7 @@ class DecodingTask:
             language = options.language or "en"
 
         tokenizer = get_tokenizer(
-            model.is_multilingual,
+            multilingual=model.is_multilingual,
             resource_path=resource_path,
             language=language,
             task=options.task)
@@ -1346,11 +1556,16 @@ class Whisper(nn.Layer):
 
     @property
     def device(self):
+        # return str(paddle.device.get_device()).split(":")[0]
         return paddle.device.get_device()
 
     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
 
     def install_kv_cache_hooks(self, cache: Optional[dict]=None):
         """
@@ -1364,7 +1579,7 @@ class Whisper(nn.Layer):
         cache : Dict[nn.Layer, paddle.Tensor]
             A dictionary object mapping the key/value projection modules to its cache
         hooks : List[RemovableHandle]
-            List of PyTorch RemovableHandle objects to stop the hooks to be called
+            List of Paddle RemovableHandle objects to stop the hooks to be called
         """
         cache = {**cache} if cache is not None else {}
         hooks = []
@@ -1435,7 +1650,7 @@ def hann_window(n_fft: int=N_FFT):
 
 
 @lru_cache(maxsize=None)
-def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
+def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -1445,13 +1660,19 @@ def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
             mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
+    # assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    # with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
+    #     return paddle.to_tensor(f[f"mel_{n_mels}"])
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+    filters_path = os.path.join(resource_path, "assets", "mel_filters.npz")
+    with np.load(filters_path, allow_pickle=False) as f:
         return paddle.to_tensor(f[f"mel_{n_mels}"])
 
 
 def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
-                        n_mels: int=N_MELS,
+                        n_mels: int=80,
+                        padding: int=0,
                         resource_path: str=None):
     """
     Compute the log-Mel spectrogram of
@@ -1475,7 +1696,8 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
         audio = audio[:, 0]
     logger.info(f"audio shape: {audio.shape}")
    audio = paddle.to_tensor(audio)
-
+    if padding > 0:
+        audio = F.pad(audio, (0, padding), data_format="NLC")
     window = hann_window(N_FFT)
     stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
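
For reference, a minimal usage sketch of the two front-end APIs this patch changes, not part of the patch itself. It assumes the `log_mel_spectrogram` and `get_tokenizer` signatures shown above and the import paths used by `test_wav.py`; the resource directory, wav path, and `num_languages=100` value are illustrative placeholders.

# Sketch only (assumptions noted above): exercises the 128-mel front end and the
# tiktoken-based tokenizer introduced by this patch.
from paddlespeech.s2t.models.whisper import log_mel_spectrogram  # assumed import path
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer

resource_path = "/path/to/whisper/resources"  # placeholder: holds assets/multilingual.tiktoken, assets/mel_filters.npz
audio_file = "/path/to/audio.wav"  # placeholder

# large-v3 style features: 128 mel bins, padded with 30 s (480000 samples) of zeros.
mel = log_mel_spectrogram(
    audio_file, n_mels=128, padding=480000, resource_path=resource_path)
print(mel.shape)

# Multilingual tiktoken tokenizer; num_languages=100 is what the new
# Whisper.num_languages property yields for n_vocab == 51866 (51866 - 51765 - 1).
tokenizer = get_tokenizer(
    multilingual=True,
    resource_path=resource_path,
    num_languages=100,
    language="zh",
    task="transcribe")
print(tokenizer.sot_sequence, tokenizer.timestamp_begin)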