From 6f44ac92c8ff50e47fda3d1d19b366b2b4ee205a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com>
Date: Thu, 14 Nov 2024 19:44:32 +0800
Subject: [PATCH] fix the shape error in layer_norm (#3884)

---
 paddlespeech/s2t/models/hubert/hubert_ASR.py     | 4 ++--
 paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py | 4 ++--
 paddlespeech/s2t/models/wavlm/wavlm_asr.py       | 9 +++++----
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/paddlespeech/s2t/models/hubert/hubert_ASR.py b/paddlespeech/s2t/models/hubert/hubert_ASR.py
index df3475897..4a0dc2aa6 100644
--- a/paddlespeech/s2t/models/hubert/hubert_ASR.py
+++ b/paddlespeech/s2t/models/hubert/hubert_ASR.py
@@ -84,13 +84,13 @@ class HubertASR(nn.Layer):
 
     def forward(self, wav, wavs_lens_rate, target, target_lens):
         if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape)
+            wav = F.layer_norm(wav, wav.shape[1:])
 
         # Extract wav2vec output
         out = self.hubert.extract_features(wav)[0]
         # We normalize the output if required
         if self.output_norm:
-            out = F.layer_norm(out, out.shape)
+            out = F.layer_norm(out, out.shape[1:])
 
         if self.training and hasattr(self.config, 'spec_augment'):
             feats = self.spec_augment(out)
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index 64195defc..179f7038d 100755
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -59,13 +59,13 @@ class Wav2vec2ASR(nn.Layer):
 
     def forward(self, wav, wavs_lens_rate, target, target_lens):
         if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape)
+            wav = F.layer_norm(wav, wav.shape[1:])
 
         # Extract wav2vec output
         out = self.wav2vec2(wav)[0]
         # We normalize the output if required
         if self.output_norm:
-            out = F.layer_norm(out, out.shape)
+            out = F.layer_norm(out, out.shape[1:])
 
         if self.training and hasattr(self.config, 'spec_augment'):
             feats = self.spec_augment(out)
diff --git a/paddlespeech/s2t/models/wavlm/wavlm_asr.py b/paddlespeech/s2t/models/wavlm/wavlm_asr.py
index 5764890d2..53dd498d5 100644
--- a/paddlespeech/s2t/models/wavlm/wavlm_asr.py
+++ b/paddlespeech/s2t/models/wavlm/wavlm_asr.py
@@ -19,6 +19,9 @@ from typing import Tuple
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+
+from .wavlm_paddle import WavLM
+from .wavlm_paddle import WavLMConfig
 from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
@@ -26,8 +29,6 @@ from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
 from paddlespeech.s2t.utils.utility import log_add
 
-from .wavlm_paddle import WavLM, WavLMConfig
-
 
 class WavLMASR(nn.Layer):
     def __init__(self, config: dict):
@@ -56,13 +57,13 @@ class WavLMASR(nn.Layer):
 
     def forward(self, wav, wavs_lens_rate, target, target_lens):
         if self.normalize_wav:
-            wav = F.layer_norm(wav, wav.shape)
+            wav = F.layer_norm(wav, wav.shape[1:])
 
         # Extract wav2vec output
         out = self.wavlm(wav)
         # We normalize the output if required
         if self.output_norm:
-            out = F.layer_norm(out, out.shape)
+            out = F.layer_norm(out, out.shape[1:])
 
         if self.training and hasattr(self.config, 'spec_augment'):
             feats = self.spec_augment(out)
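
Reviewer note (not part of the patch): paddle.nn.functional.layer_norm takes a
normalized_shape argument that is expected to describe only the trailing,
per-sample axes of the input; passing the full tensor shape also covers the
batch axis, which is what this patch removes by slicing with shape[1:]. Below
is a minimal sketch of the corrected usage. The tensor shapes (4 x 16000
waveform, 4 x 49 x 768 features) and variable names are illustrative
assumptions, not values taken from the repository:

    import paddle
    import paddle.nn.functional as F

    # [batch, time] waveform batch; shapes are illustrative only.
    wav = paddle.randn([4, 16000])
    # Normalize each sample over its own time axis: normalized_shape = [16000].
    wav = F.layer_norm(wav, wav.shape[1:])

    # [batch, frames, feature_dim] encoder output, again with assumed sizes.
    out = paddle.randn([4, 49, 768])
    # Normalize each sample over [frames, feature_dim], excluding the batch axis.
    out = F.layer_norm(out, out.shape[1:])

    print(wav.shape, out.shape)  # [4, 16000] [4, 49, 768]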