From 82044aecc52367f9f6cd91323fbcd7c8414d14b6 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 22 Oct 2021 06:28:44 +0000
Subject: [PATCH 1/5] u2 kaldi wer4p0

---
 deepspeech/__init__.py                | 14 ++++++--
 examples/librispeech/s2/README.md     | 46 ++++++---------------------
 examples/librispeech/s2/local/test.sh | 21 ++++++++----
 examples/librispeech/s2/run.sh        |  4 ++-
 4 files changed, 38 insertions(+), 47 deletions(-)

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 493f10a6f..f0cd1bae2 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -362,11 +362,19 @@ def ctc_loss(logits,
              label_lengths,
              blank=0,
              reduction='mean',
-             norm_by_times=True):
+             norm_by_times=False,
+             norm_by_batchsize=True,
+             norm_by_total_logits_len=False):
     #logger.info("my ctc loss with norm by times")
     ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
+    loss_out = paddle.fluid.layers.warpctc(
+        logits,
+        labels,
+        blank,
+        norm_by_times,
+        input_lengths,
+        label_lengths,
+        norm_by_batchsize, )

     loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
     assert reduction in ['mean', 'sum', 'none']
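For orientation, the new keyword arguments select one normalization mode for the summed CTC loss. The actual scaling happens inside warpctc, so the helper below is a rough sketch of the assumed semantics, not the op's implementation:

```python
# Hypothetical helper, for illustration only: the assumed meaning of the
# three mutually exclusive normalization switches.
import numpy as np

def normalize_ctc_loss(per_utt_loss, logit_lens,
                       norm_by_times=False,
                       norm_by_batchsize=True,
                       norm_by_total_logits_len=False):
    # per_utt_loss: (B,) CTC loss summed per utterance; logit_lens: (B,) frames
    if norm_by_times:                 # scale each utterance by its frame count
        return float((per_utt_loss / logit_lens).sum())
    if norm_by_total_logits_len:      # scale by total frames in the batch
        return float(per_utt_loss.sum() / logit_lens.sum())
    if norm_by_batchsize:             # scale by batch size (the new default here)
        return float(per_utt_loss.sum() / len(per_utt_loss))
    return float(per_utt_loss.sum())

print(normalize_ctc_loss(np.array([12.0, 8.0]), np.array([50, 40])))  # 10.0
```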
diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md
index e4022f014..1f7c69194 100644
--- a/examples/librispeech/s2/README.md
+++ b/examples/librispeech/s2/README.md
@@ -1,41 +1,13 @@
 # LibriSpeech

-## Data
-| Data Subset | Duration in Seconds |
-| data/manifest.train | 0.83s ~ 29.735s |
-| data/manifest.dev | 1.065 ~ 35.155s |
-| data/manifest.test-clean | 1.285s ~ 34.955s |
+| Model | Params | Config | Augmentation | Loss |
+| --- | --- | --- | --- | --- |
+| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |

-## Conformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | |
-### Test w/o length filter
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | |
-
-
-## Chunk Conformer
-
-| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - |
-| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - |
-
-
-## Transformer
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | |
-
-### Test w/o length filter
-| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | |
+| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
+| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
+| test-clean | ctc_prefix_beam_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
+| test-clean | attention_rescoring | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh
index 5eeb2d612..379a3787e 100755
--- a/examples/librispeech/s2/local/test.sh
+++ b/examples/librispeech/s2/local/test.sh
@@ -6,7 +6,7 @@ expdir=exp
 datadir=data
 nj=32

-lmtag=
+lmtag='nolm'

 recog_set="test-clean test-other dev-clean dev-other"
 recog_set="test-clean"
@@ -29,11 +29,18 @@ config_path=$1
 dict=$2
 ckpt_prefix=$3

+
+ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
+echo "ckpt dir: ${ckpt_dir}"
+
+ckpt_tag=$(basename ${ckpt_prefix})
+echo "ckpt tag: ${ckpt_tag}"
+
 chunk_mode=false
 if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
     chunk_mode=true
 fi
-echo "chunk mode ${chunk_mode}"
+echo "chunk mode: ${chunk_mode}"


 # download language model
@@ -46,11 +53,13 @@ pids=() # initialize pids

 for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
 (
+    echo "decode method: ${dmethd}"
     for rtask in ${recog_set}; do
     (
-        decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
+        echo "dataset: ${rtask}"
+        decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}
         feat_recog_dir=${datadir}
-        mkdir -p ${expdir}/${decode_dir}
+        mkdir -p ${decode_dir}
         mkdir -p ${feat_recog_dir}

         # split data
@@ -61,7 +70,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
         # set batchsize 0 to disable batch decoding
         batch_size=1
-        ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
+        ${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \
         python3 -u ${BIN_DIR}/test.py \
         --model-name u2_kaldi \
         --run-mode test \
@@ -69,7 +78,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
         --dict-path ${dict} \
         --config ${config_path} \
         --checkpoint_path ${ckpt_prefix} \
-        --result-file ${expdir}/${decode_dir}/data.JOB.json \
+        --result-file ${decode_dir}/data.JOB.json \
         --opts decoding.decoding_method ${dmethd} \
         --opts decoding.batch_size ${batch_size} \
         --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh
index 46b6ac1b4..1ffe3e5c5 100755
--- a/examples/librispeech/s2/run.sh
+++ b/examples/librispeech/s2/run.sh
@@ -1,4 +1,5 @@
 #!/bin/bash
+
 set -e
 . ./path.sh || exit 1;
@@ -7,8 +8,9 @@ set -e
 stage=0
 stop_stage=100
 conf_path=conf/transformer.yaml
-dict_path=data/train_960_unigram5000_units.txt
+dict_path=data/bpe_unigram_5000_units.txt
 avg_num=10
+
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

 avg_ckpt=avg_${avg_num}

From 1d81d577b1f48ca294338a9591d9dc71e6c7724e Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 22 Oct 2021 08:34:59 +0000
Subject: [PATCH 2/5] add recog interface

---
 deepspeech/decoders/recog.py       | 154 +++++++++++++++++++++++++++++
 deepspeech/decoders/utils.py       |  75 ++++++++++++++
 deepspeech/models/asr_interface.py | 148 +++++++++++++++++++++++++++
 deepspeech/models/u2/u2.py         |  28 +++++-
 deepspeech/modules/decoder.py      |  69 ++++++++++++-
 deepspeech/modules/mask.py         |  16 ++-
 6 files changed, 484 insertions(+), 6 deletions(-)
 create mode 100644 deepspeech/decoders/recog.py
 create mode 100644 deepspeech/models/asr_interface.py

diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py
new file mode 100644
index 000000000..399c5c54f
--- /dev/null
+++ b/deepspeech/decoders/recog.py
@@ -0,0 +1,154 @@
+"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
+import json
+
+import paddle
+
+# from espnet.asr.asr_utils import get_model_conf
+# from espnet.asr.asr_utils import torch_load
+# from espnet.asr.pytorch_backend.asr import load_trained_model
+# from espnet.nets.lm_interface import dynamic_import_lm
+
+# from espnet.nets.asr_interface import ASRInterface
+
+from .utils import add_results_to_json
+# from .batch_beam_search import BatchBeamSearch
+from .beam_search import BeamSearch
+from .scorer_interface import BatchScorerInterface
+from .scorers.length_bonus import LengthBonus
+
+from deepspeech.io.reader import LoadInputsAndTargets
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+
+def recog_v2(args):
+    """Decode with custom models that implements ScorerInterface.
+
+    Args:
+        args (namespace): The program arguments.
+        See py:func:`bin.asr_recog.get_parser` for details
+
+    """
+    logger.warning("experimental API for custom LMs is selected by --api v2")
+    if args.batchsize > 1:
+        raise NotImplementedError("multi-utt batch decoding is not implemented")
+    if args.streaming_mode is not None:
+        raise NotImplementedError("streaming mode is not implemented")
+    if args.word_rnnlm:
+        raise NotImplementedError("word LM is not implemented")
+
+    # set_deterministic(args)
+    model, train_args = load_trained_model(args.model)
+    # assert isinstance(model, ASRInterface)
+    model.eval()
+
+    load_inputs_and_targets = LoadInputsAndTargets(
+        mode="asr",
+        load_output=False,
+        sort_in_input_length=False,
+        preprocess_conf=train_args.preprocess_conf
+        if args.preprocess_conf is None
+        else args.preprocess_conf,
+        preprocess_args={"train": False},
+    )
+
+    if args.rnnlm:
+        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
+        # NOTE: for a compatibility with less than 0.5.0 version models
+        lm_model_module = getattr(lm_args, "model_module", "default")
+        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
+        lm = lm_class(len(train_args.char_list), lm_args)
+        torch_load(args.rnnlm, lm)
+        lm.eval()
+    else:
+        lm = None
+
+    if args.ngram_model:
+        from .scorers.ngram import NgramFullScorer
+        from .scorers.ngram import NgramPartScorer
+
+        if args.ngram_scorer == "full":
+            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
+        else:
+            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
+    else:
+        ngram = None
+
+    scorers = model.scorers()
+    scorers["lm"] = lm
+    scorers["ngram"] = ngram
+    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
+    weights = dict(
+        decoder=1.0 - args.ctc_weight,
+        ctc=args.ctc_weight,
+        lm=args.lm_weight,
+        ngram=args.ngram_weight,
+        length_bonus=args.penalty,
+    )
+    beam_search = BeamSearch(
+        beam_size=args.beam_size,
+        vocab_size=len(train_args.char_list),
+        weights=weights,
+        scorers=scorers,
+        sos=model.sos,
+        eos=model.eos,
+        token_list=train_args.char_list,
+        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
+    )
+    # TODO(karita): make all scorers batchfied
+    if args.batchsize == 1:
+        non_batch = [
+            k
+            for k, v in beam_search.full_scorers.items()
+            if not isinstance(v, BatchScorerInterface)
+        ]
+        if len(non_batch) == 0:
+            beam_search.__class__ = BatchBeamSearch
+            logger.info("BatchBeamSearch implementation is selected.")
+        else:
+            logger.warning(
+                f"As non-batch scorers {non_batch} are found, "
+                f"fall back to non-batch implementation."
+            )
+
+    if args.ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+    if args.ngpu == 1:
+        device = "gpu:0"
+    else:
+        device = "cpu"
+    dtype = getattr(paddle, args.dtype)
+    logger.info(f"Decoding device={device}, dtype={dtype}")
+    model.to(device=device, dtype=dtype)
+    model.eval()
+    beam_search.to(device=device, dtype=dtype)
+    beam_search.eval()
+
+    # read json data
+    with open(args.recog_json, "rb") as f:
+        js = json.load(f)
+    # jsonlines to dict, key by 'utt'
+    js = {item['utt']: item for item in js}
+
+    new_js = {}
+    with paddle.no_grad():
+        for idx, name in enumerate(js.keys(), 1):
+            logger.info("(%d/%d) decoding " + name, idx, len(js.keys()))
+            batch = [(name, js[name])]
+            feat = load_inputs_and_targets(batch)[0][0]
+            enc = model.encode(paddle.to_tensor(feat).to(device=device, dtype=dtype))
+            nbest_hyps = beam_search(
+                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
+            )
+            nbest_hyps = [
+                h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
+            ]
+            new_js[name] = add_results_to_json(
+                js[name], nbest_hyps, train_args.char_list
+            )
+
+    with open(args.result_label, "wb") as f:
+        f.write(
+            json.dumps(
+                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
+            ).encode("utf_8")
+        )
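For orientation, the decoder/ctc/lm/length_bonus weights built above enter the search as a weighted sum of per-scorer token log-probabilities. A minimal self-contained sketch of that ranking rule (not the BeamSearch class itself; values invented):

```python
# Sketch of the per-hypothesis ranking rule: weighted sum of each scorer's
# token log-probabilities, then pick (or beam) over the combined scores.
import numpy as np

def combined_token_scores(scores: dict, weights: dict) -> np.ndarray:
    total = np.zeros_like(next(iter(scores.values())))
    for name, logp in scores.items():      # each logp has shape (vocab,)
        total += weights.get(name, 0.0) * logp
    return total

decoder_logp = np.log(np.array([0.7, 0.2, 0.1]))
ctc_logp = np.log(np.array([0.5, 0.4, 0.1]))
weights = {"decoder": 0.7, "ctc": 0.3}     # decoder = 1 - ctc_weight
best = combined_token_scores(
    {"decoder": decoder_logp, "ctc": ctc_logp}, weights).argmax()
print(best)  # token 0 wins under the combined score
```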
+ ) + + if args.ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + if args.ngpu == 1: + device = "gpu:0" + else: + device = "cpu" + dtype = getattr(paddle, args.dtype) + logger.info(f"Decoding device={device}, dtype={dtype}") + model.to(device=device, dtype=dtype) + model.eval() + beam_search.to(device=device, dtype=dtype) + beam_search.eval() + + # read json data + with open(args.recog_json, "rb") as f: + js = json.load(f) + # josnlines to dict, key by 'utt' + js = {item['utt']: item for item in js} + + new_js = {} + with paddle.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logger.info("(%d/%d) decoding " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + enc = model.encode(paddle.to_tensor(feat).to(device=device, dtype=dtype)) + nbest_hyps = beam_search( + x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio + ) + nbest_hyps = [ + h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)] + ] + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/deepspeech/decoders/utils.py b/deepspeech/decoders/utils.py index 92f65814d..0281a78bb 100644 --- a/deepspeech/decoders/utils.py +++ b/deepspeech/decoders/utils.py @@ -47,3 +47,78 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): return True else: return False + + +# * ------------------ recognition related ------------------ * +def parse_hypothesis(hyp, char_list): + """Parse hypothesis. + + Args: + hyp (list[dict[str, Any]]): Recognition hypothesis. + char_list (list[str]): List of characters. + + Returns: + tuple(str, str, str, float) + + """ + # remove sos and get results + tokenid_as_list = list(map(int, hyp["yseq"][1:])) + token_as_list = [char_list[idx] for idx in tokenid_as_list] + score = float(hyp["score"]) + + # convert to string + tokenid = " ".join([str(idx) for idx in tokenid_as_list]) + token = " ".join(token_as_list) + text = "".join(token_as_list).replace("", " ") + + return text, token, tokenid, score + + +def add_results_to_json(js, nbest_hyps, char_list): + """Add N-best results to json. + + Args: + js (dict[str, Any]): Groundtruth utterance dict. + nbest_hyps_sd (list[dict[str, Any]]): + List of hypothesis for multi_speakers: nutts x nspkrs. + char_list (list[str]): List of characters. + + Returns: + dict[str, Any]: N-best results added utterance dict. 
diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py
new file mode 100644
index 000000000..eb820fc05
--- /dev/null
+++ b/deepspeech/models/asr_interface.py
@@ -0,0 +1,148 @@
+"""ASR Interface module."""
+import argparse
+
+from deepspeech.utils.dynamic_import import dynamic_import
+
+
+class ASRInterface:
+    """ASR Interface for ESPnet model implementation."""
+
+    @staticmethod
+    def add_arguments(parser):
+        """Add arguments to parser."""
+        return parser
+
+    @classmethod
+    def build(cls, idim: int, odim: int, **kwargs):
+        """Initialize this class with python-level args.
+
+        Args:
+            idim (int): The number of an input feature dim.
+            odim (int): The number of output vocab.
+
+        Returns:
+            ASRInterface: A new instance of ASRInterface.
+
+        """
+        args = argparse.Namespace(**kwargs)
+        return cls(idim, odim, args)
+
+    def forward(self, xs, ilens, ys, olens):
+        """Compute loss for training.
+
+        :param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim)
+        :param ilens: batch of lengths of source sequences (B), paddle.Tensor
+        :param ys: batch of padded target sequences paddle.Tensor (B, Lmax)
+        :param olens: batch of lengths of target sequences (B), paddle.Tensor
+        :return: loss value
+        :rtype: paddle.Tensor
+        """
+        raise NotImplementedError("forward method is not implemented")
+
+    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
+        """Recognize x for evaluation.
+
+        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
+        :param namespace recog_args: argument namespace containing options
+        :param list char_list: list of characters
+        :param paddle.nn.Layer rnnlm: language model module
+        :return: N-best decoding results
+        :rtype: list
+        """
+        raise NotImplementedError("recognize method is not implemented")
+
+    def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
+        """Beam search implementation for batch.
+
+        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
+        :param namespace recog_args: argument namespace containing options
+        :param list char_list: list of characters
+        :param paddle.nn.Layer rnnlm: language model module
+        :return: N-best decoding results
+        :rtype: list
+        """
+        raise NotImplementedError("Batch decoding is not supported yet.")
+
+    def calculate_all_attentions(self, xs, ilens, ys):
+        """Calculate attention.
+
+        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
+        :param ndarray ilens: batch of lengths of input sequences (B)
+        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
+        :return: attention weights (B, Lmax, Tmax)
+        :rtype: float ndarray
+        """
+        raise NotImplementedError("calculate_all_attentions method is not implemented")
+
+    def calculate_all_ctc_probs(self, xs, ilens, ys):
+        """Calculate CTC probability.
+
+        :param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
+        :param ndarray ilens: batch of lengths of input sequences (B)
+        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
+        :return: CTC probabilities (B, Tmax, vocab)
+        :rtype: float ndarray
+        """
+        raise NotImplementedError("calculate_all_ctc_probs method is not implemented")
+
+    @property
+    def attention_plot_class(self):
+        """Get attention plot class."""
+        from espnet.asr.asr_utils import PlotAttentionReport
+
+        return PlotAttentionReport
+
+    @property
+    def ctc_plot_class(self):
+        """Get CTC plot class."""
+        from espnet.asr.asr_utils import PlotCTCReport
+
+        return PlotCTCReport
+
+    def get_total_subsampling_factor(self):
+        """Get total subsampling factor."""
+        raise NotImplementedError(
+            "get_total_subsampling_factor method is not implemented"
+        )
+
+    def encode(self, feat):
+        """Encode feature in `beam_search` (optional).
+
+        Args:
+            x (numpy.ndarray): input feature (T, D)
+        Returns:
+            paddle.Tensor: encoded feature (T, D)
+        """
+        raise NotImplementedError("encode method is not implemented")
+
+    def scorers(self):
+        """Get scorers for `beam_search` (optional).
+
+        Returns:
+            dict[str, ScorerInterface]: dict of `ScorerInterface` objects
+
+        """
+        raise NotImplementedError("scorers method is not implemented")
+
+
+predefined_asr = {
+    "transformer": "deepspeech.models.u2:E2E",
+    "conformer": "deepspeech.models.u2:E2E",
+}
+
+
+def dynamic_import_asr(module, name):
+    """Import ASR models dynamically.
+
+    Args:
+        module (str): module_name:class_name or alias in `predefined_asr`
+        name (str): asr name. e.g., transformer, conformer
+
+    Returns:
+        type: ASR class
+
+    """
+    model_class = dynamic_import(module, predefined_asr.get(name, ""))
+    assert issubclass(
+        model_class, ASRInterface
+    ), f"{module} does not implement ASRInterface"
+    return model_class
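A hedged sketch of how a concrete model plugs into this interface; `ToyASR` is hypothetical, only the `build()` contract comes from the code above:

```python
# ToyASR is an invented stand-in; build() packs kwargs into an
# argparse.Namespace and calls cls(idim, odim, args), as defined above.
from deepspeech.models.asr_interface import ASRInterface

class ToyASR(ASRInterface):
    def __init__(self, idim, odim, args):
        self.idim, self.odim, self.args = idim, odim, args

model = ToyASR.build(80, 5000, ctc_weight=0.3)
print(model.args.ctc_weight)  # 0.3
```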
diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py
index fd63fa9c5..8915cbd7d 100644
--- a/deepspeech/models/u2/u2.py
+++ b/deepspeech/models/u2/u2.py
@@ -49,13 +49,15 @@ from deepspeech.utils.tensor_utils import pad_sequence
 from deepspeech.utils.tensor_utils import th_accuracy
 from deepspeech.utils.utility import log_add
 from deepspeech.utils.utility import UpdateConfig
+from deepspeech.models.asr_interface import ASRInterface
+from deepspeech.decoders.scorers.ctc_prefix_score import CTCPrefixScorer

 __all__ = ["U2Model", "U2InferModel"]

 logger = Log(__name__).getlog()


-class U2BaseModel(nn.Layer):
+class U2BaseModel(ASRInterface, nn.Layer):
     """CTC-Attention hybrid Encoder-Decoder model"""

     @classmethod
@@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer):
                  **kwargs):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight

-        super().__init__()
+        nn.Layer.__init__(self)
         # note that eos is the same as sos (equivalent ID)
         self.sos = vocab_size - 1
         self.eos = vocab_size - 1
@@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer):
         return res, res_tokenids


-class U2Model(U2BaseModel):
+class U2DecodeModel(U2BaseModel):
+
+    def scorers(self):
+        """Scorers."""
+        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
+
+    def encode(self, x):
+        """Encode acoustic features.
+
+        :param ndarray x: source acoustic feature (T, D)
+        :return: encoder outputs
+        :rtype: paddle.Tensor
+        """
+        self.eval()
+        x = paddle.to_tensor(x).unsqueeze(0)
+        ilen = x.size(1)
+        enc_output, _ = self._forward_encoder(x, ilen)
+        return enc_output.squeeze(0)
+
+
+class U2Model(U2DecodeModel):
     def __init__(self, configs: dict):
         vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index 1ae3ce371..154b7390f 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -15,6 +15,7 @@ from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Any

 import paddle
 from paddle import nn
@@ -25,7 +26,9 @@ from deepspeech.modules.decoder_layer import DecoderLayer
 from deepspeech.modules.embedding import PositionalEncoding
 from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.mask import subsequent_mask
+from deepspeech.modules.mask import make_xs_mask
 from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
+from deepspeech.decoders.scorers.score_interface import BatchScorerInterface
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()
@@ -33,7 +36,7 @@ logger = Log(__name__).getlog()
 __all__ = ["TransformerDecoder"]


-class TransformerDecoder(nn.Layer):
+class TransformerDecoder(BatchScorerInterface, nn.Layer):
     """Base class of Transformer decoder module.
     Args:
         vocab_size: output dim
@@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer):
             concat_after: bool=False, ):

         assert check_argument_types()
-        super().__init__()
+        nn.Layer.__init__(self)
+        self.selfattention_layer_type = 'selfattn'
         attention_dim = encoder_output_size

         if input_layer == "embed":
@@ -180,3 +184,64 @@ class TransformerDecoder(nn.Layer):
         if self.use_output_layer:
             y = paddle.log_softmax(self.output_layer(y), axis=-1)
         return y, new_cache
+
+    # beam search API (see ScorerInterface)
+    def score(self, ys, state, x):
+        """Score.
+        ys: (ylen,)
+        x: (xlen, n_feat)
+        """
+        ys_mask = subsequent_mask(len(ys)).unsqueeze(0)
+        x_mask = make_xs_mask(x.unsqueeze(0))
+        if self.selfattention_layer_type != "selfattn":
+            # TODO(karita): implement cache
+            logger.warning(
+                f"{self.selfattention_layer_type} does not support cached decoding."
+            )
+            state = None
+        logp, state = self.forward_one_step(
+            x.unsqueeze(0), x_mask,
+            ys.unsqueeze(0), ys_mask,
+            cache=state
+        )
+        return logp.squeeze(0), state
+
+    # batch beam search API (see BatchScorerInterface)
+    def batch_score(
+        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, List[Any]]:
+        """Score new token batch (required).
+
+        Args:
+            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (paddle.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[paddle.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, n_vocab)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.decoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [
+                paddle.stack([states[b][i] for b in range(n_batch)])
+                for i in range(n_layers)
+            ]
+
+        # batch decoding
+        ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)
+
+        xs_mask = make_xs_mask(xs)
+        logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
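`score()` follows the ScorerInterface contract: take a prefix `ys`, an opaque per-hypothesis `state`, and the encoder output `x`; return next-token log-probabilities and the new state. A hedged sketch of a caller threading that state (greedy for brevity; BeamSearch does the same per hypothesis, and `scorer`/`enc` are assumed to come from a built model):

```python
# Sketch of the ScorerInterface calling convention; e.g. scorer could be
# model.scorers()["decoder"] and enc could be model.encode(feat).
import paddle

def greedy_decode(scorer, enc, sos: int, eos: int, max_len: int = 50):
    ys = paddle.to_tensor([sos], dtype='int64')
    state = None
    for _ in range(max_len):
        logp, state = scorer.score(ys, state, enc)  # logp: (vocab,)
        next_tok = int(paddle.argmax(logp))
        ys = paddle.concat(
            [ys, paddle.to_tensor([next_tok], dtype='int64')])
        if next_tok == eos:
            break
    return ys
```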
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index 00f228a2b..cffa10a7b 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -18,12 +18,24 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

 __all__ = [
-    "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
+    "make_xs_mask", "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
     "subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
     "mask_finished_preds"
 ]


+def make_xs_mask(xs: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of non-padded part.
+    Args:
+        xs (paddle.Tensor): (B, T, D), zeros for pad.
+    Returns:
+        paddle.Tensor: Mask tensor with indices of the non-padded part. (B, T, D)
+    """
+    pad_frame = paddle.zeros([1, 1, xs.shape[-1]], dtype=xs.dtype)
+    mask = xs != pad_frame
+    return mask
+
+
 def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
     """Make mask tensor containing indices of padded part.
     See description of make_non_pad_mask.
@@ -31,6 +43,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
         lengths (paddle.Tensor): Batch of lengths (B,).
     Returns:
         paddle.Tensor: Mask tensor containing indices of padded part.
+            (B, T)
     Examples:
         >>> lengths = [5, 3, 2]
         >>> make_pad_mask(lengths)
@@ -62,6 +75,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
         lengths (paddle.Tensor): Batch of lengths (B,).
     Returns:
         paddle.Tensor: mask tensor containing indices of padded part.
+            (B, T)
     Examples:
         >>> lengths = [5, 3, 2]
         >>> make_non_pad_mask(lengths)
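A toy check of `make_xs_mask` as defined here (this version returns an elementwise (B, T, D) mask; patch 3 below tightens it to a per-frame (B, T) mask):

```python
# Assumes the deepspeech package is importable.
import paddle
from deepspeech.modules.mask import make_xs_mask

xs = paddle.to_tensor([[[1.0, 2.0],
                        [0.0, 0.0]]])  # (B=1, T=2, D=2); second frame is padding
print(make_xs_mask(xs))                # True for non-zero entries, False for pads
```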
From 9c62ad69d71a6d4c23155879b85a3492d40c5a1f Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 22 Oct 2021 10:18:46 +0000
Subject: [PATCH 3/5] decoder with ctc prefix score

---
 deepspeech/__init__.py                        |  20 +-
 deepspeech/decoders/beam_search.py            |  22 +-
 deepspeech/decoders/recog.py                  | 111 +++--
 deepspeech/decoders/scorers/ctc.py            |   4 +-
 .../decoders/scorers/ctc_prefix_score.py      |   4 +-
 ...score_interface.py => scorer_interface.py} |   0
 deepspeech/decoders/utils.py                  |  10 +-
 deepspeech/exps/u2_kaldi/bin/recog.py         | 379 ++++++++++++++++++
 deepspeech/exps/u2_kaldi/model.py             |   2 -
 .../frontend/featurizer/text_featurizer.py    |   2 +-
 deepspeech/models/u2/u2.py                    |   2 +-
 deepspeech/modules/decoder.py                 |  11 +-
 deepspeech/modules/mask.py                    |   7 +-
 deepspeech/training/cli.py                    |   6 +-
 deepspeech/training/trainer.py                |  20 +-
 examples/librispeech/s2/README.md             |   1 +
 .../librispeech/s2/conf/decode/decode.yaml    |   7 +
 .../s2/conf/decode/decode_all.yaml            |   7 +
 .../s2/conf/decode/decode_wo_lm.yaml          |   7 +
 examples/librispeech/s2/local/recog.sh        | 103 +++++
 examples/librispeech/s2/local/test.sh         |   2 +-
 requirements.txt                              |   1 +
 22 files changed, 647 insertions(+), 81 deletions(-)
 rename deepspeech/decoders/scorers/{score_interface.py => scorer_interface.py} (100%)
 create mode 100644 deepspeech/exps/u2_kaldi/bin/recog.py
 create mode 100644 examples/librispeech/s2/conf/decode/decode.yaml
 create mode 100644 examples/librispeech/s2/conf/decode/decode_all.yaml
 create mode 100644 examples/librispeech/s2/conf/decode/decode_wo_lm.yaml
 create mode 100755 examples/librispeech/s2/local/recog.sh

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 5f9ba007e..ca5b4ba5c 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -233,7 +233,7 @@ def is_broadcastable(shp1, shp2):

 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True
+    assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, mask.shape)
     bshape = paddle.broadcast_shape(xs.shape, mask.shape)
     mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
@@ -312,18 +312,18 @@ if not hasattr(paddle.Tensor, 'type_as'):

 def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
-    assert len(args) == 1
-    if isinstance(args[0], str):  # dtype
-        return x.astype(args[0])
-    elif isinstance(args[0], paddle.Tensor):  #Tensor
-        return x.astype(args[0].dtype)
-    else:  # Device
-        return x
+    assert len(args) == 1
+    if isinstance(args[0], str):  # dtype
+        return x.astype(args[0])
+    elif isinstance(args[0], paddle.Tensor):  # Tensor
+        return x.astype(args[0].dtype)
+    else:  # Device
+        return x


 if not hasattr(paddle.Tensor, 'to'):
-    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
-    setattr(paddle.Tensor, 'to', to)
+    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'to', to)

diff --git a/deepspeech/decoders/beam_search.py b/deepspeech/decoders/beam_search.py
index 5bf0d3e2f..afb8aefa5 100644
--- a/deepspeech/decoders/beam_search.py
+++ b/deepspeech/decoders/beam_search.py
@@ -1,7 +1,6 @@
 """Beam search module."""
 from itertools import chain
-import logger
 from typing import Any
 from typing import Dict
 from typing import List
@@ -141,7 +140,7 @@ class BeamSearch(paddle.nn.Layer):
         ]

     @staticmethod
-    def append_token(xs: paddle.Tensor, x: int) -> paddle.Tensor:
+    def append_token(xs: paddle.Tensor, x: Union[int, paddle.Tensor]) -> paddle.Tensor:
         """Append new token to prefix tokens.

         Args:
@@ -152,8 +151,8 @@ class BeamSearch(paddle.nn.Layer):
             paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device

         """
-        x = paddle.to_tensor([x], dtype=xs.dtype, place=xs.place)
-        return paddle.cat((xs, x))
+        x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x
+        return paddle.concat((xs, x))

     def score_full(
         self, hyp: Hypothesis, x: paddle.Tensor
@@ -306,7 +305,7 @@ class BeamSearch(paddle.nn.Layer):
         part_ids = paddle.arange(self.n_vocab)  # no pre-beam
         for hyp in running_hyps:
             # scoring
-            weighted_scores = paddle.zeros(self.n_vocab, dtype=x.dtype)
+            weighted_scores = paddle.zeros([self.n_vocab], dtype=x.dtype)
             scores, states = self.score_full(hyp, x)
             for k in self.full_scorers:
                 weighted_scores += self.weights[k] * scores[k]
@@ -410,15 +409,20 @@ class BeamSearch(paddle.nn.Layer):
         best = nbest_hyps[0]
         for k, v in best.scores.items():
             logger.info(
-                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+                f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}"
             )
-        logger.info(f"total log probability: {best.score:.2f}")
-        logger.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logger.info(f"total log probability: {float(best.score):.2f}")
+        logger.info(f"normalized log probability: {float(best.score) / len(best.yseq):.2f}")
         logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
         if self.token_list is not None:
+            # logger.info(
+            #     "best hypo: "
+            #     + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+            #     + "\n"
+            # )
             logger.info(
                 "best hypo: "
-                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "".join([self.token_list[x] for x in best.yseq[1:]])
                 + "\n"
             )
         return nbest_hyps
diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py
index 399c5c54f..867569aa0 100644
--- a/deepspeech/decoders/recog.py
+++ b/deepspeech/decoders/recog.py
@@ -2,18 +2,22 @@

 import json
+
 import paddle
+import yaml
+from yacs.config import CfgNode
+from pathlib import Path
+import jsonlines

 # from espnet.asr.asr_utils import get_model_conf
 # from espnet.asr.asr_utils import torch_load
 # from espnet.asr.pytorch_backend.asr import load_trained_model
 # from espnet.nets.lm_interface import dynamic_import_lm

-# from espnet.nets.asr_interface import ASRInterface
+from deepspeech.models.asr_interface import ASRInterface

 from .utils import add_results_to_json
 # from .batch_beam_search import BatchBeamSearch
 from .beam_search import BeamSearch
-from .scorer_interface import BatchScorerInterface
+from .scorers.scorer_interface import BatchScorerInterface
 from .scorers.length_bonus import LengthBonus

 from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()

+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.utility import print_arguments
+
+model_test_alias = {
+    "u2": "deepspeech.exps.u2.model:U2Tester",
+    "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
+}
+

 def recog_v2(args):
     """Decode with custom models that implements ScorerInterface.
@@ -36,16 +48,31 @@ def recog_v2(args):
         raise NotImplementedError("streaming mode is not implemented")
     if args.word_rnnlm:
         raise NotImplementedError("word LM is not implemented")
-
+    args.nprocs = args.ngpu
     # set_deterministic(args)
-    model, train_args = load_trained_model(args.model)
-    # assert isinstance(model, ASRInterface)
-    model.eval()
+
+    #model, train_args = load_trained_model(args.model)
+    model_path = Path(args.model)
+    ckpt_dir = model_path.parent.parent
+
+    confs = CfgNode()
+    confs.set_new_allowed(True)
+    confs.merge_from_file(args.model_conf)
+
+    class_obj = dynamic_import(args.model_name, model_test_alias)
+    exp = class_obj(confs, args)
+    with exp.eval():
+        exp.setup()
+        exp.restore()
+    char_list = exp.args.char_list
+
+    model = exp.model
+    assert isinstance(model, ASRInterface)

     load_inputs_and_targets = LoadInputsAndTargets(
         mode="asr",
         load_output=False,
         sort_in_input_length=False,
-        preprocess_conf=train_args.preprocess_conf
+        preprocess_conf=confs.collator.augmentation_config
         if args.preprocess_conf is None
         else args.preprocess_conf,
         preprocess_args={"train": False},
@@ -56,7 +83,7 @@ def recog_v2(args):
         # NOTE: for a compatibility with less than 0.5.0 version models
         lm_model_module = getattr(lm_args, "model_module", "default")
         lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
-        lm = lm_class(len(train_args.char_list), lm_args)
+        lm = lm_class(len(char_list), lm_args)
         torch_load(args.rnnlm, lm)
         lm.eval()
     else:
@@ -67,16 +94,16 @@ def recog_v2(args):
         from .scorers.ngram import NgramPartScorer

         if args.ngram_scorer == "full":
-            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
+            ngram = NgramFullScorer(args.ngram_model, char_list)
         else:
-            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
+            ngram = NgramPartScorer(args.ngram_model, char_list)
     else:
         ngram = None

     scorers = model.scorers()
     scorers["lm"] = lm
     scorers["ngram"] = ngram
-    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
+    scorers["length_bonus"] = LengthBonus(len(char_list))
     weights = dict(
         decoder=1.0 - args.ctc_weight,
         ctc=args.ctc_weight,
@@ -86,14 +113,15 @@ def recog_v2(args):
     )
     beam_search = BeamSearch(
         beam_size=args.beam_size,
-        vocab_size=len(train_args.char_list),
+        vocab_size=len(char_list),
         weights=weights,
         scorers=scorers,
         sos=model.sos,
         eos=model.eos,
-        token_list=train_args.char_list,
+        token_list=char_list,
         pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
     )
+
     # TODO(karita): make all scorers batchfied
     if args.batchsize == 1:
         non_batch = [
@@ -116,6 +144,7 @@ def recog_v2(args):
         device = "gpu:0"
     else:
         device = "cpu"
+    paddle.set_device(device)
     dtype = getattr(paddle, args.dtype)
     logger.info(f"Decoding device={device}, dtype={dtype}")
     model.to(device=device, dtype=dtype)
@@ -124,31 +153,41 @@ def recog_v2(args):
     beam_search.eval()

     # read json data
-    with open(args.recog_json, "rb") as f:
-        js = json.load(f)
+    js = []
+    with jsonlines.open(args.recog_json, "r") as reader:
+        for item in reader:
+            js.append(item)
     # jsonlines to dict, key by 'utt'
     js = {item['utt']: item for item in js}

     new_js = {}
     with paddle.no_grad():
-        for idx, name in enumerate(js.keys(), 1):
-            logger.info("(%d/%d) decoding " + name, idx, len(js.keys()))
-            batch = [(name, js[name])]
-            feat = load_inputs_and_targets(batch)[0][0]
-            enc = model.encode(paddle.to_tensor(feat).to(device=device, dtype=dtype))
-            nbest_hyps = beam_search(
-                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
-            )
-            nbest_hyps = [
-                h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
-            ]
-            new_js[name] = add_results_to_json(
-                js[name], nbest_hyps, train_args.char_list
-            )
-
-    with open(args.result_label, "wb") as f:
-        f.write(
-            json.dumps(
-                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
-            ).encode("utf_8")
-        )
+        with jsonlines.open(args.result_label, "w") as f:
+            for idx, name in enumerate(js.keys(), 1):
+                logger.info(f"({idx}/{len(js.keys())}) decoding " + name)
+                batch = [(name, js[name])]
+                feat = load_inputs_and_targets(batch)[0][0]
+                logger.info(f'feat: {feat.shape}')
+                enc = model.encode(paddle.to_tensor(feat).to(dtype))
+                logger.info(f'eouts: {enc.shape}')
+                nbest_hyps = beam_search(
+                    x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
+                )
+                nbest_hyps = [
+                    h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
+                ]
+                new_js[name] = add_results_to_json(
+                    js[name], nbest_hyps, char_list
+                )
+
+                item = new_js[name]['output'][0]  # 1-best
+                utt = name
+                ref = item['text']
+                rec_text = item['rec_text'].replace('▁', ' ').replace('<eos>', '').strip()
+                rec_tokenid = list(map(int, item['rec_tokenid'].split()))
+                f.write({
+                    "utt": utt,
+                    "refs": [ref],
+                    "hyps": [rec_text],
+                    "hyps_tokenid": [rec_tokenid],
+                })
\ No newline at end of file
diff --git a/deepspeech/decoders/scorers/ctc.py b/deepspeech/decoders/scorers/ctc.py
index 36b4bfd36..4871d6e12 100644
--- a/deepspeech/decoders/scorers/ctc.py
+++ b/deepspeech/decoders/scorers/ctc.py
@@ -15,8 +15,8 @@
 import numpy as np
 import paddle

-from .ctc_prefix_score import CTCPrefixScorer
-from .ctc_prefix_score import CTCPrefixScorerPD
+from .ctc_prefix_score import CTCPrefixScore
+from .ctc_prefix_score import CTCPrefixScorePD
 from .scorer_interface import BatchPartialScorerInterface

diff --git a/deepspeech/decoders/scorers/ctc_prefix_score.py b/deepspeech/decoders/scorers/ctc_prefix_score.py
index 5f568c811..c85d546d3 100644
--- a/deepspeech/decoders/scorers/ctc_prefix_score.py
+++ b/deepspeech/decoders/scorers/ctc_prefix_score.py
@@ -6,7 +6,7 @@ import paddle
 import six


-class CTCPrefixScorerPD():
+class CTCPrefixScorePD():
     """Batch processing of CTCPrefixScore which is based on Algorithm 2 in
     WATANABE et al.
@@ -267,7 +267,7 @@ class CTCPrefixScorerPD():
         return (r_prev_new, s_prev, f_min_prev, f_max_prev)


-class CTCPrefixScorer():
+class CTCPrefixScore():
     """Compute CTC label sequence scores which is based on Algorithm 2 in
     WATANABE et al.
diff --git a/deepspeech/decoders/scorers/score_interface.py b/deepspeech/decoders/scorers/scorer_interface.py
similarity index 100%
rename from deepspeech/decoders/scorers/score_interface.py
rename to deepspeech/decoders/scorers/scorer_interface.py
diff --git a/deepspeech/decoders/utils.py b/deepspeech/decoders/utils.py
index 0281a78bb..f59b55d9b 100644
--- a/deepspeech/decoders/utils.py
+++ b/deepspeech/decoders/utils.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__all__ = ["end_detect"]
+import numpy as np
+from deepspeech.utils.log import Log
+logger = Log(__name__).getlog()
+
+__all__ = ["end_detect", "parse_hypothesis", "add_results_to_json"]


 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
@@ -118,7 +122,7 @@ def add_results_to_json(js, nbest_hyps, char_list):
         # show 1-best result
         if n == 1:
             if "text" in out_dic.keys():
-                logging.info("groundtruth: %s" % out_dic["text"])
-                logging.info("prediction : %s" % out_dic["rec_text"])
+                logger.info("groundtruth: %s" % out_dic["text"])
+                logger.info("prediction : %s" % out_dic["rec_text"])

     return new_js
\ No newline at end of file
diff --git a/deepspeech/exps/u2_kaldi/bin/recog.py b/deepspeech/exps/u2_kaldi/bin/recog.py
new file mode 100644
index 000000000..60f4e58ac
--- /dev/null
+++ b/deepspeech/exps/u2_kaldi/bin/recog.py
@@ -0,0 +1,379 @@
+"""End-to-end speech recognition model decoding script."""
+import configargparse
+import logging
+import os
+import random
+import sys
+
+import numpy as np
+from distutils.util import strtobool
+
+from deepspeech.training.cli import default_argument_parser
+
+
+# NOTE: you need this func to generate our sphinx doc
+def get_parser():
+    """Get default arguments."""
+    parser = configargparse.ArgumentParser(
+        description="Transcribe text from speech using "
+        "a speech recognition model on one CPU or GPU",
+        config_file_parser_class=configargparse.YAMLConfigFileParser,
+        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add(
+        '--model-name',
+        type=str,
+        default='u2_kaldi',
+        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
+    # general configuration
+    parser.add("--config", is_config_file=True, help="Config file path")
+    parser.add(
+        "--config2",
+        is_config_file=True,
+        help="Second config file path that overwrites the settings in `--config`",
+    )
+    parser.add(
+        "--config3",
+        is_config_file=True,
+        help="Third config file path that overwrites the settings "
+        "in `--config` and `--config2`",
+    )
+
+    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
+    parser.add_argument(
+        "--dtype",
+        choices=("float16", "float32", "float64"),
+        default="float32",
+        help="Float precision (only available in --api v2)",
+    )
+    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
+    parser.add_argument("--seed", type=int, default=1, help="Random seed")
+    parser.add_argument("--verbose", "-V", type=int, default=2, help="Verbose option")
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        default=1,
+        help="Batch size for beam search (0: means no batch processing)",
+    )
+    parser.add_argument(
+        "--preprocess-conf",
+        type=str,
+        default=None,
+        help="The configuration file for the pre-processing",
+    )
+    parser.add_argument(
+        "--api",
+        default="v2",
+        choices=["v2"],
+        help="Beam search APIs "
+        "v2: Experimental API. It supports any models that implements ScorerInterface.",
+    )
+    # task related
+    parser.add_argument(
+        "--recog-json", type=str, help="Filename of recognition data (json)"
+    )
+    parser.add_argument(
+        "--result-label",
+        type=str,
+        required=True,
+        help="Filename of result label data (json)",
+    )
+    # model (parameter) related
+    parser.add_argument(
+        "--model", type=str, required=True, help="Model file parameters to read"
+    )
+    parser.add_argument(
+        "--model-conf", type=str, default=None, help="Model config file"
+    )
+    parser.add_argument(
+        "--num-spkrs",
+        type=int,
+        default=1,
+        choices=[1, 2],
+        help="Number of speakers in the speech",
+    )
+    parser.add_argument(
+        "--num-encs", default=1, type=int, help="Number of encoders in the model."
+    )
+    # search related
+    parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
+    parser.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    parser.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="""Input length ratio to obtain max output length.
+        If maxlenratio=0.0 (default), it uses an end-detect function
+        to automatically find maximum hypothesis lengths.
+        If maxlenratio<0.0, its absolute value is interpreted
+        as a constant max output length""",
+    )
+    parser.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    parser.add_argument(
+        "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
+    )
+    parser.add_argument(
+        "--weights-ctc-dec",
+        type=float,
+        action="append",
+        help="ctc weight assigned to each encoder during decoding."
+        "[in multi-encoder mode only]",
+    )
+    parser.add_argument(
+        "--ctc-window-margin",
+        type=int,
+        default=0,
+        help="""Use CTC window with margin parameter to accelerate
+        CTC/attention decoding especially on GPU. Smaller margin
+        makes decoding faster, but may increase search errors.
+        If margin=0 (default), this function is disabled""",
+    )
+    # transducer related
+    parser.add_argument(
+        "--search-type",
+        type=str,
+        default="default",
+        choices=["default", "nsc", "tsd", "alsd", "maes"],
+        help="""Type of beam search implementation to use during inference.
+        Can be either: default beam search ("default"),
+        N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
+        Alignment-Length Synchronous Decoding ("alsd") or
+        modified Adaptive Expansion Search ("maes").""",
+    )
+    parser.add_argument(
+        "--nstep",
+        type=int,
+        default=1,
+        help="""Number of expansion steps allowed in NSC beam search or mAES
+        (nstep > 0 for NSC and nstep > 1 for mAES).""",
+    )
+    parser.add_argument(
+        "--prefix-alpha",
+        type=int,
+        default=2,
+        help="Length prefix difference allowed in NSC beam search or mAES.",
+    )
+    parser.add_argument(
+        "--max-sym-exp",
+        type=int,
+        default=2,
+        help="Number of symbol expansions allowed in TSD.",
+    )
+    parser.add_argument(
+        "--u-max",
+        type=int,
+        default=400,
+        help="Length prefix difference allowed in ALSD.",
+    )
+    parser.add_argument(
+        "--expansion-gamma",
+        type=float,
+        default=2.3,
+        help="Allowed logp difference for prune-by-value method in mAES.",
+    )
+    parser.add_argument(
+        "--expansion-beta",
+        type=int,
+        default=2,
+        help="""Number of additional candidates for expanded hypotheses
+        selection in mAES.""",
+    )
+    parser.add_argument(
+        "--score-norm",
+        type=strtobool,
+        nargs="?",
+        default=True,
+        help="Normalize final hypotheses' score by length",
+    )
+    parser.add_argument(
+        "--softmax-temperature",
+        type=float,
+        default=1.0,
+        help="Penalization term for softmax function.",
+    )
+    # rnnlm related
+    parser.add_argument(
+        "--rnnlm", type=str, default=None, help="RNNLM model file to read"
+    )
+    parser.add_argument(
+        "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
+    )
+    parser.add_argument(
+        "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
+    )
+    parser.add_argument(
+        "--word-rnnlm-conf",
+        type=str,
+        default=None,
+        help="Word RNNLM model config file to read",
+    )
+    parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
+    parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
+    # ngram related
+    parser.add_argument(
+        "--ngram-model", type=str, default=None, help="ngram model file to read"
+    )
+    parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
+    parser.add_argument(
+        "--ngram-scorer",
+        type=str,
+        default="part",
+        choices=("full", "part"),
+        help="""if the ngram is set as a part scorer, similar with CTC scorer,
+        ngram scorer only scores topK hypotheses.
+        if the ngram is set as full scorer, ngram scorer scores all hypotheses.
+        the decoding speed of part scorer is much faster than full one""",
+    )
+    # streaming related
+    parser.add_argument(
+        "--streaming-mode",
+        type=str,
+        default=None,
+        choices=["window", "segment"],
+        help="""Use streaming recognizer for inference.
+        `--batchsize` must be set to 0 to enable this mode""",
+    )
+    parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
+    parser.add_argument(
+        "--streaming-min-blank-dur",
+        type=int,
+        default=10,
+        help="Minimum blank duration threshold",
+    )
+    parser.add_argument(
+        "--streaming-onset-margin", type=int, default=1, help="Onset margin"
+    )
+    parser.add_argument(
+        "--streaming-offset-margin", type=int, default=1, help="Offset margin"
+    )
+    # non-autoregressive related
+    # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
+    parser.add_argument(
+        "--maskctc-n-iterations",
+        type=int,
+        default=10,
+        help="Number of decoding iterations."
+        "For Mask CTC, set 0 to predict 1 mask/iter.",
+    )
+    parser.add_argument(
+        "--maskctc-probability-threshold",
+        type=float,
+        default=0.999,
+        help="Threshold probability for CTC output",
+    )
+    # quantize model related
+    parser.add_argument(
+        "--quantize-config",
+        nargs="*",
+        help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
+    )
+    parser.add_argument(
+        "--quantize-dtype", type=str, default="qint8", help="Dtype dynamic quantize"
+    )
+    parser.add_argument(
+        "--quantize-asr-model",
+        type=bool,
+        default=False,
+        help="Quantize asr model",
+    )
+    parser.add_argument(
+        "--quantize-lm-model",
+        type=bool,
+        default=False,
+        help="Quantize lm model",
+    )
+    return parser
+
+
+def main(args):
+    """Run the main decoding function."""
+    parser = get_parser()
+    parser.add_argument(
+        "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
+    parser.add_argument(
+        "--checkpoint_path", type=str, help="path to load checkpoint")
+    parser.add_argument(
+        "--dict-path", type=str, help="path to load vocabulary")
+    # parser = default_argument_parser(parser)
+    args = parser.parse_args(args)
+
+    if args.ngpu == 0 and args.dtype == "float16":
+        raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
+
+    # logging info
+    if args.verbose == 1:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    elif args.verbose == 2:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+        logging.warning("Skip DEBUG/INFO messages")
+    logging.info(args)
+
+    # check CUDA_VISIBLE_DEVICES
+    if args.ngpu > 0:
+        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
+        if cvd is None:
+            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
+        elif args.ngpu != len(cvd.split(",")):
+            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
+            sys.exit(1)
+
+    # TODO(mn5k): support of multiple GPUs
+    if args.ngpu > 1:
+        logging.error("The program only supports ngpu=1.")
+        sys.exit(1)
+
+    # display PYTHONPATH
+    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
+
+    # seed setting
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    logging.info("set random seed = %d" % args.seed)
+
+    # validate rnn options
+    if args.rnnlm is not None and args.word_rnnlm is not None:
+        logging.error(
+            "It seems that both --rnnlm and --word-rnnlm are specified. "
+            "Please use either option."
+        )
+        sys.exit(1)
+
+    # recog
+    if args.num_spkrs == 1:
+        if args.num_encs == 1:
+            # Experimental API that supports custom LMs
+            if args.api == "v2":
+                from deepspeech.decoders.recog import recog_v2
+                recog_v2(args)
+            else:
+                raise ValueError("Only support --api v2")
+        else:
+            if args.api == "v2":
+                raise NotImplementedError(
+                    f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
+                )
+    elif args.num_spkrs == 2:
+        raise ValueError("asr_mix not supported.")
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 18e29b28f..3aadca856 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -317,11 +317,9 @@ class U2Trainer(Trainer):
         with UpdateConfig(model_conf):
             model_conf.input_dim = self.train_loader.feat_dim
             model_conf.output_dim = self.train_loader.vocab_size
-
         model = U2Model.from_config(model_conf)
         if self.parallel:
             model = paddle.DataParallel(model)
-        logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)

         # lr
diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py
index 34220432b..18ff411b0 100644
--- a/deepspeech/frontend/featurizer/text_featurizer.py
+++ b/deepspeech/frontend/featurizer/text_featurizer.py
@@ -207,7 +207,7 @@ class TextFeaturizer():
         """Load vocabulary from file."""
         vocab_list = load_dict(vocab_filepath, maskctc)
         assert vocab_list is not None
-        logger.info(f"Vocab: {pformat(vocab_list)}")
+        logger.debug(f"Vocab: {pformat(vocab_list)}")

         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py
index 8915cbd7d..fa517906f 100644
--- a/deepspeech/models/u2/u2.py
+++ b/deepspeech/models/u2/u2.py
@@ -50,7 +50,7 @@ from deepspeech.utils.tensor_utils import th_accuracy
 from deepspeech.utils.utility import log_add
 from deepspeech.utils.utility import UpdateConfig
 from deepspeech.models.asr_interface import ASRInterface
-from deepspeech.decoders.scorers.ctc_prefix_score import CTCPrefixScorer
+from deepspeech.decoders.scorers.ctc import CTCPrefixScorer

 __all__ = ["U2Model", "U2InferModel"]
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index 154b7390f..ee3572ea9 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -28,7 +28,7 @@ from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.mask import subsequent_mask
 from deepspeech.modules.mask import make_xs_mask
 from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
-from deepspeech.decoders.scorers.score_interface import BatchScorerInterface
+from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()
@@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         ys: (ylen,)
         x: (xlen, n_feat)
         """
-        ys_mask = subsequent_mask(len(ys)).unsqueeze(0)
-        x_mask = make_xs_mask(x.unsqueeze(0))
+        ys_mask = subsequent_mask(len(ys)).unsqueeze(0)  # (B,L,L)
+        x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1)  # (B,1,T)
         if self.selfattention_layer_type != "selfattn":
             # TODO(karita): implement cache
             logger.warning(
@@ -237,9 +237,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         # batch decoding
-        ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)
-
-        xs_mask = make_xs_mask(xs)
+        ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)  # (B,L,L)
+        xs_mask = make_xs_mask(xs).unsqueeze(1)  # (B,1,T)
         logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state)

         # transpose state of [layer, batch] into [batch, layer]
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index cffa10a7b..7ae418c03 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -24,15 +24,16 @@ __all__ = [
 ]


-def make_xs_mask(xs: paddle.Tensor) -> paddle.Tensor:
+def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor:
     """Make mask tensor containing indices of non-padded part.
     Args:
         xs (paddle.Tensor): (B, T, D), zeros for pad.
     Returns:
-        paddle.Tensor: Mask tensor with indices of the non-padded part. (B, T, D)
+        paddle.Tensor: Mask tensor with indices of the non-padded part. (B, T)
     """
-    pad_frame = paddle.zeros([1, 1, xs.shape[-1]], dtype=xs.dtype)
+    pad_frame = paddle.full([1, 1, xs.shape[-1]], pad_value, dtype=xs.dtype)
     mask = xs != pad_frame
+    mask = mask.all(axis=-1)
     return mask

diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py
index ca9652e3c..db7076d3d 100644
--- a/deepspeech/training/cli.py
+++ b/deepspeech/training/cli.py
@@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action):
             parser.parse_args(f.read().split(), namespace)


-def default_argument_parser():
+def default_argument_parser(parser=None):
     r"""A simple yet general argument parser for experiments with parakeet.

     This is used in examples with parakeet. And it is intended to be used by
@@ -62,7 +62,9 @@ def default_argument_parser():
     argparse.ArgumentParser
         the parser
     """
-    parser = argparse.ArgumentParser()
+    if parser is None:
+        parser = argparse.ArgumentParser()
+
     parser.register('action', 'extend', ExtendAction)
     parser.add_argument(
         '--conf', type=open, action=LoadFromFile, help="config file.")
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 6c18fa369..b2ee7a1b8 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -126,7 +126,7 @@ class Trainer():
             logger.info(f"Set seed {args.seed}")

         # profiler and benchmark options
-        if self.args.benchmark_batch_size:
+        if hasattr(self.args, "benchmark_batch_size") and self.args.benchmark_batch_size:
             with UpdateConfig(self.config):
                 self.config.collator.batch_size = self.args.benchmark_batch_size
                 self.config.training.log_interval = 1
@@ -326,12 +326,25 @@ class Trainer():
         finally:
             self.destory()

+    def restore(self):
+        """Resume from latest checkpoint at checkpoints in the output
+        directory or load a specified checkpoint.
+
+        If ``args.checkpoint_path`` is not None, load the checkpoint, else
+        resume training.
+        """
+        assert self.args.checkpoint_path
+        infos = self.checkpoint.load_latest_parameters(
+            self.model,
+            checkpoint_path=self.args.checkpoint_path)
+        return infos
+
     def run_test(self):
         """Do Test/Decode"""
         try:
             with Timer("Test/Decode Done: {}"):
                 with self.eval():
-                    self.resume_or_scratch()
+                    self.restore()
                     self.test()
         except KeyboardInterrupt:
             exit(-1)
@@ -341,6 +354,7 @@ class Trainer():
         try:
             with Timer("Export Done: {}"):
                 with self.eval():
+                    self.restore()
                     self.export()
         except KeyboardInterrupt:
             exit(-1)
@@ -350,7 +364,7 @@ class Trainer():
         try:
             with Timer("Align Done: {}"):
                 with self.eval():
-                    self.resume_or_scratch()
+                    self.restore()
                     self.align()
         except KeyboardInterrupt:
             sys.exit(-1)
diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md
index 1f7c69194..6e41cc370 100644
--- a/examples/librispeech/s2/README.md
+++ b/examples/librispeech/s2/README.md
@@ -11,3 +11,4 @@
 | test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
 | test-clean | ctc_prefix_beam_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
 | test-clean | attention_rescoring | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
+| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
diff --git a/examples/librispeech/s2/conf/decode/decode.yaml b/examples/librispeech/s2/conf/decode/decode.yaml
new file mode 100644
index 000000000..867bf6118
--- /dev/null
+++ b/examples/librispeech/s2/conf/decode/decode.yaml
@@ -0,0 +1,7 @@
+batchsize: 0
+beam-size: 60
+ctc-weight: 0.4
+lm-weight: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+penalty: 0.0
diff --git a/examples/librispeech/s2/conf/decode/decode_all.yaml b/examples/librispeech/s2/conf/decode/decode_all.yaml
new file mode 100644
index 000000000..87d5f6d19
--- /dev/null
+++ b/examples/librispeech/s2/conf/decode/decode_all.yaml
@@ -0,0 +1,7 @@
+batchsize: 0
+beam-size: 60
+ctc-weight: 0.4
+lm-weight: 0.6
+maxlenratio: 0.0
+minlenratio: 0.0
+penalty: 0.0
\ No newline at end of file
diff --git a/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml b/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml
new file mode 100644
index 000000000..a82dba23b
--- /dev/null
+++ b/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml
@@ -0,0 +1,7 @@
+batchsize: 0
+beam-size: 60
+ctc-weight: 0.4
+lm-weight: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+penalty: 0.0
\ No newline at end of file
+ +config_path=$1 +dict=$2 +ckpt_prefix=$3 + +ckpt_dir=$(dirname `dirname ${ckpt_prefix}`) +echo "ckpt dir: ${ckpt_dir}" + +ckpt_tag=$(basename ${ckpt_prefix}) +echo "ckpt tag: ${ckpt_tag}" + +chunk_mode=false +if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then + chunk_mode=true +fi +echo "chunk mode: ${chunk_mode}" +echo "decode conf: ${decode_config}" + +# download language model +#bash local/download_lm_en.sh +#if [ $? -ne 0 ]; then +# exit 1 +#fi + + +pids=() # initialize pids + +for dmethd in join_ctc; do +( + echo "${dmethd} decoding" + for rtask in ${recog_set}; do + ( + echo "${rtask} dataset" + decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag} + feat_recog_dir=${datadir} + mkdir -p ${decode_dir} + mkdir -p ${feat_recog_dir} + + # split data + split_json.sh manifest.${rtask} ${nj} + + #### use CPU for decoding + ngpu=0 + + # set batchsize 0 to disable batch decoding + ${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \ + python3 -u ${BIN_DIR}/recog.py \ + --api v2 \ + --config ${decode_config} \ + --ngpu ${ngpu} \ + --batchsize 0 \ + --checkpoint_path ${ckpt_prefix} \ + --dict-path ${dict} \ + --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \ + --result-label ${decode_dir}/data.JOB.json \ + --model-conf ${config_path} \ + --model ${ckpt_prefix}.pdparams + + #--rnnlm ${lmexpdir}/${lang_model} \ + + score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict} + + ) & + pids+=($!) # store background pids + i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done + [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." || true + done +) +done + +echo "Finished" + +exit 0 diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index a6b6cfa9d..006a13c5f 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -83,7 +83,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco --opts decoding.batch_size ${batch_size} \ --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} - score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict} + score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict} ) & pids+=($!) 
# store background pids diff --git a/requirements.txt b/requirements.txt index 2a3f06514..42fd645a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,3 +40,4 @@ pyworld jieba phkit yq +ConfigArgParse \ No newline at end of file From 1b75ca1eead918c9b6ab25426f2eea1a39a1c53f Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 23 Oct 2021 15:31:43 +0000 Subject: [PATCH 4/5] avg model dump val loss mean --- utils/avg_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/utils/avg_model.py b/utils/avg_model.py index 1fc00cb65..7c05ec789 100755 --- a/utils/avg_model.py +++ b/utils/avg_model.py @@ -25,8 +25,8 @@ def main(args): paddle.set_device('cpu') val_scores = [] - beat_val_scores = [] - selected_epochs = [] + beat_val_scores = None + selected_epochs = None jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') jsons = sorted(jsons, key=os.path.getmtime, reverse=True) @@ -80,9 +80,10 @@ def main(args): data = json.dumps({ "mode": 'val_best' if args.val_best else 'latest', "avg_ckpt": args.dst_model, - "ckpt": path_list, - "epoch": selected_epochs.tolist(), - "val_loss": beat_val_scores.tolist(), + "val_loss_mean": np.mean(beat_val_scores), + "ckpts": path_list, + "epochs": selected_epochs.tolist(), + "val_losses": beat_val_scores.tolist(), }) f.write(data + "\n") From 2101e20d9dc77dc5907b91a0eaa4cdc0198c599d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 23 Oct 2021 15:59:53 +0000 Subject: [PATCH 5/5] recog into decoders, format code --- deepspeech/__init__.py | 38 +- deepspeech/decoders/beam_search/__init__.py | 17 + .../decoders/beam_search/batch_beam_search.py | 17 + .../decoders/{ => beam_search}/beam_search.py | 236 ++++++----- deepspeech/decoders/recog.py | 134 +++--- deepspeech/decoders/recog_bin.py | 376 +++++++++++++++++ deepspeech/decoders/scorers/ngram.py | 5 +- deepspeech/decoders/utils.py | 7 +- deepspeech/exps/__init__.py | 49 +++ deepspeech/exps/u2_kaldi/bin/recog.py | 388 +----------------- deepspeech/exps/u2_kaldi/model.py | 5 +- .../frontend/featurizer/text_featurizer.py | 2 +- deepspeech/models/asr_interface.py | 39 +- deepspeech/models/u2/u2.py | 8 +- deepspeech/modules/decoder.py | 32 +- deepspeech/modules/mask.py | 2 +- deepspeech/training/cli.py | 2 +- deepspeech/training/trainer.py | 6 +- examples/librispeech/README.md | 6 +- examples/librispeech/s2/README.md | 14 +- .../librispeech/s2/conf/decode/decode.yaml | 2 +- examples/librispeech/s2/local/recog.sh | 24 +- examples/librispeech/s2/local/test.sh | 17 +- requirements.txt | 42 +- setup.py | 23 +- 25 files changed, 808 insertions(+), 683 deletions(-) create mode 100644 deepspeech/decoders/beam_search/__init__.py create mode 100644 deepspeech/decoders/beam_search/batch_beam_search.py rename deepspeech/decoders/{ => beam_search}/beam_search.py (74%) create mode 100644 deepspeech/decoders/recog_bin.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index ca5b4ba5c..da3b1acb1 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -233,7 +233,8 @@ def is_broadcastable(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): - assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, mask.shape) + assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, + mask.shape) bshape = paddle.broadcast_shape(xs.shape, mask.shape) mask = mask.broadcast_to(bshape) trues = paddle.ones_like(xs) * value @@ -312,18 +313,18 @@ if not hasattr(paddle.Tensor, 'type_as'): def to(x: paddle.Tensor, *args, **kwargs) -> 
paddle.Tensor: - assert len(args) == 1 - if isinstance(args[0], str): # dtype - return x.astype(args[0]) - elif isinstance(args[0], paddle.Tensor): # Tensor - return x.astype(args[0].dtype) - else: # Device - return x + assert len(args) == 1 + if isinstance(args[0], str): # dtype + return x.astype(args[0]) + elif isinstance(args[0], paddle.Tensor): # Tensor + return x.astype(args[0].dtype) + else: # Device + return x if not hasattr(paddle.Tensor, 'to'): - logger.debug("register user to to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'to', to) + logger.debug("register user to to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'to', to) def func_float(x: paddle.Tensor) -> paddle.Tensor: @@ -355,7 +356,6 @@ if not hasattr(paddle.Tensor, 'tolist'): setattr(paddle.Tensor, 'tolist', tolist) - ########### hcak paddle.nn.functional ############# # hack loss def ctc_loss(logits, @@ -384,7 +384,6 @@ logger.debug( ) F.ctc_loss = ctc_loss - ########### hcak paddle.nn ############# from paddle.nn import Layer from typing import Optional @@ -394,6 +393,7 @@ from typing import Tuple from typing import Iterator from collections import OrderedDict, abc as container_abcs + class LayerDict(paddle.nn.Layer): r"""Holds submodules in a dictionary. @@ -438,7 +438,7 @@ class LayerDict(paddle.nn.Layer): return x """ - def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None: + def __init__(self, modules: Optional[Mapping[str, Layer]]=None) -> None: super(LayerDict, self).__init__() if modules is not None: self.update(modules) @@ -505,10 +505,11 @@ class LayerDict(paddle.nn.Layer): """ if not isinstance(modules, container_abcs.Iterable): raise TypeError("LayerDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(modules).__name__) + "iterable of key/value pairs, but got " + type( + modules).__name__) - if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)): + if isinstance(modules, + (OrderedDict, LayerDict, container_abcs.Mapping)): for key, module in modules.items(): self[key] = module else: @@ -520,14 +521,15 @@ class LayerDict(paddle.nn.Layer): type(m).__name__) if not len(m) == 2: raise ValueError("LayerDict update sequence element " - "#" + str(j) + " has length " + str(len(m)) + - "; 2 is required") + "#" + str(j) + " has length " + str( + len(m)) + "; 2 is required") # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] # remove forward alltogether to fallback on Module's _forward_unimplemented + if not hasattr(paddle.nn, 'LayerDict'): logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") diff --git a/deepspeech/decoders/beam_search/__init__.py b/deepspeech/decoders/beam_search/__init__.py new file mode 100644 index 000000000..79a1e9d30 --- /dev/null +++ b/deepspeech/decoders/beam_search/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .batch_beam_search import BatchBeamSearch +from .beam_search import beam_search +from .beam_search import BeamSearch +from .beam_search import Hypothesis diff --git a/deepspeech/decoders/beam_search/batch_beam_search.py b/deepspeech/decoders/beam_search/batch_beam_search.py new file mode 100644 index 000000000..3fc1c435f --- /dev/null +++ b/deepspeech/decoders/beam_search/batch_beam_search.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class BatchBeamSearch(): + pass diff --git a/deepspeech/decoders/beam_search.py b/deepspeech/decoders/beam_search/beam_search.py similarity index 74% rename from deepspeech/decoders/beam_search.py rename to deepspeech/decoders/beam_search/beam_search.py index afb8aefa5..8fd8f9b8f 100644 --- a/deepspeech/decoders/beam_search.py +++ b/deepspeech/decoders/beam_search/beam_search.py @@ -1,5 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Beam search module.""" - from itertools import chain from typing import Any from typing import Dict @@ -10,18 +22,18 @@ from typing import Union import paddle -from .utils import end_detect -from .scorers.scorer_interface import PartialScorerInterface -from .scorers.scorer_interface import ScorerInterface - +from ..scorers.scorer_interface import PartialScorerInterface +from ..scorers.scorer_interface import ScorerInterface +from ..utils import end_detect from deepspeech.utils.log import Log logger = Log(__name__).getlog() + class Hypothesis(NamedTuple): """Hypothesis data type.""" - yseq: paddle.Tensor # (T,) + yseq: paddle.Tensor # (T,) score: Union[float, paddle.Tensor] = 0 scores: Dict[str, Union[float, paddle.Tensor]] = dict() states: Dict[str, Any] = dict() @@ -31,25 +43,24 @@ class Hypothesis(NamedTuple): return self._replace( yseq=self.yseq.tolist(), score=float(self.score), - scores={k: float(v) for k, v in self.scores.items()}, - )._asdict() + scores={k: float(v) + for k, v in self.scores.items()}, )._asdict() class BeamSearch(paddle.nn.Layer): """Beam search implementation.""" def __init__( - self, - scorers: Dict[str, ScorerInterface], - weights: Dict[str, float], - beam_size: int, - vocab_size: int, - sos: int, - eos: int, - token_list: List[str] = None, - pre_beam_ratio: float = 1.5, - pre_beam_score_key: str = None, - ): + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str]=None, + pre_beam_ratio: float=1.5, + pre_beam_score_key: str=None, ): """Initialize beam search. Args: @@ -71,12 +82,12 @@ class BeamSearch(paddle.nn.Layer): super().__init__() # set scorers self.weights = weights - self.scorers = dict() # all = full + partial - self.full_scorers = dict() # full tokens - self.part_scorers = dict() # partial tokens + self.scorers = dict() # all = full + partial + self.full_scorers = dict() # full tokens + self.part_scorers = dict() # partial tokens # this module dict is required for recursive cast # `self.to(device, dtype)` in `recog.py` - self.nn_dict = paddle.nn.LayerDict() # nn.Layer + self.nn_dict = paddle.nn.LayerDict() # nn.Layer for k, v in scorers.items(): w = weights.get(k, 0) if w == 0 or v is None: @@ -100,20 +111,16 @@ class BeamSearch(paddle.nn.Layer): self.pre_beam_size = int(pre_beam_ratio * beam_size) self.beam_size = beam_size self.n_vocab = vocab_size - if ( - pre_beam_score_key is not None - and pre_beam_score_key != "full" - and pre_beam_score_key not in self.full_scorers - ): - raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + if (pre_beam_score_key is not None and pre_beam_score_key != "full" and + pre_beam_score_key not in self.full_scorers): + raise KeyError( + f"{pre_beam_score_key} is not found in {self.full_scorers}") # selected `key` scorer to do pre beam search self.pre_beam_score_key = pre_beam_score_key # do_pre_beam when need, valid and has part_scorers - self.do_pre_beam = ( - self.pre_beam_score_key is not None - and self.pre_beam_size < self.n_vocab - and len(self.part_scorers) > 0 - ) + self.do_pre_beam = (self.pre_beam_score_key is not None and + self.pre_beam_size < self.n_vocab and + len(self.part_scorers) > 0) def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]: """Get an initial hypothesis data. 
@@ -135,12 +142,12 @@ class BeamSearch(paddle.nn.Layer): yseq=paddle.to_tensor([self.sos], place=x.place), score=0.0, scores=init_scores, - states=init_states, - ) + states=init_states, ) ] @staticmethod - def append_token(xs: paddle.Tensor, x: Union[int, paddle.Tensor]) -> paddle.Tensor: + def append_token(xs: paddle.Tensor, + x: Union[int, paddle.Tensor]) -> paddle.Tensor: """Append new token to prefix tokens. Args: @@ -154,9 +161,8 @@ class BeamSearch(paddle.nn.Layer): x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x return paddle.concat((xs, x)) - def score_full( - self, hyp: Hypothesis, x: paddle.Tensor - ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: + def score_full(self, hyp: Hypothesis, x: paddle.Tensor + ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: """Score new hypothesis by `self.full_scorers`. Args: @@ -178,9 +184,11 @@ class BeamSearch(paddle.nn.Layer): scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) return scores, states - def score_partial( - self, hyp: Hypothesis, ids: paddle.Tensor, x: paddle.Tensor - ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: + def score_partial(self, + hyp: Hypothesis, + ids: paddle.Tensor, + x: paddle.Tensor + ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: """Score new hypothesis by `self.part_scorers`. Args: @@ -201,12 +209,12 @@ class BeamSearch(paddle.nn.Layer): states = dict() for k, d in self.part_scorers.items(): # scores[k] shape (len(ids),) - scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], + x) return scores, states - def beam( - self, weighted_scores: paddle.Tensor, ids: paddle.Tensor - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def beam(self, weighted_scores: paddle.Tensor, + ids: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute topk full token ids and partial token ids. Args: @@ -223,7 +231,8 @@ class BeamSearch(paddle.nn.Layer): """ # no pre beam performed, `ids` equal to `weighted_scores` if weighted_scores.size(0) == ids.size(0): - top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab + top_ids = weighted_scores.topk( + self.beam_size)[1] # index in n_vocab return top_ids, top_ids # mask pruned in pre-beam not to select in topk @@ -231,18 +240,18 @@ class BeamSearch(paddle.nn.Layer): weighted_scores[:] = -float("inf") weighted_scores[ids] = tmp # top_ids no equal to local_ids, since ids shape not same - top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab - local_ids = weighted_scores[ids].topk(self.beam_size)[1] # index in len(ids) + top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab + local_ids = weighted_scores[ids].topk( + self.beam_size)[1] # index in len(ids) return top_ids, local_ids @staticmethod def merge_scores( - prev_scores: Dict[str, float], - next_full_scores: Dict[str, paddle.Tensor], - full_idx: int, - next_part_scores: Dict[str, paddle.Tensor], - part_idx: int, - ) -> Dict[str, paddle.Tensor]: + prev_scores: Dict[str, float], + next_full_scores: Dict[str, paddle.Tensor], + full_idx: int, + next_part_scores: Dict[str, paddle.Tensor], + part_idx: int, ) -> Dict[str, paddle.Tensor]: """Merge scores for new hypothesis. 
Args: @@ -288,9 +297,8 @@ class BeamSearch(paddle.nn.Layer): new_states[k] = d.select_state(part_states[k], part_idx) return new_states - def search( - self, running_hyps: List[Hypothesis], x: paddle.Tensor - ) -> List[Hypothesis]: + def search(self, running_hyps: List[Hypothesis], + x: paddle.Tensor) -> List[Hypothesis]: """Search new tokens for running hypotheses and encoded speech x. Args: @@ -311,11 +319,9 @@ class BeamSearch(paddle.nn.Layer): weighted_scores += self.weights[k] * scores[k] # partial scoring if self.do_pre_beam: - pre_beam_scores = ( - weighted_scores - if self.pre_beam_score_key == "full" - else scores[self.pre_beam_score_key] - ) + pre_beam_scores = (weighted_scores + if self.pre_beam_score_key == "full" else + scores[self.pre_beam_score_key]) part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1] part_scores, part_states = self.score_partial(hyp, part_ids, x) for k in self.part_scorers: @@ -331,22 +337,21 @@ class BeamSearch(paddle.nn.Layer): Hypothesis( score=weighted_scores[j], yseq=self.append_token(hyp.yseq, j), - scores=self.merge_scores( - hyp.scores, scores, j, part_scores, part_j - ), + scores=self.merge_scores(hyp.scores, scores, j, + part_scores, part_j), states=self.merge_states(states, part_states, part_j), - ) - ) + )) # sort and prune 2 x beam -> beam - best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ - : min(len(best_hyps), self.beam_size) - ] + best_hyps = sorted( + best_hyps, key=lambda x: x.score, + reverse=True)[:min(len(best_hyps), self.beam_size)] return best_hyps - def forward( - self, x: paddle.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 - ) -> List[Hypothesis]: + def forward(self, + x: paddle.Tensor, + maxlenratio: float=0.0, + minlenratio: float=0.0) -> List[Hypothesis]: """Perform beam search. Args: @@ -381,9 +386,11 @@ class BeamSearch(paddle.nn.Layer): logger.debug("position " + str(i)) best = self.search(running_hyps, x) # post process of one iteration - running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + running_hyps = self.post_process(i, maxlen, maxlenratio, best, + ended_hyps) # end detection - if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + if maxlenratio == 0.0 and end_detect( + [h.asdict() for h in ended_hyps], i): logger.info(f"end detected at {i}") break if len(running_hyps) == 0: @@ -395,15 +402,10 @@ class BeamSearch(paddle.nn.Layer): nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) # check the number of hypotheses reaching to eos if len(nbest_hyps) == 0: - logger.warning( - "there is no N-best results, perform recognition " - "again with smaller minlenratio." 
- ) - return ( - [] - if minlenratio < 0.1 - else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) - ) + logger.warning("there is no N-best results, perform recognition " + "again with smaller minlenratio.") + return ([] if minlenratio < 0.1 else + self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))) # report the best result best = nbest_hyps[0] @@ -412,7 +414,9 @@ class BeamSearch(paddle.nn.Layer): f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}" ) logger.info(f"total log probability: {float(best.score):.2f}") - logger.info(f"normalized log probability: {float(best.score) / len(best.yseq):.2f}") + logger.info( + f"normalized log probability: {float(best.score) / len(best.yseq):.2f}" + ) logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}") if self.token_list is not None: # logger.info( @@ -420,21 +424,17 @@ class BeamSearch(paddle.nn.Layer): # + "".join([self.token_list[x] for x in best.yseq[1:-1]]) # + "\n" # ) - logger.info( - "best hypo: " - + "".join([self.token_list[x] for x in best.yseq[1:]]) - + "\n" - ) + logger.info("best hypo: " + "".join( + [self.token_list[x] for x in best.yseq[1:]]) + "\n") return nbest_hyps def post_process( - self, - i: int, - maxlen: int, - maxlenratio: float, - running_hyps: List[Hypothesis], - ended_hyps: List[Hypothesis], - ) -> List[Hypothesis]: + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], ) -> List[Hypothesis]: """Perform post-processing of beam search iterations. Args: @@ -450,10 +450,8 @@ class BeamSearch(paddle.nn.Layer): """ logger.debug(f"the number of running hypotheses: {len(running_hyps)}") if self.token_list is not None: - logger.debug( - "best hypo: " - + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) - ) + logger.debug("best hypo: " + "".join( + [self.token_list[x] for x in running_hyps[0].yseq[1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logger.info("adding in the last position in the loop") @@ -468,7 +466,8 @@ class BeamSearch(paddle.nn.Layer): for hyp in running_hyps: if hyp.yseq[-1] == self.eos: # e.g., Word LM needs to add final score - for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + for k, d in chain(self.full_scorers.items(), + self.part_scorers.items()): s = d.final_score(hyp.states[k]) hyp.scores[k] += s hyp = hyp._replace(score=hyp.score + self.weights[k] * s) @@ -479,19 +478,18 @@ class BeamSearch(paddle.nn.Layer): def beam_search( - x: paddle.Tensor, - sos: int, - eos: int, - beam_size: int, - vocab_size: int, - scorers: Dict[str, ScorerInterface], - weights: Dict[str, float], - token_list: List[str] = None, - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - pre_beam_ratio: float = 1.5, - pre_beam_score_key: str = "full", -) -> list: + x: paddle.Tensor, + sos: int, + eos: int, + beam_size: int, + vocab_size: int, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + token_list: List[str]=None, + maxlenratio: float=0.0, + minlenratio: float=0.0, + pre_beam_ratio: float=1.5, + pre_beam_score_key: str="full", ) -> list: """Perform beam search with scorers. 
Args: @@ -527,6 +525,6 @@ def beam_search( pre_beam_score_key=pre_beam_score_key, sos=sos, eos=eos, - token_list=token_list, - ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) + token_list=token_list, ).forward( + x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) return [h.asdict() for h in ret] diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py index 867569aa0..c8df65d68 100644 --- a/deepspeech/decoders/recog.py +++ b/deepspeech/decoders/recog.py @@ -1,37 +1,57 @@ -"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`.""" - +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`.""" import json +from pathlib import Path + +import jsonlines import paddle import yaml from yacs.config import CfgNode -from pathlib import Path -import jsonlines -# from espnet.asr.asr_utils import get_model_conf -# from espnet.asr.asr_utils import torch_load -# from espnet.asr.pytorch_backend.asr import load_trained_model -# from espnet.nets.lm_interface import dynamic_import_lm - -from deepspeech.models.asr_interface import ASRInterface - -from .utils import add_results_to_json -# from .batch_beam_search import BatchBeamSearch +from .beam_search import BatchBeamSearch from .beam_search import BeamSearch -from .scorers.scorer_interface import BatchScorerInterface from .scorers.length_bonus import LengthBonus - +from .scorers.scorer_interface import BatchScorerInterface +from .utils import add_results_to_json +from deepspeech.exps import dynamic_import_tester from deepspeech.io.reader import LoadInputsAndTargets +from deepspeech.models.asr_interface import ASRInterface from deepspeech.utils.log import Log +# from espnet.asr.asr_utils import get_model_conf +# from espnet.asr.asr_utils import torch_load +# from espnet.nets.lm_interface import dynamic_import_lm + logger = Log(__name__).getlog() +# NOTE: you need this func to generate our sphinx doc -from deepspeech.utils.dynamic_import import dynamic_import -from deepspeech.utils.utility import print_arguments -model_test_alias = { - "u2": "deepspeech.exps.u2.model:U2Tester", - "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", -} +def load_trained_model(args): + args.nprocs = args.ngpu + confs = CfgNode() + confs.set_new_allowed(True) + confs.merge_from_file(args.model_conf) + class_obj = dynamic_import_tester(args.model_name) + exp = class_obj(confs, args) + with exp.eval(): + exp.setup() + exp.restore() + char_list = exp.args.char_list + model = exp.model + return model, char_list, exp, confs + def recog_v2(args): """Decode with custom models that implements ScorerInterface. 
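
For orientation, the recog_v2() body shown in the next hunks reduces to the driver sketch below. This is illustrative code, not part of the patch: `model`, `char_list`, and `feat` stand in for the objects produced by load_trained_model() and the feature loader, `model.sos`/`model.eos` are assumed attributes of the U2 model, and the weight values are examples (the 0.4 CTC weight mirrors conf/decode/decode.yaml, not a tuned default):

    import paddle

    from deepspeech.decoders.beam_search import beam_search
    from deepspeech.decoders.scorers.length_bonus import LengthBonus

    # model.scorers() returns the attention decoder and a CTC prefix scorer;
    # recog_v2() adds lm/ngram/length_bonus entries on top of these.
    scorers = model.scorers()
    scorers["length_bonus"] = LengthBonus(len(char_list))
    weights = {"decoder": 0.6, "ctc": 0.4, "length_bonus": 0.0}

    enc = model.encode(paddle.to_tensor(feat))  # (T, D) encoder states
    nbest = beam_search(
        x=enc,
        sos=model.sos,
        eos=model.eos,
        beam_size=60,
        vocab_size=len(char_list),
        scorers=scorers,
        weights=weights,
        token_list=char_list,
        maxlenratio=0.0,  # 0.0 relies on end detection instead of a hard cap
        minlenratio=0.0)
    best = nbest[0]  # hypotheses come back as dicts, best first

Here beam_search() is the convenience wrapper defined above: it builds a BeamSearch instance from the scorers and weights and returns the hypotheses already converted with asdict().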
@@ -48,33 +68,17 @@ def recog_v2(args): raise NotImplementedError("streaming mode is not implemented") if args.word_rnnlm: raise NotImplementedError("word LM is not implemented") - args.nprocs = args.ngpu - # set_deterministic(args) - - #model, train_args = load_trained_model(args.model) - model_path = Path(args.model) - ckpt_dir = model_path.parent.parent - - confs = CfgNode() - confs.set_new_allowed(True) - confs.merge_from_file(args.model_conf) - - class_obj = dynamic_import(args.model_name, model_test_alias) - exp = class_obj(confs, args) - with exp.eval(): - exp.setup() - exp.restore() - char_list = exp.args.char_list - model = exp.model + # set_deterministic(args) + model, char_list, exp, confs = load_trained_model(args) assert isinstance(model, ASRInterface) + load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=confs.collator.augmentation_config - if args.preprocess_conf is None - else args.preprocess_conf, + if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) @@ -100,7 +104,7 @@ def recog_v2(args): else: ngram = None - scorers = model.scorers() + scorers = model.scorers() # decoder scorers["lm"] = lm scorers["ngram"] = ngram scorers["length_bonus"] = LengthBonus(len(char_list)) @@ -125,18 +129,15 @@ def recog_v2(args): # TODO(karita): make all scorers batchfied if args.batchsize == 1: non_batch = [ - k - for k, v in beam_search.full_scorers.items() + k for k, v in beam_search.full_scorers.items() if not isinstance(v, BatchScorerInterface) ] if len(non_batch) == 0: beam_search.__class__ = BatchBeamSearch logger.info("BatchBeamSearch implementation is selected.") else: - logger.warning( - f"As non-batch scorers {non_batch} are found, " - f"fall back to non-batch implementation." 
- ) + logger.warning(f"As non-batch scorers {non_batch} are found, " + f"fall back to non-batch implementation.") if args.ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") @@ -157,7 +158,7 @@ def recog_v2(args): with jsonlines.open(args.recog_json, "r") as reader: for item in reader: js.append(item) - # josnlines to dict, key by 'utt' + # jsonlines to dict, key by 'utt', value by jsonline js = {item['utt']: item for item in js} new_js = {} @@ -169,25 +170,26 @@ def recog_v2(args): feat = load_inputs_and_targets(batch)[0][0] logger.info(f'feat: {feat.shape}') enc = model.encode(paddle.to_tensor(feat).to(dtype)) - logger.info(f'eouts: {enc.shape}') - nbest_hyps = beam_search( - x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio - ) + logger.info(f'eout: {enc.shape}') + nbest_hyps = beam_search(x=enc, + maxlenratio=args.maxlenratio, + minlenratio=args.minlenratio) nbest_hyps = [ - h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)] + h.asdict() + for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)] ] - new_js[name] = add_results_to_json( - js[name], nbest_hyps, char_list - ) + new_js[name] = add_results_to_json(js[name], nbest_hyps, + char_list) - item = new_js[name]['output'][0] # 1-best - utt = name + item = new_js[name]['output'][0] # 1-best ref = item['text'] - rec_text = item['rec_text'].replace('▁', ' ').replace('', '').strip() - rec_tokenid = map(int, item['rec_tokenid'].split()) + rec_text = item['rec_text'].replace('▁', + ' ').replace('', + '').strip() + rec_tokenid = list(map(int, item['rec_tokenid'].split())) f.write({ - "utt": utt, - "refs": [ref], - "hyps": [rec_text], - "hyps_tokenid": [rec_tokenid], - }) \ No newline at end of file + "utt": name, + "refs": [ref], + "hyps": [rec_text], + "hyps_tokenid": [rec_tokenid], + }) diff --git a/deepspeech/decoders/recog_bin.py b/deepspeech/decoders/recog_bin.py new file mode 100644 index 000000000..567dfecde --- /dev/null +++ b/deepspeech/decoders/recog_bin.py @@ -0,0 +1,376 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""End-to-end speech recognition model decoding script."""
+import logging
+import os
+import random
+import sys
+from distutils.util import strtobool
+
+import configargparse
+import numpy as np
+
+from .recog import recog_v2
+
+
+def get_parser():
+    """Get default arguments."""
+    parser = configargparse.ArgumentParser(
+        description="Transcribe text from speech using "
+        "a speech recognition model on one CPU or GPU",
+        config_file_parser_class=configargparse.YAMLConfigFileParser,
+        formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
+    parser.add(
+        '--model-name',
+        type=str,
+        default='u2_kaldi',
+        help='model name, e.g.: deepspeech2, u2, u2_kaldi, u2_st')
+    # general configuration
+    parser.add("--config", is_config_file=True, help="Config file path")
+    parser.add(
+        "--config2",
+        is_config_file=True,
+        help="Second config file path that overwrites the settings in `--config`",
+    )
+    parser.add(
+        "--config3",
+        is_config_file=True,
+        help="Third config file path that overwrites the settings "
+        "in `--config` and `--config2`", )
+
+    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
+    parser.add_argument(
+        "--dtype",
+        choices=("float16", "float32", "float64"),
+        default="float32",
+        help="Float precision (only available in --api v2)", )
+    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
+    parser.add_argument("--seed", type=int, default=1, help="Random seed")
+    parser.add_argument(
+        "--verbose", "-V", type=int, default=2, help="Verbose option")
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        default=1,
+        help="Batch size for beam search (0: means no batch processing)", )
+    parser.add_argument(
+        "--preprocess-conf",
+        type=str,
+        default=None,
+        help="The configuration file for the pre-processing", )
+    parser.add_argument(
+        "--api",
+        default="v2",
+        choices=["v2"],
+        help="Beam search APIs "
+        "v2: Experimental API. It supports any model that implements ScorerInterface.",
+    )
+    # task related
+    parser.add_argument(
+        "--recog-json", type=str, help="Filename of recognition data (json)")
+    parser.add_argument(
+        "--result-label",
+        type=str,
+        required=True,
+        help="Filename of result label data (json)", )
+    # model (parameter) related
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Model file parameters to read")
+    parser.add_argument(
+        "--model-conf", type=str, default=None, help="Model config file")
+    parser.add_argument(
+        "--num-spkrs",
+        type=int,
+        default=1,
+        choices=[1, 2],
+        help="Number of speakers in the speech", )
+    parser.add_argument(
+        "--num-encs",
+        default=1,
+        type=int,
+        help="Number of encoders in the model.")
+    # search related
+    parser.add_argument(
+        "--nbest", type=int, default=1, help="Output N-best hypotheses")
+    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
+    parser.add_argument(
+        "--penalty", type=float, default=0.0, help="Insertion penalty")
+    parser.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="""Input length ratio to obtain max output length.
+        If maxlenratio=0.0 (default), it uses an end-detect function
+        to automatically find maximum hypothesis lengths.
+        If maxlenratio<0.0, its absolute value is interpreted
+        as a constant max output length""", )
+    parser.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length", )
+    parser.add_argument(
+        "--ctc-weight",
+        type=float,
+        default=0.0,
+        help="CTC weight in joint decoding")
+    parser.add_argument(
+        "--weights-ctc-dec",
+        type=float,
+        action="append",
+        help="ctc weight assigned to each encoder during decoding."
+        "[in multi-encoder mode only]", )
+    parser.add_argument(
+        "--ctc-window-margin",
+        type=int,
+        default=0,
+        help="""Use CTC window with margin parameter to accelerate
+        CTC/attention decoding especially on GPU. Smaller margin
+        makes decoding faster, but may increase search errors.
+        If margin=0 (default), this function is disabled""", )
+    # transducer related
+    parser.add_argument(
+        "--search-type",
+        type=str,
+        default="default",
+        choices=["default", "nsc", "tsd", "alsd", "maes"],
+        help="""Type of beam search implementation to use during inference.
+        Can be either: default beam search ("default"),
+        N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
+        Alignment-Length Synchronous Decoding ("alsd") or
+        modified Adaptive Expansion Search ("maes").""", )
+    parser.add_argument(
+        "--nstep",
+        type=int,
+        default=1,
+        help="""Number of expansion steps allowed in NSC beam search or mAES
+        (nstep > 0 for NSC and nstep > 1 for mAES).""", )
+    parser.add_argument(
+        "--prefix-alpha",
+        type=int,
+        default=2,
+        help="Length prefix difference allowed in NSC beam search or mAES.", )
+    parser.add_argument(
+        "--max-sym-exp",
+        type=int,
+        default=2,
+        help="Number of symbol expansions allowed in TSD.", )
+    parser.add_argument(
+        "--u-max",
+        type=int,
+        default=400,
+        help="Length prefix difference allowed in ALSD.", )
+    parser.add_argument(
+        "--expansion-gamma",
+        type=float,
+        default=2.3,
+        help="Allowed logp difference for prune-by-value method in mAES.", )
+    parser.add_argument(
+        "--expansion-beta",
+        type=int,
+        default=2,
+        help="""Number of additional candidates for expanded hypotheses
+        selection in mAES.""", )
+    parser.add_argument(
+        "--score-norm",
+        type=strtobool,
+        nargs="?",
+        default=True,
+        help="Normalize final hypotheses' score by length", )
+    parser.add_argument(
+        "--softmax-temperature",
+        type=float,
+        default=1.0,
+        help="Penalization term for softmax function.", )
+    # rnnlm related
+    parser.add_argument(
+        "--rnnlm", type=str, default=None, help="RNNLM model file to read")
+    parser.add_argument(
+        "--rnnlm-conf",
+        type=str,
+        default=None,
+        help="RNNLM model config file to read")
+    parser.add_argument(
+        "--word-rnnlm",
+        type=str,
+        default=None,
+        help="Word RNNLM model file to read")
+    parser.add_argument(
+        "--word-rnnlm-conf",
+        type=str,
+        default=None,
+        help="Word RNNLM model config file to read", )
+    parser.add_argument(
+        "--word-dict", type=str, default=None, help="Word list to read")
+    parser.add_argument(
+        "--lm-weight", type=float, default=0.1, help="RNNLM weight")
+    # ngram related
+    parser.add_argument(
+        "--ngram-model",
+        type=str,
+        default=None,
+        help="ngram model file to read")
+    parser.add_argument(
+        "--ngram-weight", type=float, default=0.1, help="ngram weight")
+    parser.add_argument(
+        "--ngram-scorer",
+        type=str,
+        default="part",
+        choices=("full", "part"),
+        help="""if the ngram is set as a part scorer, similar to the CTC scorer,
+        the ngram scorer only scores the topK hypotheses.
+        if the ngram is set as a full scorer, it scores all hypotheses;
+        the decoding speed of the part scorer is much faster than the full one""",
+    )
+    # streaming related
+    parser.add_argument(
+        "--streaming-mode",
+        type=str,
+        default=None,
+        choices=["window", "segment"],
+        help="""Use streaming recognizer for inference.
+        `--batchsize` must be set to 0 to enable this mode""", )
+    parser.add_argument(
+        "--streaming-window", type=int, default=10, help="Window size")
+    parser.add_argument(
+        "--streaming-min-blank-dur",
+        type=int,
+        default=10,
+        help="Minimum blank duration threshold", )
+    parser.add_argument(
+        "--streaming-onset-margin", type=int, default=1, help="Onset margin")
+    parser.add_argument(
+        "--streaming-offset-margin", type=int, default=1, help="Offset margin")
+    # non-autoregressive related
+    # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
+    parser.add_argument(
+        "--maskctc-n-iterations",
+        type=int,
+        default=10,
+        help="Number of decoding iterations."
+        "For Mask CTC, set 0 to predict 1 mask/iter.", )
+    parser.add_argument(
+        "--maskctc-probability-threshold",
+        type=float,
+        default=0.999,
+        help="Threshold probability for CTC output", )
+    # quantize model related
+    parser.add_argument(
+        "--quantize-config",
+        nargs="*",
+        help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
+    )
+    parser.add_argument(
+        "--quantize-dtype",
+        type=str,
+        default="qint8",
+        help="Dtype dynamic quantize")
+    parser.add_argument(
+        "--quantize-asr-model",
+        type=bool,
+        default=False,
+        help="Quantize asr model", )
+    parser.add_argument(
+        "--quantize-lm-model",
+        type=bool,
+        default=False,
+        help="Quantize lm model", )
+    return parser
+
+
+def main(args):
+    """Run the main decoding function."""
+    parser = get_parser()
+    parser.add_argument(
+        "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
+    parser.add_argument(
+        "--checkpoint_path", type=str, help="path to load checkpoint")
+    parser.add_argument("--dict-path", type=str, help="path to dict file")
+    args = parser.parse_args(args)
+
+    if args.ngpu == 0 and args.dtype == "float16":
+        raise ValueError(
+            f"--dtype {args.dtype} does not support the CPU backend.")
+
+    # logging info
+    if args.verbose == 1:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    elif args.verbose == 2:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+        logging.warning("Skip DEBUG/INFO messages")
+    logging.info(args)
+
+    # check CUDA_VISIBLE_DEVICES
+    if args.ngpu > 0:
+        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
+        if cvd is None:
+            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
+        elif args.ngpu != len(cvd.split(",")):
+            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
+            sys.exit(1)
+
+    # TODO(mn5k): support of multiple GPUs
+    if args.ngpu > 1:
+        logging.error("The program only supports ngpu=1.")
+        sys.exit(1)
+
+    # display PYTHONPATH
+    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
+
+    # seed setting
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    logging.info("set random seed = %d" % args.seed)
+
+    # validate rnn options
+    if args.rnnlm is not None and args.word_rnnlm is not None:
+        logging.error(
+            "It seems that both --rnnlm and --word-rnnlm are specified. "
+            "Please use either option.")
+        sys.exit(1)
+
+    # recog
+    if args.num_spkrs == 1:
+        if args.num_encs == 1:
+            # Experimental API that supports custom LMs
+            if args.api == "v2":
+                from deepspeech.decoders.recog import recog_v2
+                recog_v2(args)
+            else:
+                raise ValueError("Only support --api v2")
+        else:
+            if args.api == "v2":
+                raise NotImplementedError(
+                    f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
+                )
+    elif args.num_spkrs == 2:
+        raise ValueError("asr_mix not supported.")
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/deepspeech/decoders/scorers/ngram.py b/deepspeech/decoders/scorers/ngram.py
index 050a8c81f..a34d82483 100644
--- a/deepspeech/decoders/scorers/ngram.py
+++ b/deepspeech/decoders/scorers/ngram.py
@@ -85,8 +85,9 @@ class NgramFullScorer(Ngrambase, BatchScorerInterface):
         and next state list for ys.
 
         """
-        return self.score_partial_(
-            y, paddle.to_tensor(range(self.charlen)), state, x)
+        return self.score_partial_(y,
+                                   paddle.to_tensor(range(self.charlen)), state,
+                                   x)
 
 
 class NgramPartScorer(Ngrambase, PartialScorerInterface):
diff --git a/deepspeech/decoders/utils.py b/deepspeech/decoders/utils.py
index f59b55d9b..3ed9c5da5 100644
--- a/deepspeech/decoders/utils.py
+++ b/deepspeech/decoders/utils.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import numpy as np
+
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
@@ -98,7 +98,8 @@
 
     for n, hyp in enumerate(nbest_hyps, 1):
         # parse hypothesis
-        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
+        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
+                                                                   char_list)
 
         # copy ground-truth
         if len(js["output"]) > 0:
@@ -125,4 +126,4 @@
     logger.info("groundtruth: %s" % out_dic["text"])
     logger.info("prediction : %s" % out_dic["rec_text"])
 
-    return new_js
\ No newline at end of file
+    return new_js
diff --git a/deepspeech/exps/__init__.py b/deepspeech/exps/__init__.py
index 185a92b8d..299530146 100644
--- a/deepspeech/exps/__init__.py
+++ b/deepspeech/exps/__init__.py
@@ -11,3 +11,52 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from deepspeech.training.trainer import Trainer
+from deepspeech.utils.dynamic_import import dynamic_import
+
+model_trainer_alias = {
+    "ds2": "deepspeech.exps.deepspeech2.model:DeepSpeech2Trainer",
+    "u2": "deepspeech.exps.u2.model:U2Trainer",
+    "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
+    "u2_st": "deepspeech.exps.u2_st.model:U2STTrainer",
+}
+
+
+def dynamic_import_trainer(module):
+    """Import Trainer dynamically.
+
+    Args:
+        module (str): trainer name. e.g., ds2, u2, u2_kaldi
+
+    Returns:
+        type: Trainer class
+
+    """
+    model_class = dynamic_import(module, model_trainer_alias)
+    assert issubclass(model_class,
+                      Trainer), f"{module} does not implement Trainer"
+    return model_class
+
+
+model_tester_alias = {
+    "ds2": "deepspeech.exps.deepspeech2.model:DeepSpeech2Tester",
+    "u2": "deepspeech.exps.u2.model:U2Tester",
+    "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
+    "u2_st": "deepspeech.exps.u2_st.model:U2STTester",
+}
+
+
+def dynamic_import_tester(module):
+    """Import Tester dynamically.
+ + Args: + module (str): tester name. e.g., ds2, u2, u2_kaldi + + Returns: + type: Tester class + + """ + model_class = dynamic_import(module, model_tester_alias) + assert issubclass(model_class, + Trainer), f"{module} does not implement Tester" + return model_class diff --git a/deepspeech/exps/u2_kaldi/bin/recog.py b/deepspeech/exps/u2_kaldi/bin/recog.py index 60f4e58ac..e94a1ab18 100644 --- a/deepspeech/exps/u2_kaldi/bin/recog.py +++ b/deepspeech/exps/u2_kaldi/bin/recog.py @@ -1,379 +1,19 @@ - -"""End-to-end speech recognition model decoding script.""" - -import configargparse -import logging -import os -import random +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sys -import numpy as np - -from distutils.util import strtobool -from deepspeech.training.cli import default_argument_parser - -# NOTE: you need this func to generate our sphinx doc - -def get_parser(): - """Get default arguments.""" - parser = configargparse.ArgumentParser( - description="Transcribe text from speech using " - "a speech recognition model on one CPU or GPU", - config_file_parser_class=configargparse.YAMLConfigFileParser, - formatter_class=configargparse.ArgumentDefaultsHelpFormatter, - ) - parser.add( - '--model-name', - type=str, - default='u2_kaldi', - help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') - # general configuration - parser.add("--config", is_config_file=True, help="Config file path") - parser.add( - "--config2", - is_config_file=True, - help="Second config file path that overwrites the settings in `--config`", - ) - parser.add( - "--config3", - is_config_file=True, - help="Third config file path that overwrites the settings " - "in `--config` and `--config2`", - ) - - parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") - parser.add_argument( - "--dtype", - choices=("float16", "float32", "float64"), - default="float32", - help="Float precision (only available in --api v2)", - ) - parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") - parser.add_argument("--seed", type=int, default=1, help="Random seed") - parser.add_argument("--verbose", "-V", type=int, default=2, help="Verbose option") - parser.add_argument( - "--batchsize", - type=int, - default=1, - help="Batch size for beam search (0: means no batch processing)", - ) - parser.add_argument( - "--preprocess-conf", - type=str, - default=None, - help="The configuration file for the pre-processing", - ) - parser.add_argument( - "--api", - default="v2", - choices=["v2"], - help="Beam search APIs " - "v2: Experimental API. 
It supports any models that implements ScorerInterface.", - ) - # task related - parser.add_argument( - "--recog-json", type=str, help="Filename of recognition data (json)" - ) - parser.add_argument( - "--result-label", - type=str, - required=True, - help="Filename of result label data (json)", - ) - # model (parameter) related - parser.add_argument( - "--model", type=str, required=True, help="Model file parameters to read" - ) - parser.add_argument( - "--model-conf", type=str, default=None, help="Model config file" - ) - parser.add_argument( - "--num-spkrs", - type=int, - default=1, - choices=[1, 2], - help="Number of speakers in the speech", - ) - parser.add_argument( - "--num-encs", default=1, type=int, help="Number of encoders in the model." - ) - # search related - parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - parser.add_argument("--beam-size", type=int, default=1, help="Beam size") - parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty") - parser.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="""Input length ratio to obtain max output length. - If maxlenratio=0.0 (default), it uses a end-detect function - to automatically find maximum hypothesis lengths. - If maxlenratio<0.0, its absolute value is interpreted - as a constant max output length""", - ) - parser.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - parser.add_argument( - "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding" - ) - parser.add_argument( - "--weights-ctc-dec", - type=float, - action="append", - help="ctc weight assigned to each encoder during decoding." - "[in multi-encoder mode only]", - ) - parser.add_argument( - "--ctc-window-margin", - type=int, - default=0, - help="""Use CTC window with margin parameter to accelerate - CTC/attention decoding especially on GPU. Smaller magin - makes decoding faster, but may increase search errors. - If margin=0 (default), this function is disabled""", - ) - # transducer related - parser.add_argument( - "--search-type", - type=str, - default="default", - choices=["default", "nsc", "tsd", "alsd", "maes"], - help="""Type of beam search implementation to use during inference. 
- Can be either: default beam search ("default"), - N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"), - Alignment-Length Synchronous Decoding ("alsd") or - modified Adaptive Expansion Search ("maes").""", - ) - parser.add_argument( - "--nstep", - type=int, - default=1, - help="""Number of expansion steps allowed in NSC beam search or mAES - (nstep > 0 for NSC and nstep > 1 for mAES).""", - ) - parser.add_argument( - "--prefix-alpha", - type=int, - default=2, - help="Length prefix difference allowed in NSC beam search or mAES.", - ) - parser.add_argument( - "--max-sym-exp", - type=int, - default=2, - help="Number of symbol expansions allowed in TSD.", - ) - parser.add_argument( - "--u-max", - type=int, - default=400, - help="Length prefix difference allowed in ALSD.", - ) - parser.add_argument( - "--expansion-gamma", - type=float, - default=2.3, - help="Allowed logp difference for prune-by-value method in mAES.", - ) - parser.add_argument( - "--expansion-beta", - type=int, - default=2, - help="""Number of additional candidates for expanded hypotheses - selection in mAES.""", - ) - parser.add_argument( - "--score-norm", - type=strtobool, - nargs="?", - default=True, - help="Normalize final hypotheses' score by length", - ) - parser.add_argument( - "--softmax-temperature", - type=float, - default=1.0, - help="Penalization term for softmax function.", - ) - # rnnlm related - parser.add_argument( - "--rnnlm", type=str, default=None, help="RNNLM model file to read" - ) - parser.add_argument( - "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read" - ) - parser.add_argument( - "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read" - ) - parser.add_argument( - "--word-rnnlm-conf", - type=str, - default=None, - help="Word RNNLM model config file to read", - ) - parser.add_argument("--word-dict", type=str, default=None, help="Word list to read") - parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight") - # ngram related - parser.add_argument( - "--ngram-model", type=str, default=None, help="ngram model file to read" - ) - parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight") - parser.add_argument( - "--ngram-scorer", - type=str, - default="part", - choices=("full", "part"), - help="""if the ngram is set as a part scorer, similar with CTC scorer, - ngram scorer only scores topK hypethesis. - if the ngram is set as full scorer, ngram scorer scores all hypthesis - the decoding speed of part scorer is musch faster than full one""", - ) - # streaming related - parser.add_argument( - "--streaming-mode", - type=str, - default=None, - choices=["window", "segment"], - help="""Use streaming recognizer for inference. - `--batchsize` must be set to 0 to enable this mode""", - ) - parser.add_argument("--streaming-window", type=int, default=10, help="Window size") - parser.add_argument( - "--streaming-min-blank-dur", - type=int, - default=10, - help="Minimum blank duration threshold", - ) - parser.add_argument( - "--streaming-onset-margin", type=int, default=1, help="Onset margin" - ) - parser.add_argument( - "--streaming-offset-margin", type=int, default=1, help="Offset margin" - ) - # non-autoregressive related - # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail. - parser.add_argument( - "--maskctc-n-iterations", - type=int, - default=10, - help="Number of decoding iterations." 
- "For Mask CTC, set 0 to predict 1 mask/iter.", - ) - parser.add_argument( - "--maskctc-probability-threshold", - type=float, - default=0.999, - help="Threshold probability for CTC output", - ) - # quantize model related - parser.add_argument( - "--quantize-config", - nargs="*", - help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]", - ) - parser.add_argument( - "--quantize-dtype", type=str, default="qint8", help="Dtype dynamic quantize" - ) - parser.add_argument( - "--quantize-asr-model", - type=bool, - default=False, - help="Quantize asr model", - ) - parser.add_argument( - "--quantize-lm-model", - type=bool, - default=False, - help="Quantize lm model", - ) - return parser - - -def main(args): - """Run the main decoding function.""" - parser = get_parser() - parser.add_argument( - "--output", metavar="CKPT_DIR", help="path to save checkpoint.") - parser.add_argument( - "--checkpoint_path", type=str, help="path to load checkpoint") - parser.add_argument( - "--dict-path", type=str, help="path to load checkpoint") - # parser = default_argument_parser(parser) - args = parser.parse_args(args) - - if args.ngpu == 0 and args.dtype == "float16": - raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.") - - # logging info - if args.verbose == 1: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose == 2: - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - logging.info(args) - - # check CUDA_VISIBLE_DEVICES - if args.ngpu > 0: - cvd = os.environ.get("CUDA_VISIBLE_DEVICES") - if cvd is None: - logging.warning("CUDA_VISIBLE_DEVICES is not set.") - elif args.ngpu != len(cvd.split(",")): - logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") - sys.exit(1) - - # TODO(mn5k): support of multiple GPUs - if args.ngpu > 1: - logging.error("The program only supports ngpu=1.") - sys.exit(1) - - # display PYTHONPATH - logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) - - # seed setting - random.seed(args.seed) - np.random.seed(args.seed) - logging.info("set random seed = %d" % args.seed) - - # validate rnn options - if args.rnnlm is not None and args.word_rnnlm is not None: - logging.error( - "It seems that both --rnnlm and --word-rnnlm are specified. " - "Please use either option." 
- )
- sys.exit(1)
-
- # recog
- if args.num_spkrs == 1:
- if args.num_encs == 1:
- # Experimental API that supports custom LMs
- if args.api == "v2":
- from deepspeech.decoders.recog import recog_v2
- recog_v2(args)
- else:
- raise ValueError("Only support --api v2")
- else:
- if args.api == "v2":
- raise NotImplementedError(
- f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
- )
- elif args.num_spkrs == 2:
- raise ValueError("asr_mix not supported.")
-
+from deepspeech.decoders.recog_bin import main

if __name__ == "__main__":
    main(sys.argv[1:])
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 3aadca856..f86243269 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -434,8 +434,9 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time

- for i, (utt, target, result, rec_tids) in enumerate(zip(
- utts, target_transcripts, result_transcripts, result_tokenids)):
+ for i, (utt, target, result, rec_tids) in enumerate(
+ zip(utts, target_transcripts, result_transcripts,
+ result_tokenids)):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py
index 18ff411b0..a6834ebc6 100644
--- a/deepspeech/frontend/featurizer/text_featurizer.py
+++ b/deepspeech/frontend/featurizer/text_featurizer.py
@@ -140,7 +140,7 @@ class TextFeaturizer():
Returns:
str: text string.
"""
- tokens = [t.replace(SPACE, " ") for t in tokens ]
+ tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens)

def word_tokenize(self, text):
diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py
index eb820fc05..7dac81b4f 100644
--- a/deepspeech/models/asr_interface.py
+++ b/deepspeech/models/asr_interface.py
@@ -1,3 +1,16 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""ASR Interface module."""
import argparse

@@ -72,7 +85,8 @@ class ASRInterface:
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
- raise NotImplementedError("calculate_all_attentions method is not implemented")
+ raise NotImplementedError(
+ "calculate_all_attentions method is not implemented")

def calculate_all_ctc_probs(self, xs, ilens, ys):
"""Calculate CTC probability.
@@ -83,7 +97,8 @@ class ASRInterface:
:return: CTC probabilities (B, Tmax, vocab)
:rtype: float ndarray
"""
- raise NotImplementedError("calculate_all_ctc_probs method is not implemented")
+ raise NotImplementedError(
+ "calculate_all_ctc_probs method is not implemented")

@property
def attention_plot_class(self):
@@ -102,8 +117,7 @@ class ASRInterface:
def get_total_subsampling_factor(self):
"""Get total subsampling factor."""
raise NotImplementedError(
- "get_total_subsampling_factor method is not implemented"
- )
+ "get_total_subsampling_factor method is not implemented")

def encode(self, feat):
"""Encode feature in `beam_search` (optional).
@@ -126,23 +140,22 @@

predefined_asr = {
- "transformer": "deepspeech.models.u2:E2E",
- "conformer": "deepspeech.models.u2:E2E",
+ "transformer": "deepspeech.models.u2:U2Model",
+ "conformer": "deepspeech.models.u2:U2Model",
}

-def dynamic_import_asr(module, name):
+
+def dynamic_import_asr(module):
"""Import ASR models dynamically.

Args:
- module (str): module_name:class_name or alias in `predefined_asr`
- name (str): asr name. e.g., transformer, conformer
+ module (str): asr name. e.g., transformer, conformer

Returns:
type: ASR class

- model_class = dynamic_import(module, predefined_asr.get(name, ""))
- assert issubclass(
- model_class, ASRInterface
- ), f"{module} does not implement ASRInterface"
+ model_class = dynamic_import(module, predefined_asr)
+ assert issubclass(model_class,
+ ASRInterface), f"{module} does not implement ASRInterface"
return model_class
diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py
index fa517906f..6cd3b7751 100644
--- a/deepspeech/models/u2/u2.py
+++ b/deepspeech/models/u2/u2.py
@@ -28,8 +28,10 @@ from paddle import jit
from paddle import nn
from yacs.config import CfgNode

+from deepspeech.decoders.scorers.ctc import CTCPrefixScorer
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn
+from deepspeech.models.asr_interface import ASRInterface
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
@@ -49,8 +51,6 @@ from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add
from deepspeech.utils.utility import UpdateConfig
-from deepspeech.models.asr_interface import ASRInterface
-from deepspeech.decoders.scorers.ctc import CTCPrefixScorer

__all__ = ["U2Model", "U2InferModel"]

@@ -816,10 +816,10 @@ class U2BaseModel(ASRInterface, nn.Layer):

class U2DecodeModel(U2BaseModel):
-
def scorers(self):
"""Scorers."""
- return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
+ return dict(
+ decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

def encode(self, x):
"""Encode acoustic features.
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index ee3572ea9..735f06dc6 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder definition.""" +from typing import Any from typing import List from typing import Optional from typing import Tuple -from typing import Any import paddle from paddle import nn from typeguard import check_argument_types +from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.decoder_layer import DecoderLayer from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.mask import make_non_pad_mask -from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.mask import make_xs_mask +from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward -from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ys: (ylen,) x: (xlen, n_feat) """ - ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L) - x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T) + ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L) + x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T) if self.selfattention_layer_type != "selfattn": # TODO(karita): implement cache logging.warning( @@ -200,16 +200,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ) state = None logp, state = self.forward_one_step( - x.unsqueeze(0), x_mask, - ys.unsqueeze(0), ys_mask, - cache=state - ) + x.unsqueeze(0), x_mask, ys.unsqueeze(0), ys_mask, cache=state) return logp.squeeze(0), state # batch beam search API (see BatchScorerInterface) - def batch_score( - self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor - ) -> Tuple[paddle.Tensor, List[Any]]: + def batch_score(self, + ys: paddle.Tensor, + states: List[Any], + xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]: """Score new token batch (required). Args: @@ -237,10 +235,12 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ] # batch decoding - ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) - xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) - logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state) + ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) + xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) + logp, states = self.forward_one_step( + xs, xs_mask, ys, ys_mask, cache=batch_state) # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + state_list = [[states[i][b] for i in range(n_layers)] + for b in range(n_batch)] return logp, state_list diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index 7ae418c03..52f8e4bca 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -24,7 +24,7 @@ __all__ = [ ] -def make_xs_mask(xs:paddle.Tensor, pad_value=0.0) -> paddle.Tensor: +def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor: """Maks mask tensor containing indices of non-padded part. Args: xs (paddle.Tensor): (B, T, D), zeros for pad. 
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py
index db7076d3d..14a34cb75 100644
--- a/deepspeech/training/cli.py
+++ b/deepspeech/training/cli.py
@@ -64,7 +64,7 @@ def default_argument_parser(parser=None):
"""
if parser is None:
parser = argparse.ArgumentParser()
-
+
parser.register('action', 'extend', ExtendAction)
parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.")
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index b2ee7a1b8..2c2389203 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -126,7 +126,8 @@ class Trainer():
logger.info(f"Set seed {args.seed}")

# profiler and benchmark options
- if hasattr(self.args, "benchmark_batch_size") and self.args.benchmark_batch_size:
+ if hasattr(self.args,
+ "benchmark_batch_size") and self.args.benchmark_batch_size:
with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size
self.config.training.log_interval = 1
@@ -335,8 +336,7 @@ class Trainer():
"""
assert self.args.checkpoint_path
infos = self.checkpoint.load_latest_parameters(
- self.model,
- checkpoint_path=self.args.checkpoint_path)
+ self.model, checkpoint_path=self.args.checkpoint_path)
return infos

def run_test(self):
diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md
index 724590952..5943cf1d7 100644
--- a/examples/librispeech/README.md
+++ b/examples/librispeech/README.md
@@ -1,8 +1,8 @@
# ASR

-* s0 is for deepspeech2 offline
-* s1 is for transformer/conformer/U2
-* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi
+* s0 is for deepspeech2 offline
+* s1 is for transformer/conformer/U2
+* s2 is for transformer/conformer/U2 w/ kaldi feat, need to install Kaldi

## Data
| Data Subset | Duration in Seconds |
diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md
index 6e41cc370..d5df37d84 100644
--- a/examples/librispeech/s2/README.md
+++ b/examples/librispeech/s2/README.md
@@ -1,14 +1,14 @@
# LibriSpeech

| Model | Params | Config | Augmentation | Loss |
-| --- | --- | --- | --- |
+| --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |

-| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
+| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
-| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
-| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
-| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
-| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
+| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
+| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
+| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
+| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
+| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
diff --git a/examples/librispeech/s2/conf/decode/decode.yaml b/examples/librispeech/s2/conf/decode/decode.yaml
index 867bf6118..4c702db56 100644
--- a/examples/librispeech/s2/conf/decode/decode.yaml
+++ b/examples/librispeech/s2/conf/decode/decode.yaml
@@ -1,6 +1,6 @@
batchsize: 0
beam-size: 60
-ctc-weight: 0.4
+ctc-weight: 0.0
lm-weight: 0.0
maxlenratio: 0.0
minlenratio: 0.0
diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh
index 0393535b4..df3846c02 100755
--- a/examples/librispeech/s2/local/recog.sh
+++ b/examples/librispeech/s2/local/recog.sh
@@ -5,11 +5,14 @@ set -e
expdir=exp
datadir=data
nj=32
+tag=
+# decode config
decode_config=conf/decode/decode.yaml
+
+# lm params
lang_model=rnnlm.model.best
lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
-
lmtag='nolm'

recog_set="test-clean test-other dev-clean dev-other"
@@ -21,18 +24,21 @@ bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model

-if [ $# != 3 ];then
- echo "usage: ${0} config_path dict_path ckpt_path_prefix"
- exit -1
+# bin params
+config_path=conf/transformer.yaml
+dict=data/bpe_unigram_5000_units.txt
+ckpt_prefix=
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+if [ -z ${ckpt_prefix} ]; then
+ echo "usage: $0 --ckpt_prefix ckpt_prefix"
+ exit 1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

-config_path=$1
-dict=$2
-ckpt_prefix=$3
-
ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
echo "ckpt dir: ${ckpt_dir}"

@@ -61,7 +67,7 @@ for dmethd in join_ctc; do
for rtask in ${recog_set}; do
(
echo "${rtask} dataset"
- decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}
+ decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}_${tag}
feat_recog_dir=${datadir}
mkdir -p ${decode_dir}
mkdir -p ${feat_recog_dir}
diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh
index 006a13c5f..5f662d292 100755
--- a/examples/librispeech/s2/local/test.sh
+++ b/examples/librispeech/s2/local/test.sh
@@ -17,19 +17,20 @@ bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model

-if [ $# != 3 ];then
- echo "usage: ${0} config_path dict_path ckpt_path_prefix"
- exit -1
+config_path=conf/transformer.yaml
+dict=data/bpe_unigram_5000_units.txt
+ckpt_prefix=
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+if [ -z ${ckpt_prefix} ]; then
+ echo "usage: $0 --ckpt_prefix ckpt_prefix"
+ exit 1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
-config_path=$1 -dict=$2 -ckpt_prefix=$3 - - ckpt_dir=$(dirname `dirname ${ckpt_prefix}`) echo "ckpt dir: ${ckpt_dir}" diff --git a/requirements.txt b/requirements.txt index 42fd645a5..a7310a024 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,43 +1,43 @@ +ConfigArgParse coverage editdistance +g2p_en +g2pM gpustat +h5py +inflect +jieba jsonlines kaldiio +librosa +llvmlite loguru +matplotlib +nltk +numba +numpy==1.20.0 +pandas +phkit Pillow +praatio~=4.1 pre-commit pybind11 +pypinyin +pyworld resampy==0.2.2 sacrebleu scipy==1.2.1 sentencepiece snakeviz +soundfile~=0.10 sox tensorboardX textgrid +timer tqdm typeguard -visualdl==2.2.0 -yacs -numpy==1.20.0 -numba -nltk -inflect -librosa unidecode -llvmlite -matplotlib -pandas -soundfile~=0.10 -g2p_en -pypinyin +visualdl==2.2.0 webrtcvad -g2pM -praatio~=4.1 -h5py -timer -pyworld -jieba -phkit +yacs yq -ConfigArgParse \ No newline at end of file diff --git a/setup.py b/setup.py index 07b6aac04..bd982129c 100644 --- a/setup.py +++ b/setup.py @@ -11,20 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import inspect import io import os import re +import subprocess as sp import sys from pathlib import Path -import contextlib -import inspect +from setuptools import Command from setuptools import find_packages from setuptools import setup -from setuptools import Command from setuptools.command.develop import develop from setuptools.command.install import install -import subprocess as sp HERE = Path(os.path.abspath(os.path.dirname(__file__))) @@ -40,16 +40,18 @@ def pushd(new_dir): def read(*names, **kwargs): - with io.open(os.path.join(os.path.dirname(__file__), *names), - encoding=kwargs.get("encoding", "utf8")) as fp: + with io.open( + os.path.join(os.path.dirname(__file__), *names), + encoding=kwargs.get("encoding", "utf8")) as fp: return fp.read() def check_call(cmd: str, shell=False, executable=None): try: - sp.check_call(cmd.split(), - shell=shell, - executable="/bin/bash" if shell else executable) + sp.check_call( + cmd.split(), + shell=shell, + executable="/bin/bash" if shell else executable) except sp.CalledProcessError as e: print( f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:", @@ -189,7 +191,6 @@ setup_info = dict( 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', - ], -) + ], ) setup(**setup_info)
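A note on the setup.py helpers touched above: `pushd` appears only as a hunk-header context line, and the reformatted `check_call` splits a command string before handing it to `subprocess`. A sketch of how the two compose, with the `pushd` body assumed (the patch shows only its signature, so this is the conventional implementation of the pattern, not necessarily the project's exact code):

    import contextlib
    import os
    import subprocess as sp

    @contextlib.contextmanager
    def pushd(new_dir):
        # Assumed body: temporarily switch the working directory,
        # restoring the old one even if the block raises.
        old_dir = os.getcwd()
        os.chdir(new_dir)
        try:
            yield
        finally:
            os.chdir(old_dir)

    def check_call(cmd: str):
        # Same shape as the reformatted setup.py helper: split the command
        # string into argv and let CalledProcessError propagate on failure.
        sp.check_call(cmd.split())

    with pushd("/tmp"):       # illustrative directory
        check_call("ls -la")  # runs with /tmp as the working directory

Keeping `check_call` string-based makes the build steps in setup.py read like shell one-liners while still avoiding `shell=True` for the non-shell case.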