|
|
|
@ -82,7 +82,6 @@ class U2Trainer(Trainer):
|
|
|
|
|
with context():
|
|
|
|
|
if scaler:
|
|
|
|
|
scaler.scale(loss).backward()
|
|
|
|
|
scaler.unscale_(self.optimizer)
|
|
|
|
|
else:
|
|
|
|
|
loss.backward()
|
|
|
|
|
layer_tools.print_grads(self.model, print_func=None)
|
|
|
|
@ -91,6 +90,8 @@ class U2Trainer(Trainer):
|
|
|
|
|
if (batch_index + 1) % train_conf.accum_grad == 0:
|
|
|
|
|
# do global grad clip
|
|
|
|
|
if train_conf.global_grad_clip != 0:
|
|
|
|
|
if scaler:
|
|
|
|
|
scaler.unscale_(self.optimizer)
|
|
|
|
|
# need paddlepaddle==develop or paddlepaddle>=2.5
|
|
|
|
|
clip_grad_norm_(self.model.parameters(),
|
|
|
|
|
train_conf.global_grad_clip)
|
|
|
|
|