|
|
|
@ -80,6 +80,24 @@ class Wav2vec2ASR(nn.Layer):
|
|
|
|
|
|
|
|
|
|
return ctc_loss
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def extract_features(self, wav):
|
|
|
|
|
if self.normalize_wav:
|
|
|
|
|
wav = F.layer_norm(wav, wav.shape[-1])
|
|
|
|
|
|
|
|
|
|
# Extract wav2vec output
|
|
|
|
|
out = self.wav2vec2(wav)[0]
|
|
|
|
|
# We normalize the output if required
|
|
|
|
|
if self.output_norm:
|
|
|
|
|
out = F.layer_norm(out, out.shape[-1])
|
|
|
|
|
|
|
|
|
|
if self.training and hasattr(self.config, 'spec_augment'):
|
|
|
|
|
feats = self.spec_augment(out)
|
|
|
|
|
else:
|
|
|
|
|
feats = out
|
|
|
|
|
|
|
|
|
|
return feats
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def decode(self,
|
|
|
|
|
feats: paddle.Tensor,
|
|
|
|
|