diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index ed209f3d..5762e635 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
 
-
 ########### hack paddle.nn #############
 from paddle.nn import Layer
 from typing import Optional
@@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
     logger.debug(
         "register user LayerDict to paddle.nn, remove this when fixed!")
     setattr(paddle.nn, 'LayerDict', LayerDict)
-
-
diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py
index 6868bc00..dae3cd42 100644
--- a/deepspeech/decoders/recog.py
+++ b/deepspeech/decoders/recog.py
@@ -12,12 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
-import json
-from pathlib import Path
-
 import jsonlines
 import paddle
-import yaml
 from yacs.config import CfgNode
 
 from .beam_search import BatchBeamSearch
diff --git a/deepspeech/decoders/recog_bin.py b/deepspeech/decoders/recog_bin.py
index 567dfecd..7c866648 100644
--- a/deepspeech/decoders/recog_bin.py
+++ b/deepspeech/decoders/recog_bin.py
@@ -21,8 +21,6 @@ from distutils.util import strtobool
 import configargparse
 import numpy as np
 
-from .recog import recog_v2
-
 
 def get_parser():
     """Get default arguments."""
diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py
index 7dac81b4..d86daa0b 100644
--- a/deepspeech/models/asr_interface.py
+++ b/deepspeech/models/asr_interface.py
@@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import
 
 
 class ASRInterface:
-    """ASR Interface for ESPnet model implementation."""
+    """ASR Interface model implementation."""
 
     @staticmethod
     def add_arguments(parser):
@@ -103,14 +103,14 @@ class ASRInterface:
     @property
     def attention_plot_class(self):
         """Get attention plot class."""
-        from espnet.asr.asr_utils import PlotAttentionReport
+        from deepspeech.training.extensions.plot import PlotAttentionReport
         return PlotAttentionReport
 
     @property
     def ctc_plot_class(self):
         """Get CTC plot class."""
-        from espnet.asr.asr_utils import PlotCTCReport
+        from deepspeech.training.extensions.plot import PlotCTCReport
         return PlotCTCReport
 
diff --git a/deepspeech/models/lm/__init__.py b/deepspeech/models/lm/__init__.py
index e69de29b..185a92b8 100644
--- a/deepspeech/models/lm/__init__.py
+++ b/deepspeech/models/lm/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/deepspeech/models/lm_interface.py b/deepspeech/models/lm_interface.py
index ed6d5d9c..e2987282 100644
--- a/deepspeech/models/lm_interface.py
+++ b/deepspeech/models/lm_interface.py
@@ -1,12 +1,25 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Language model interface."""
-
 import argparse
 
 from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
 from deepspeech.utils.dynamic_import import dynamic_import
 
+
 class LMInterface(ScorerInterface):
-    """LM Interface for ESPnet model implementation."""
+    """LM Interface model implementation."""
 
     @staticmethod
     def add_arguments(parser):
@@ -52,6 +65,7 @@ predefined_lms = {
     "transformer": "deepspeech.models.lm.transformer:TransformerLM",
 }
 
+
 def dynamic_import_lm(module):
     """Import LM class dynamically.
 
@@ -63,7 +77,6 @@ def dynamic_import_lm(module):
 
     """
     model_class = dynamic_import(module, predefined_lms)
-    assert issubclass(
-        model_class, LMInterface
-    ), f"{module} does not implement LMInterface"
+    assert issubclass(model_class,
+                      LMInterface), f"{module} does not implement LMInterface"
     return model_class
diff --git a/deepspeech/models/st_interface.py b/deepspeech/models/st_interface.py
new file mode 100644
index 00000000..05939f9a
--- /dev/null
+++ b/deepspeech/models/st_interface.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ST Interface module."""
+from .asr_interface import ASRInterface
+from deepspeech.utils.dynamic_import import dynamic_import
+
+
+class STInterface(ASRInterface):
+    """ST Interface model implementation.
+
+    NOTE: This class is inherited from ASRInterface to enable joint translation
+    and recognition when performing multi-task learning with the ASR task.
+
+    """
+
+    def translate(self,
+                  x,
+                  trans_args,
+                  char_list=None,
+                  rnnlm=None,
+                  ensemble_models=[]):
+        """Translate x for evaluation.
+
+        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
+        :param namespace trans_args: argument namespace containing options
+        :param list char_list: list of characters
+        :param paddle.nn.Layer rnnlm: language model module
+        :return: N-best decoding results
+        :rtype: list
+        """
+        raise NotImplementedError("translate method is not implemented")
+
+    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
+        """Beam search implementation for batch.
+
+        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
+        :param namespace trans_args: argument namespace containing options
+        :param list char_list: list of characters
+        :param paddle.nn.Layer rnnlm: language model module
+        :return: N-best decoding results
+        :rtype: list
+        """
+        raise NotImplementedError("Batch decoding is not supported yet.")
+
+
+predefined_st = {
+    "transformer": "deepspeech.models.u2_st:U2STModel",
+}
+
+
+def dynamic_import_st(module):
+    """Import ST models dynamically.
+
+    Args:
+        module (str): module_name:class_name or alias in `predefined_st`
+
+    Returns:
+        type: ST class
+
+    """
+    model_class = dynamic_import(module, predefined_st)
+    assert issubclass(model_class,
+                      STInterface), f"{module} does not implement STInterface"
+    return model_class
diff --git a/deepspeech/models/u2_st/__init__.py b/deepspeech/models/u2_st/__init__.py
new file mode 100644
index 00000000..6b10b083
--- /dev/null
+++ b/deepspeech/models/u2_st/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .u2_st import U2STInferModel
+from .u2_st import U2STModel
diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st/u2_st.py
similarity index 100%
rename from deepspeech/models/u2_st.py
rename to deepspeech/models/u2_st/u2_st.py
diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py
index e0c8006d..df6848db 100644
--- a/deepspeech/modules/ctc.py
+++ b/deepspeech/modules/ctc.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
+
 import paddle
 from paddle import nn
-from typing import Union
 from paddle.nn import functional as F
 from typeguard import check_argument_types
 
diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py
index 7e8a2a85..2df877b0 100644
--- a/deepspeech/modules/embedding.py
+++ b/deepspeech/modules/embedding.py
@@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"]
+__all__ = [
+    "NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"
+]
+
 
 class NoPositionalEncoding(nn.Layer):
     def __init__(self,
diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py
index 6288e2ee..79411771 100644
--- a/deepspeech/modules/encoder.py
+++ b/deepspeech/modules/encoder.py
@@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
             pos_enc_class = PositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
-        elif pos_enc_layer_type is "no_pos":
+        elif pos_enc_layer_type == "no_pos":
             pos_enc_class = NoPositionalEncoding
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py
index 694f9f6f..13e2c8ef 100644
--- a/deepspeech/modules/subsampling.py
+++ b/deepspeech/modules/subsampling.py
@@ -60,8 +60,8 @@ class LinearNoSubsampling(BaseSubsampling):
         self.out = nn.Sequential(
             nn.Linear(idim, odim),
             nn.LayerNorm(odim, epsilon=1e-12),
-            nn.Dropout(dropout_rate),
-            nn.ReLU(),)
+            nn.Dropout(dropout_rate),
+            nn.ReLU(), )
         self.right_context = 0
         self.subsampling_rate = 1
@@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
         x, pos_emb = self.pos_enc(x, offset)
         return x, pos_emb, x_mask
 
+
 class Conv2dSubsampling(BaseSubsampling):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+
 class Conv2dSubsampling4(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/4 length)."""
 
diff --git a/deepspeech/training/extensions/plot.py b/deepspeech/training/extensions/plot.py
new file mode 100644
index 00000000..6fbb4d4d
--- /dev/null
+++ b/deepspeech/training/extensions/plot.py
@@ -0,0 +1,418 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import os
+
+import numpy as np
+
+from . import extension
+
+
+class PlotAttentionReport(extension.Extension):
+    """Plot attention reporter.
+
+    Args:
+        att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
+            Function for attention visualization.
+        data (list[tuple(str, dict[str, list[Any]])]): List of (utt_id, json dict) items.
+        outdir (str): Directory to save figures.
+        converter (espnet.asr.*_backend.asr.CustomConverter):
+            Function to convert data.
+        device (int | torch.device): Device.
+        reverse (bool): If True, input and output lengths are reversed.
+        ikey (str): Key to access input
+            (for ASR/ST ikey="input", for MT ikey="output".)
+        iaxis (int): Dimension to access input
+            (for ASR/ST iaxis=0, for MT iaxis=1.)
+        okey (str): Key to access output
+            (for ASR/ST okey="output", for MT okey="output".)
+        oaxis (int): Dimension to access output
+            (for ASR/ST oaxis=0, for MT oaxis=0.)
+        subsampling_factor (int): subsampling factor in encoder
+
+    """
+
+    def __init__(
+            self,
+            att_vis_fn,
+            data,
+            outdir,
+            converter,
+            transform,
+            device,
+            reverse=False,
+            ikey="input",
+            iaxis=0,
+            okey="output",
+            oaxis=0,
+            subsampling_factor=1, ):
+        self.att_vis_fn = att_vis_fn
+        self.data = copy.deepcopy(data)
+        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
+        # key is utterance ID
+        self.outdir = outdir
+        self.converter = converter
+        self.transform = transform
+        self.device = device
+        self.reverse = reverse
+        self.ikey = ikey
+        self.iaxis = iaxis
+        self.okey = okey
+        self.oaxis = oaxis
+        self.factor = subsampling_factor
+        if not os.path.exists(self.outdir):
+            os.makedirs(self.outdir)
+
+    def __call__(self, trainer):
+        """Plot and save image file of att_ws matrix."""
+        att_ws, uttid_list = self.get_attention_weights()
+        if isinstance(att_ws, list):  # multi-encoder case
+            num_encs = len(att_ws) - 1
+            # atts
+            for i in range(num_encs):
+                for idx, att_w in enumerate(att_ws[i]):
+                    filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
+                        self.outdir, uttid_list[idx], i + 1, )
+                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
+                    np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
+                        self.outdir, uttid_list[idx], i + 1, )
+                    np.save(np_filename.format(trainer), att_w)
+                    self._plot_and_save_attention(att_w,
+                                                  filename.format(trainer))
+            # han
+            for idx, att_w in enumerate(att_ws[num_encs]):
+                filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
+                    self.outdir, uttid_list[idx], )
+                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
+                np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
+                    self.outdir, uttid_list[idx], )
+                np.save(np_filename.format(trainer), att_w)
+                self._plot_and_save_attention(
+                    att_w, filename.format(trainer), han_mode=True)
+        else:
+            for idx, att_w in enumerate(att_ws):
+                filename = "%s/%s.ep.{.updater.epoch}.png" % (
+                    self.outdir, uttid_list[idx], )
+                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
+                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
+                    self.outdir, uttid_list[idx], )
+                np.save(np_filename.format(trainer), att_w)
+                self._plot_and_save_attention(att_w, filename.format(trainer))
+
+    def log_attentions(self, logger, step):
+        """Add image files of att_ws matrix to the tensorboard."""
+        att_ws, uttid_list = self.get_attention_weights()
+        if isinstance(att_ws, list):  # multi-encoder case
+            num_encs = len(att_ws) - 1
+            # atts
+            for i in range(num_encs):
+                for idx, att_w in enumerate(att_ws[i]):
+                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
+                    plot = self.draw_attention_plot(att_w)
+                    logger.add_figure(
+                        "%s_att%d" % (uttid_list[idx], i + 1),
+                        plot.gcf(),
+                        step, )
+            # han
+            for idx, att_w in enumerate(att_ws[num_encs]):
+                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
+                plot = self.draw_han_plot(att_w)
+                logger.add_figure(
+                    "%s_han" % (uttid_list[idx]),
+                    plot.gcf(),
+                    step, )
+        else:
+            for idx, att_w in enumerate(att_ws):
+                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
+                plot = self.draw_attention_plot(att_w)
+                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
+
+    def get_attention_weights(self):
+        """Return attention weights.
+
+        Returns:
+            numpy.ndarray: attention weights. float. Its shape depends on the
+                backend:
+                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
+                  other case => (B, Lmax, Tmax).
+                * chainer-> (B, Lmax, Tmax)
+
+        """
+        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
+        batch = self.converter([return_batch], self.device)
+        if isinstance(batch, tuple):
+            att_ws = self.att_vis_fn(*batch)
+        else:
+            att_ws = self.att_vis_fn(**batch)
+        return att_ws, uttid_list
+
+    def trim_attention_weight(self, uttid, att_w):
+        """Trim attention weight matrix with regard to self.reverse."""
+        if self.reverse:
+            enc_key, enc_axis = self.okey, self.oaxis
+            dec_key, dec_axis = self.ikey, self.iaxis
+        else:
+            enc_key, enc_axis = self.ikey, self.iaxis
+            dec_key, dec_axis = self.okey, self.oaxis
+        dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
+        enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
+        if self.factor > 1:
+            enc_len //= self.factor
+        if len(att_w.shape) == 3:
+            att_w = att_w[:, :dec_len, :enc_len]
+        else:
+            att_w = att_w[:dec_len, :enc_len]
+        return att_w
+
+    def draw_attention_plot(self, att_w):
+        """Plot the att_w matrix.
+
+        Returns:
+            matplotlib.pyplot: pyplot object with attention matrix image.
+
+        """
+        import matplotlib
+
+        matplotlib.use("Agg")
+        import matplotlib.pyplot as plt
+
+        plt.clf()
+        att_w = att_w.astype(np.float32)
+        if len(att_w.shape) == 3:
+            for h, aw in enumerate(att_w, 1):
+                plt.subplot(1, len(att_w), h)
+                plt.imshow(aw, aspect="auto")
+                plt.xlabel("Encoder Index")
+                plt.ylabel("Decoder Index")
+        else:
+            plt.imshow(att_w, aspect="auto")
+            plt.xlabel("Encoder Index")
+            plt.ylabel("Decoder Index")
+        plt.tight_layout()
+        return plt
+
+    def draw_han_plot(self, att_w):
+        """Plot the att_w matrix for hierarchical attention.
+
+        Returns:
+            matplotlib.pyplot: pyplot object with attention matrix image.
+
+        """
+        import matplotlib
+
+        matplotlib.use("Agg")
+        import matplotlib.pyplot as plt
+
+        plt.clf()
+        if len(att_w.shape) == 3:
+            for h, aw in enumerate(att_w, 1):
+                legends = []
+                plt.subplot(1, len(att_w), h)
+                for i in range(aw.shape[1]):
+                    plt.plot(aw[:, i])
+                    legends.append("Att{}".format(i))
+                plt.ylim([0, 1.0])
+                plt.xlim([0, aw.shape[0]])
+                plt.grid(True)
+                plt.ylabel("Attention Weight")
+                plt.xlabel("Decoder Index")
+                plt.legend(legends)
+        else:
+            legends = []
+            for i in range(att_w.shape[1]):
+                plt.plot(att_w[:, i])
+                legends.append("Att{}".format(i))
+            plt.ylim([0, 1.0])
+            plt.xlim([0, att_w.shape[0]])
+            plt.grid(True)
+            plt.ylabel("Attention Weight")
+            plt.xlabel("Decoder Index")
+            plt.legend(legends)
+        plt.tight_layout()
+        return plt
+
+    def _plot_and_save_attention(self, att_w, filename, han_mode=False):
+        if han_mode:
+            plt = self.draw_han_plot(att_w)
+        else:
+            plt = self.draw_attention_plot(att_w)
+        plt.savefig(filename)
+        plt.close()
+
+
+class PlotCTCReport(extension.Extension):
+    """Plot CTC reporter.
+
+    Args:
+        ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
+            Function for CTC visualization.
+        data (list[tuple(str, dict[str, list[Any]])]): List of (utt_id, json dict) items.
+        outdir (str): Directory to save figures.
+        converter (espnet.asr.*_backend.asr.CustomConverter):
+            Function to convert data.
+        device (int | torch.device): Device.
+        reverse (bool): If True, input and output lengths are reversed.
+        ikey (str): Key to access input
+            (for ASR/ST ikey="input", for MT ikey="output".)
+        iaxis (int): Dimension to access input
+            (for ASR/ST iaxis=0, for MT iaxis=1.)
+        okey (str): Key to access output
+            (for ASR/ST okey="output", for MT okey="output".)
+        oaxis (int): Dimension to access output
+            (for ASR/ST oaxis=0, for MT oaxis=0.)
+        subsampling_factor (int): subsampling factor in encoder
+
+    """
+
+    def __init__(
+            self,
+            ctc_vis_fn,
+            data,
+            outdir,
+            converter,
+            transform,
+            device,
+            reverse=False,
+            ikey="input",
+            iaxis=0,
+            okey="output",
+            oaxis=0,
+            subsampling_factor=1, ):
+        self.ctc_vis_fn = ctc_vis_fn
+        self.data = copy.deepcopy(data)
+        self.data_dict = {k: v for k, v in copy.deepcopy(data)}
+        # key is utterance ID
+        self.outdir = outdir
+        self.converter = converter
+        self.transform = transform
+        self.device = device
+        self.reverse = reverse
+        self.ikey = ikey
+        self.iaxis = iaxis
+        self.okey = okey
+        self.oaxis = oaxis
+        self.factor = subsampling_factor
+        if not os.path.exists(self.outdir):
+            os.makedirs(self.outdir)
+
+    def __call__(self, trainer):
+        """Plot and save image file of ctc prob."""
+        ctc_probs, uttid_list = self.get_ctc_probs()
+        if isinstance(ctc_probs, list):  # multi-encoder case
+            num_encs = len(ctc_probs) - 1
+            for i in range(num_encs):
+                for idx, ctc_prob in enumerate(ctc_probs[i]):
+                    filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
+                        self.outdir, uttid_list[idx], i + 1, )
+                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
+                    np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
+                        self.outdir, uttid_list[idx], i + 1, )
+                    np.save(np_filename.format(trainer), ctc_prob)
+                    self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
+        else:
+            for idx, ctc_prob in enumerate(ctc_probs):
+                filename = "%s/%s.ep.{.updater.epoch}.png" % (
+                    self.outdir, uttid_list[idx], )
+                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
+                np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
+                    self.outdir, uttid_list[idx], )
+                np.save(np_filename.format(trainer), ctc_prob)
+                self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
+
+    def log_ctc_probs(self, logger, step):
+        """Add image files of ctc probs to the tensorboard."""
+        ctc_probs, uttid_list = self.get_ctc_probs()
+        if isinstance(ctc_probs, list):  # multi-encoder case
+            num_encs = len(ctc_probs) - 1
+            for i in range(num_encs):
+                for idx, ctc_prob in enumerate(ctc_probs[i]):
+                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
+                    plot = self.draw_ctc_plot(ctc_prob)
+                    logger.add_figure(
+                        "%s_ctc%d" % (uttid_list[idx], i + 1),
+                        plot.gcf(),
+                        step, )
+        else:
+            for idx, ctc_prob in enumerate(ctc_probs):
+                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
+                plot = self.draw_ctc_plot(ctc_prob)
+                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
+
+    def get_ctc_probs(self):
+        """Return CTC probs.
+
+        Returns:
+            numpy.ndarray: CTC probs, float, of shape (B, Tmax, vocab).
+
+        """
+        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
+        batch = self.converter([return_batch], self.device)
+        if isinstance(batch, tuple):
+            probs = self.ctc_vis_fn(*batch)
+        else:
+            probs = self.ctc_vis_fn(**batch)
+        return probs, uttid_list
+
+    def trim_ctc_prob(self, uttid, prob):
+        """Trim CTC posteriors according to input lengths."""
+        enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
+        if self.factor > 1:
+            enc_len //= self.factor
+        prob = prob[:enc_len]
+        return prob
+
+    def draw_ctc_plot(self, ctc_prob):
+        """Plot the ctc_prob matrix.
+
+        Returns:
+            matplotlib.pyplot: pyplot object with CTC prob matrix image.
+
+        """
+        import matplotlib
+
+        matplotlib.use("Agg")
+        import matplotlib.pyplot as plt
+
+        ctc_prob = ctc_prob.astype(np.float32)
+
+        plt.clf()
+        topk_ids = np.argsort(ctc_prob, axis=1)
+        n_frames, vocab = ctc_prob.shape
+        times_probs = np.arange(n_frames)
+
+        plt.figure(figsize=(20, 8))
+
+        # NOTE: index 0 is reserved for blank
+        for idx in set(topk_ids.reshape(-1).tolist()):
+            if idx == 0:
+                plt.plot(
+                    times_probs,
+                    ctc_prob[:, 0],
+                    ":",
+                    label="<blank>",
+                    color="grey")
+            else:
+                plt.plot(times_probs, ctc_prob[:, idx])
+        plt.xlabel(u"Input [frame]", fontsize=12)
+        plt.ylabel("Posteriors", fontsize=12)
+        plt.xticks(list(range(0, int(n_frames) + 1, 10)))
+        plt.yticks(list(range(0, 2, 1)))
+        plt.tight_layout()
+        return plt
+
+    def _plot_and_save_ctc(self, ctc_prob, filename):
+        plt = self.draw_ctc_plot(ctc_prob)
+        plt.savefig(filename)
+        plt.close()
diff --git a/deepspeech/training/triggers/__init__.py b/deepspeech/training/triggers/__init__.py
index 1a7c4292..185a92b8 100644
--- a/deepspeech/training/triggers/__init__.py
+++ b/deepspeech/training/triggers/__init__.py
@@ -11,18 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .interval_trigger import IntervalTrigger
-
-
-def never_fail_trigger(trainer):
-    return False
-
-
-def get_trigger(trigger):
-    if trigger is None:
-        return never_fail_trigger
-    if callable(trigger):
-        return trigger
-    else:
-        trigger = IntervalTrigger(*trigger)
-        return trigger
diff --git a/deepspeech/training/triggers/compare_value_trigger.py b/deepspeech/training/triggers/compare_value_trigger.py
new file mode 100644
index 00000000..efb928e2
--- /dev/null
+++ b/deepspeech/training/triggers/compare_value_trigger.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ..reporter import DictSummary
+from .utils import get_trigger
+
+
+class CompareValueTrigger():
+    """Trigger invoked when the value of a key gets bigger or smaller than before.
+
+    Args:
+        key (str): Key of the value to watch.
+        compare_fn ((float, float) -> bool): Function to compare the values.
+        trigger (tuple(int, str)): Trigger that decides the comparison interval.
+
+    """
+
+    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
+        self._key = key
+        self._best_value = None
+        self._interval_trigger = get_trigger(trigger)
+        self._init_summary()
+        self._compare_fn = compare_fn
+
+    def __call__(self, trainer):
+        """Get value related to the key and compare with current value."""
+        observation = trainer.observation
+        summary = self._summary
+        key = self._key
+        if key in observation:
+            summary.add({key: observation[key]})
+
+        if not self._interval_trigger(trainer):
+            return False
+
+        stats = summary.compute_mean()
+        value = float(stats[key])  # copy to CPU
+        self._init_summary()
+
+        if self._best_value is None:
+            # initialize best value
+            self._best_value = value
+            return False
+        elif self._compare_fn(self._best_value, value):
+            return True
+        else:
+            self._best_value = value
+            return False
+
+    def _init_summary(self):
+        self._summary = DictSummary()
diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py
index ea8fe562..e31179a9 100644
--- a/deepspeech/training/triggers/time_trigger.py
+++ b/deepspeech/training/triggers/time_trigger.py
@@ -30,3 +30,12 @@ class TimeTrigger():
             return True
         else:
             return False
+
+    def state_dict(self):
+        state_dict = {
+            "next_time": self._next_time,
+        }
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        self._next_time = state_dict['next_time']
diff --git a/deepspeech/training/triggers/utils.py b/deepspeech/training/triggers/utils.py
new file mode 100644
index 00000000..1a7c4292
--- /dev/null
+++ b/deepspeech/training/triggers/utils.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .interval_trigger import IntervalTrigger
+
+
+def never_fail_trigger(trainer):
+    return False
+
+
+def get_trigger(trigger):
+    if trigger is None:
+        return never_fail_trigger
+    if callable(trigger):
+        return trigger
+    else:
+        trigger = IntervalTrigger(*trigger)
+        return trigger
diff --git a/deepspeech/utils/asr_utils.py b/deepspeech/utils/asr_utils.py
index 06cf6487..6f86e56f 100644
--- a/deepspeech/utils/asr_utils.py
+++ b/deepspeech/utils/asr_utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+
 import numpy as np
 
 __all__ = ["label_smoothing_dist"]
@@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
 
     if lsm_type == "unigram":
         assert transcript is not None, (
-            "transcript is required for %s label smoothing" % lsm_type
-        )
+            "transcript is required for %s label smoothing" % lsm_type)
         labelcount = np.zeros(odim)
         for k, v in trans_json.items():
             ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
diff --git a/deepspeech/utils/bleu_score.py b/deepspeech/utils/bleu_score.py
index 93749ddd..ea32fcf9 100644
--- a/deepspeech/utils/bleu_score.py
+++ b/deepspeech/utils/bleu_score.py
@@ -14,9 +14,9 @@
 """This module provides functions to calculate bleu score in different level.
 e.g. wer for word-level, cer for char-level.
 """
-import sacrebleu
 import nltk
 import numpy as np
+import sacrebleu
 
 __all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
 
@@ -106,11 +106,14 @@ class ErrorCalculator():
             # NOTE: padding index (-1) in y_true is used to pad y_hat
             # because y_hats is not padded with -1
             seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
-            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
+            seq_true = [
+                self.char_list[int(idx)] for idx in y_true if int(idx) != -1
+            ]
             seq_hat_text = "".join(seq_hat).replace(self.space, " ")
             seq_hat_text = seq_hat_text.replace(self.pad, "")
             seq_true_text = "".join(seq_true).replace(self.space, " ")
             seqs_hat.append(seq_hat_text)
             seqs_true.append(seq_true_text)
-        bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat)
+        bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true],
+                                           seqs_hat)
         return bleu * 100
diff --git a/deepspeech/utils/error_rate.py b/deepspeech/utils/error_rate.py
index 0ad62b6b..548376aa 100644
--- a/deepspeech/utils/error_rate.py
+++ b/deepspeech/utils/error_rate.py
@@ -14,11 +14,10 @@
 """This module provides functions to calculate error rate in different level.
 e.g. wer for word-level, cer for char-level.
""" +from itertools import groupby + import editdistance import numpy as np -import logging -import sys -from itertools import groupby __all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"] @@ -225,9 +224,12 @@ class ErrorCalculator(): :return: """ - def __init__( - self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False - ): + def __init__(self, + char_list, + sym_space, + sym_blank, + report_cer=False, + report_wer=False): """Construct an ErrorCalculator object.""" super().__init__() @@ -317,7 +319,9 @@ class ErrorCalculator(): ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) # NOTE: padding index (-1) in y_true is used to pad y_hat seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] - seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = seq_hat_text.replace(self.blank, "") seq_true_text = "".join(seq_true).replace(self.space, " ") diff --git a/setup.py b/setup.py index bd982129..be17e0a4 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ import contextlib import inspect import io import os -import re import subprocess as sp import sys from pathlib import Path @@ -84,7 +83,7 @@ def _post_install(install_lib_dir): tools_extrs_dir = HERE / 'tools/extras' with pushd(tools_extrs_dir): print(os.getcwd()) - check_call(f"./install_autolog.sh") + check_call("./install_autolog.sh") print("autolog install.") # ctcdecoder diff --git a/utils/json2trn.py b/utils/json2trn.py index 873fde4f..4adfa491 100755 --- a/utils/json2trn.py +++ b/utils/json2trn.py @@ -4,7 +4,6 @@ # 2018 Xuankai Chang (Shanghai Jiao Tong University) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse -import json import logging import sys