diff --git a/paddlespeech/s2t/exps/whisper/test_wav.py b/paddlespeech/s2t/exps/whisper/test_wav.py index 27c167063..f9fc11302 100644 --- a/paddlespeech/s2t/exps/whisper/test_wav.py +++ b/paddlespeech/s2t/exps/whisper/test_wav.py @@ -45,6 +45,7 @@ class WhisperInfer(): model_dict = paddle.load(self.config.model_file) config.pop("model_file") dims = ModelDimensions(**model_dict["dims"]) + self.dims = dims self.model = Whisper(dims) self.model.load_dict(model_dict) @@ -63,12 +64,10 @@ class WhisperInfer(): temperature = [temperature] #load audio - # mel = log_mel_spectrogram( - # args.audio_file, resource_path=config.resource_path, , n_mels=128) - audio = log_mel_spectrogram( + mel = log_mel_spectrogram( args.audio_file, resource_path=config.resource_path, - n_mels=128, + n_mels=self.dims.n_mels, padding=480000) result = transcribe( self.model, mel, temperature=temperature, **config) diff --git a/paddlespeech/s2t/models/whisper/whisper.py b/paddlespeech/s2t/models/whisper/whisper.py index 504b7637d..4203d9021 100644 --- a/paddlespeech/s2t/models/whisper/whisper.py +++ b/paddlespeech/s2t/models/whisper/whisper.py @@ -397,9 +397,7 @@ def detect_language( # skip encoder forward pass if already-encoded audio features were given if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): - mel = model.encoder( - mel - ) # TODO zhaoxi: torch return float16, while cause e-3 diff with paddle float32 + mel = model.encoder(mel) # forward pass using a single token, startoftranscript batch_size = mel.shape[0] @@ -1149,7 +1147,6 @@ class DecodingTask: self.options: DecodingOptions = self._verify_options(options) self.resource_path: str = resource_path - # self.beam_size: int = options.beam_size or options.best_of or 1 self.n_group: int = options.beam_size or options.best_of or 1 self.n_ctx: int = model.dims.n_text_ctx self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2