commit f2f305cd66 (parent bb75735fac)

@@ -0,0 +1,154 @@
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
import json

import paddle

# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.nets.lm_interface import dynamic_import_lm

# from espnet.nets.asr_interface import ASRInterface
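# NOTE: get_model_conf, torch_load, load_trained_model and dynamic_import_lm
# are still called below even though their ESPnet imports are commented out;
# paddle replacements are presumably pending, so model loading and the
# --rnnlm path will raise NameError until they are ported.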
from .utils import add_results_to_json
# from .batch_beam_search import BatchBeamSearch
from .beam_search import BeamSearch
from .scorer_interface import BatchScorerInterface
from .scorers.length_bonus import LengthBonus

from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


def recog_v2(args):
    """Decode with custom models that implement ScorerInterface.

    Args:
        args (namespace): The program arguments.
            See py:func:`bin.asr_recog.get_parser` for details.

    """
    logger.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")

    # set_deterministic(args)
    model, train_args = load_trained_model(args.model)
    # assert isinstance(model, ASRInterface)
    model.eval()

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.rnnlm:
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for compatibility with models older than version 0.5.0
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        lm.eval()
    else:
        lm = None

    if args.ngram_model:
        from .scorers.ngram import NgramFullScorer
        from .scorers.ngram import NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    scorers = model.scorers()
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
    )
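    # BeamSearch combines the scorers linearly: each partial hypothesis is
    # scored as sum_k weights[k] * score_k, so ctc_weight trades the attention
    # decoder off against the CTC prefix score, and penalty adds a per-token
    # length bonus.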
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )
    # TODO(karita): make all scorers batchified
    if args.batchsize == 1:
        non_batch = [
            k
            for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if len(non_batch) == 0:
            # NOTE: BatchBeamSearch is referenced here but its import above is
            # still commented out, so this branch would raise NameError until
            # the batch implementation is ported.
            beam_search.__class__ = BatchBeamSearch
            logger.info("BatchBeamSearch implementation is selected.")
        else:
            logger.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "gpu:0"
    else:
        device = "cpu"
    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype)
    model.eval()
    beam_search.to(device=device, dtype=dtype)
    beam_search.eval()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)
    # json lines to dict, keyed by 'utt'
    js = {item['utt']: item for item in js}
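    # NOTE: args.recog_json is therefore assumed to hold a JSON list of
    # per-utterance dicts, each with at least an 'utt' key, e.g.
    # [{"utt": "uttid1", ...}, ...]; the remaining fields are consumed by
    # LoadInputsAndTargets.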

    new_js = {}
    with paddle.no_grad():
        for idx, name in enumerate(js.keys(), 1):
            logger.info("(%d/%d) decoding " + name, idx, len(js.keys()))
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(paddle.to_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
            )
            nbest_hyps = [
                h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
            ]
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
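

# ---------------------------------------------------------------------------
# Illustrative driver (a minimal sketch, not part of the original commit):
# every Namespace field below is one recog_v2 actually reads; the values are
# placeholder paths and defaults, and decoding still depends on the pending
# load_trained_model port noted above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    example_args = argparse.Namespace(
        model="exp/model.pdparams",         # checkpoint for load_trained_model
        recog_json="data/test/recog.json",  # JSON list of {'utt': ..., ...} items
        result_label="exp/result.json",     # where the n-best json is written
        batchsize=1, streaming_mode=None, word_rnnlm=None,
        preprocess_conf=None, rnnlm=None, rnnlm_conf=None,
        ngram_model=None, ngram_scorer="part", ngram_weight=0.0,
        beam_size=10, ctc_weight=0.5, lm_weight=0.0, penalty=0.0,
        nbest=1, maxlenratio=0.0, minlenratio=0.0,
        ngpu=0, dtype="float32",
    )
    recog_v2(example_args)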

@@ -0,0 +1,148 @@
"""ASR Interface module."""
|
||||
import argparse
|
||||
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
|
||||
|
||||
class ASRInterface:
|
||||
"""ASR Interface for ESPnet model implementation."""
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser):
|
||||
"""Add arguments to parser."""
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def build(cls, idim: int, odim: int, **kwargs):
|
||||
"""Initialize this class with python-level args.
|
||||
|
||||
Args:
|
||||
idim (int): The number of an input feature dim.
|
||||
odim (int): The number of output vocab.
|
||||
|
||||
Returns:
|
||||
ASRinterface: A new instance of ASRInterface.
|
||||
|
||||
"""
|
||||
args = argparse.Namespace(**kwargs)
|
||||
return cls(idim, odim, args)
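    # Example (hypothetical subclass, for illustration only): given
    #     class MyASR(ASRInterface):
    #         def __init__(self, idim, odim, args): ...
    # a model can be built from keyword args alone, e.g.
    #     model = MyASR.build(83, 5002, adim=256, elayers=12)
    # since build() wraps the kwargs into an argparse.Namespace.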

    def forward(self, xs, ilens, ys, olens):
        """Compute loss for training.

        :param xs: batch of padded source sequences, paddle.Tensor (B, Tmax, idim)
        :param ilens: batch of lengths of source sequences (B), paddle.Tensor
        :param ys: batch of padded target sequences, paddle.Tensor (B, Lmax)
        :param olens: batch of lengths of target sequences (B), paddle.Tensor
        :return: loss value
        :rtype: paddle.Tensor
        """
        raise NotImplementedError("forward method is not implemented")

    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("recognize method is not implemented")

    def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")

    def calculate_all_attentions(self, xs, ilens, ys):
        """Calculate attention.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: attention weights (B, Lmax, Tmax)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    def calculate_all_ctc_probs(self, xs, ilens, ys):
        """Calculate CTC probability.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: CTC probabilities (B, Tmax, vocab)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_ctc_probs method is not implemented")

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def ctc_plot_class(self):
        """Get CTC plot class."""
        from espnet.asr.asr_utils import PlotCTCReport

        return PlotCTCReport

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        raise NotImplementedError(
            "get_total_subsampling_factor method is not implemented"
        )

    def encode(self, feat):
        """Encode feature in `beam_search` (optional).

        Args:
            feat (numpy.ndarray): input feature (T, D)
        Returns:
            paddle.Tensor: encoded feature (T, D)

        """
        raise NotImplementedError("encode method is not implemented")

    def scorers(self):
        """Get scorers for `beam_search` (optional).

        Returns:
            dict[str, ScorerInterface]: dict of `ScorerInterface` objects

        """
        raise NotImplementedError("scorers method is not implemented")


predefined_asr = {
    "transformer": "deepspeech.models.u2:E2E",
    "conformer": "deepspeech.models.u2:E2E",
}


def dynamic_import_asr(module, name):
    """Import ASR models dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_asr`
        name (str): asr name, e.g., transformer, conformer

    Returns:
        type: ASR class

    """
    model_class = dynamic_import(module, predefined_asr.get(name, ""))
    assert issubclass(
        model_class, ASRInterface
    ), f"{module} does not implement ASRInterface"
    return model_class
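

# Usage sketch (illustrative, not part of the original commit): passing the
# full "module:class" path works regardless of how the alias table is applied,
# assuming deepspeech.utils.dynamic_import resolves such strings:
#
#     E2E = dynamic_import_asr("deepspeech.models.u2:E2E", "transformer")
#     model = E2E.build(83, 5002)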