@@ -33,7 +33,7 @@ from paddlespeech.s2t.io.speechbrain import data_pipeline
 from paddlespeech.s2t.io.speechbrain import dataio
 from paddlespeech.s2t.io.speechbrain import dataset
 from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
-from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
+from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment
 from paddlespeech.s2t.models.wavlm.wavlm_asr import WavLMASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
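
Note: the re-pointed TimeDomainSpecAugment is the same time-domain augmentation pipeline (speed perturbation, frequency dropping, chunk dropping) the wav2vec2 recipe used, now vendored under the wavlm package. A minimal usage sketch, assuming the port keeps the SpeechBrain-style constructor and call signature (the concrete argument values below are illustrative, not taken from this PR):

    import paddle

    from paddlespeech.s2t.models.wavlm.processing.speech_augmentation import TimeDomainSpecAugment

    # Argument names follow the SpeechBrain API this module is ported from;
    # treat both the names and the values as assumptions.
    augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[90, 100, 110])

    wavs = paddle.randn([4, 16000])  # batch of 1-second, 16 kHz waveforms
    lens = paddle.ones([4])          # relative lengths in [0, 1]
    wavs_aug = augment(wavs, lens)   # perturbed waveforms, same shape
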
@@ -211,7 +211,7 @@ class WavLMASRTrainer(Trainer):
             loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
         # NOTE: the code below asserted that backward() is problematic; as more steps are accumulated, the output from wavlm alone will be the same for all frames
         # optimizer step old
         if (batch_index + 1) % train_conf.accum_grad == 0:
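
The accumulation guard above pairs with loss scaling and a deferred optimizer step. A runnable toy sketch of that pattern in Paddle (the model, values, and single optimizer are stand-ins; only the modulo guard comes from this hunk):

    import paddle

    accum_grad = 4  # stand-in for train_conf.accum_grad
    model = paddle.nn.Linear(80, 32)  # toy stand-in for the ASR model
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-4, parameters=model.parameters())

    for batch_index in range(16):
        loss = model(paddle.randn([8, 80])).mean()
        # Scale the loss so the accumulated gradient matches one large batch.
        (loss / accum_grad).backward()
        # Step and clear gradients only every accum_grad micro-batches.
        if (batch_index + 1) % accum_grad == 0:
            optimizer.step()
            optimizer.clear_grad()
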
@@ -428,8 +428,7 @@ class WavLMASRTrainer(Trainer):
                 report("epoch", self.epoch)
                 report('step', self.iteration)
                 report("model_lr", self.model_optimizer.get_lr())
-                report("wavlm_lr",
-                       self.wavlm_optimizer.get_lr())
+                report("wavlm_lr", self.wavlm_optimizer.get_lr())
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()
                 report('iter', batch_index + 1)
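
These report() calls only take effect inside an ObsScope, which collects them into a dict for the progress log. A rough sketch of the pattern, assuming the reporter works the way it is used elsewhere in paddlespeech.s2t (the observation dict and the log-line join are illustrative):

    from collections import OrderedDict

    from paddlespeech.s2t.training.reporter import ObsScope
    from paddlespeech.s2t.training.reporter import report

    observation = OrderedDict()
    with ObsScope(observation):  # every report() lands in `observation`
        report("epoch", 3)
        report("model_lr", 1e-4)
    print(', '.join(f'{k}: {v}' for k, v in observation.items()))
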
@@ -680,8 +679,7 @@ class WavLMASRTrainer(Trainer):
         logger.info("optim_model: {}, {}", model_optim_type, model_optim_conf)
         wavlm_optim_type = train_config.wavlm_optim
         wavlm_optim_conf = train_config.wavlm_optim_conf
-        logger.info("optim_model: {}, {}", wavlm_optim_type,
-                    wavlm_optim_conf)
+        logger.info("optim_model: {}, {}", wavlm_optim_type, wavlm_optim_conf)
 
         model_scheduler_type = train_config.model_scheduler
         model_scheduler_conf = train_config.model_scheduler_conf
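
For orientation, the attributes read in this hunk come from the training config. A hedged sketch of the matching section written as a plain dict (only the key names are grounded in this diff; every value and the optimizer/scheduler choices are assumptions):

    # Hypothetical config values; only the keys come from the hunk above.
    train_config_sketch = {
        'model_optim': 'adadelta',
        'model_optim_conf': {'lr': 1.0, 'epsilon': 1e-6},
        'wavlm_optim': 'adam',
        'wavlm_optim_conf': {'lr': 1e-5},
        'model_scheduler': 'warmuplr',
        'model_scheduler_conf': {'warmup_steps': 25000},
    }
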
@@ -698,8 +696,8 @@ class WavLMASRTrainer(Trainer):
         model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
                                                           model_scheduler_args)
-        wavlm_lr_scheduler = LRSchedulerFactory.from_args(
-            wavlm_scheduler_type, wavlm_scheduler_args)
+        wavlm_lr_scheduler = LRSchedulerFactory.from_args(wavlm_scheduler_type,
+                                                          wavlm_scheduler_args)
 
         def optimizer_args(
                 config,
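
LRSchedulerFactory.from_args turns a scheduler name plus a kwargs dict into a Paddle LR scheduler, which the optimizers below expose through get_lr(). A self-contained sketch of the equivalent using core Paddle's LinearWarmup (the factory's registered names and the numbers here are assumptions):

    import paddle

    base_lr = 1e-4
    scheduler = paddle.optimizer.lr.LinearWarmup(
        learning_rate=base_lr, warmup_steps=1000, start_lr=0.0, end_lr=base_lr)

    linear = paddle.nn.Linear(8, 8)
    optimizer = paddle.optimizer.Adam(
        learning_rate=scheduler, parameters=linear.parameters())

    for step in range(3):
        scheduler.step()                 # advance the schedule once per step
        print(step, optimizer.get_lr())  # the value report("model_lr", ...) logs
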
@@ -716,24 +714,31 @@ class WavLMASRTrainer(Trainer):
             })
             return optim_arg
-        model_optimizer_args = optimizer_args(
-            config, model_optim_type,
-            model_optim_conf,
-            [{'params': model._layers.enc.parameters()}, {'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.enc.parameters()}, {'params': model.ctc.parameters()}],
-            model_lr_scheduler
-        )
-        # [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
+        model_optimizer_args = optimizer_args(config, model_optim_type,
+                                              model_optim_conf, [{
+                                                  'params':
+                                                  model._layers.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model._layers.ctc.parameters()
+                                              }] if self.parallel else [{
+                                                  'params':
+                                                  model.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model.ctc.parameters()
+                                              }], model_lr_scheduler)
+        # [{'params': model._layers.ctc.parameters()}] if self.parallel else [{'params': model.ctc.parameters()}], model_lr_scheduler)
         wavlm_optimizer_args = optimizer_args(
             config, wavlm_optim_type, wavlm_optim_conf,
-            model._layers.wavlm.parameters() if self.parallel else
-            model.wavlm.parameters(), wavlm_lr_scheduler)
+            model._layers.wavlm.parameters()
+            if self.parallel else model.wavlm.parameters(), wavlm_lr_scheduler)
         model_optimizer = OptimizerFactory.from_args(model_optim_type,
                                                      model_optimizer_args)
         wavlm_optimizer = OptimizerFactory.from_args(wavlm_optim_type,
-                                                        wavlm_optimizer_args)
+                                                     wavlm_optimizer_args)
 
         self.model_optimizer = model_optimizer
         self.wavlm_optimizer = wavlm_optimizer
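
Net effect of this block: one optimizer over the downstream encoder and CTC head, and a second, separately scheduled optimizer over the wavlm backbone. A condensed, runnable sketch of that split using plain Paddle parameter groups, with toy layers standing in for model.enc / model.ctc / model.wavlm:

    import paddle

    wavlm = paddle.nn.Linear(80, 80)  # stand-in for the pretrained backbone
    enc = paddle.nn.Linear(80, 64)    # stand-in for the downstream encoder
    ctc = paddle.nn.Linear(64, 29)    # stand-in for the CTC head

    # Parameter groups mirror the diff's [{'params': ...}, {'params': ...}] lists.
    model_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        parameters=[{'params': enc.parameters()},
                    {'params': ctc.parameters()}])
    # The backbone gets its own optimizer, typically with a much smaller LR.
    wavlm_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-5, parameters=wavlm.parameters())

    loss = ctc(enc(wavlm(paddle.randn([4, 80])))).mean()
    loss.backward()
    # Each optimizer steps and clears only its own parameter set.
    model_optimizer.step()
    model_optimizer.clear_grad()
    wavlm_optimizer.step()
    wavlm_optimizer.clear_grad()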