fix some bug. (#2825)

pull/2828/head
zxcd 2 years ago committed by GitHub
parent faa2f86651
commit ad40dafa85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -155,6 +155,10 @@ class Tokenizer:
if ids < len(self.tokenizer): if ids < len(self.tokenizer):
ids_list.append(ids) ids_list.append(ids)
token_ids = ids_list token_ids = ids_list
elif len(token_ids) == 1:
token_ids = token_ids[0]
else:
raise ValueError(f"token_ids {token_ids} load error.")
return self.tokenizer.decode(token_ids, **kwargs) return self.tokenizer.decode(token_ids, **kwargs)

@ -17,12 +17,11 @@ from typing import Union
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
import paddlespeech.s2t.modules.align as paddlespeech_nn
import soundfile import soundfile
import tqdm import tqdm
from paddle import nn from paddle import nn
from paddle.distribution import Categorical from paddle.distribution import Categorical
import paddlespeech.s2t.modules.align as paddlespeech_nn
from paddlespeech.s2t.models.whisper import utils from paddlespeech.s2t.models.whisper import utils
from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer from paddlespeech.s2t.models.whisper.tokenizer import get_tokenizer
from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES from paddlespeech.s2t.models.whisper.tokenizer import LANGUAGES
@ -771,8 +770,10 @@ class GreedyDecoder(TokenDecoder):
if temperature == 0: if temperature == 0:
next_tokens = paddle.argmax(logits, axis=-1) next_tokens = paddle.argmax(logits, axis=-1)
else: else:
next_tokens = Categorical(logits=logits / temperature).sample( next_tokens = Categorical(logits=logits / temperature).sample([1])
shape=logits.shape) next_tokens = paddle.reshape(next_tokens, [
next_tokens.shape[0] * next_tokens.shape[1],
])
logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32) logprobs = F.log_softmax(logits, axis=-1, dtype=paddle.float32)
current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), current_logprobs = logprobs[paddle.arange(logprobs.shape[0]),
@ -1205,9 +1206,8 @@ class DecodingTask:
DecodingResult( DecodingResult(
audio_features=features, audio_features=features,
language=language, language=language,
language_probs=probs) language_probs=probs) for features, language, probs in
for features, language, probs in zip(audio_features, languages, zip(audio_features, languages, language_probs)
language_probs)
] ]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling

Loading…
Cancel
Save