|
|
@ -56,8 +56,8 @@ def print_params(model, print_func=print):
|
|
|
|
if print_func:
|
|
|
|
if print_func:
|
|
|
|
print_func(msg)
|
|
|
|
print_func(msg)
|
|
|
|
if print_func:
|
|
|
|
if print_func:
|
|
|
|
total = total / 1024**3
|
|
|
|
total = total / 1024**2
|
|
|
|
print_func(f"Total parameters: {num_params}, {total}G elements.")
|
|
|
|
print_func(f"Total parameters: {num_params}, {total:.4f}M elements.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradient_norm(layer: nn.Layer):
|
|
|
|
def gradient_norm(layer: nn.Layer):
|
|
|
|