|
|
@ -33,7 +33,7 @@ from paddlespeech.s2t.io.speechbrain import data_pipeline
|
|
|
|
from paddlespeech.s2t.io.speechbrain import dataio
|
|
|
|
from paddlespeech.s2t.io.speechbrain import dataio
|
|
|
|
from paddlespeech.s2t.io.speechbrain import dataset
|
|
|
|
from paddlespeech.s2t.io.speechbrain import dataset
|
|
|
|
from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
|
|
|
|
from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
|
|
|
|
from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment
|
|
|
|
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
|
|
|
|
from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
|
|
|
|
from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
|
|
|
|
from paddlespeech.s2t.training.optimizer import OptimizerFactory
|
|
|
|
from paddlespeech.s2t.training.optimizer import OptimizerFactory
|
|
|
|
from paddlespeech.s2t.training.reporter import ObsScope
|
|
|
|
from paddlespeech.s2t.training.reporter import ObsScope
|
|
|
@ -428,8 +428,7 @@ class WavLMASRTrainer(Trainer):
|
|
|
|
report("epoch", self.epoch)
|
|
|
|
report("epoch", self.epoch)
|
|
|
|
report('step', self.iteration)
|
|
|
|
report('step', self.iteration)
|
|
|
|
report("model_lr", self.model_optimizer.get_lr())
|
|
|
|
report("model_lr", self.model_optimizer.get_lr())
|
|
|
|
report("wavlm_lr",
|
|
|
|
report("wavlm_lr", self.wavlm_optimizer.get_lr())
|
|
|
|
self.wavlm_optimizer.get_lr())
|
|
|
|
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
self.after_train_batch()
|
|
|
|
self.after_train_batch()
|
|
|
|
report('iter', batch_index + 1)
|
|
|
|
report('iter', batch_index + 1)
|
|
|
@ -680,8 +679,7 @@ class WavLMASRTrainer(Trainer):
|
|
|
|
logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
|
|
|
|
logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
|
|
|
|
wavlm_optim_type = train_config.wavlm_optim
|
|
|
|
wavlm_optim_type = train_config.wavlm_optim
|
|
|
|
wavlm_optim_conf = train_config.wavlm_optim_conf
|
|
|
|
wavlm_optim_conf = train_config.wavlm_optim_conf
|
|
|
|
logger.info("optim_model:{},{}", wavlm_optim_type,
|
|
|
|
logger.info("optim_model:{},{}", wavlm_optim_type, wavlm_optim_conf)
|
|
|
|
wavlm_optim_conf)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_scheduler_type = train_config.model_scheduler
|
|
|
|
model_scheduler_type = train_config.model_scheduler
|
|
|
|
model_scheduler_conf = train_config.model_scheduler_conf
|
|
|
|
model_scheduler_conf = train_config.model_scheduler_conf
|
|
|
@ -698,8 +696,8 @@ class WavLMASRTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
|
|
|
|
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
|
|
|
|
model_scheduler_args)
|
|
|
|
model_scheduler_args)
|
|
|
|
wavlm_lr_scheduler = LRSchedulerFactory.from_args(
|
|
|
|
wavlm_lr_scheduler = LRSchedulerFactory.from_args(wavlm_scheduler_type,
|
|
|
|
wavlm_scheduler_type, wavlm_scheduler_args)
|
|
|
|
wavlm_scheduler_args)
|
|
|
|
|
|
|
|
|
|
|
|
def optimizer_args(
|
|
|
|
def optimizer_args(
|
|
|
|
config,
|
|
|
|
config,
|
|
|
@ -716,19 +714,26 @@ class WavLMASRTrainer(Trainer):
|
|
|
|
})
|
|
|
|
})
|
|
|
|
return optim_arg
|
|
|
|
return optim_arg
|
|
|
|
|
|
|
|
|
|
|
|
model_optimizer_args = optimizer_args(
|
|
|
|
model_optimizer_args = optimizer_args(config, model_optim_type,
|
|
|
|
config, model_optim_type,
|
|
|
|
model_optim_conf, [{
|
|
|
|
model_optim_conf,
|
|
|
|
'params':
|
|
|
|
[{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}],
|
|
|
|
model._layers.enc.parameters()
|
|
|
|
model_lr_scheduler
|
|
|
|
}, {
|
|
|
|
)
|
|
|
|
'params':
|
|
|
|
|
|
|
|
model._layers.ctc.parameters()
|
|
|
|
|
|
|
|
}] if self.parallel else [{
|
|
|
|
|
|
|
|
'params':
|
|
|
|
|
|
|
|
model.enc.parameters()
|
|
|
|
|
|
|
|
}, {
|
|
|
|
|
|
|
|
'params':
|
|
|
|
|
|
|
|
model.ctc.parameters()
|
|
|
|
|
|
|
|
}], model_lr_scheduler)
|
|
|
|
# [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
|
|
|
|
# [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wavlm_optimizer_args = optimizer_args(
|
|
|
|
wavlm_optimizer_args = optimizer_args(
|
|
|
|
config, wavlm_optim_type, wavlm_optim_conf,
|
|
|
|
config, wavlm_optim_type, wavlm_optim_conf,
|
|
|
|
model._layers.wavlm.parameters() if self.parallel else
|
|
|
|
model._layers.wavlm.parameters()
|
|
|
|
model.wavlm.parameters(), wavlm_lr_scheduler)
|
|
|
|
if self.parallel else model.wavlm.parameters(), wavlm_lr_scheduler)
|
|
|
|
|
|
|
|
|
|
|
|
model_optimizer = OptimizerFactory.from_args(model_optim_type,
|
|
|
|
model_optimizer = OptimizerFactory.from_args(model_optim_type,
|
|
|
|
model_optimizer_args)
|
|
|
|
model_optimizer_args)
|
|
|
|