@ -56,8 +56,8 @@ def print_params(model, print_func=print):
if print_func:
print_func(msg)
total = total / 1024**3
print_func(f"Total parameters: {num_params}, {total}G elements.")
total = total / 1024**2
print_func(f"Total parameters: {num_params}, {total:.4f}M elements.")
def gradient_norm(layer: nn.Layer):