pull/4101/head
zxcd 2 weeks ago
parent fe4860550a
commit 86f2bd1f6a

@ -45,6 +45,7 @@ class WhisperInfer():
model_dict = paddle.load(self.config.model_file)
config.pop("model_file")
dims = ModelDimensions(**model_dict["dims"])
self.dims = dims
self.model = Whisper(dims)
self.model.load_dict(model_dict)
@ -63,12 +64,10 @@ class WhisperInfer():
temperature = [temperature]
#load audio
# mel = log_mel_spectrogram(
# args.audio_file, resource_path=config.resource_path, , n_mels=128)
audio = log_mel_spectrogram(
mel = log_mel_spectrogram(
args.audio_file,
resource_path=config.resource_path,
n_mels=128,
n_mels=self.dims.n_mels,
padding=480000)
result = transcribe(
self.model, mel, temperature=temperature, **config)

@ -397,9 +397,7 @@ def detect_language(
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(
mel
) # TODO zhaoxi: torch return float16, while cause e-3 diff with paddle float32
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
batch_size = mel.shape[0]
@ -1149,7 +1147,6 @@ class DecodingTask:
self.options: DecodingOptions = self._verify_options(options)
self.resource_path: str = resource_path
# self.beam_size: int = options.beam_size or options.best_of or 1
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

Loading…
Cancel
Save