def extract_features(self, wav):
    """Encode ``wav`` with the wav2vec2 model and return the features.

    Recovered from a patch hunk whose header places this method in
    ``class Wav2vec2ASR(nn.Layer)``.

    Args:
        wav: Input handed to ``self.wav2vec2``. Presumably a batch of raw
            waveforms -- TODO confirm shape/dtype against the caller.

    Returns:
        The first element of the ``self.wav2vec2`` output, layer-normalized
        when ``self.output_norm`` is set, and passed through
        ``self.spec_augment`` when the model is in training mode and
        ``self.config`` has a ``spec_augment`` attribute.
    """
    if self.normalize_wav:
        # Normalize the input over its last axis before encoding.
        wav = F.layer_norm(wav, wav.shape[-1])

    # Only the first element of the encoder's return value is used.
    encoded = self.wav2vec2(wav)[0]
    if self.output_norm:
        encoded = F.layer_norm(encoded, encoded.shape[-1])

    # Augmentation applies only while training, and only when configured;
    # ``hasattr`` is not evaluated unless ``self.training`` is truthy.
    augment = self.training and hasattr(self.config, 'spec_augment')
    return self.spec_augment(encoded) if augment else encoded

# NOTE: in the original file the next definition follows here and is
# decorated with ``@paddle.no_grad()``; it is outside this excerpt.