diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 5f9ba007..ca5b4ba5 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -233,7 +233,7 @@ def is_broadcastable(shp1, shp2):
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True
+    assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, mask.shape)
     bshape = paddle.broadcast_shape(xs.shape, mask.shape)
     mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
@@ -312,18 +312,18 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
-    assert len(args) == 1
-    if isinstance(args[0], str):  # dtype
-        return x.astype(args[0])
-    elif isinstance(args[0], paddle.Tensor):  #Tensor
-        return x.astype(args[0].dtype)
-    else:  # Device
-        return x
+    assert len(args) == 1
+    if isinstance(args[0], str):  # dtype
+        return x.astype(args[0])
+    elif isinstance(args[0], paddle.Tensor):  # Tensor
+        return x.astype(args[0].dtype)
+    else:  # Device
+        return x
 
 
 if not hasattr(paddle.Tensor, 'to'):
-    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
-    setattr(paddle.Tensor, 'to', to)
+    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'to', to)
 
 
 def func_float(x: paddle.Tensor) -> paddle.Tensor:
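A quick sketch of what the richer assert message buys (shapes invented for illustration; `masked_fill` is the module-level helper patched above):

    import paddle
    from deepspeech import masked_fill

    xs = paddle.rand([2, 1, 5])
    ok_mask = paddle.ones([2, 3, 5], dtype='bool')   # broadcastable with xs
    bad_mask = paddle.ones([2, 3, 4], dtype='bool')  # 5 vs 4 in the last dim: not broadcastable
    masked_fill(xs, ok_mask, 0.0)   # fine
    masked_fill(xs, bad_mask, 0.0)  # AssertionError: ([2, 1, 5], [2, 3, 4])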
diff --git a/deepspeech/decoders/beam_search.py b/deepspeech/decoders/beam_search.py
index 5bf0d3e2..afb8aefa 100644
--- a/deepspeech/decoders/beam_search.py
+++ b/deepspeech/decoders/beam_search.py
@@ -1,7 +1,6 @@
 """Beam search module."""
 from itertools import chain
-import logger
 from typing import Any
 from typing import Dict
 from typing import List
@@ -141,7 +140,7 @@ class BeamSearch(paddle.nn.Layer):
         ]
 
     @staticmethod
-    def append_token(xs: paddle.Tensor, x: int) -> paddle.Tensor:
+    def append_token(xs: paddle.Tensor, x: Union[int, paddle.Tensor]) -> paddle.Tensor:
         """Append new token to prefix tokens.
 
         Args:
@@ -152,8 +151,8 @@ class BeamSearch(paddle.nn.Layer):
             paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device
 
         """
-        x = paddle.to_tensor([x], dtype=xs.dtype, place=xs.place)
-        return paddle.cat((xs, x))
+        x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x
+        return paddle.concat((xs, x))
 
     def score_full(
             self, hyp: Hypothesis, x: paddle.Tensor
@@ -306,7 +305,7 @@ class BeamSearch(paddle.nn.Layer):
             part_ids = paddle.arange(self.n_vocab)  # no pre-beam
         for hyp in running_hyps:
             # scoring
-            weighted_scores = paddle.zeros(self.n_vocab, dtype=x.dtype)
+            weighted_scores = paddle.zeros([self.n_vocab], dtype=x.dtype)
             scores, states = self.score_full(hyp, x)
             for k in self.full_scorers:
                 weighted_scores += self.weights[k] * scores[k]
@@ -410,15 +409,20 @@ class BeamSearch(paddle.nn.Layer):
         best = nbest_hyps[0]
         for k, v in best.scores.items():
             logger.info(
-                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+                f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}"
             )
-        logger.info(f"total log probability: {best.score:.2f}")
-        logger.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logger.info(f"total log probability: {float(best.score):.2f}")
+        logger.info(f"normalized log probability: {float(best.score) / len(best.yseq):.2f}")
         logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
         if self.token_list is not None:
+            # logger.info(
+            #     "best hypo: "
+            #     + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+            #     + "\n"
+            # )
             logger.info(
                 "best hypo: "
-                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "".join([self.token_list[x] for x in best.yseq[1:]])
                 + "\n"
             )
         return nbest_hyps
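For reference, `append_token` now takes either a plain int or an already-built one-element Tensor; a minimal sketch of the two paths (values invented):

    import paddle
    from deepspeech.decoders.beam_search import BeamSearch

    ys = paddle.to_tensor([4, 2, 7])
    BeamSearch.append_token(ys, 9)                      # int: wrapped into a Tensor first
    BeamSearch.append_token(ys, paddle.to_tensor([9]))  # Tensor: concatenated as-is
    # both yield [4, 2, 7, 9]; paddle.concat replaces the nonexistent paddle.cat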
diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py
index 399c5c54..867569aa 100644
--- a/deepspeech/decoders/recog.py
+++ b/deepspeech/decoders/recog.py
@@ -2,18 +2,22 @@
 import json
 
 import paddle
+import yaml
+from yacs.config import CfgNode
+from pathlib import Path
+import jsonlines
 
 # from espnet.asr.asr_utils import get_model_conf
 # from espnet.asr.asr_utils import torch_load
 # from espnet.asr.pytorch_backend.asr import load_trained_model
 # from espnet.nets.lm_interface import dynamic_import_lm
-# from espnet.nets.asr_interface import ASRInterface
+from deepspeech.models.asr_interface import ASRInterface
 
 from .utils import add_results_to_json
 # from .batch_beam_search import BatchBeamSearch
 from .beam_search import BeamSearch
-from .scorer_interface import BatchScorerInterface
+from .scorers.scorer_interface import BatchScorerInterface
 from .scorers.length_bonus import LengthBonus
 
 from deepspeech.io.reader import LoadInputsAndTargets
@@ -21,6 +25,14 @@ from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.utility import print_arguments
+
+model_test_alias = {
+    "u2": "deepspeech.exps.u2.model:U2Tester",
+    "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
+}
+
 
 def recog_v2(args):
     """Decode with custom models that implements ScorerInterface.
@@ -36,16 +48,31 @@ def recog_v2(args):
         raise NotImplementedError("streaming mode is not implemented")
     if args.word_rnnlm:
         raise NotImplementedError("word LM is not implemented")
-
+    args.nprocs = args.ngpu
     # set_deterministic(args)
-    model, train_args = load_trained_model(args.model)
-    # assert isinstance(model, ASRInterface)
-    model.eval()
+
+    # model, train_args = load_trained_model(args.model)
+    model_path = Path(args.model)
+    ckpt_dir = model_path.parent.parent
+
+    confs = CfgNode()
+    confs.set_new_allowed(True)
+    confs.merge_from_file(args.model_conf)
+
+    class_obj = dynamic_import(args.model_name, model_test_alias)
+    exp = class_obj(confs, args)
+    with exp.eval():
+        exp.setup()
+        exp.restore()
+    char_list = exp.args.char_list
+
+    model = exp.model
+    assert isinstance(model, ASRInterface)
 
     load_inputs_and_targets = LoadInputsAndTargets(
         mode="asr",
         load_output=False,
         sort_in_input_length=False,
-        preprocess_conf=train_args.preprocess_conf
+        preprocess_conf=confs.collator.augmentation_config
         if args.preprocess_conf is None
         else args.preprocess_conf,
         preprocess_args={"train": False},
@@ -56,7 +83,7 @@ def recog_v2(args):
         # NOTE: for a compatibility with less than 0.5.0 version models
         lm_model_module = getattr(lm_args, "model_module", "default")
         lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
-        lm = lm_class(len(train_args.char_list), lm_args)
+        lm = lm_class(len(char_list), lm_args)
         torch_load(args.rnnlm, lm)
         lm.eval()
     else:
@@ -67,16 +94,16 @@ def recog_v2(args):
         from .scorers.ngram import NgramPartScorer
 
         if args.ngram_scorer == "full":
-            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
+            ngram = NgramFullScorer(args.ngram_model, char_list)
         else:
-            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
+            ngram = NgramPartScorer(args.ngram_model, char_list)
     else:
         ngram = None
 
     scorers = model.scorers()
     scorers["lm"] = lm
     scorers["ngram"] = ngram
-    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
+    scorers["length_bonus"] = LengthBonus(len(char_list))
     weights = dict(
         decoder=1.0 - args.ctc_weight,
         ctc=args.ctc_weight,
@@ -86,14 +113,15 @@ def recog_v2(args):
     )
     beam_search = BeamSearch(
         beam_size=args.beam_size,
-        vocab_size=len(train_args.char_list),
+        vocab_size=len(char_list),
         weights=weights,
         scorers=scorers,
         sos=model.sos,
         eos=model.eos,
-        token_list=train_args.char_list,
+        token_list=char_list,
         pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
     )
+
     # TODO(karita): make all scorers batchfied
     if args.batchsize == 1:
         non_batch = [
@@ -116,6 +144,7 @@ def recog_v2(args):
         device = "gpu:0"
     else:
         device = "cpu"
+    paddle.set_device(device)
     dtype = getattr(paddle, args.dtype)
     logger.info(f"Decoding device={device}, dtype={dtype}")
     model.to(device=device, dtype=dtype)
@@ -124,31 +153,41 @@ def recog_v2(args):
     beam_search.eval()
 
     # read json data
-    with open(args.recog_json, "rb") as f:
-        js = json.load(f)
+    js = []
+    with jsonlines.open(args.recog_json, "r") as reader:
+        for item in reader:
+            js.append(item)
+    # jsonlines to dict, keyed by 'utt'
+    js = {item['utt']: item for item in js}
 
     new_js = {}
     with paddle.no_grad():
-        for idx, name in enumerate(js.keys(), 1):
-            logger.info("(%d/%d) decoding " + name, idx, len(js.keys()))
-            batch = [(name, js[name])]
-            feat = load_inputs_and_targets(batch)[0][0]
-            enc = model.encode(paddle.to_tensor(feat).to(device=device, dtype=dtype))
-            nbest_hyps = beam_search(
-                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
-            )
-            nbest_hyps = [
-                h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
-            ]
-            new_js[name] = add_results_to_json(
-                js[name], nbest_hyps, train_args.char_list
-            )
-
-    with open(args.result_label, "wb") as f:
-        f.write(
-            json.dumps(
-                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
-            ).encode("utf_8")
-        )
+        with jsonlines.open(args.result_label, "w") as f:
+            for idx, name in enumerate(js.keys(), 1):
+                logger.info(f"({idx}/{len(js.keys())}) decoding " + name)
+                batch = [(name, js[name])]
+                feat = load_inputs_and_targets(batch)[0][0]
+                logger.info(f'feat: {feat.shape}')
+                enc = model.encode(paddle.to_tensor(feat).to(dtype))
+                logger.info(f'eouts: {enc.shape}')
+                nbest_hyps = beam_search(
+                    x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
+                )
+                nbest_hyps = [
+                    h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
+                ]
+                new_js[name] = add_results_to_json(
+                    js[name], nbest_hyps, char_list
+                )
+
+                item = new_js[name]['output'][0]  # 1-best
+                utt = name
+                ref = item['text']
+                # strip sentencepiece spacers and the <eos> marker
+                rec_text = item['rec_text'].replace('▁', ' ').replace('<eos>', '').strip()
+                # materialize the map() so jsonlines can serialize it
+                rec_tokenid = list(map(int, item['rec_tokenid'].split()))
+                f.write({
+                    "utt": utt,
+                    "refs": [ref],
+                    "hyps": [rec_text],
+                    "hyps_tokenid": [rec_tokenid],
+                })
\ No newline at end of file
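The result file changes from one big `{"utts": ...}` JSON blob to one JSON object per line. A sketch of consuming it (the file name is illustrative; the field names are the ones written by `f.write` above):

    import jsonlines

    with jsonlines.open("data.1.json") as reader:
        for r in reader:
            print(r["utt"], "|", r["refs"][0], "|", r["hyps"][0])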
diff --git a/deepspeech/decoders/scorers/ctc.py b/deepspeech/decoders/scorers/ctc.py
index 36b4bfd3..4871d6e1 100644
--- a/deepspeech/decoders/scorers/ctc.py
+++ b/deepspeech/decoders/scorers/ctc.py
@@ -15,8 +15,8 @@
 import numpy as np
 import paddle
 
-from .ctc_prefix_score import CTCPrefixScorer
-from .ctc_prefix_score import CTCPrefixScorerPD
+from .ctc_prefix_score import CTCPrefixScore
+from .ctc_prefix_score import CTCPrefixScorePD
 from .scorer_interface import BatchPartialScorerInterface
diff --git a/deepspeech/decoders/scorers/ctc_prefix_score.py b/deepspeech/decoders/scorers/ctc_prefix_score.py
index 5f568c81..c85d546d 100644
--- a/deepspeech/decoders/scorers/ctc_prefix_score.py
+++ b/deepspeech/decoders/scorers/ctc_prefix_score.py
@@ -6,7 +6,7 @@ import paddle
 import six
 
 
-class CTCPrefixScorerPD():
+class CTCPrefixScorePD():
     """Batch processing of CTCPrefixScore
 
     which is based on Algorithm 2 in WATANABE et al.
@@ -267,7 +267,7 @@ class CTCPrefixScorePD():
         return (r_prev_new, s_prev, f_min_prev, f_max_prev)
 
 
-class CTCPrefixScorer():
+class CTCPrefixScore():
     """Compute CTC label sequence scores
 
     which is based on Algorithm 2 in WATANABE et al.
diff --git a/deepspeech/decoders/scorers/score_interface.py b/deepspeech/decoders/scorers/scorer_interface.py
similarity index 100%
rename from deepspeech/decoders/scorers/score_interface.py
rename to deepspeech/decoders/scorers/scorer_interface.py
diff --git a/deepspeech/decoders/utils.py b/deepspeech/decoders/utils.py
index 0281a78b..f59b55d9 100644
--- a/deepspeech/decoders/utils.py
+++ b/deepspeech/decoders/utils.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["end_detect"]
+import numpy as np
+from deepspeech.utils.log import Log
+logger = Log(__name__).getlog()
+
+__all__ = ["end_detect", "parse_hypothesis", "add_results_to_json"]
 
 
 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
@@ -118,7 +122,7 @@ def add_results_to_json(js, nbest_hyps, char_list):
         # show 1-best result
         if n == 1:
             if "text" in out_dic.keys():
-                logging.info("groundtruth: %s" % out_dic["text"])
-            logging.info("prediction : %s" % out_dic["rec_text"])
+                logger.info("groundtruth: %s" % out_dic["text"])
+            logger.info("prediction : %s" % out_dic["rec_text"])
 
     return new_js
\ No newline at end of file
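A side note on `end_detect` (only re-exported here, body unchanged): its default threshold `D_end = np.log(1 * np.exp(-10))` is just -10 in log space, written in a roundabout way:

    import numpy as np
    assert np.isclose(np.log(1 * np.exp(-10)), -10.0)  # the default D_end is simply -10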
diff --git a/deepspeech/exps/u2_kaldi/bin/recog.py b/deepspeech/exps/u2_kaldi/bin/recog.py
new file mode 100644
index 00000000..60f4e58a
--- /dev/null
+++ b/deepspeech/exps/u2_kaldi/bin/recog.py
@@ -0,0 +1,379 @@
+"""End-to-end speech recognition model decoding script."""
+import configargparse
+import logging
+import os
+import random
+import sys
+
+import numpy as np
+
+from distutils.util import strtobool
+from deepspeech.training.cli import default_argument_parser
+
+
+# NOTE: you need this func to generate our sphinx doc
+def get_parser():
+    """Get default arguments."""
+    parser = configargparse.ArgumentParser(
+        description="Transcribe text from speech using "
+        "a speech recognition model on one CPU or GPU",
+        config_file_parser_class=configargparse.YAMLConfigFileParser,
+        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add(
+        '--model-name',
+        type=str,
+        default='u2_kaldi',
+        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
+    # general configuration
+    parser.add("--config", is_config_file=True, help="Config file path")
+    parser.add(
+        "--config2",
+        is_config_file=True,
+        help="Second config file path that overwrites the settings in `--config`",
+    )
+    parser.add(
+        "--config3",
+        is_config_file=True,
+        help="Third config file path that overwrites the settings "
+        "in `--config` and `--config2`",
+    )
+
+    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
+    parser.add_argument(
+        "--dtype",
+        choices=("float16", "float32", "float64"),
+        default="float32",
+        help="Float precision (only available in --api v2)",
+    )
+    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
+    parser.add_argument("--seed", type=int, default=1, help="Random seed")
+    parser.add_argument("--verbose", "-V", type=int, default=2, help="Verbose option")
+    parser.add_argument(
+        "--batchsize",
+        type=int,
+        default=1,
+        help="Batch size for beam search (0: means no batch processing)",
+    )
+    parser.add_argument(
+        "--preprocess-conf",
+        type=str,
+        default=None,
+        help="The configuration file for the pre-processing",
+    )
+    parser.add_argument(
+        "--api",
+        default="v2",
+        choices=["v2"],
+        help="Beam search APIs "
+        "v2: Experimental API. It supports any models that implements ScorerInterface.",
+    )
+    # task related
+    parser.add_argument(
+        "--recog-json", type=str, help="Filename of recognition data (json)"
+    )
+    parser.add_argument(
+        "--result-label",
+        type=str,
+        required=True,
+        help="Filename of result label data (json)",
+    )
+    # model (parameter) related
+    parser.add_argument(
+        "--model", type=str, required=True, help="Model file parameters to read"
+    )
+    parser.add_argument(
+        "--model-conf", type=str, default=None, help="Model config file"
+    )
+    parser.add_argument(
+        "--num-spkrs",
+        type=int,
+        default=1,
+        choices=[1, 2],
+        help="Number of speakers in the speech",
+    )
+    parser.add_argument(
+        "--num-encs", default=1, type=int, help="Number of encoders in the model."
+    )
+    # search related
+    parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
+    parser.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    parser.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="""Input length ratio to obtain max output length.
+        If maxlenratio=0.0 (default), it uses an end-detect function
+        to automatically find maximum hypothesis lengths.
+        If maxlenratio<0.0, its absolute value is interpreted
+        as a constant max output length""",
+    )
+    parser.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    parser.add_argument(
+        "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
+    )
+    parser.add_argument(
+        "--weights-ctc-dec",
+        type=float,
+        action="append",
+        help="ctc weight assigned to each encoder during decoding."
+        "[in multi-encoder mode only]",
+    )
+    parser.add_argument(
+        "--ctc-window-margin",
+        type=int,
+        default=0,
+        help="""Use CTC window with margin parameter to accelerate
+        CTC/attention decoding especially on GPU. Smaller margin
+        makes decoding faster, but may increase search errors.
+        If margin=0 (default), this function is disabled""",
+    )
+    # transducer related
+    parser.add_argument(
+        "--search-type",
+        type=str,
+        default="default",
+        choices=["default", "nsc", "tsd", "alsd", "maes"],
+        help="""Type of beam search implementation to use during inference.
+        Can be either: default beam search ("default"),
+        N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
+        Alignment-Length Synchronous Decoding ("alsd") or
+        modified Adaptive Expansion Search ("maes").""",
+    )
+    parser.add_argument(
+        "--nstep",
+        type=int,
+        default=1,
+        help="""Number of expansion steps allowed in NSC beam search or mAES
+        (nstep > 0 for NSC and nstep > 1 for mAES).""",
+    )
+    parser.add_argument(
+        "--prefix-alpha",
+        type=int,
+        default=2,
+        help="Length prefix difference allowed in NSC beam search or mAES.",
+    )
+    parser.add_argument(
+        "--max-sym-exp",
+        type=int,
+        default=2,
+        help="Number of symbol expansions allowed in TSD.",
+    )
+    parser.add_argument(
+        "--u-max",
+        type=int,
+        default=400,
+        help="Length prefix difference allowed in ALSD.",
+    )
+    parser.add_argument(
+        "--expansion-gamma",
+        type=float,
+        default=2.3,
+        help="Allowed logp difference for prune-by-value method in mAES.",
+    )
+    parser.add_argument(
+        "--expansion-beta",
+        type=int,
+        default=2,
+        help="""Number of additional candidates for expanded hypotheses
+        selection in mAES.""",
+    )
+    parser.add_argument(
+        "--score-norm",
+        type=strtobool,
+        nargs="?",
+        default=True,
+        help="Normalize final hypotheses' score by length",
+    )
+    parser.add_argument(
+        "--softmax-temperature",
+        type=float,
+        default=1.0,
+        help="Penalization term for softmax function.",
+    )
+    # rnnlm related
+    parser.add_argument(
+        "--rnnlm", type=str, default=None, help="RNNLM model file to read"
+    )
+    parser.add_argument(
+        "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
+    )
+    parser.add_argument(
+        "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
+    )
+    parser.add_argument(
+        "--word-rnnlm-conf",
+        type=str,
+        default=None,
+        help="Word RNNLM model config file to read",
+    )
+    parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
+    parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
+    # ngram related
+    parser.add_argument(
+        "--ngram-model", type=str, default=None, help="ngram model file to read"
+    )
+    parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
+    parser.add_argument(
+        "--ngram-scorer",
+        type=str,
+        default="part",
+        choices=("full", "part"),
+        help="""if the ngram is set as a part scorer, similar with CTC scorer,
+        ngram scorer only scores topK hypotheses.
+        if the ngram is set as full scorer, ngram scorer scores all hypotheses.
+        the decoding speed of part scorer is much faster than full one""",
+    )
+    # streaming related
+    parser.add_argument(
+        "--streaming-mode",
+        type=str,
+        default=None,
+        choices=["window", "segment"],
+        help="""Use streaming recognizer for inference.
+        `--batchsize` must be set to 0 to enable this mode""",
+    )
+    parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
+    parser.add_argument(
+        "--streaming-min-blank-dur",
+        type=int,
+        default=10,
+        help="Minimum blank duration threshold",
+    )
+    parser.add_argument(
+        "--streaming-onset-margin", type=int, default=1, help="Onset margin"
+    )
+    parser.add_argument(
+        "--streaming-offset-margin", type=int, default=1, help="Offset margin"
+    )
+    # non-autoregressive related
+    # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
+    parser.add_argument(
+        "--maskctc-n-iterations",
+        type=int,
+        default=10,
+        help="Number of decoding iterations."
+        "For Mask CTC, set 0 to predict 1 mask/iter.",
+    )
+    parser.add_argument(
+        "--maskctc-probability-threshold",
+        type=float,
+        default=0.999,
+        help="Threshold probability for CTC output",
+    )
+    # quantize model related
+    parser.add_argument(
+        "--quantize-config",
+        nargs="*",
+        help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
+    )
+    parser.add_argument(
+        "--quantize-dtype", type=str, default="qint8", help="Dtype dynamic quantize"
+    )
+    parser.add_argument(
+        "--quantize-asr-model",
+        type=bool,
+        default=False,
+        help="Quantize asr model",
+    )
+    parser.add_argument(
+        "--quantize-lm-model",
+        type=bool,
+        default=False,
+        help="Quantize lm model",
+    )
+    return parser
+
+
+def main(args):
+    """Run the main decoding function."""
+    parser = get_parser()
+    parser.add_argument(
+        "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
+    parser.add_argument(
+        "--checkpoint_path", type=str, help="path to load checkpoint")
+    parser.add_argument(
+        "--dict-path", type=str, help="path to vocabulary dict")
+    # parser = default_argument_parser(parser)
+    args = parser.parse_args(args)
+
+    if args.ngpu == 0 and args.dtype == "float16":
+        raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
+
+    # logging info
+    if args.verbose == 1:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    elif args.verbose == 2:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+        logging.warning("Skip DEBUG/INFO messages")
+    logging.info(args)
+
+    # check CUDA_VISIBLE_DEVICES
+    if args.ngpu > 0:
+        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
+        if cvd is None:
+            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
+        elif args.ngpu != len(cvd.split(",")):
+            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
+            sys.exit(1)
+
+    # TODO(mn5k): support of multiple GPUs
+    if args.ngpu > 1:
+        logging.error("The program only supports ngpu=1.")
+        sys.exit(1)
+
+    # display PYTHONPATH
+    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
+
+    # seed setting
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    logging.info("set random seed = %d" % args.seed)
+
+    # validate rnn options
+    if args.rnnlm is not None and args.word_rnnlm is not None:
+        logging.error(
+            "It seems that both --rnnlm and --word-rnnlm are specified. "
+            "Please use either option."
+        )
+        sys.exit(1)
+
+    # recog
+    if args.num_spkrs == 1:
+        if args.num_encs == 1:
+            # Experimental API that supports custom LMs
+            if args.api == "v2":
+                from deepspeech.decoders.recog import recog_v2
+                recog_v2(args)
+            else:
+                raise ValueError("Only support --api v2")
+        else:
+            if args.api == "v2":
+                raise NotImplementedError(
+                    f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
+                )
+    elif args.num_spkrs == 2:
+        raise ValueError("asr_mix not supported.")
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 18e29b28..3aadca85 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -317,11 +317,9 @@ class U2Trainer(Trainer):
         with UpdateConfig(model_conf):
             model_conf.input_dim = self.train_loader.feat_dim
             model_conf.output_dim = self.train_loader.vocab_size
-
         model = U2Model.from_config(model_conf)
         if self.parallel:
             model = paddle.DataParallel(model)
-        logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
 
         # lr
diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py
index 34220432..18ff411b 100644
--- a/deepspeech/frontend/featurizer/text_featurizer.py
+++ b/deepspeech/frontend/featurizer/text_featurizer.py
@@ -207,7 +207,7 @@ class TextFeaturizer():
         """Load vocabulary from file."""
         vocab_list = load_dict(vocab_filepath, maskctc)
         assert vocab_list is not None
-        logger.info(f"Vocab: {pformat(vocab_list)}")
+        logger.debug(f"Vocab: {pformat(vocab_list)}")
 
         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py
index 8915cbd7..fa517906 100644
--- a/deepspeech/models/u2/u2.py
+++ b/deepspeech/models/u2/u2.py
@@ -50,7 +50,7 @@ from deepspeech.utils.tensor_utils import th_accuracy
 from deepspeech.utils.utility import log_add
 from deepspeech.utils.utility import UpdateConfig
 from deepspeech.models.asr_interface import ASRInterface
-from deepspeech.decoders.scorers.ctc_prefix_score import CTCPrefixScorer
+from deepspeech.decoders.scorers.ctc import CTCPrefixScorer
 
 __all__ = ["U2Model", "U2InferModel"]
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index 154b7390..ee3572ea 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -28,7 +28,7 @@ from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.mask import subsequent_mask
 from deepspeech.modules.mask import make_xs_mask
 from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
-from deepspeech.decoders.scorers.score_interface import BatchScorerInterface
+from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
 from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
@@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
             ys: (ylen,)
             x: (xlen, n_feat)
         """
-        ys_mask = subsequent_mask(len(ys)).unsqueeze(0)
-        x_mask = make_xs_mask(x.unsqueeze(0))
+        ys_mask = subsequent_mask(len(ys)).unsqueeze(0)  # (B,L,L)
+        x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1)  # (B,1,T)
         if self.selfattention_layer_type != "selfattn":
             # TODO(karita): implement cache
             logging.warning(
@@ -237,9 +237,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         ]
 
         # batch decoding
-        ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)
-
-        xs_mask = make_xs_mask(xs)
+        ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)  # (B,L,L)
+        xs_mask = make_xs_mask(xs).unsqueeze(1)  # (B,1,T)
         logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state)
 
         # transpose state of [layer, batch] into [batch, layer]
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index cffa10a7..7ae418c0 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -24,15 +24,16 @@ __all__ = [
 ]
 
 
-def make_xs_mask(xs:paddle.Tensor) -> paddle.Tensor:
+def make_xs_mask(xs:paddle.Tensor, pad_value=0.0) -> paddle.Tensor:
     """Make mask tensor containing indices of non-padded part.
 
     Args:
         xs (paddle.Tensor): (B, T, D), zeros for pad.
 
     Returns:
-        paddle.Tensor: Mask Tensor indices of non-padded part. (B, T, D)
+        paddle.Tensor: Mask Tensor indices of non-padded part. (B, T)
     """
-    pad_frame = paddle.zeros([1, 1, xs.shape[-1]], dtype=xs.dtype)
+    pad_frame = paddle.full([1, 1, xs.shape[-1]], pad_value, dtype=xs.dtype)
     mask = xs != pad_frame
+    mask = mask.all(axis=-1)
     return mask
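The decoder hunks and the `make_xs_mask` change above belong together: the mask now collapses the feature axis down to (B, T), and callers add the broadcastable head dimension themselves. A shape sketch with invented sizes (B=1, T=6, D=80):

    import paddle
    from deepspeech.modules.mask import make_xs_mask, subsequent_mask

    xs = paddle.rand([1, 6, 80])               # (B, T, D); random, so no all-zero pad frames
    x_mask = make_xs_mask(xs).unsqueeze(1)     # (B, T) -> (B, 1, T)
    ys_mask = subsequent_mask(4).unsqueeze(0)  # (L, L) -> (B, L, L), causal over the prefix
    print(x_mask.shape, ys_mask.shape)         # [1, 1, 6] [1, 4, 4]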
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py
index ca9652e3..db7076d3 100644
--- a/deepspeech/training/cli.py
+++ b/deepspeech/training/cli.py
@@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action):
             parser.parse_args(f.read().split(), namespace)
 
 
-def default_argument_parser():
+def default_argument_parser(parser=None):
     r"""A simple yet general argument parser for experiments with parakeet.
 
     This is used in examples with parakeet. And it is intended to be used by
@@ -62,7 +62,9 @@ def default_argument_parser():
     argparse.ArgumentParser
         the parser
     """
-    parser = argparse.ArgumentParser()
+    if parser is None:
+        parser = argparse.ArgumentParser()
+
     parser.register('action', 'extend', ExtendAction)
     parser.add_argument(
         '--conf', type=open, action=LoadFromFile, help="config file.")
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 6c18fa36..b2ee7a1b 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -126,7 +126,7 @@ class Trainer():
             logger.info(f"Set seed {args.seed}")
 
         # profiler and benchmark options
-        if self.args.benchmark_batch_size:
+        if hasattr(self.args, "benchmark_batch_size") and self.args.benchmark_batch_size:
             with UpdateConfig(self.config):
                 self.config.collator.batch_size = self.args.benchmark_batch_size
                 self.config.training.log_interval = 1
@@ -326,12 +326,25 @@ class Trainer():
         finally:
             self.destory()
 
+    def restore(self):
+        """Resume from the latest checkpoint in the output directory, or
+        load a specified checkpoint.
+
+        If ``args.checkpoint_path`` is not None, load that checkpoint; else
+        resume training.
+        """
+        assert self.args.checkpoint_path
+        infos = self.checkpoint.load_latest_parameters(
+            self.model,
+            checkpoint_path=self.args.checkpoint_path)
+        return infos
+
     def run_test(self):
         """Do Test/Decode"""
         try:
             with Timer("Test/Decode Done: {}"):
                 with self.eval():
-                    self.resume_or_scratch()
+                    self.restore()
                     self.test()
         except KeyboardInterrupt:
             exit(-1)
@@ -341,6 +354,7 @@ class Trainer():
         try:
             with Timer("Export Done: {}"):
                 with self.eval():
+                    self.restore()
                     self.export()
         except KeyboardInterrupt:
             exit(-1)
@@ -350,7 +364,7 @@ class Trainer():
         try:
             with Timer("Align Done: {}"):
                 with self.eval():
-                    self.resume_or_scratch()
+                    self.restore()
                     self.align()
         except KeyboardInterrupt:
             sys.exit(-1)
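With `default_argument_parser` accepting an existing parser, the configargparse-based script can layer the training CLI options onto its own parser. A sketch of the intended composition (this mirrors the call left commented out in `bin/recog.py`):

    import configargparse
    from deepspeech.training.cli import default_argument_parser

    parser = configargparse.ArgumentParser()
    parser = default_argument_parser(parser)  # adds --conf etc. to the given parser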
+ """ + assert self.args.checkpoint_path + infos = self.checkpoint.load_latest_parameters( + self.model, + checkpoint_path=self.args.checkpoint_path) + return infos + def run_test(self): """Do Test/Decode""" try: with Timer("Test/Decode Done: {}"): with self.eval(): - self.resume_or_scratch() + self.restore() self.test() except KeyboardInterrupt: exit(-1) @@ -341,6 +354,7 @@ class Trainer(): try: with Timer("Export Done: {}"): with self.eval(): + self.restore() self.export() except KeyboardInterrupt: exit(-1) @@ -350,7 +364,7 @@ class Trainer(): try: with Timer("Align Done: {}"): with self.eval(): - self.resume_or_scratch() + self.restore() self.align() except KeyboardInterrupt: sys.exit(-1) diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md index 1f7c6919..6e41cc37 100644 --- a/examples/librispeech/s2/README.md +++ b/examples/librispeech/s2/README.md @@ -11,3 +11,4 @@ | test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 | | test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 | | test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 | +| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 | diff --git a/examples/librispeech/s2/conf/decode/decode.yaml b/examples/librispeech/s2/conf/decode/decode.yaml new file mode 100644 index 00000000..867bf611 --- /dev/null +++ b/examples/librispeech/s2/conf/decode/decode.yaml @@ -0,0 +1,7 @@ +batchsize: 0 +beam-size: 60 +ctc-weight: 0.4 +lm-weight: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +penalty: 0.0 diff --git a/examples/librispeech/s2/conf/decode/decode_all.yaml b/examples/librispeech/s2/conf/decode/decode_all.yaml new file mode 100644 index 00000000..87d5f6d1 --- /dev/null +++ b/examples/librispeech/s2/conf/decode/decode_all.yaml @@ -0,0 +1,7 @@ +batchsize: 0 +beam-size: 60 +ctc-weight: 0.4 +lm-weight: 0.6 +maxlenratio: 0.0 +minlenratio: 0.0 +penalty: 0.0 \ No newline at end of file diff --git a/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml b/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml new file mode 100644 index 00000000..a82dba23 --- /dev/null +++ b/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml @@ -0,0 +1,7 @@ +batchsize: 0 +beam-size: 60 +ctc-weight: 0.4 +lm-weight: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +penalty: 0.0 \ No newline at end of file diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh new file mode 100755 index 00000000..0393535b --- /dev/null +++ b/examples/librispeech/s2/local/recog.sh @@ -0,0 +1,103 @@ +#!/bin/bash + +set -e + +expdir=exp +datadir=data +nj=32 + +decode_config=conf/decode/decode.yaml +lang_model=rnnlm.model.best +lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/ + +lmtag='nolm' + +recog_set="test-clean test-other dev-clean dev-other" +recog_set="test-clean" + +# bpemode (unigram or bpe) +nbpe=5000 +bpemode=unigram +bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpemodel=${bpeprefix}.model + +if [ $# != 3 ];then + echo "usage: ${0} config_path dict_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." 
diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh
index a6b6cfa9..006a13c5 100755
--- a/examples/librispeech/s2/local/test.sh
+++ b/examples/librispeech/s2/local/test.sh
@@ -83,7 +83,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescore; do
                 --opts decoding.batch_size ${batch_size} \
                 --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
 
-        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict}
+        score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
 
     ) &
     pids+=($!) # store background pids
diff --git a/requirements.txt b/requirements.txt
index 2a3f0651..42fd645a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,3 +40,4 @@ pyworld
 jieba
 phkit
 yq
+ConfigArgParse
\ No newline at end of file