@@ -23,6 +23,7 @@ import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_

 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,17 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)

-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()

         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                    text_len)
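+        # run the forward pass under autocast; a no-op when scaler is None (AMP off)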
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)

         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
@@ -77,12 +81,29 @@ class U2Trainer(Trainer):
         # processes.
         context = nullcontext
         with context():
-            loss.backward()
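+            # scale the loss before backward so small fp16 gradients do not underflow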
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)

         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
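+                # unscale gradients first so clipping sees their true magnitudes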
+                if scaler:
+                    scaler.unscale_(self.optimizer)
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
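+            # scaler.step() skips the update if inf/NaN gradients were found; update() retunes the loss scale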
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
@@ -173,7 +194,9 @@ class U2Trainer(Trainer):
                             report("epoch", self.epoch)
                             report('step', self.iteration)
                             report("lr", self.lr_scheduler())
-                            self.train_batch(batch_index, batch, msg)
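+                            # pass the shared GradScaler (None when AMP is off) into each step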
+                            self.train_batch(batch_index, batch, self.scaler,
+                                             msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
                             if not self.use_streamdata:
@@ -253,6 +276,20 @@ class U2Trainer(Trainer):
                model_conf.output_dim = self.test_loader.vocab_size

         model = U2Model.from_config(model_conf)
+
+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
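+            # dynamic loss scaling: start large and let GradScaler back off on overflow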
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  # AMP default loss scaling: 32768.0
+            # set amp_level; 'O2' additionally requires decorating the model
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None

         if self.parallel:
             model = paddle.DataParallel(model)
@@ -290,7 +327,7 @@ class U2Trainer(Trainer):
            scheduler_type = train_config.scheduler
            scheduler_conf = train_config.scheduler_conf
            return {
-                "grad_clip": train_config.global_grad_clip,
" weight_decay " : optim_conf . weight_decay ,
" learning_rate " : lr_scheduler
if lr_scheduler else optim_conf . lr ,