From 9bc11ce671bc6c04a3d35fc191960c4f90975923 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com>
Date: Wed, 6 Nov 2024 00:18:54 +0800
Subject: [PATCH] Update wav2vec2_ASR.py

---
 .../s2t/models/wav2vec2/wav2vec2_ASR.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index 64195defc..3fb91cb86 100755
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -80,6 +80,24 @@ class Wav2vec2ASR(nn.Layer):
 
         return ctc_loss
 
+    @paddle.no_grad()  # the whole method runs with gradient tracking disabled
+    def extract_features(self, wav):  # returns wav2vec2 features for `wav`, optionally normalized / augmented
+        if self.normalize_wav:  # optional normalization of the raw input
+            wav = F.layer_norm(wav, wav.shape[-1])  # layer-norm over the last dimension of wav
+
+        # Run the wav2vec2 model and keep only the first element of what it returns
+        out = self.wav2vec2(wav)[0]
+        # Layer-normalize the model output over its last dimension when output_norm is set
+        if self.output_norm:
+            out = F.layer_norm(out, out.shape[-1])
+
+        if self.training and hasattr(self.config, 'spec_augment'):  # augment only in training mode, and only if the config defines spec_augment
+            feats = self.spec_augment(out)
+        else:
+            feats = out  # otherwise the (possibly normalized) output is returned as-is
+
+        return feats
+
     @paddle.no_grad()
     def decode(self,
                feats: paddle.Tensor,