diff --git a/paddlespeech/s2t/models/whisper/tokenizer.py b/paddlespeech/s2t/models/whisper/tokenizer.py index 8bd85c91..1e1aea04 100644 --- a/paddlespeech/s2t/models/whisper/tokenizer.py +++ b/paddlespeech/s2t/models/whisper/tokenizer.py @@ -155,6 +155,10 @@ class Tokenizer: if ids < len(self.tokenizer): ids_list.append(ids) token_ids = ids_list + elif len(token_ids) == 1: + token_ids = token_ids[0] + else: + raise ValueError(f"token_ids {token_ids} load error.") return self.tokenizer.decode(token_ids, **kwargs) diff --git a/paddlespeech/s2t/models/whisper/whipser.py b/paddlespeech/s2t/models/whisper/whipser.py index 63cafbdb..81692f37 100644 --- a/paddlespeech/s2t/models/whisper/whipser.py +++ b/paddlespeech/s2t/models/whisper/whipser.py @@ -17,12 +17,11 @@ from typing import Union import numpy as np import paddle import paddle.nn.functional as F +import paddlespeech.s2t.modules.align as paddlespeech_nn import soundfile import tqdm from paddle import nn from paddle.distribution import Categorical - -import paddlespeech.s2t.modules.align as paddlespeech_nn from paddlespeech.s2t.models.whisper import utils from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES @@ -771,8 +770,10 @@ class GreedyDecoder(TokenDecoder): if temperature == 0: next_tokens = paddle.argmax(logits, axis=-1) else: - next_tokens = Categorical(logits=logits / temperature).sample( - shape=logits.shape) + next_tokens = Categorical(logits=logits / temperature).sample([1]) + next_tokens = paddle.reshape(next_tokens, [ + next_tokens.shape[0] * next_tokens.shape[1], + ]) logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32) current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), @@ -1205,9 +1206,8 @@ class DecodingTask: DecodingResult( audio_features=features, language=language, - language_probs=probs) - for features, language, probs in zip(audio_features, languages, - language_probs) + language_probs=probs) for features, language, probs in + zip(audio_features, languages, language_probs) ] # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling