@@ -19,6 +19,9 @@ from typing import Tuple
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+
+from .wavlm_paddle import WavLM
+from .wavlm_paddle import WavLMConfig
 from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
@@ -26,8 +29,6 @@ from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
 from paddlespeech.s2t.utils.utility import log_add
 
-from .wavlm_paddle import WavLM, WavLMConfig
-
 
 class WavLMASR(nn.Layer):
     def __init__(self, config: dict):
@@ -56,13 +57,13 @@ class WavLMASR(nn.Layer):
 
     def forward(self, wav, wavs_lens_rate, target, target_lens):
         if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape)
+            wav = F.layer_norm(wav, wav.shape[1:])
 
         # Extract wav2vec output
         out = self.wavlm(wav)
         # We normalize the output if required
         if self.output_norm:
-            out = F.layer_norm(out, out.shape)
+            out = F.layer_norm(out, out.shape[1:])
 
         if self.training and hasattr(self.config, 'spec_augment'):
             feats = self.spec_augment(out)
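For context on the two layer_norm edits above: paddle.nn.functional.layer_norm takes a normalized_shape argument that should span only the trailing, per-sample dimensions. With wav.shape[1:], the statistics are computed per sample over the trailing axes only, rather than over the whole tensor batch axis included. A minimal sketch of what the patched call computes, using hypothetical tensor sizes that are not part of the patch:

import paddle
import paddle.nn.functional as F

# Toy batch: 4 utterances of 16000 samples each (hypothetical sizes).
wav = paddle.randn([4, 16000])

# Patched form: normalized_shape = wav.shape[1:] == [16000], so mean/variance
# are computed per utterance over the time axis only.
out = F.layer_norm(wav, wav.shape[1:])
print(out.shape)  # [4, 16000]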