Move scaler.unscale_ below grad_clip.

pull/3167/head
zxcd 2 years ago
parent 7399d560e7
commit a1e5f27003

@@ -82,7 +82,6 @@ class U2Trainer(Trainer):
             with context():
                 if scaler:
                     scaler.scale(loss).backward()
-                    scaler.unscale_(self.optimizer)
                 else:
                     loss.backward()
                 layer_tools.print_grads(self.model, print_func=None)
@@ -91,6 +90,8 @@ class U2Trainer(Trainer):
             if (batch_index + 1) % train_conf.accum_grad == 0:
                 # do global grad clip
                 if train_conf.global_grad_clip != 0:
+                    if scaler:
+                        scaler.unscale_(self.optimizer)
                     # need paddlepaddle==develop or paddlepaddle>=2.5
                     clip_grad_norm_(self.model.parameters(),
                                     train_conf.global_grad_clip)
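
Rationale: after scaler.scale(loss).backward() the gradients still carry the loss-scale factor, so unscale_ must run immediately before clip_grad_norm_ for the clip threshold to apply to the true gradient norms. A minimal standalone sketch of the corrected ordering (placeholder model and optimizer; the trainer's accum_grad bookkeeping is omitted):

import paddle
from paddle.nn.utils import clip_grad_norm_  # needs paddlepaddle>=2.5

model = paddle.nn.Linear(10, 10)  # placeholder model
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler()

x = paddle.randn([4, 10])
with paddle.amp.auto_cast():
    loss = model(x).mean()

scaler.scale(loss).backward()             # gradients are scaled at this point

scaler.unscale_(optimizer)                # 1. restore true gradient magnitudes
clip_grad_norm_(model.parameters(), 5.0)  # 2. clip against the real norm
scaler.step(optimizer)                    # 3. step (skipped if inf/nan found)
scaler.update()
optimizer.clear_grad()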
