@@ -33,13 +33,18 @@ logger = Log(__name__).getlog()
_MODELS = ["large"]
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = utils.exact_div(
    N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = utils.exact_div(SAMPLE_RATE,
                                    HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = utils.exact_div(SAMPLE_RATE,
                                    N_SAMPLES_PER_TOKEN)  # 20ms per audio token


@dataclass
class ModelDimensions:
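The added timing constants are pure arithmetic over the 16 kHz, 30-second analysis window. A quick standalone check of the numbers they resolve to (the `exact_div` here is a local stand-in for `utils.exact_div`, assumed to be exact integer division):

```python
# Sketch only: exact_div is a local stand-in for utils.exact_div,
# assumed to assert divisibility and return the integer quotient.
def exact_div(x: int, y: int) -> int:
    assert x % y == 0
    return x // y

SAMPLE_RATE = 16000    # Hz
N_FFT = 400            # 25 ms analysis window
HOP_LENGTH = 160       # 10 ms hop
CHUNK_LENGTH = 30      # seconds per chunk

N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE                           # 480000 samples per chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)                      # 3000 mel frames per chunk
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2                             # 320 samples (conv stride 2)
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)           # 100 -> 10 ms per frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 -> 20 ms per token

print(N_SAMPLES, N_FRAMES, FRAMES_PER_SECOND, TOKENS_PER_SECOND)
# 480000 3000 100 50
```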
@@ -378,7 +383,9 @@ def detect_language(
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, resource_path=resource_path)
            multilingual=model.is_multilingual,
            resource_path=resource_path,
            num_languages=model.num_languages)
    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
@@ -428,6 +435,13 @@ def transcribe(
        logprob_threshold: Optional[float]=-1.0,
        no_speech_threshold: Optional[float]=0.6,
        condition_on_previous_text: bool=True,
        initial_prompt: Optional[str]=None,
        carry_initial_prompt: bool=False,
        word_timestamps: bool=False,
        prepend_punctuations: str="\"'“¿([{-",
        append_punctuations: str="\"'.。,,!!??::”)]}、",
        clip_timestamps: Union[str, List[float]]="0",
        hallucination_silence_threshold: Optional[float]=None,
        **decode_options, ):
    """
    Transcribe an audio file using Whisper
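For orientation, the new keyword arguments all land on this same `transcribe(...)` signature. A hypothetical call sketch follows; `model` and `mel` are placeholders for objects built elsewhere, and only the keyword names come from the signature above:

```python
def run_example(model, mel):
    # Hypothetical wiring: `model` and `mel` stand in for objects created elsewhere;
    # only the keyword names and defaults come from the signature in this hunk.
    return transcribe(
        model,
        mel,
        word_timestamps=True,                  # also return per-word timing
        carry_initial_prompt=True,             # re-send initial_prompt for every window
        clip_timestamps="0,30,45,60",          # only transcribe 0-30 s and 45-60 s
        hallucination_silence_threshold=2.0,   # skip long silent gaps around suspect segments
        initial_prompt="PaddleSpeech, Whisper",  # bias decoding toward these spellings
        no_speech_threshold=0.6, )
```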
@@ -476,8 +490,11 @@ def transcribe(
    if dtype == np.float32:
        decode_options["fp16"] = False

    if decode_options.get("language") == 'None' or decode_options.get(
            "language", None) is None:
    content_frames = mel.shape[-1] - N_FRAMES
    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
    if decode_options.get("language", None) in {None, "None"}:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
@@ -485,25 +502,49 @@ def transcribe(
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES)
            _, probs = model.detect_language(segment, resource_path)
            mel_segment = pad_or_trim(mel,
                                      N_FRAMES).to(model.device).astype(dtype)
            _, probs = model.detect_language(mel_segment, resource_path)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        multilingual=model.is_multilingual,
        resource_path=resource_path,
        num_languages=model.num_languages,
        language=language,
        task=task)
        task=task, )

    if isinstance(clip_timestamps, str):
        clip_timestamps = [
            float(ts)
            for ts in (clip_timestamps.split(",") if clip_timestamps else [])
        ]
    seek_points: List[
        int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
    if len(seek_points) == 0:
        seek_points.append(0)
    if len(seek_points) % 2 == 1:
        seek_points.append(content_frames)
    seek_clips: List[Tuple[int, int]] = list(
        zip(seek_points[::2], seek_points[1::2]))

    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

    if word_timestamps and task == "translate":
        warnings.warn(
            "Word-level timestamps on translations may not be reliable.")

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        temperatures = [temperature] if isinstance(temperature, (
            int, float)) else temperature
        temperatures = ([temperature] if isinstance(temperature, (int, float))
                        else temperature)
        decode_result = None

        for t in temperatures:
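The clip handling in this hunk turns a comma-separated string of seconds into `(start, end)` frame pairs, padding an odd list with the end of the audio. A small standalone sketch of that parsing, using the 100 frames-per-second constant defined earlier:

```python
from typing import List, Tuple

FRAMES_PER_SECOND = 100  # 1 mel frame = 10 ms, as defined above

def parse_clips(clip_timestamps: str, content_frames: int) -> List[Tuple[int, int]]:
    # "0,30,45,60" -> [(0, 3000), (4500, 6000)]; an odd count is closed with the audio end.
    points = [float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])]
    seek_points = [round(ts * FRAMES_PER_SECOND) for ts in points]
    if len(seek_points) == 0:
        seek_points.append(0)
    if len(seek_points) % 2 == 1:
        seek_points.append(content_frames)
    return list(zip(seek_points[::2], seek_points[1::2]))

print(parse_clips("0,30,45,60", content_frames=9000))  # [(0, 3000), (4500, 6000)]
print(parse_clips("0", content_frames=9000))           # [(0, 9000)]
```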
@@ -517,20 +558,29 @@ def transcribe(
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
            if (compression_ratio_threshold is not None and
                    decode_result.compression_ratio >
                    compression_ratio_threshold):
                needs_fallback = True  # too repetitive
            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
            if (logprob_threshold is not None and
                    decode_result.avg_logprob < logprob_threshold):
                needs_fallback = True  # average log probability is too low
            if (no_speech_threshold is not None and
                    decode_result.no_speech_prob > no_speech_threshold and
                    logprob_threshold is not None and
                    decode_result.avg_logprob < logprob_threshold):
                needs_fallback = False  # silence
            if not needs_fallback:
                break

        return decode_result

    seek = 0
    clip_idx = 0
    seek = seek_clips[clip_idx][0]
    input_stride = utils.exact_div(
        N_FRAMES, model.dims.n_audio_ctx)  # mel frames per output token: 2
    time_precision = (input_stride * HOP_LENGTH /
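The reshaped conditions keep the same retry rule as before and add the new silence exception: a low average log-probability no longer forces a retry when the no-speech probability is also high. A self-contained sketch of that decision (the result object is a stand-in dataclass, not the real `DecodingResult`, and the thresholds shown are the usual defaults):

```python
from dataclasses import dataclass

@dataclass
class FakeResult:  # stand-in for DecodingResult, illustration only
    compression_ratio: float
    avg_logprob: float
    no_speech_prob: float

def needs_fallback(r: FakeResult,
                   compression_ratio_threshold=2.4,
                   logprob_threshold=-1.0,
                   no_speech_threshold=0.6) -> bool:
    fallback = False
    if compression_ratio_threshold is not None and r.compression_ratio > compression_ratio_threshold:
        fallback = True   # too repetitive
    if logprob_threshold is not None and r.avg_logprob < logprob_threshold:
        fallback = True   # average log probability is too low
    if (no_speech_threshold is not None and r.no_speech_prob > no_speech_threshold
            and logprob_threshold is not None and r.avg_logprob < logprob_threshold):
        fallback = False  # probably silence: do not burn retries on it
    return fallback

print(needs_fallback(FakeResult(1.5, -0.3, 0.1)))  # False: accept as-is
print(needs_fallback(FakeResult(3.0, -0.5, 0.1)))  # True: repetitive, retry at a higher temperature
print(needs_fallback(FakeResult(1.5, -1.5, 0.9)))  # False: treated as silence
```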
@@ -539,127 +589,287 @@ def transcribe(
    all_segments = []
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None) or []
    if initial_prompt:
        initial_prompt = tokenizer.encode(" " +
                                          initial_prompt.strip()).input_ids
        all_tokens.extend(initial_prompt)
    remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
        remaining_prompt_length -= len(initial_prompt_tokens)
    else:
        initial_prompt_tokens = []
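The new `remaining_prompt_length` budget caps how much rolling context can be re-fed next to a carried initial prompt. A toy walk-through with made-up token lists (the 448-token text context is the value used by the released Whisper models; the real code takes it from `model.dims.n_text_ctx`):

```python
# Toy numbers: Whisper's text context is 448 tokens, so the prompt budget is 223.
n_text_ctx = 448
remaining_prompt_length = n_text_ctx // 2 - 1            # 223

initial_prompt_tokens = list(range(20))                  # pretend-encoded initial prompt
remaining_prompt_length -= len(initial_prompt_tokens)    # 203 left for rolling context

all_tokens = initial_prompt_tokens + list(range(1000, 1500))  # accumulated history
prompt_reset_since = 0

# carry_initial_prompt=True: always re-send the initial prompt, then as much recent
# history as still fits in the budget (mirrors the decode_options["prompt"] branch below).
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
prompt = initial_prompt_tokens + remaining_prompt

print(len(remaining_prompt), len(prompt))  # 203 223: the carried prompt stays within budget
```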
    def add_segment(*,
    def new_segment(*,
                    start: float,
                    end: float,
                    text_tokens: paddle.Tensor,
                    tokens: paddle.Tensor,
                    result: DecodingResult):
        text = tokenizer.decode(
            [token for token in text_tokens if token < tokenizer.eot])
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append({
            "id": len(all_segments),
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": text,
            "tokens": result.tokens,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        })
        if verbose:
            print(
                f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
            )

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek
        }
    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    with tqdm.tqdm(
            total=num_frames, unit='frames',
            total=content_frames, unit="frames",
            disable=verbose is not False) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
        last_speech_timestamp = 0.0
        # NOTE: This loop is obscurely flattened to make the diff readable.
        # A later commit should turn this into a simpler nested loop.
        # for seek_clip_start, seek_clip_end in seek_clips:
        #     while seek < seek_clip_end
        while clip_idx < len(seek_clips):
            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
            if seek < seek_clip_start:
                seek = seek_clip_start
            if seek >= seek_clip_end:
                clip_idx += 1
                if clip_idx < len(seek_clips):
                    seek = seek_clips[clip_idx][0]
                continue
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            window_end_time = float(
                (seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
            segment_size = min(N_FRAMES, content_frames - seek,
                               seek_clip_end - seek)
            mel_segment = mel[:, seek:seek + segment_size]
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment,
                                      N_FRAMES).to(model.device).astype(dtype)

            if carry_initial_prompt:
                nignored = max(len(initial_prompt_tokens), prompt_reset_since)
                remaining_prompt = all_tokens[nignored:][
                    -remaining_prompt_length:]
                decode_options[
                    "prompt"] = initial_prompt_tokens + remaining_prompt
            else:
                decode_options["prompt"] = all_tokens[prompt_reset_since:]

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            result: DecodingResult = decode_with_fallback(mel_segment)
            tokens = paddle.to_tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                if (logprob_threshold is not None and
                        result.avg_logprob > logprob_threshold):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[
                        -1]  # fast-forward to the next segment boundary
                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

            previous_seek = seek
            current_segments = []
            # anomalous words are very long/short/improbable
            def word_anomaly_score(word: dict) -> float:
                probability = word.get("probability", 0.0)
                duration = word["end"] - word["start"]
                score = 0.0
                if probability < 0.15:
                    score += 1.0
                if duration < 0.133:
                    score += (0.133 - duration) * 15
                if duration > 2.0:
                    score += duration - 2.0
                return score

            def is_segment_anomaly(segment: Optional[dict]) -> bool:
                if segment is None or not segment["words"]:
                    return False
                words = [
                    w for w in segment["words"] if w["word"] not in punctuation
                ]
                words = words[:8]
                score = sum(word_anomaly_score(w) for w in words)
                return score >= 3 or score + 0.01 >= len(words)

            def next_words_segment(segments: List[dict]) -> Optional[dict]:
                return next((s for s in segments if s["words"]), None)
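These helpers score how implausible each aligned word looks: very low probability, or an extremely short or long duration. A self-contained rerun of the same heuristic on two made-up word dictionaries:

```python
from typing import Optional

punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

def word_anomaly_score(word: dict) -> float:
    probability = word.get("probability", 0.0)
    duration = word["end"] - word["start"]
    score = 0.0
    if probability < 0.15:
        score += 1.0
    if duration < 0.133:
        score += (0.133 - duration) * 15
    if duration > 2.0:
        score += duration - 2.0
    return score

def is_segment_anomaly(segment: Optional[dict]) -> bool:
    if segment is None or not segment["words"]:
        return False
    words = [w for w in segment["words"] if w["word"] not in punctuation][:8]
    score = sum(word_anomaly_score(w) for w in words)
    return score >= 3 or score + 0.01 >= len(words)

ok = {"words": [{"word": " hello", "start": 1.0, "end": 1.4, "probability": 0.9}]}
bad = {"words": [{"word": " the", "start": 0.0, "end": 3.5, "probability": 0.05}]}
print(is_segment_anomaly(ok), is_segment_anomaly(bad))  # False True
```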
            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                paddle.to_tensor(tokenizer.timestamp_begin))
            single_timestamp_ending = timestamp_tokens[
                -2:].tolist() == [False, True]

            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[
                1:])[0]
            if len(
                    consecutive
            ) > 0:  # if the output contains two consecutive timestamp tokens
                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
            consecutive = paddle.add(consecutive, paddle.to_tensor(1))
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in consecutive:
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin)
                    end_timestamp_position = (
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin)
                    add_segment(
                        start=timestamp_offset + start_timestamp_position *
                        time_precision,
                        end=timestamp_offset + end_timestamp_position *
                        time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result, )
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos *
                            time_precision,
                            end=time_offset + end_timestamp_pos *
                            time_precision,
                            tokens=sliced_tokens,
                            result=result, ))
                    last_slice = current_slice
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin)
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[:last_slice + 1].tolist())

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (tokens[last_slice - 1].item() -
                                          tokenizer.timestamp_begin)
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if len(timestamps) > 0 and timestamps[
                        -1].item() != tokenizer.timestamp_begin:
                if (len(timestamps) > 0 and
                        timestamps[-1].item() != tokenizer.timestamp_begin):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # single timestamp at the end means no speech after the last timestamp.
                    last_timestamp_position = timestamps[
                        -1].item() - tokenizer.timestamp_begin
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result, )
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin)
                    duration = last_timestamp_pos * time_precision

                current_segments.append(
                    new_segment(
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                        result=result, ))
                seek += segment_size
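Throughout this block, a timestamp token is converted to seconds as its offset from `tokenizer.timestamp_begin` times `time_precision` (2 mel frames per token at 10 ms each, i.e. 0.02 s). A quick arithmetic check with assumed token ids (50364 is the usual `<|0.00|>` id for the multilingual tokenizer; the exact value is model-specific):

```python
HOP_LENGTH = 160
SAMPLE_RATE = 16000
N_FRAMES = 3000
n_audio_ctx = 1500                        # Whisper audio context length
input_stride = N_FRAMES // n_audio_ctx    # 2 mel frames per output token
time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE  # 0.02 s per timestamp step

timestamp_begin = 50364   # assumed id of <|0.00|> in the multilingual tokenizer
token = 50464             # a hypothetical timestamp token from a decoded window
time_offset = 60.0        # the window starts 60 s into the audio

position = token - timestamp_begin               # 100 steps
print(time_offset + position * time_precision)   # 62.0 -> timestamp 2 s into the window
```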
            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                    last_speech_timestamp=last_speech_timestamp, )

                if not single_timestamp_ending:
                    last_word_end = get_end(current_segments)
                    if last_word_end is not None and last_word_end > time_offset:
                        seek = round(last_word_end * FRAMES_PER_SECOND)

                # skip silence before possible hallucinations
                if hallucination_silence_threshold is not None:
                    threshold = hallucination_silence_threshold
                    if not single_timestamp_ending:
                        last_word_end = get_end(current_segments)
                        if last_word_end is not None and last_word_end > time_offset:
                            remaining_duration = window_end_time - last_word_end
                            if remaining_duration > threshold:
                                seek = round(last_word_end * FRAMES_PER_SECOND)
                            else:
                                seek = previous_seek + segment_size

                    # if first segment might be a hallucination, skip leading silence
                    first_segment = next_words_segment(current_segments)
                    if first_segment is not None and is_segment_anomaly(
                            first_segment):
                        gap = first_segment["start"] - time_offset
                        if gap > threshold:
                            seek = previous_seek + round(gap *
                                                         FRAMES_PER_SECOND)
                            continue

                    # skip silence before any possible hallucination that is surrounded
                    # by silence or more hallucinations
                    hal_last_end = last_speech_timestamp
                    for si in range(len(current_segments)):
                        segment = current_segments[si]
                        if not segment["words"]:
                            continue
                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(
                                current_segments[si + 1:])
                            if next_segment is not None:
                                hal_next_start = next_segment["words"][0][
                                    "start"]
                            else:
                                hal_next_start = time_offset + segment_duration
                            silence_before = (
                                segment["start"] - hal_last_end > threshold or
                                segment["start"] < threshold or
                                segment["start"] - time_offset < 2.0)
                            silence_after = (
                                hal_next_start - segment["end"] > threshold or
                                is_segment_anomaly(next_segment) or
                                window_end_time - segment["end"] < 2.0)
                            if silence_before and silence_after:
                                seek = round(
                                    max(time_offset + 1, segment["start"]) *
                                    FRAMES_PER_SECOND)
                                if content_duration - segment[
                                        "end"] < threshold:
                                    seek = content_frames
                                current_segments[si:] = []
                                break
                        hal_last_end = segment["end"]

                last_word_end = get_end(current_segments)
                if last_word_end is not None:
                    last_speech_timestamp = last_word_end
            seek += segment.shape[-1]
            all_tokens.extend(tokens.tolist())
            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment[
                        "end"], segment["text"]
                    line = f"[{utils.format_timestamp(start)} --> {utils.format_timestamp(end)}] {text}"
                    print(line)

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment[
                        "text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend(
                [{
                    "id": i,
                    **
                    segment
                }
                 for i, segment in enumerate(
                     current_segments, start=len(all_segments))])
            all_tokens.extend([
                token
                for segment in current_segments for token in segment["tokens"]
            ])

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek
            pbar.update(min(content_frames, seek) - previous_seek)

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt):]),
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
        segments=all_segments,
        language=language)
        language=language, )


class SequenceRanker:
@@ -776,11 +986,11 @@ class GreedyDecoder(TokenDecoder):
                next_tokens.shape[0] * next_tokens.shape[1],
            ])

        logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
        logprobs = F.log_softmax(logits, axis=-1, dtype="float32")
        current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
                                    next_tokens]
        sum_logprobs += current_logprobs * paddle.to_tensor(
            (tokens[:, -1] != self.eot), dtype=paddle.float32)
            (tokens[:, -1] != self.eot), dtype="float32")

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
@@ -928,8 +1138,8 @@ class SuppressBlank(LogitFilter):

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ").input_ids +
                   [self.tokenizer.eot]] = -np.inf
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot
                                                    ]] = -np.inf


class SuppressTokens(LogitFilter):
@@ -1005,7 +1215,7 @@ class DecodingTask:

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            multilingual=model.is_multilingual,
            resource_path=resource_path,
            language=language,
            task=options.task)
@@ -1346,11 +1556,16 @@ class Whisper(nn.Layer):

    @property
    def device(self):
        # return str(paddle.device.get_device()).split(":")[0]
        return paddle.device.get_device()

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict]=None):
        """
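The relaxed `is_multilingual` check and the new `num_languages` property tell the vocabularies apart by size alone: 51864 is the English-only vocab, 51865 the 99-language multilingual vocab, and 51866 the large-v3 vocab with one extra language token. A quick check of the arithmetic:

```python
def is_multilingual(n_vocab: int) -> bool:
    return n_vocab >= 51865

def num_languages(n_vocab: int) -> int:
    return n_vocab - 51765 - int(is_multilingual(n_vocab))

for n_vocab in (51864, 51865, 51866):
    print(n_vocab, is_multilingual(n_vocab), num_languages(n_vocab))
# 51864 False 99   -> English-only vocab
# 51865 True  99   -> multilingual v1/v2 vocab (99 languages)
# 51866 True  100  -> large-v3 vocab with the extra Cantonese token
```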
@@ -1364,7 +1579,7 @@ class Whisper(nn.Layer):
        cache : Dict[nn.Layer, paddle.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
            List of Paddle RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []
@@ -1435,7 +1650,7 @@ def hann_window(n_fft: int=N_FFT):


@lru_cache(maxsize=None)
def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
def mel_filters(resource_path: str, n_mels: int) -> paddle.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:
@@ -1445,13 +1660,19 @@ def mel_filters(resource_path: str, n_mels: int=N_MELS) -> paddle.Tensor:
        mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
    )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
    # assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    # with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
    #     return paddle.to_tensor(f[f"mel_{n_mels}"])
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(resource_path, "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return paddle.to_tensor(f[f"mel_{n_mels}"])
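`mel_filters` now accepts 80 or 128 mel bins and expects one array per size in `assets/mel_filters.npz`. Following the recipe quoted in the docstring, a sketch of how such a file could be produced and read back (the 128-bin entry and the use of librosa here are assumptions mirroring the new assert, not code from this patch):

```python
import numpy as np
import librosa

# Build and save the filterbanks once, as the docstring suggests.
np.savez_compressed(
    "mel_filters.npz",
    mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
    mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), )

# Load the one matching n_mels, as mel_filters() does above.
n_mels = 80
with np.load("mel_filters.npz", allow_pickle=False) as f:
    filters = f[f"mel_{n_mels}"]
print(filters.shape)  # (80, 201): n_mels x (N_FFT // 2 + 1) frequency bins
```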
def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
                        n_mels: int=N_MELS,
                        n_mels: int=80,
                        padding: int=0,
                        resource_path: str=None):
    """
    Compute the log-Mel spectrogram of
@@ -1475,7 +1696,8 @@ def log_mel_spectrogram(audio: Union[str, np.ndarray, paddle.Tensor],
            audio = audio[:, 0]
        logger.info(f"audio shape: {audio.shape}")
        audio = paddle.to_tensor(audio)

    if padding > 0:
        audio = F.pad(audio, (0, padding), data_format="NLC")
    window = hann_window(N_FFT)
    stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
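For context on what surrounds this STFT call, here is a rough numpy-only restatement of the standard Whisper log-mel recipe. It is an illustration under assumptions, not the Paddle code path above; the scaling steps after the STFT (clipping, the dynamic-range floor of 8, and the final `(x + 4) / 4`) come from the upstream recipe rather than from this hunk:

```python
import numpy as np

SAMPLE_RATE, N_FFT, HOP_LENGTH = 16000, 400, 160

def log_mel_sketch(audio: np.ndarray, filters: np.ndarray) -> np.ndarray:
    """Rough numpy sketch of the usual Whisper log-mel recipe (illustration only)."""
    window = np.hanning(N_FFT + 1)[:-1]  # periodic Hann, like hann_window(N_FFT)
    # Naive framing standing in for paddle.signal.stft(audio, N_FFT, HOP_LENGTH, ...)
    frames = np.lib.stride_tricks.sliding_window_view(audio, N_FFT)[::HOP_LENGTH]
    stft = np.fft.rfft(frames * window, axis=-1)
    magnitudes = np.abs(stft) ** 2
    mel_spec = magnitudes @ filters.T                     # (frames, n_mels)
    log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0

rng = np.random.default_rng(0)
fake_filters = np.abs(rng.normal(size=(80, N_FFT // 2 + 1))).astype(np.float32)
mel = log_mel_sketch(rng.normal(size=SAMPLE_RATE).astype(np.float32), fake_filters)
print(mel.shape)  # (98, 80): roughly 100 frames for one second of audio
```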