diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index c414ff130..3db144341 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -31,12 +31,8 @@
 from deepspeech.training import Trainer
 from deepspeech.training.gradclip import MyClipGradByGlobalNorm
 from deepspeech.utils import mp_tools
-from deepspeech.utils.error_rate import char_errors
-from deepspeech.utils.error_rate import word_errors
-from deepspeech.utils.error_rate import cer
-from deepspeech.utils.error_rate import wer
-from deepspeech.utils.utility import print_grads
-from deepspeech.utils.utility import print_params
+from deepspeech.utils import layer_tools
+from deepspeech.utils import error_rate
 
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
@@ -59,7 +55,7 @@ class DeepSpeech2Trainer(Trainer):
         self.model.train()
         loss = self.model(*batch_data)
         loss.backward()
-        print_grads(self.model, logger=None)
+        layer_tools.print_grads(self.model, print_func=None)
         self.optimizer.step()
         self.optimizer.clear_grad()
 
@@ -127,7 +123,7 @@ class DeepSpeech2Trainer(Trainer):
         if self.parallel:
             model = paddle.DataParallel(model)
 
-        print_params(model, self.logger)
+        layer_tools.print_params(model, self.logger.info)
 
         grad_clip = MyClipGradByGlobalNorm(config.training.global_grad_clip)
         lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
@@ -237,8 +233,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def compute_metrics(self, audio, texts, audio_len, texts_len):
         cfg = self.config.decoding
         errors_sum, len_refs, num_ins = 0.0, 0, 0
-        errors_func = char_errors if cfg.error_rate_type == 'cer' else word_errors
-        error_rate_func = cer if cfg.error_rate_type == 'cer' else wer
+        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
+        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
 
         vocab_list = self.test_loader.dataset.vocab_list
 
diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py
index 374d15ef4..1173c852d 100644
--- a/deepspeech/models/deepspeech2.py
+++ b/deepspeech/models/deepspeech2.py
@@ -382,19 +382,45 @@ class DeepSpeech2Model(nn.Layer):
         """Build a model from a pretrained model.
         Parameters
         ----------
-        model: nn.Layer
-            Asr Model.
-
         checkpoint_path: Path or str
             The path of pretrained model checkpoint, without extension name.
 
         Returns
         -------
-        Model
+        DeepSpeech2Model
             The model build from pretrined result.
         """
         checkpoint.load_parameters(self, checkpoint_path=checkpoint_path)
-        return
+        return self
+
+    @classmethod
+    def from_pretrained(cls, dataset, config, checkpoint_path):
+        """Build a DeepSpeech2Model model from a pretrained model.
+        Parameters
+        ----------
+        dataset: paddle.io.Dataset
+
+        config: yacs.config.CfgNode
+            model configs
+
+        checkpoint_path: Path or str
+            the path of pretrained model checkpoint, without extension name
+
+        Returns
+        -------
+        DeepSpeech2Model
+            The model built from pretrained result.
+ """ + model = cls(feat_size=dataset.feature_size, + dict_size=dataset.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + model.from_pretrained(checkpoint_path) + layer_tools.summary(model) + return model class DeepSpeech2InferModel(DeepSpeech2Model): diff --git a/deepspeech/utils/layer_tools.py b/deepspeech/utils/layer_tools.py new file mode 100644 index 000000000..46a354761 --- /dev/null +++ b/deepspeech/utils/layer_tools.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from paddle import nn + +__all__ = [ + "summary", "gradient_norm", "freeze", "unfreeze", "print_grads", + "print_params" +] + + +def summary(layer: nn.Layer, print_func=print): + num_params = num_elements = 0 + print_func("layer summary:") + for name, param in layer.state_dict().items(): + print_func("{}|{}|{}".format(name, param.shape, np.prod(param.shape))) + num_elements += np.prod(param.shape) + num_params += 1 + print_func("layer has {} parameters, {} elements.".format(num_params, + num_elements)) + + +def gradient_norm(layer: nn.Layer): + grad_norm_dict = {} + for name, param in layer.state_dict().items(): + if param.trainable: + grad = param.gradient() + grad_norm_dict[name] = np.linalg.norm(grad) / grad.size + return grad_norm_dict + + +def recursively_remove_weight_norm(layer: nn.Layer): + for layer in layer.sublayers(): + try: + nn.utils.remove_weight_norm(layer) + except: + # ther is not weight norm hoom in this layer + pass + + +def freeze(layer: nn.Layer): + for param in layer.parameters(): + param.trainable = False + + +def unfreeze(layer: nn.Layer): + for param in layer.parameters(): + param.trainable = True + + +def print_grads(model, print_func=print): + for n, p in model.named_parameters(): + msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}" + if print_func: + print_func(msg) + + +def print_params(model, print_func=print): + total = 0.0 + for n, p in model.named_parameters(): + msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}" + total += np.prod(p.shape) + if print_func: + print_func(msg) + if print_func: + print_func(f"Total parameters: {total}!") diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 7892f9150..72a45e29a 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -16,7 +16,7 @@ import numpy as np import distutils.util -__all__ = ['print_arguments', 'add_arguments', 'print_grads', 'print_params'] +__all__ = ['print_arguments', 'add_arguments'] def print_arguments(args): @@ -57,22 +57,4 @@ def add_arguments(argname, type, default, help, argparser, **kwargs): default=default, type=type, help=help + ' Default: %(default)s.', - **kwargs) - - -def print_grads(model, logger=None): - for n, p in model.named_parameters(): - msg = f"param grad: {n}: 
shape: {p.shape} grad: {p.grad}" - if logger: - logger.info(msg) - - -def print_params(model, logger=None): - total = 0.0 - for n, p in model.named_parameters(): - msg = f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}" - total += np.prod(p.shape) - if logger: - logger.info(msg) - if logger: - logger.info(f"Total parameters: {total}!") + **kwargs) \ No newline at end of file
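
A minimal usage sketch of the relocated layer_tools helpers (not part of the diff): the toy nn.Linear layer and tensor shapes below are illustrative assumptions; the new DeepSpeech2Model.from_pretrained(dataset, config, checkpoint_path) classmethod builds and loads a full model the same way from the exps scripts.

    import paddle
    from paddle import nn

    from deepspeech.utils import layer_tools

    layer = nn.Linear(8, 4)            # illustrative toy layer
    layer_tools.summary(layer)         # prints name|shape|numel per parameter
    layer_tools.print_params(layer)    # prints shape and stop_gradient per parameter

    out = layer(paddle.randn([2, 8]))
    out.sum().backward()
    layer_tools.print_grads(layer)     # gradients are populated after backward()

    layer_tools.freeze(layer)          # sets param.trainable = False on every parameter
    layer_tools.unfreeze(layer)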