|
|
|
@ -33,7 +33,7 @@ def summary(layer: nn.Layer, print_func=print):
|
|
|
|
|
if print_func:
|
|
|
|
|
num_elements = num_elements / 1024**2
|
|
|
|
|
print_func(
|
|
|
|
|
f"Total parameters: {num_params}, {num_elements:.4f}M elements.")
|
|
|
|
|
f"Total parameters: {num_params}, {num_elements:.2f} M elements.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_grads(model, print_func=print):
|
|
|
|
@ -57,7 +57,7 @@ def print_params(model, print_func=print):
|
|
|
|
|
print_func(msg)
|
|
|
|
|
if print_func:
|
|
|
|
|
total = total / 1024**2
|
|
|
|
|
print_func(f"Total parameters: {num_params}, {total:.4f}M elements.")
|
|
|
|
|
print_func(f"Total parameters: {num_params}, {total:.2f} M elements.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradient_norm(layer: nn.Layer):
|
|
|
|
|