From 890c87ea93f3146666c6825306ceb8e21b18d099 Mon Sep 17 00:00:00 2001
From: megemini
Date: Mon, 2 Dec 2024 11:08:28 +0800
Subject: [PATCH] [Fix] import TimeDomainSpecAugment (#3919)

---
 paddlespeech/s2t/exps/wavlm/model.py | 43 ++++++++++++++++------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/paddlespeech/s2t/exps/wavlm/model.py b/paddlespeech/s2t/exps/wavlm/model.py
index 6ed2c5d8..606867ea 100644
--- a/paddlespeech/s2t/exps/wavlm/model.py
+++ b/paddlespeech/s2t/exps/wavlm/model.py
@@ -33,7 +33,7 @@ from paddlespeech.s2t.io.speechbrain import data_pipeline
 from paddlespeech.s2t.io.speechbrain import dataio
 from paddlespeech.s2t.io.speechbrain import dataset
 from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
-from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment
+from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
 from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
@@ -211,7 +211,7 @@ class WavLMASRTrainer(Trainer):
 
             loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
-            
+
         # NOTE: the code below asserted that the backward() is problematic, and as more steps are accumulated, the output from wavlm alone will be the same for all frames
         # optimizer step old
         if (batch_index + 1) % train_conf.accum_grad == 0:
@@ -428,8 +428,7 @@ class WavLMASRTrainer(Trainer):
                             report("epoch", self.epoch)
                             report('step', self.iteration)
                             report("model_lr", self.model_optimizer.get_lr())
-                            report("wavlm_lr",
-                                   self.wavlm_optimizer.get_lr())
+                            report("wavlm_lr", self.wavlm_optimizer.get_lr())
                             self.train_batch(batch_index, batch, msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
@@ -680,8 +679,7 @@ class WavLMASRTrainer(Trainer):
         logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
         wavlm_optim_type = train_config.wavlm_optim
         wavlm_optim_conf = train_config.wavlm_optim_conf
-        logger.info("optim_model:{},{}", wavlm_optim_type,
-                    wavlm_optim_conf)
+        logger.info("optim_model:{},{}", wavlm_optim_type, wavlm_optim_conf)
 
         model_scheduler_type = train_config.model_scheduler
         model_scheduler_conf = train_config.model_scheduler_conf
@@ -698,8 +696,8 @@ class WavLMASRTrainer(Trainer):
 
         model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
                                                           model_scheduler_args)
-        wavlm_lr_scheduler = LRSchedulerFactory.from_args(
-            wavlm_scheduler_type, wavlm_scheduler_args)
+        wavlm_lr_scheduler = LRSchedulerFactory.from_args(wavlm_scheduler_type,
+                                                          wavlm_scheduler_args)
 
         def optimizer_args(
                 config,
@@ -716,24 +714,31 @@ class WavLMASRTrainer(Trainer):
             })
             return optim_arg
 
-        model_optimizer_args = optimizer_args(
-            config, model_optim_type,
-            model_optim_conf,
-            [{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}],
-            model_lr_scheduler
-        )
-        # [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
-
+        model_optimizer_args = optimizer_args(config, model_optim_type,
+                                              model_optim_conf, [{
+                                                  'params':
+                                                  model._layers.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model._layers.ctc.parameters()
+                                              }] if self.parallel else [{
+                                                  'params':
+                                                  model.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model.ctc.parameters()
+                                              }], model_lr_scheduler)
+        # [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
         wavlm_optimizer_args = optimizer_args(
             config, wavlm_optim_type, wavlm_optim_conf,
-            model._layers.wavlm.parameters() if self.parallel else
-            model.wavlm.parameters(), wavlm_lr_scheduler)
+            model._layers.wavlm.parameters()
+            if self.parallel else model.wavlm.parameters(), wavlm_lr_scheduler)
         model_optimizer = OptimizerFactory.from_args(model_optim_type,
                                                      model_optimizer_args)
         wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type,
-                                                         wavlm_optimizer_args)
+                                                     wavlm_optimizer_args)
         self.model_optimizer = model_optimizer
         self.wavlm_optimizer = wavlm_optimizer
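
A minimal sketch of the corrected import in use, for quick verification that the new module path resolves. The default-argument construction and the (waveform, relative-lengths) call convention are assumptions carried over from the SpeechBrain-style augmentation API this module is ported from; they are not defined by this patch.

    import paddle
    # Corrected path: the wavlm recipe reuses the augmentation pipeline that
    # ships with the wav2vec2 package.
    from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment

    # Assumed: the default configuration (16 kHz audio) is acceptable here.
    augment = TimeDomainSpecAugment()
    wav = paddle.randn([4, 16000])   # (batch, time) raw waveforms, 1 s at 16 kHz
    wav_len = paddle.ones([4])       # relative lengths in [0, 1]
    wav_aug = augment(wav, wav_len)  # time-domain augmented waveforms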