fix the shape error in layer_norm (#3884)

pull/3890/head
张春乔 10 months ago committed by GitHub
parent 4fdb0647f4
commit 6f44ac92c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -84,13 +84,13 @@ class HubertASR(nn.Layer):
def forward(self, wav, wavs_lens_rate, target, target_lens):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape)
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape)
out = F.layer_norm(out, out.shape[1:])
if self.training and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)

@ -59,13 +59,13 @@ class Wav2vec2ASR(nn.Layer):
def forward(self, wav, wavs_lens_rate, target, target_lens):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape)
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape)
out = F.layer_norm(out, out.shape[1:])
if self.training and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)

@ -19,6 +19,9 @@ from typing import Tuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .wavlm_paddle import WavLM
from .wavlm_paddle import WavLMConfig
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
@ -26,8 +29,6 @@ from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.utility import log_add
from .wavlm_paddle import WavLM, WavLMConfig
class WavLMASR(nn.Layer):
def __init__(self, config: dict):
@ -56,13 +57,13 @@ class WavLMASR(nn.Layer):
def forward(self, wav, wavs_lens_rate, target, target_lens):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape)
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wavlm(wav)
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape)
out = F.layer_norm(out, out.shape[1:])
if self.training and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)

Loading…
Cancel
Save