Fix saving and loading of the GradScaler state.

pull/3167/head
zxcd 1 year ago
parent 2f4414a5f8
commit 7399d560e7

@ -82,6 +82,7 @@ class U2Trainer(Trainer):
with context():
if scaler:
scaler.scale(loss).backward()
scaler.unscale_(self.optimizer)
else:
loss.backward()
layer_tools.print_grads(self.model, print_func=None)

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
from collections import OrderedDict
@ -189,8 +190,12 @@ class Trainer():
"step": self.iteration,
"epoch": self.epoch,
"lr": self.optimizer.get_lr(),
"scaler": self.scaler.state_dict()
})
if self.scaler:
scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
paddle.save(self.scaler.state_dict(), scaler_path)
self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
@ -213,8 +218,13 @@ class Trainer():
# lr will be restored from the optimizer ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.scaler = paddle.amp.GradScaler()
self.scaler.load_state_dict(infos["scaler"])
scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
if os.path.exists(scaler_path):
scaler_state_dict = paddle.load(scaler_path)
self.scaler.load_state_dict(scaler_state_dict)
scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")

Loading…
Cancel
Save