[Fix] import TimeDomainSpecAugment (#3919)

pull/3923/head
megemini 3 weeks ago committed by GitHub
parent 5b3612f273
commit 890c87ea93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -33,7 +33,7 @@ from paddlespeech.s2t.io.speechbrain import data_pipeline
from paddlespeech.s2t.io.speechbrain import dataio from paddlespeech.s2t.io.speechbrain import dataio
from paddlespeech.s2t.io.speechbrain import dataset from paddlespeech.s2t.io.speechbrain import dataset
from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
@ -428,8 +428,7 @@ class WavLMASRTrainer(Trainer):
report("epoch", self.epoch) report("epoch", self.epoch)
report('step', self.iteration) report('step', self.iteration)
report("model_lr", self.model_optimizer.get_lr()) report("model_lr", self.model_optimizer.get_lr())
report("wavlm_lr", report("wavlm_lr", self.wavlm_optimizer.get_lr())
self.wavlm_optimizer.get_lr())
self.train_batch(batch_index, batch, msg) self.train_batch(batch_index, batch, msg)
self.after_train_batch() self.after_train_batch()
report('iter', batch_index + 1) report('iter', batch_index + 1)
@ -680,8 +679,7 @@ class WavLMASRTrainer(Trainer):
logger.info("optim_model:{},{}", model_optim_type, model_optim_conf) logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
wavlm_optim_type = train_config.wavlm_optim wavlm_optim_type = train_config.wavlm_optim
wavlm_optim_conf = train_config.wavlm_optim_conf wavlm_optim_conf = train_config.wavlm_optim_conf
logger.info("optim_model:{},{}", wavlm_optim_type, logger.info("optim_model:{},{}", wavlm_optim_type, wavlm_optim_conf)
wavlm_optim_conf)
model_scheduler_type = train_config.model_scheduler model_scheduler_type = train_config.model_scheduler
model_scheduler_conf = train_config.model_scheduler_conf model_scheduler_conf = train_config.model_scheduler_conf
@ -698,8 +696,8 @@ class WavLMASRTrainer(Trainer):
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type, model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
model_scheduler_args) model_scheduler_args)
wavlm_lr_scheduler = LRSchedulerFactory.from_args( wavlm_lr_scheduler = LRSchedulerFactory.from_args(wavlm_scheduler_type,
wavlm_scheduler_type, wavlm_scheduler_args) wavlm_scheduler_args)
def optimizer_args( def optimizer_args(
config, config,
@ -716,24 +714,31 @@ class WavLMASRTrainer(Trainer):
}) })
return optim_arg return optim_arg
model_optimizer_args = optimizer_args( model_optimizer_args = optimizer_args(config, model_optim_type,
config, model_optim_type, model_optim_conf, [{
model_optim_conf, 'params':
[{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}], model._layers.enc.parameters()
model_lr_scheduler }, {
) 'params':
# [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler) model._layers.ctc.parameters()
}] if self.parallel else [{
'params':
model.enc.parameters()
}, {
'params':
model.ctc.parameters()
}], model_lr_scheduler)
# [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
wavlm_optimizer_args = optimizer_args( wavlm_optimizer_args = optimizer_args(
config, wavlm_optim_type, wavlm_optim_conf, config, wavlm_optim_type, wavlm_optim_conf,
model._layers.wavlm.parameters() if self.parallel else model._layers.wavlm.parameters()
model.wavlm.parameters(), wavlm_lr_scheduler) if self.parallel else model.wavlm.parameters(), wavlm_lr_scheduler)
model_optimizer = OptimizerFactory.from_args(model_optim_type, model_optimizer = OptimizerFactory.from_args(model_optim_type,
model_optimizer_args) model_optimizer_args)
wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type, wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type,
wavlm_optimizer_args) wavlm_optimizer_args)
self.model_optimizer = model_optimizer self.model_optimizer = model_optimizer
self.wavlm_optimizer = wavlm_optimizer self.wavlm_optimizer = wavlm_optimizer

Loading…
Cancel
Save