From e4ecfb22fd9e3053b73c7272d087d96e3de377c2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 25 Oct 2021 03:25:49 +0000 Subject: [PATCH] format code --- deepspeech/__init__.py | 3 - deepspeech/decoders/recog.py | 25 ++- deepspeech/decoders/recog_bin.py | 2 - deepspeech/models/asr_interface.py | 2 +- deepspeech/models/lm/__init__.py | 13 ++ deepspeech/models/lm/transformer.py | 64 ++++---- deepspeech/models/lm_interface.py | 21 ++- deepspeech/models/st_interface.py | 32 +++- deepspeech/models/u2_st/__init__.py | 15 +- deepspeech/modules/ctc.py | 3 +- deepspeech/modules/embedding.py | 5 +- deepspeech/modules/encoder.py | 10 +- deepspeech/modules/subsampling.py | 6 +- deepspeech/training/extensions/plot.py | 147 ++++++++---------- deepspeech/training/extensions/visualizer.py | 2 +- .../triggers/compare_value_trigger.py | 7 +- deepspeech/training/triggers/limit_trigger.py | 2 +- deepspeech/training/triggers/time_trigger.py | 2 +- deepspeech/training/triggers/utils.py | 13 ++ deepspeech/utils/asr_utils.py | 4 +- deepspeech/utils/bleu_score.py | 9 +- deepspeech/utils/error_rate.py | 18 ++- setup.py | 3 +- utils/json2trn.py | 1 - 24 files changed, 229 insertions(+), 180 deletions(-) diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index ed209f3d..5762e635 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'): "register user tolist to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'tolist', tolist) - ########### hack paddle.nn ############# from paddle.nn import Layer from typing import Optional @@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'): logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'LayerDict', LayerDict) - - diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py index c8df65d6..6dea6b70 100644 --- a/deepspeech/decoders/recog.py +++ b/deepspeech/decoders/recog.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`.""" -import json -from pathlib import Path - import jsonlines import paddle -import yaml from yacs.config import CfgNode from .beam_search import BatchBeamSearch @@ -79,8 +75,7 @@ def recog_v2(args): sort_in_input_length=False, preprocess_conf=confs.collator.augmentation_config if args.preprocess_conf is None else args.preprocess_conf, - preprocess_args={"train": False}, - ) + preprocess_args={"train": False}, ) if args.rnnlm: lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) @@ -113,8 +108,7 @@ def recog_v2(args): ctc=args.ctc_weight, lm=args.lm_weight, ngram=args.ngram_weight, - length_bonus=args.penalty, - ) + length_bonus=args.penalty, ) beam_search = BeamSearch( beam_size=args.beam_size, vocab_size=len(char_list), @@ -123,8 +117,7 @@ def recog_v2(args): sos=model.sos, eos=model.eos, token_list=char_list, - pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", - ) + pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", ) # TODO(karita): make all scorers batchfied if args.batchsize == 1: @@ -171,9 +164,10 @@ def recog_v2(args): logger.info(f'feat: {feat.shape}') enc = model.encode(paddle.to_tensor(feat).to(dtype)) logger.info(f'eout: {enc.shape}') - nbest_hyps = beam_search(x=enc, - maxlenratio=args.maxlenratio, - minlenratio=args.minlenratio) + nbest_hyps = beam_search( + x=enc, + maxlenratio=args.maxlenratio, + minlenratio=args.minlenratio) nbest_hyps = [ h.asdict() for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)] @@ -183,9 +177,8 @@ def recog_v2(args): item = new_js[name]['output'][0] # 1-best ref = item['text'] - rec_text = item['rec_text'].replace('▁', - ' ').replace('', - '').strip() + rec_text = item['rec_text'].replace('▁', ' ').replace( + '', '').strip() rec_tokenid = list(map(int, item['rec_tokenid'].split())) f.write({ "utt": name, diff --git a/deepspeech/decoders/recog_bin.py b/deepspeech/decoders/recog_bin.py index 567dfecd..7c866648 100644 --- a/deepspeech/decoders/recog_bin.py +++ b/deepspeech/decoders/recog_bin.py @@ -21,8 +21,6 @@ from distutils.util import strtobool import configargparse import numpy as np -from .recog import recog_v2 - def get_parser(): """Get default arguments.""" diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py index b6f5c664..d86daa0b 100644 --- a/deepspeech/models/asr_interface.py +++ b/deepspeech/models/asr_interface.py @@ -110,7 +110,7 @@ class ASRInterface: @property def ctc_plot_class(self): """Get CTC plot class.""" - from deepspeech.training.extensions.plot import PlotCTCReport + from deepspeech.training.extensions.plot import PlotCTCReport return PlotCTCReport diff --git a/deepspeech/models/lm/__init__.py b/deepspeech/models/lm/__init__.py index e69de29b..185a92b8 100644 --- a/deepspeech/models/lm/__init__.py +++ b/deepspeech/models/lm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py index dcae4ea0..467c4ab9 100644 --- a/deepspeech/models/lm/transformer.py +++ b/deepspeech/models/lm/transformer.py @@ -20,11 +20,11 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from deepspeech.modules.mask import subsequent_mask -from deepspeech.modules.encoder import TransformerEncoder from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface -from deepspeech.models.lm_interface import -#LMInterface +from deepspeech.models.lm_interface import LMInterface +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.mask import subsequent_mask + class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): def __init__( @@ -36,10 +36,10 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): head: int=2, unit: int=1024, layer: int=4, - dropout_rate: float=0.5, - emb_dropout_rate: float = 0.0, - att_dropout_rate: float = 0.0, - tie_weights: bool = False,): + dropout_rate: float=0.5, + emb_dropout_rate: float=0.0, + att_dropout_rate: float=0.0, + tie_weights: bool=False, ): nn.Layer.__init__(self) if pos_enc == "sinusoidal": @@ -89,9 +89,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0) return ys_mask.unsqueeze(-2) & m - def forward( - self, x: paddle.Tensor, t: paddle.Tensor - ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + def forward(self, x: paddle.Tensor, t: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Compute LM loss value from buffer sequences. Args: @@ -117,7 +116,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): emb = self.embed(x) h, _ = self.encoder(emb, xlen) y = self.decoder(h) - loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") + loss = F.cross_entropy( + y.view(-1, y.shape[-1]), t.view(-1), reduction="none") mask = xm.to(dtype=loss.dtype) logp = loss * mask.view(-1) logp = logp.sum() @@ -148,16 +148,16 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): emb = self.embed(y) h, _, cache = self.encoder.forward_one_step( - emb, self._target_mask(y), cache=state - ) + emb, self._target_mask(y), cache=state) h = self.decoder(h[:, -1]) logp = h.log_softmax(axis=-1).squeeze(0) return logp, cache # batch beam search API (see BatchScorerInterface) - def batch_score( - self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor - ) -> Tuple[paddle.Tensor, List[Any]]: + def batch_score(self, + ys: paddle.Tensor, + states: List[Any], + xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]: """Score new token batch (required). Args: @@ -191,13 +191,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): # batch decoding h, _, states = self.encoder.forward_one_step( - emb, self._target_mask(ys), cache=batch_state - ) + emb, self._target_mask(ys), cache=batch_state) h = self.decoder(h[:, -1]) logp = h.log_softmax(axi=-1) # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + state_list = [[states[i][b] for i in range(n_layers)] + for b in range(n_batch)] return logp, state_list @@ -212,17 +212,17 @@ if __name__ == "__main__": layer=16, dropout_rate=0.5, ) - # n_vocab: int, - # pos_enc: str=None, - # embed_unit: int=128, - # att_unit: int=256, - # head: int=2, - # unit: int=1024, - # layer: int=4, - # dropout_rate: float=0.5, - # emb_dropout_rate: float = 0.0, - # att_dropout_rate: float = 0.0, - # tie_weights: bool = False,): + # n_vocab: int, + # pos_enc: str=None, + # embed_unit: int=128, + # att_unit: int=256, + # head: int=2, + # unit: int=1024, + # layer: int=4, + # dropout_rate: float=0.5, + # emb_dropout_rate: float = 0.0, + # att_dropout_rate: float = 0.0, + # tie_weights: bool = False,): paddle.set_device("cpu") model_dict = paddle.load("transformerLM.pdparams") tlm.set_state_dict(model_dict) @@ -256,4 +256,4 @@ if __name__ == "__main__": print("output", output) #print("cache", cache) #np.save("output_pd.npy", output) - """ \ No newline at end of file + """ diff --git a/deepspeech/models/lm_interface.py b/deepspeech/models/lm_interface.py index 66466bcd..e2987282 100644 --- a/deepspeech/models/lm_interface.py +++ b/deepspeech/models/lm_interface.py @@ -1,10 +1,23 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Language model interface.""" - import argparse from deepspeech.decoders.scorers.scorer_interface import ScorerInterface from deepspeech.utils.dynamic_import import dynamic_import + class LMInterface(ScorerInterface): """LM Interface model implementation.""" @@ -52,6 +65,7 @@ predefined_lms = { "transformer": "deepspeech.models.lm.transformer:TransformerLM", } + def dynamic_import_lm(module): """Import LM class dynamically. @@ -63,7 +77,6 @@ def dynamic_import_lm(module): """ model_class = dynamic_import(module, predefined_lms) - assert issubclass( - model_class, LMInterface - ), f"{module} does not implement LMInterface" + assert issubclass(model_class, + LMInterface), f"{module} does not implement LMInterface" return model_class diff --git a/deepspeech/models/st_interface.py b/deepspeech/models/st_interface.py index 7bbcbf99..05939f9a 100644 --- a/deepspeech/models/st_interface.py +++ b/deepspeech/models/st_interface.py @@ -1,9 +1,20 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ST Interface module.""" - - -import argparse -from deepspeech.utils.dynamic_import import dynamic_import from .asr_interface import ASRInterface +from deepspeech.utils.dynamic_import import dynamic_import + class STInterface(ASRInterface): """ST Interface model implementation. @@ -13,7 +24,12 @@ class STInterface(ASRInterface): """ - def translate(self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=[]): + def translate(self, + x, + trans_args, + char_list=None, + rnnlm=None, + ensemble_models=[]): """Recognize x for evaluation. :param ndarray x: input acouctic feature (B, T, D) or (T, D) @@ -42,6 +58,7 @@ predefined_st = { "transformer": "deepspeech.models.u2_st:U2STModel", } + def dynamic_import_st(module): """Import ST models dynamically. @@ -53,7 +70,6 @@ def dynamic_import_st(module): """ model_class = dynamic_import(module, predefined_st) - assert issubclass( - model_class, STInterface - ), f"{module} does not implement STInterface" + assert issubclass(model_class, + STInterface), f"{module} does not implement STInterface" return model_class diff --git a/deepspeech/models/u2_st/__init__.py b/deepspeech/models/u2_st/__init__.py index f2b2f747..6b10b083 100644 --- a/deepspeech/models/u2_st/__init__.py +++ b/deepspeech/models/u2_st/__init__.py @@ -1,2 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .u2_st import U2STInferModel from .u2_st import U2STModel -from .u2_st import U2STInferModel \ No newline at end of file diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index e0c8006d..df6848db 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import paddle from paddle import nn -from typing import Union from paddle.nn import functional as F from typeguard import check_argument_types diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index 7e8a2a85..2df877b0 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -22,7 +22,10 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"] +__all__ = [ + "NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding" +] + class NoPositionalEncoding(nn.Layer): def __init__(self, diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index 0f8f1075..f2c26988 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -24,9 +24,9 @@ from deepspeech.modules.activation import get_activation from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import RelPositionMultiHeadedAttention from deepspeech.modules.conformer_convolution import ConvolutionModule +from deepspeech.modules.embedding import NoPositionalEncoding from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import RelPositionalEncoding -from deepspeech.modules.embedding import NoPositionalEncoding from deepspeech.modules.encoder_layer import ConformerEncoderLayer from deepspeech.modules.encoder_layer import TransformerEncoderLayer from deepspeech.modules.mask import add_optional_chunk_mask @@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer): pos_enc_class = PositionalEncoding elif pos_enc_layer_type == "rel_pos": pos_enc_class = RelPositionalEncoding - elif pos_enc_layer_type is "no_pos": + elif pos_enc_layer_type == "no_pos": pos_enc_class = NoPositionalEncoding else: raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) @@ -378,8 +378,7 @@ class TransformerEncoder(BaseEncoder): self, xs: paddle.Tensor, masks: paddle.Tensor, - cache=None, - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Encode input frame. Args: @@ -397,7 +396,8 @@ class TransformerEncoder(BaseEncoder): if isinstance(self.embed, Conv2dSubsampling): #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor - xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) + xs, pos_emb, masks = self.embed( + xs, masks.astype(xs.dtype), offset=0) else: xs = self.embed(xs) #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py index 694f9f6f..13e2c8ef 100644 --- a/deepspeech/modules/subsampling.py +++ b/deepspeech/modules/subsampling.py @@ -60,8 +60,8 @@ class LinearNoSubsampling(BaseSubsampling): self.out = nn.Sequential( nn.Linear(idim, odim), nn.LayerNorm(odim, epsilon=1e-12), - nn.Dropout(dropout_rate), - nn.ReLU(),) + nn.Dropout(dropout_rate), + nn.ReLU(), ) self.right_context = 0 self.subsampling_rate = 1 @@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling): x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask + class Conv2dSubsampling(BaseSubsampling): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + class Conv2dSubsampling4(Conv2dSubsampling): """Convolutional 2D subsampling (to 1/4 length).""" diff --git a/deepspeech/training/extensions/plot.py b/deepspeech/training/extensions/plot.py index 7b1ee74a..6fbb4d4d 100644 --- a/deepspeech/training/extensions/plot.py +++ b/deepspeech/training/extensions/plot.py @@ -1,15 +1,22 @@ - -import argparse +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy -import json import os -import shutil -import tempfile -import numpy as np +import numpy as np from . import extension -from ..updaters.trainer import Trainer class PlotAttentionReport(extension.Extension): @@ -37,20 +44,19 @@ class PlotAttentionReport(extension.Extension): """ def __init__( - self, - att_vis_fn, - data, - outdir, - converter, - transform, - device, - reverse=False, - ikey="input", - iaxis=0, - okey="output", - oaxis=0, - subsampling_factor=1, - ): + self, + att_vis_fn, + data, + outdir, + converter, + transform, + device, + reverse=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=1, ): self.att_vis_fn = att_vis_fn self.data = copy.deepcopy(data) self.data_dict = {k: v for k, v in copy.deepcopy(data)} @@ -77,44 +83,30 @@ class PlotAttentionReport(extension.Extension): for i in range(num_encs): for idx, att_w in enumerate(att_ws[i]): filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % ( - self.outdir, - uttid_list[idx], - i + 1, - ) + self.outdir, uttid_list[idx], i + 1, ) att_w = self.trim_attention_weight(uttid_list[idx], att_w) np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % ( - self.outdir, - uttid_list[idx], - i + 1, - ) + self.outdir, uttid_list[idx], i + 1, ) np.save(np_filename.format(trainer), att_w) - self._plot_and_save_attention(att_w, filename.format(trainer)) + self._plot_and_save_attention(att_w, + filename.format(trainer)) # han for idx, att_w in enumerate(att_ws[num_encs]): filename = "%s/%s.ep.{.updater.epoch}.han.png" % ( - self.outdir, - uttid_list[idx], - ) + self.outdir, uttid_list[idx], ) att_w = self.trim_attention_weight(uttid_list[idx], att_w) np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % ( - self.outdir, - uttid_list[idx], - ) + self.outdir, uttid_list[idx], ) np.save(np_filename.format(trainer), att_w) self._plot_and_save_attention( - att_w, filename.format(trainer), han_mode=True - ) + att_w, filename.format(trainer), han_mode=True) else: for idx, att_w in enumerate(att_ws): - filename = "%s/%s.ep.{.updater.epoch}.png" % ( - self.outdir, - uttid_list[idx], - ) + filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir, + uttid_list[idx], ) att_w = self.trim_attention_weight(uttid_list[idx], att_w) np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( - self.outdir, - uttid_list[idx], - ) + self.outdir, uttid_list[idx], ) np.save(np_filename.format(trainer), att_w) self._plot_and_save_attention(att_w, filename.format(trainer)) @@ -131,8 +123,7 @@ class PlotAttentionReport(extension.Extension): logger.add_figure( "%s_att%d" % (uttid_list[idx], i + 1), plot.gcf(), - step, - ) + step, ) # han for idx, att_w in enumerate(att_ws[num_encs]): att_w = self.trim_attention_weight(uttid_list[idx], att_w) @@ -140,8 +131,7 @@ class PlotAttentionReport(extension.Extension): logger.add_figure( "%s_han" % (uttid_list[idx]), plot.gcf(), - step, - ) + step, ) else: for idx, att_w in enumerate(att_ws): att_w = self.trim_attention_weight(uttid_list[idx], att_w) @@ -286,20 +276,19 @@ class PlotCTCReport(extension.Extension): """ def __init__( - self, - ctc_vis_fn, - data, - outdir, - converter, - transform, - device, - reverse=False, - ikey="input", - iaxis=0, - okey="output", - oaxis=0, - subsampling_factor=1, - ): + self, + ctc_vis_fn, + data, + outdir, + converter, + transform, + device, + reverse=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=1, ): self.ctc_vis_fn = ctc_vis_fn self.data = copy.deepcopy(data) self.data_dict = {k: v for k, v in copy.deepcopy(data)} @@ -325,29 +314,19 @@ class PlotCTCReport(extension.Extension): for i in range(num_encs): for idx, ctc_prob in enumerate(ctc_probs[i]): filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % ( - self.outdir, - uttid_list[idx], - i + 1, - ) + self.outdir, uttid_list[idx], i + 1, ) ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % ( - self.outdir, - uttid_list[idx], - i + 1, - ) + self.outdir, uttid_list[idx], i + 1, ) np.save(np_filename.format(trainer), ctc_prob) self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) else: for idx, ctc_prob in enumerate(ctc_probs): - filename = "%s/%s.ep.{.updater.epoch}.png" % ( - self.outdir, - uttid_list[idx], - ) + filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir, + uttid_list[idx], ) ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( - self.outdir, - uttid_list[idx], - ) + self.outdir, uttid_list[idx], ) np.save(np_filename.format(trainer), ctc_prob) self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) @@ -363,8 +342,7 @@ class PlotCTCReport(extension.Extension): logger.add_figure( "%s_ctc%d" % (uttid_list[idx], i + 1), plot.gcf(), - step, - ) + step, ) else: for idx, ctc_prob in enumerate(ctc_probs): ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) @@ -420,8 +398,11 @@ class PlotCTCReport(extension.Extension): for idx in set(topk_ids.reshape(-1).tolist()): if idx == 0: plt.plot( - times_probs, ctc_prob[:, 0], ":", label="", color="grey" - ) + times_probs, + ctc_prob[:, 0], + ":", + label="", + color="grey") else: plt.plot(times_probs, ctc_prob[:, idx]) plt.xlabel(u"Input [frame]", fontsize=12) @@ -434,4 +415,4 @@ class PlotCTCReport(extension.Extension): def _plot_and_save_ctc(self, ctc_prob, filename): plt = self.draw_ctc_plot(ctc_prob) plt.savefig(filename) - plt.close() \ No newline at end of file + plt.close() diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py index dcec2a76..e5f456ca 100644 --- a/deepspeech/training/extensions/visualizer.py +++ b/deepspeech/training/extensions/visualizer.py @@ -36,4 +36,4 @@ class VisualDL(extension.Extension): self.writer.add_scalar(k, v, step=trainer.updater.state.iteration) def finalize(self, trainer): - self.writer.close() \ No newline at end of file + self.writer.close() diff --git a/deepspeech/training/triggers/compare_value_trigger.py b/deepspeech/training/triggers/compare_value_trigger.py index 75f83b42..efb928e2 100644 --- a/deepspeech/training/triggers/compare_value_trigger.py +++ b/deepspeech/training/triggers/compare_value_trigger.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .utils import get_trigger from ..reporter import DictSummary +from .utils import get_trigger + class CompareValueTrigger(): """Trigger invoked when key value getting bigger or lower than before. @@ -24,6 +24,7 @@ class CompareValueTrigger(): trigger (tuple(int, str)) : Trigger that decide the comparison interval. """ + def __init__(self, key, compare_fn, trigger=(1, "epoch")): self._key = key self._best_value = None @@ -57,4 +58,4 @@ class CompareValueTrigger(): return False def _init_summary(self): - self._summary = DictSummary() \ No newline at end of file + self._summary = DictSummary() diff --git a/deepspeech/training/triggers/limit_trigger.py b/deepspeech/training/triggers/limit_trigger.py index b34b537a..ecd527ac 100644 --- a/deepspeech/training/triggers/limit_trigger.py +++ b/deepspeech/training/triggers/limit_trigger.py @@ -28,4 +28,4 @@ class LimitTrigger(): state = trainer.updater.state index = getattr(state, self.unit) fire = index >= self.limit - return fire \ No newline at end of file + return fire diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py index 88f5d2f1..e31179a9 100644 --- a/deepspeech/training/triggers/time_trigger.py +++ b/deepspeech/training/triggers/time_trigger.py @@ -38,4 +38,4 @@ class TimeTrigger(): return state_dict def set_state_dict(self, state_dict): - self._next_time = state_dict['next_time'] \ No newline at end of file + self._next_time = state_dict['next_time'] diff --git a/deepspeech/training/triggers/utils.py b/deepspeech/training/triggers/utils.py index 42a697a9..1a7c4292 100644 --- a/deepspeech/training/triggers/utils.py +++ b/deepspeech/training/triggers/utils.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .interval_trigger import IntervalTrigger diff --git a/deepspeech/utils/asr_utils.py b/deepspeech/utils/asr_utils.py index 06cf6487..6f86e56f 100644 --- a/deepspeech/utils/asr_utils.py +++ b/deepspeech/utils/asr_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json + import numpy as np __all__ = ["label_smoothing_dist"] @@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): if lsm_type == "unigram": assert transcript is not None, ( - "transcript is required for %s label smoothing" % lsm_type - ) + "transcript is required for %s label smoothing" % lsm_type) labelcount = np.zeros(odim) for k, v in trans_json.items(): ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) diff --git a/deepspeech/utils/bleu_score.py b/deepspeech/utils/bleu_score.py index 93749ddd..ea32fcf9 100644 --- a/deepspeech/utils/bleu_score.py +++ b/deepspeech/utils/bleu_score.py @@ -14,9 +14,9 @@ """This module provides functions to calculate bleu score in different level. e.g. wer for word-level, cer for char-level. """ -import sacrebleu import nltk import numpy as np +import sacrebleu __all__ = ['bleu', 'char_bleu', "ErrorCalculator"] @@ -106,11 +106,14 @@ class ErrorCalculator(): # NOTE: padding index (-1) in y_true is used to pad y_hat # because y_hats is not padded with -1 seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] - seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = seq_hat_text.replace(self.pad, "") seq_true_text = "".join(seq_true).replace(self.space, " ") seqs_hat.append(seq_hat_text) seqs_true.append(seq_true_text) - bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat) + bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], + seqs_hat) return bleu * 100 diff --git a/deepspeech/utils/error_rate.py b/deepspeech/utils/error_rate.py index 0ad62b6b..548376aa 100644 --- a/deepspeech/utils/error_rate.py +++ b/deepspeech/utils/error_rate.py @@ -14,11 +14,10 @@ """This module provides functions to calculate error rate in different level. e.g. wer for word-level, cer for char-level. """ +from itertools import groupby + import editdistance import numpy as np -import logging -import sys -from itertools import groupby __all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"] @@ -225,9 +224,12 @@ class ErrorCalculator(): :return: """ - def __init__( - self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False - ): + def __init__(self, + char_list, + sym_space, + sym_blank, + report_cer=False, + report_wer=False): """Construct an ErrorCalculator object.""" super().__init__() @@ -317,7 +319,9 @@ class ErrorCalculator(): ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) # NOTE: padding index (-1) in y_true is used to pad y_hat seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] - seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = seq_hat_text.replace(self.blank, "") seq_true_text = "".join(seq_true).replace(self.space, " ") diff --git a/setup.py b/setup.py index bd982129..be17e0a4 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ import contextlib import inspect import io import os -import re import subprocess as sp import sys from pathlib import Path @@ -84,7 +83,7 @@ def _post_install(install_lib_dir): tools_extrs_dir = HERE / 'tools/extras' with pushd(tools_extrs_dir): print(os.getcwd()) - check_call(f"./install_autolog.sh") + check_call("./install_autolog.sh") print("autolog install.") # ctcdecoder diff --git a/utils/json2trn.py b/utils/json2trn.py index 873fde4f..4adfa491 100755 --- a/utils/json2trn.py +++ b/utils/json2trn.py @@ -4,7 +4,6 @@ # 2018 Xuankai Chang (Shanghai Jiao Tong University) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import argparse -import json import logging import sys