From 506f2bfd208764af0293008f9a2795c3f3ea35c3 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Sun, 24 Oct 2021 13:48:56 +0000
Subject: [PATCH] add lm interface

---
 deepspeech/models/lm/transformer.py |  7 ++++---
 deepspeech/models/lm_interface.py   | 69 +++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)
 create mode 100644 deepspeech/models/lm_interface.py

diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py
index ebe70e4b..9392a15e 100644
--- a/deepspeech/models/lm/transformer.py
+++ b/deepspeech/models/lm/transformer.py
@@ -23,9 +23,10 @@ import paddle.nn.functional as F
 from deepspeech.modules.mask import subsequent_mask
 from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
+from deepspeech.models.lm_interface import LMInterface
 
 
-class TransformerLM(nn.Layer, BatchScorerInterface):
+class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
     def __init__(
         self,
         n_vocab: int,
@@ -90,7 +91,7 @@ class TransformerLM(nn.Layer, BatchScorerInterface):
         return ys_mask.unsqueeze(-2) & m
 
     def forward(
-        self, x: paddle.Tensor, xlens, t: paddle.Tensor
+        self, x: paddle.Tensor, t: paddle.Tensor
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute LM loss value from buffer sequences.
 
@@ -110,11 +111,11 @@ class TransformerLM(nn.Layer, BatchScorerInterface):
         """
         xm = x != 0
+        xlen = xm.sum(axis=1)
         if self.embed_drop is not None:
             emb = self.embed_drop(self.embed(x))
         else:
             emb = self.embed(x)
-        xlen = xm.sum(axis=1)
         h, _ = self.encoder(emb, xlen)
         y = self.decoder(h)
         loss = F.cross_entropy(
             y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
diff --git a/deepspeech/models/lm_interface.py b/deepspeech/models/lm_interface.py
new file mode 100644
index 00000000..ed6d5d9c
--- /dev/null
+++ b/deepspeech/models/lm_interface.py
@@ -0,0 +1,69 @@
+"""Language model interface."""
+
+import argparse
+
+from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
+from deepspeech.utils.dynamic_import import dynamic_import
+
+class LMInterface(ScorerInterface):
+    """LM Interface for ESPnet model implementation."""
+
+    @staticmethod
+    def add_arguments(parser):
+        """Add arguments to the command-line argument parser."""
+        return parser
+
+    @classmethod
+    def build(cls, n_vocab: int, **kwargs):
+        """Initialize this class with Python-level args.
+
+        Args:
+            n_vocab (int): Vocabulary size.
+
+        Returns:
+            LMInterface: A new instance of LMInterface.
+
+        """
+        args = argparse.Namespace(**kwargs)
+        return cls(n_vocab, args)
+
+    def forward(self, x, t):
+        """Compute LM loss value from buffer sequences.
+
+        Args:
+            x (paddle.Tensor): Input ids. (batch, len)
+            t (paddle.Tensor): Target ids. (batch, len)
+
+        Returns:
+            tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
+                loss to backward (scalar),
+                negative log-likelihood of t: -log p(t) (scalar) and
+                the number of elements in x (scalar)
+
+        Notes:
+            The last two return values are used
+            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
+
+        """
+        raise NotImplementedError("forward method is not implemented")
+
+
+predefined_lms = {
+    "transformer": "deepspeech.models.lm.transformer:TransformerLM",
+}
+
+def dynamic_import_lm(module):
+    """Import an LM class dynamically.
+
+    Args:
+        module (str): "module_name:class_name" or an alias in `predefined_lms`.
+
+    Returns:
+        type: LM class.
+
+    """
+    model_class = dynamic_import(module, predefined_lms)
+    assert issubclass(
+        model_class, LMInterface
+    ), f"{module} does not implement LMInterface"
+    return model_class
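Usage sketch (illustrative only, not part of the patch): `dynamic_import_lm` accepts either an alias from `predefined_lms` or a fully qualified "module:class" path and asserts that the resolved class implements `LMInterface`. The vocabulary size and id values below are made-up examples, and the sketch assumes the remaining `TransformerLM` constructor arguments have defaults; pad id 0 follows the `xm = x != 0` masking in transformer.py.

    import paddle

    from deepspeech.models.lm_interface import dynamic_import_lm

    # Resolve by alias; the fully qualified form
    # "deepspeech.models.lm.transformer:TransformerLM" is equivalent.
    lm_class = dynamic_import_lm("transformer")

    # Construct with keyword hyperparameters (example value). Alternatively,
    # LMInterface.build(n_vocab, **kwargs) wraps the kwargs in an
    # argparse.Namespace and calls lm_class(n_vocab, args).
    lm = lm_class(n_vocab=5000)

    # forward() takes padded input/target id batches (0 = pad) and returns
    # (loss, nll, count) per the interface docstring.
    x = paddle.to_tensor([[21, 4, 7, 0]])  # input ids, shape (batch, len)
    t = paddle.to_tensor([[4, 7, 36, 0]])  # target ids, shape (batch, len)
    loss, nll, count = lm(x, t)

    # Perplexity, following the docstring note: exp(-log p(t) / n).
    ppl = paddle.exp(nll / count.astype(nll.dtype))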