diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py index 64195defc..3fb91cb86 100755 --- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py +++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py @@ -80,6 +80,24 @@ class Wav2vec2ASR(nn.Layer): return ctc_loss + @paddle.no_grad() + def extract_features(self, wav): + if self.normalize_wav: + wav = F.layer_norm(wav, wav.shape[-1]) + + # Extract wav2vec output + out = self.wav2vec2(wav)[0] + # We normalize the output if required + if self.output_norm: + out = F.layer_norm(out, out.shape[-1]) + + if self.training and hasattr(self.config, 'spec_augment'): + feats = self.spec_augment(out) + else: + feats = out + + return feats + @paddle.no_grad() def decode(self, feats: paddle.Tensor,