Update wav2vec2_ASR.py

pull/3872/head
张春乔 11 months ago committed by GitHub
parent cc4904b67a
commit 0e8468d9fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -80,24 +80,6 @@ class Wav2vec2ASR(nn.Layer):
return ctc_loss
@paddle.no_grad()
def extract_features(self, wav):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[-1])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[-1])
if self.training and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)
else:
feats = out
return feats
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,

Loading…
Cancel
Save