Update wav2vec2_ASR.py

pull/3872/head
张春乔 11 months ago committed by GitHub
parent fca3b80178
commit 9bc11ce671
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -80,6 +80,24 @@ class Wav2vec2ASR(nn.Layer):
return ctc_loss
@paddle.no_grad()
def extract_features(self, wav):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[-1])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[-1])
if self.training and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)
else:
feats = out
return feats
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,

Loading…
Cancel
Save