add recog interface

pull/926/head
Hui Zhang 3 years ago
parent bb75735fac
commit f2f305cd66

@ -0,0 +1,154 @@
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
import json
import paddle
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.nets.lm_interface import dynamic_import_lm
# from espnet.nets.asr_interface import ASRInterface
from .utils import add_results_to_json
# from .batch_beam_search import BatchBeamSearch
from .beam_search import BeamSearch
from .scorer_interface import BatchScorerInterface
from .scorers.length_bonus import LengthBonus
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
def recog_v2(args):
    """Decode with custom models that implement ScorerInterface.

    Loads the trained model named by ``args.model``, assembles the scorers
    (the model's own scorers plus an optional LM, an optional n-gram scorer
    and a length bonus), runs beam search for every utterance found in
    ``args.recog_json`` and writes the N-best results to
    ``args.result_label`` as UTF-8 encoded JSON under the key ``"utts"``.

    Args:
        args (namespace): The program arguments.
            See py:func:`bin.asr_recog.get_parser` for details

    Raises:
        NotImplementedError: If ``args.batchsize > 1``, if a streaming mode
            is requested, if a word-level RNNLM is requested, or if
            ``args.ngpu > 1``.
    """
    logger.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")

    # set_deterministic(args)
    # NOTE(review): the import of `load_trained_model` is commented out at
    # the top of this module — confirm the name is provided before use.
    model, train_args = load_trained_model(args.model)
    # assert isinstance(model, ASRInterface)
    model.eval()

    # A preprocess config given on the command line overrides the one
    # stored with the training arguments.
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.rnnlm:
        # NOTE(review): `get_model_conf`, `dynamic_import_lm` and
        # `torch_load` have their imports commented out at the top of this
        # module — confirm they are available before enabling --rnnlm.
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for a compatibility with less than 0.5.0 version models
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        lm.eval()
    else:
        lm = None

    if args.ngram_model:
        # imported lazily so the n-gram scorers are only required on demand
        from .scorers.ngram import NgramFullScorer
        from .scorers.ngram import NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    # Disabled scorers are still registered (as None) next to their weights.
    scorers = model.scorers()
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
    )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )

    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k
            for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if not non_batch:
            # NOTE(review): the import of `BatchBeamSearch` is commented out
            # at the top of this module — confirm it is available, otherwise
            # this branch raises NameError.
            beam_search.__class__ = BatchBeamSearch
            logger.info("BatchBeamSearch implementation is selected.")
        else:
            logger.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "gpu:0"
    else:
        device = "cpu"
    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype)
    model.eval()
    beam_search.to(device=device, dtype=dtype)
    beam_search.eval()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)
    # re-key the loaded utterance entries by their 'utt' field
    js = {item['utt']: item for item in js}

    new_js = {}
    total = len(js)
    with paddle.no_grad():
        for idx, name in enumerate(js, 1):
            # Format eagerly, like the other log calls in this function, so
            # the utterance name is never treated as part of a format
            # template.
            logger.info(f"({idx}/{total}) decoding {name}")
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(paddle.to_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
            )
            # slicing already clamps to the number of available hypotheses
            nbest_hyps = [h.asdict() for h in nbest_hyps[: args.nbest]]
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )

@ -47,3 +47,78 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
return True return True
else: else:
return False return False
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
    """Turn one hypothesis into text, token and token-id strings plus a score.

    Args:
        hyp (dict[str, Any]): Recognition hypothesis; ``hyp["yseq"]`` and
            ``hyp["score"]`` are read.
        char_list (list[str]): List of characters, indexed by token id.

    Returns:
        tuple(str, str, str, float): The recognized text (tokens joined
            without separator, ``<space>`` replaced by a blank), the
            space-separated tokens, the space-separated token ids and the
            score.
    """
    # skip the first element of yseq (the sos symbol)
    ids = [int(raw_id) for raw_id in hyp["yseq"][1:]]
    symbols = [char_list[i] for i in ids]
    score = float(hyp["score"])

    # render the three string views of the hypothesis
    tokenid = " ".join(map(str, ids))
    token = " ".join(symbols)
    text = "".join(symbols).replace("<space>", " ")
    return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
    """Build an utterance dict that carries the N-best recognition results.

    Args:
        js (dict[str, Any]): Groundtruth utterance dict; ``js["utt2spk"]``
            and ``js["output"]`` are read.
        nbest_hyps (list[dict[str, Any]]): List of hypotheses, best first.
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: New utterance dict with ``"utt2spk"`` copied from
            ``js`` and one ``"output"`` entry per hypothesis.
    """
    # NOTE(review): this uses the stdlib `logging` module directly; confirm
    # it is imported at the top of this file.
    result = {"utt2spk": js["utt2spk"], "output": []}

    for rank, hyp in enumerate(nbest_hyps, 1):
        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)

        # start from a copy of the first reference entry when one exists
        if len(js["output"]) > 0:
            entry = dict(js["output"][0].items())
        else:
            # for no reference case (e.g., speech translation)
            entry = {"name": ""}

        # tag the entry name with its 1-based rank
        entry["name"] += "[%d]" % rank

        # attach the recognition results
        entry["rec_text"] = rec_text
        entry["rec_token"] = rec_token
        entry["rec_tokenid"] = rec_tokenid
        entry["score"] = score
        result["output"].append(entry)

        # only the 1-best result is logged
        if rank != 1:
            continue
        if "text" in entry:
            logging.info("groundtruth: %s" % entry["text"])
        logging.info("prediction : %s" % entry["rec_text"])

    return result

@ -0,0 +1,148 @@
"""ASR Interface module."""
import argparse
from deepspeech.utils.dynamic_import import dynamic_import
class ASRInterface:
    """ASR Interface for ESPnet model implementation.

    Every method except :meth:`add_arguments` and :meth:`build` raises
    ``NotImplementedError`` here (the two plot-class properties import
    their result lazily instead); concrete models override what they
    support.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to parser.

        The base implementation adds nothing and returns ``parser`` as is.
        """
        return parser

    @classmethod
    def build(cls, idim: int, odim: int, **kwargs):
        """Initialize this class with python-level args.

        The keyword arguments are packed into an ``argparse.Namespace`` and
        passed to the constructor as ``cls(idim, odim, args)``.

        Args:
            idim (int): The number of an input feature dim.
            odim (int): The number of output vocab.

        Returns:
            ASRInterface: A new instance of ASRInterface.
        """
        args = argparse.Namespace(**kwargs)
        return cls(idim, odim, args)

    def forward(self, xs, ilens, ys, olens):
        """Compute loss for training.

        :param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim)
        :param ilens: batch of lengths of source sequences (B), paddle.Tensor
        :param ys: batch of padded target sequences paddle.Tensor (B, Lmax)
        :param olens: batch of lengths of target sequences (B), paddle.Tensor
        :return: loss value
        :rtype: paddle.Tensor
        """
        raise NotImplementedError("forward method is not implemented")

    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("recognize method is not implemented")

    def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")

    def calculate_all_attentions(self, xs, ilens, ys):
        """Calculate attention.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: attention weights (B, Lmax, Tmax)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    def calculate_all_ctc_probs(self, xs, ilens, ys):
        """Calculate CTC probability.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: CTC probabilities (B, Tmax, vocab)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_ctc_probs method is not implemented")

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        # imported on access, so `espnet` is only needed when this is used
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def ctc_plot_class(self):
        """Get CTC plot class."""
        # imported on access, so `espnet` is only needed when this is used
        from espnet.asr.asr_utils import PlotCTCReport

        return PlotCTCReport

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        raise NotImplementedError(
            "get_total_subsampling_factor method is not implemented"
        )

    def encode(self, feat):
        """Encode feature in `beam_search` (optional).

        Args:
            feat (numpy.ndarray): input feature (T, D)

        Returns:
            paddle.Tensor: encoded feature (T, D)
        """
        raise NotImplementedError("encode method is not implemented")

    def scorers(self):
        """Get scorers for `beam_search` (optional).

        Returns:
            dict[str, ScorerInterface]: dict of `ScorerInterface` objects
        """
        # the message names this method, so a missing override is easy to spot
        raise NotImplementedError("scorers method is not implemented")
# Short aliases for ASR model classes, mapped to "module_name:class_name"
# import paths; read by `dynamic_import_asr`.
# NOTE(review): both aliases point at a class named `E2E` in
# `deepspeech.models.u2` — confirm that class name exists in that module.
predefined_asr = {
    "transformer": "deepspeech.models.u2:E2E",
    "conformer": "deepspeech.models.u2:E2E",
}
def dynamic_import_asr(module, name):
    """Import ASR models dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_asr`
        name (str): asr name. e.g., transformer, conformer.
            Kept for call compatibility; alias resolution is done on
            ``module`` against the whole `predefined_asr` table.

    Returns:
        type: ASR class

    Raises:
        AssertionError: If the imported class is not a subclass of
            ``ASRInterface``.
    """
    # Hand over the alias table itself. The previous code passed
    # `predefined_asr.get(name, "")`, i.e. a single import-path string, in
    # the alias position, so an alias given as `module` could not be looked
    # up. NOTE(review): assumes `dynamic_import(import_path, alias)` takes a
    # mapping of alias -> "module:class", as ESPnet's helper does — confirm.
    model_class = dynamic_import(module, predefined_asr)
    assert issubclass(
        model_class, ASRInterface
    ), f"{module} does not implement ASRInterface"
    return model_class

@ -49,13 +49,15 @@ from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add from deepspeech.utils.utility import log_add
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.decoders.scorers.ctc_prefix_score import CTCPrefixScorer
__all__ = ["U2Model", "U2InferModel"] __all__ = ["U2Model", "U2InferModel"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class U2BaseModel(nn.Layer): class U2BaseModel(ASRInterface, nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model""" """CTC-Attention hybrid Encoder-Decoder model"""
@classmethod @classmethod
@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer):
**kwargs): **kwargs):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__() nn.Layer.__init__(self)
# note that eos is the same as sos (equivalent ID) # note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1 self.sos = vocab_size - 1
self.eos = vocab_size - 1 self.eos = vocab_size - 1
@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer):
return res, res_tokenids return res, res_tokenids
class U2Model(U2BaseModel): class U2DecodeModel(U2BaseModel):
def scorers(self):
    """Return the scorers for beam search, keyed by scorer name.

    The decoder is exposed directly under ``"decoder"``; the CTC module is
    wrapped in a ``CTCPrefixScorer`` built with ``self.eos`` under ``"ctc"``.
    """
    return {
        "decoder": self.decoder,
        "ctc": CTCPrefixScorer(self.ctc, self.eos),
    }
def encode(self, x):
    """Run the encoder on the acoustic features of one utterance.

    The model is switched to eval mode, a leading batch axis is added to
    the input before encoding and removed again from the encoder output.

    :param ndarray x: source acoustic feature (T, D)
    :return: encoder outputs
    :rtype: paddle.Tensor
    """
    self.eval()
    batched = paddle.to_tensor(x).unsqueeze(0)
    # size along axis 1 of the batched input, passed on as the length
    n_frames = batched.size(1)
    encoded, _ = self._forward_encoder(batched, n_frames)
    return encoded.squeeze(0)
class U2Model(U2DecodeModel):
def __init__(self, configs: dict): def __init__(self, configs: dict):
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)

@ -15,6 +15,7 @@
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Any
import paddle import paddle
from paddle import nn from paddle import nn
@ -25,7 +26,9 @@ from deepspeech.modules.decoder_layer import DecoderLayer
from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.mask import make_non_pad_mask from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.mask import make_xs_mask
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.decoders.scorers.score_interface import BatchScorerInterface
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -33,7 +36,7 @@ logger = Log(__name__).getlog()
__all__ = ["TransformerDecoder"] __all__ = ["TransformerDecoder"]
class TransformerDecoder(nn.Layer): class TransformerDecoder(BatchScorerInterface, nn.Layer):
"""Base class of Transfomer decoder module. """Base class of Transfomer decoder module.
Args: Args:
vocab_size: output dim vocab_size: output dim
@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer):
concat_after: bool=False, ): concat_after: bool=False, ):
assert check_argument_types() assert check_argument_types()
super().__init__() nn.Layer.__init__(self)
self.selfattention_layer_type = 'selfattn'
attention_dim = encoder_output_size attention_dim = encoder_output_size
if input_layer == "embed": if input_layer == "embed":
@ -180,3 +184,64 @@ class TransformerDecoder(nn.Layer):
if self.use_output_layer: if self.use_output_layer:
y = paddle.log_softmax(self.output_layer(y), axis=-1) y = paddle.log_softmax(self.output_layer(y), axis=-1)
return y, new_cache return y, new_cache
# beam search API (see ScorerInterface)
def score(self, ys, state, x):
    """Score the next token for a single hypothesis.

    A batch axis is added to ``ys`` and ``x``, one decoding step is run
    through ``forward_one_step`` and the batch axis is removed from the
    returned log-probabilities.

    Args:
        ys: prefix tokens, (ylen,)
        state: decoder cache for ``ys``; dropped (set to None) when the
            self-attention layer type is not ``"selfattn"``.
        x: encoder feature that generates ``ys``, (xlen, n_feat)

    Returns:
        tuple: the squeezed output of ``forward_one_step`` and the new
            decoder cache.
    """
    ys_mask = subsequent_mask(len(ys)).unsqueeze(0)
    x_mask = make_xs_mask(x.unsqueeze(0))
    if self.selfattention_layer_type != "selfattn":
        # TODO(karita): implement cache
        # use the module-level `logger` defined in this file; no `logging`
        # import is visible here
        logger.warning(
            f"{self.selfattention_layer_type} does not support cached decoding."
        )
        state = None
    logp, state = self.forward_one_step(
        x.unsqueeze(0), x_mask,
        ys.unsqueeze(0), ys_mask,
        cache=state
    )
    return logp.squeeze(0), state
# batch beam search API (see BatchScorerInterface)
def batch_score(
        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
) -> Tuple[paddle.Tensor, List[Any]]:
    """Score new token batch (required).

    Per-hypothesis decoder caches are stacked layer by layer, one decoding
    step is run for the whole batch, and the resulting caches are split
    back into one list per hypothesis.

    Args:
        ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
        states (List[Any]): Scorer states for prefix tokens.
        xs (paddle.Tensor):
            The encoder feature that generates ys (n_batch, xlen, n_feat).

    Returns:
        tuple[paddle.Tensor, List[Any]]: Tuple of
            batchfied scores for next token with shape of `(n_batch, n_vocab)`
            and next state list for ys.
    """
    n_batch = len(ys)
    n_layers = len(self.decoders)

    # regroup the caches from [batch][layer] into [layer] -> stacked batch;
    # a None first state means there is no cache yet
    batch_state = None
    if states[0] is not None:
        batch_state = []
        for layer in range(n_layers):
            per_hyp = [states[hyp][layer] for hyp in range(n_batch)]
            batch_state.append(paddle.stack(per_hyp))

    # one decoding step for the whole batch
    ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)
    xs_mask = make_xs_mask(xs)
    logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state)

    # regroup the new caches from [layer][batch] back into [batch][layer]
    state_list = []
    for hyp in range(n_batch):
        state_list.append([states[layer][hyp] for layer in range(n_layers)])
    return logp, state_list

@ -18,12 +18,24 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = [ __all__ = [
"make_pad_mask", "make_non_pad_mask", "subsequent_mask", "make_xs_mask", "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
"subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores", "subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
"mask_finished_preds" "mask_finished_preds"
] ]
def make_xs_mask(xs:paddle.Tensor) -> paddle.Tensor:
    """Make mask tensor marking the non-padded part of ``xs``.

    ``xs`` is compared element-wise against an all-zero frame, so the mask
    is True wherever an element differs from zero.

    Args:
        xs (paddle.Tensor): (B, T, D), zeros for pad.
    Returns:
        paddle.Tensor: Mask Tensor indices of non-padded part. (B, T, D)
    """
    feat_dim = xs.shape[-1]
    # a single all-zero frame, compared against every frame of xs
    zero_frame = paddle.zeros([1, 1, feat_dim], dtype=xs.dtype)
    return xs != zero_frame
def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
"""Make mask tensor containing indices of padded part. """Make mask tensor containing indices of padded part.
See description of make_non_pad_mask. See description of make_non_pad_mask.
@ -31,6 +43,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,). lengths (paddle.Tensor): Batch of lengths (B,).
Returns: Returns:
paddle.Tensor: Mask tensor containing indices of padded part. paddle.Tensor: Mask tensor containing indices of padded part.
(B, T)
Examples: Examples:
>>> lengths = [5, 3, 2] >>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths) >>> make_pad_mask(lengths)
@ -62,6 +75,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,). lengths (paddle.Tensor): Batch of lengths (B,).
Returns: Returns:
paddle.Tensor: mask tensor containing indices of padded part. paddle.Tensor: mask tensor containing indices of padded part.
(B, T)
Examples: Examples:
>>> lengths = [5, 3, 2] >>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths) >>> make_non_pad_mask(lengths)

Loading…
Cancel
Save