|
|
@ -106,7 +106,8 @@ class U2Trainer(Trainer):
|
|
|
|
# Within this context, gradients will be accumulated on module
|
|
|
|
# Within this context, gradients will be accumulated on module
|
|
|
|
# variables, which will later be synchronized.
|
|
|
|
# variables, which will later be synchronized.
|
|
|
|
# When using cpu w/o DDP, model does not have `no_sync`
|
|
|
|
# When using cpu w/o DDP, model does not have `no_sync`
|
|
|
|
context = self.model.no_sync if self.parallel else nullcontext
|
|
|
|
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
|
|
|
|
|
|
|
|
self.parallel) else nullcontext
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# Used for single gpu training and DDP gradient synchronization
|
|
|
|
# Used for single gpu training and DDP gradient synchronization
|
|
|
|
# processes.
|
|
|
|
# processes.
|
|
|
|