fix tb logger

pull/578/head
Hui Zhang 4 years ago
parent 16b8b9821c
commit 77e5641a91

@@ -88,10 +88,6 @@ class U2Trainer(Trainer):
             losses_np['ctc_loss'] = float(ctc_loss)
 
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            if dist.get_rank() == 0 and self.visualizer:
-                losses_np_v = losses_np.copy()
-                losses_np_v.update({"lr": self.lr_scheduler()})
-                self.visualizer.add_scalars("step", losses_np_v, self.iteration)
             self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
@@ -107,6 +103,12 @@ class U2Trainer(Trainer):
                              for k, v in losses_np.items())
             logger.info(msg)
 
+            if dist.get_rank() == 0 and self.visualizer:
+                losses_np_v = losses_np.copy()
+                losses_np_v.update({"lr": self.lr_scheduler()})
+                self.visualizer.add_scalars("step", losses_np_v,
+                                            self.iteration - 1)
+
     def train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
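The two hunks above move the TensorBoard/VisualDL step logging out of the gradient-accumulation branch and place it after the log message, so scalars are written once per effective optimizer step and tagged with the step that has just finished (self.iteration - 1). The sketch below only illustrates that ordering; StubWriter is a hypothetical stand-in for self.visualizer and is not part of the project.

# Minimal sketch of the logging order after this change (not the Trainer code).
class StubWriter:
    def add_scalars(self, tag, value_dict, step):
        print(f"[{tag}] step={step} {value_dict}")

def run_steps(batches, accum_grad=2, writer=StubWriter()):
    iteration = 0
    for batch_index, losses_np in enumerate(batches):
        # forward/backward and loss accumulation would happen here
        if (batch_index + 1) % accum_grad == 0:
            # optimizer.step(), optimizer.clear_grad(), lr_scheduler.step() go here
            iteration += 1
            # Log after the update, using the iteration that was just completed.
            losses_np_v = dict(losses_np, lr=1e-3)
            writer.add_scalars("step", losses_np_v, iteration - 1)

run_steps([{"loss": 12.3}, {"loss": 11.8}, {"loss": 11.2}, {"loss": 10.9}])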

@@ -46,8 +46,8 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
     return iteration
 
 
-def _save_checkpoint(checkpoint_dir: str, iteration: int):
-    """Save the iteration number of the latest model to be checkpointed.
+def _save_record(checkpoint_dir: str, iteration: int):
+    """Save the iteration number of the latest model to the checkpoint record.
     Args:
         checkpoint_dir (str): the directory where checkpoint is saved.
         iteration (int): the latest iteration number.
@@ -149,4 +149,4 @@ def save_parameters(checkpoint_dir: str,
             fout.write(data)
 
     if isinstance(tag_or_iteration, int):
-        _save_checkpoint(checkpoint_dir, tag_or_iteration)
+        _save_record(checkpoint_dir, tag_or_iteration)
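The rename clarifies that this helper only records which checkpoint is newest; save_parameters still writes the weights and then updates the record when an integer tag is used. Below is a hedged sketch of what such a record helper might look like; the record file name "checkpoint_latest" is an assumption for illustration and is not taken from this diff.

import os

# Hypothetical record helper mirroring the docstring above; the file name
# "checkpoint_latest" is assumed, not part of this commit.
def save_record(checkpoint_dir: str, iteration: int,
                record_name: str = "checkpoint_latest"):
    """Write the iteration number of the newest checkpoint to a record file."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, record_name), "w") as fout:
        fout.write(f"{iteration}\n")

def load_latest_record(checkpoint_dir: str,
                       record_name: str = "checkpoint_latest") -> int:
    """Read the recorded iteration back; -1 means no checkpoint exists yet."""
    path = os.path.join(checkpoint_dir, record_name)
    if not os.path.isfile(path):
        return -1
    with open(path) as fin:
        return int(fin.read().strip())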

@@ -21,6 +21,8 @@ __all__ = [
 
 
 def summary(layer: nn.Layer, print_func=print):
+    if print_func is None:
+        return
     num_params = num_elements = 0
     for name, param in layer.state_dict().items():
         if print_func:
@@ -32,15 +34,6 @@ def summary(layer: nn.Layer, print_func=print):
     print_func(f"Total parameters: {num_params}, {num_elements} elements.")
 
 
-def gradient_norm(layer: nn.Layer):
-    grad_norm_dict = {}
-    for name, param in layer.state_dict().items():
-        if param.trainable:
-            grad = param.gradient()  # return numpy.ndarray
-            grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
-    return grad_norm_dict
-
-
 def print_grads(model, print_func=print):
     if print_func is None:
         return
@@ -64,6 +57,15 @@ def print_params(model, print_func=print):
     print_func(f"Total parameters: {num_params}, {total} elements.")
 
 
+def gradient_norm(layer: nn.Layer):
+    grad_norm_dict = {}
+    for name, param in layer.state_dict().items():
+        if param.trainable:
+            grad = param.gradient()  # return numpy.ndarray
+            grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
+    return grad_norm_dict
+
+
 def recursively_remove_weight_norm(layer: nn.Layer):
     for layer in layer.sublayers():
         try:
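These hunks make summary() exit early when print_func is None and move gradient_norm below print_params without changing its behavior: for every trainable parameter it fetches the gradient as a NumPy array and reports its L2 norm divided by the element count. The usage sketch below is illustrative only; the toy paddle.nn.Linear layer and the extra None-gradient guard are assumptions, not part of this commit.

import numpy as np
import paddle
import paddle.nn as nn

def gradient_norm(layer: nn.Layer):
    # Same logic as the function added above, plus a guard for parameters
    # that have no gradient yet (e.g. before backward() has run).
    grad_norm_dict = {}
    for name, param in layer.state_dict().items():
        if param.trainable:
            grad = param.gradient()  # numpy.ndarray, or None before backward()
            if grad is not None:
                grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
    return grad_norm_dict

# Toy example: one backward pass on a tiny layer, then per-parameter norms.
layer = nn.Linear(4, 2)
loss = layer(paddle.randn([8, 4])).mean()
loss.backward()
for name, norm in gradient_norm(layer).items():
    print(f"{name}: {norm:.6f}")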
