add layer tools

pull/538/head
Hui Zhang 5 years ago
parent b46743f446
commit ac6a4da2e0

@@ -31,12 +31,8 @@ from deepspeech.training import Trainer
from deepspeech.training.gradclip import MyClipGradByGlobalNorm
from deepspeech.utils import mp_tools
from deepspeech.utils.error_rate import char_errors
from deepspeech.utils.error_rate import word_errors
from deepspeech.utils.error_rate import cer
from deepspeech.utils.error_rate import wer
from deepspeech.utils.utility import print_grads
from deepspeech.utils.utility import print_params
from deepspeech.utils import layer_tools
from deepspeech.utils import error_rate
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.sampler import SortagradDistributedBatchSampler
@@ -59,7 +55,7 @@ class DeepSpeech2Trainer(Trainer):
self.model.train()
loss = self.model(*batch_data)
loss.backward()
print_grads(self.model, logger=None)
layer_tools.print_grads(self.model, print_func=None)
self.optimizer.step()
self.optimizer.clear_grad()
@@ -127,7 +123,7 @@ class DeepSpeech2Trainer(Trainer):
if self.parallel:
model = paddle.DataParallel(model)
print_params(model, self.logger)
layer_tools.print_params(model, self.logger.info)
grad_clip = MyClipGradByGlobalNorm(config.training.global_grad_clip)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
@@ -237,8 +233,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def compute_metrics(self, audio, texts, audio_len, texts_len):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors
error_rate_func = cer if cfg.error_rate_type == 'cer' else wer
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.dataset.vocab_list
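For context on the two helpers selected above: errors_func returns a raw (edit_distance, reference_length) pair, while error_rate_func returns a normalized rate. Below is a minimal usage sketch of how they are typically accumulated over a batch; target_transcripts and result_transcripts are assumed lists of reference and decoded strings, not names taken from this diff.

# usage sketch only, not part of this commit
for target, result in zip(target_transcripts, result_transcripts):
    errors, len_ref = errors_func(target, result)  # e.g. error_rate.char_errors
    errors_sum += errors
    len_refs += len_ref
    num_ins += 1
# corpus-level rate: normalize by total reference length, not per utterance
batch_error_rate = errors_sum / len_refs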

@@ -382,19 +382,45 @@ class DeepSpeech2Model(nn.Layer):
"""Build a model from a pretrained model.
Parameters
----------
model: nn.Layer
ASR model.
checkpoint_path: Path or str
The path of pretrained model checkpoint, without extension name.
Returns
-------
Model
DeepSpeech2Model
The model built from the pretrained result.
"""
checkpoint.load_parameters(self, checkpoint_path=checkpoint_path)
return
return self
@classmethod
def from_pretrained(cls, dataset, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataset: paddle.io.Dataset
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2Model
The model built from the pretrained result.
"""
model = cls(feat_size=dataset.feature_size,
dict_size=dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
model.from_pretrained(checkpoint_path)
layer_tools.summary(model)
return model
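A hedged usage sketch of the new classmethod; the loader, config object, and checkpoint path below are placeholders rather than values taken from this diff.

# usage sketch only, not part of this commit
model = DeepSpeech2Model.from_pretrained(
    dataset=test_loader.dataset,           # supplies feature_size and vocab_size
    config=config,                         # yacs CfgNode with config.model.* fields
    checkpoint_path="path/to/checkpoint")  # without extension name, per the docstring
model.eval()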
class DeepSpeech2InferModel(DeepSpeech2Model):

@@ -0,0 +1,78 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import nn
__all__ = [
"summary", "gradient_norm", "freeze", "unfreeze", "print_grads",
"print_params"
]
def summary(layer: nn.Layer, print_func=print):
num_params = num_elements = 0
print_func("layer summary:")
for name, param in layer.state_dict().items():
print_func("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
num_elements += np.prod(param.shape)
num_params += 1
print_func("layer has {} parameters, {} elements.".format(num_params,
num_elements))
def gradient_norm(layer: nn.Layer):
grad_norm_dict = {}
for name, param in layer.state_dict().items():
if param.trainable:
grad = param.gradient()
grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
return grad_norm_dict
def recursively_remove_weight_norm(layer: nn.Layer):
for layer in layer.sublayers():
try:
nn.utils.remove_weight_norm(layer)
except:
# there is no weight norm hook in this layer
pass
def freeze(layer: nn.Layer):
for param in layer.parameters():
param.trainable = False
def unfreeze(layer: nn.Layer):
for param in layer.parameters():
param.trainable = True
def print_grads(model, print_func=print):
for n, p in model.named_parameters():
msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
if print_func:
print_func(msg)
def print_params(model, print_func=print):
total = 0.0
for n, p in model.named_parameters():
msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}"
total += np.prod(p.shape)
if print_func:
print_func(msg)
if print_func:
print_func(f"Total parameters: {total}!")

@@ -16,7 +16,7 @@
import numpy as np
import distutils.util
__all__ = ['print_arguments', 'add_arguments', 'print_grads', 'print_params']
__all__ = ['print_arguments', 'add_arguments']
def print_arguments(args):
@@ -57,22 +57,4 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def print_grads(model, logger=None):
for n, p in model.named_parameters():
msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
if logger:
logger.info(msg)
def print_params(model, logger=None):
total = 0.0
for n, p in model.named_parameters():
msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}"
total += np.prod(p.shape)
if logger:
logger.info(msg)
if logger:
logger.info(f"Total parameters: {total}!")
**kwargs)