fix layer tools

pull/578/head
Hui Zhang 4 years ago
parent b5339633e3
commit ec3481477d

@ -56,8 +56,8 @@ def print_params(model, print_func=print):
if print_func: if print_func:
print_func(msg) print_func(msg)
if print_func: if print_func:
total = total / 1024**3 total = total / 1024**2
print_func(f"Total parameters: {num_params}, {total}G elements.") print_func(f"Total parameters: {num_params}, {total:.4f}M elements.")
def gradient_norm(layer: nn.Layer): def gradient_norm(layer: nn.Layer):

Loading…
Cancel
Save