[Fix] import TimeDomainSpecAugment (#3919)

pull/3923/head
megemini 3 weeks ago committed by GitHub
parent 5b3612f273
commit 890c87ea93

@@ -33,7 +33,7 @@ from paddlespeech.s2t.io.speechbrain import data_pipeline
 from paddlespeech.s2t.io.speechbrain import dataio
 from paddlespeech.s2t.io.speechbrain import dataset
 from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
-from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment
+from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
 from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
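The one-line fix: TimeDomainSpecAugment is imported from the wav2vec2 processing package, presumably because the wavlm package does not ship its own copy of that module, so the old import path broke the trainer. Below is a minimal sketch of how the corrected import is used on raw waveforms; the no-argument constructor and the (waveform, relative_lengths) call convention mirror the SpeechBrain-style interface this module is ported from and are assumptions here, not a verified PaddleSpeech signature.

```python
import paddle

from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment

augment = TimeDomainSpecAugment()   # assumed defaults (16 kHz speech, speed perturb + drop augmentations)
wav = paddle.randn([4, 16000])      # (batch, samples): 1 s of fake audio per utterance
wav_lens = paddle.ones([4])         # relative lengths in [0, 1], all utterances full length

augmented = augment(wav, wav_lens)  # assumed call signature: (waveform, relative_lengths)
print(augmented.shape)              # speed perturbation may change the time dimension
```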
@@ -211,7 +211,7 @@ class WavLMASRTrainer(Trainer):
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
         # NOTE: the code below asserted that the backward() is problematic, and as more steps are accumulated, the output from wavlm alone will be the same for all frames
         # optimizer step old
         if (batch_index + 1) % train_conf.accum_grad == 0:
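For context, the loss.backward() / accum_grad pair above is the standard gradient-accumulation pattern: gradients from several small batches are summed before the optimizers step. A self-contained sketch of that pattern in plain Paddle; the toy Linear model, Adam optimizer, and accum_grad value are placeholders, not the trainer's own objects.

```python
import paddle

# Toy stand-ins; the real trainer drives a WavLM encoder plus CTC head
# with two optimizers (see the optimizer hunks further down).
model = paddle.nn.Linear(8, 2)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=1e-3)
accum_grad = 4  # plays the role of train_conf.accum_grad

for batch_index in range(16):
    x = paddle.randn([3, 8])
    y = paddle.randn([3, 2])
    # Scale the loss so the summed gradients match one large batch.
    loss = paddle.nn.functional.mse_loss(model(x), y) / accum_grad
    loss.backward()                      # gradients accumulate across batches
    if (batch_index + 1) % accum_grad == 0:
        optimizer.step()                 # update once every accum_grad batches
        optimizer.clear_grad()           # reset the accumulated gradients
```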
@@ -428,8 +428,7 @@ class WavLMASRTrainer(Trainer):
             report("epoch", self.epoch)
             report('step', self.iteration)
             report("model_lr", self.model_optimizer.get_lr())
-            report("wavlm_lr",
-                   self.wavlm_optimizer.get_lr())
+            report("wavlm_lr", self.wavlm_optimizer.get_lr())
             self.train_batch(batch_index, batch, msg)
             self.after_train_batch()
             report('iter', batch_index + 1)
@@ -680,8 +679,7 @@ class WavLMASRTrainer(Trainer):
         logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
         wavlm_optim_type = train_config.wavlm_optim
         wavlm_optim_conf = train_config.wavlm_optim_conf
-        logger.info("optim_model:{},{}", wavlm_optim_type,
-                    wavlm_optim_conf)
+        logger.info("optim_model:{},{}", wavlm_optim_type, wavlm_optim_conf)

         model_scheduler_type = train_config.model_scheduler
         model_scheduler_conf = train_config.model_scheduler_conf
@@ -698,8 +696,8 @@ class WavLMASRTrainer(Trainer):
         model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
                                                           model_scheduler_args)
-        wavlm_lr_scheduler = LRSchedulerFactory.from_args(
-            wavlm_scheduler_type, wavlm_scheduler_args)
+        wavlm_lr_scheduler = LRSchedulerFactory.from_args(wavlm_scheduler_type,
                                                           wavlm_scheduler_args)

         def optimizer_args(
                 config,
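The scheduler hunk above is pure re-wrapping, but it highlights the two-scheduler design: the encoder/CTC optimizer and the wavlm optimizer each get a learning-rate scheduler built from their own config section via LRSchedulerFactory.from_args. A rough plain-Paddle equivalent follows; the concrete scheduler classes and numbers are made-up stand-ins for whatever model_scheduler_conf and wavlm_scheduler_conf actually specify.

```python
import paddle

# Hypothetical values standing in for model_scheduler_conf / wavlm_scheduler_conf.
model_lr_scheduler = paddle.optimizer.lr.NoamDecay(
    d_model=256, warmup_steps=25000, learning_rate=1.0)
wavlm_lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
    learning_rate=1e-5, decay_steps=100000, end_lr=0.0)

# Each scheduler is later handed to its own optimizer and stepped on its own
# timetable, so the CTC head and the WavLM backbone can warm up and decay
# independently.
model_lr_scheduler.step()
wavlm_lr_scheduler.step()
```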
@@ -716,24 +714,31 @@ class WavLMASRTrainer(Trainer):
             })
             return optim_arg

-        model_optimizer_args = optimizer_args(
-            config, model_optim_type,
-            model_optim_conf,
-            [{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}],
-            model_lr_scheduler
-        )
-        # [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
+        model_optimizer_args = optimizer_args(config, model_optim_type,
+                                              model_optim_conf, [{
+                                                  'params':
+                                                  model._layers.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model._layers.ctc.parameters()
+                                              }] if self.parallel else [{
+                                                  'params':
+                                                  model.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model.ctc.parameters()
+                                              }], model_lr_scheduler)
+        # [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
         wavlm_optimizer_args = optimizer_args(
             config, wavlm_optim_type, wavlm_optim_conf,
-            model._layers.wavlm.parameters() if self.parallel else
-            model.wavlm.parameters(), wavlm_lr_scheduler)
+            model._layers.wavlm.parameters()
+            if self.parallel else model.wavlm.parameters(), wavlm_lr_scheduler)
         model_optimizer = OptimizerFactory.from_args(model_optim_type,
                                                      model_optimizer_args)
         wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type,
-                                                        wavlm_optimizer_args)
+                                                     wavlm_optimizer_args)
         self.model_optimizer = model_optimizer
         self.wavlm_optimizer = wavlm_optimizer
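The reformatted block above is the heart of the dual-optimizer setup: one optimizer covers the encoder and the CTC head as two parameter groups, a second optimizer covers only the WavLM backbone, and each is driven by its own scheduler (the _layers indirection is just the DataParallel wrapper selected when self.parallel is set). A condensed sketch of the same idea in plain Paddle; the toy module layout and the Adam choice are placeholders for what OptimizerFactory actually builds from the config.

```python
import paddle
import paddle.nn as nn

class ToyWavLMASR(nn.Layer):
    """Placeholder exposing the same submodule names the trainer indexes."""

    def __init__(self):
        super().__init__()
        self.wavlm = nn.Linear(16, 16)  # stands in for the WavLM backbone
        self.enc = nn.Linear(16, 8)     # stands in for the encoder projection
        self.ctc = nn.Linear(8, 4)      # stands in for the CTC head

model = ToyWavLMASR()

# One optimizer over the encoder + CTC head, expressed as parameter groups,
# mirroring the [{'params': ...}, {'params': ...}] list in the diff.
model_optimizer = paddle.optimizer.Adam(
    learning_rate=1e-3,
    parameters=[
        {'params': model.enc.parameters()},
        {'params': model.ctc.parameters()},
    ])

# A second optimizer that only sees the backbone, so it can be scheduled
# (or frozen) independently of the head.
wavlm_optimizer = paddle.optimizer.Adam(
    learning_rate=1e-5,
    parameters=model.wavlm.parameters())
```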
