From 4566351127ea67179b9c4dba8461a900e3332e39 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Sun, 24 Oct 2021 13:21:16 +0000
Subject: [PATCH] add transformer lm and encoder score api

---
 deepspeech/models/lm/__init__.py    |   0
 deepspeech/models/lm/transformer.py | 262 ++++++++++++++++++++++++++++
 deepspeech/modules/encoder.py       |  43 +++++
 deepspeech/modules/encoder_layer.py |   6 +-
 deepspeech/modules/subsampling.py   |  10 +-
 5 files changed, 315 insertions(+), 6 deletions(-)
 create mode 100644 deepspeech/models/lm/__init__.py
 create mode 100644 deepspeech/models/lm/transformer.py

diff --git a/deepspeech/models/lm/__init__.py b/deepspeech/models/lm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py
new file mode 100644
index 000000000..ebe70e4b2
--- /dev/null
+++ b/deepspeech/models/lm/transformer.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from deepspeech.modules.mask import subsequent_mask
+from deepspeech.modules.encoder import TransformerEncoder
+from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
+#LMInterface
+
+class TransformerLM(nn.Layer, BatchScorerInterface):
+    """Transformer language model exposing the beam-search scorer API."""
+    def __init__(
+            self,
+            n_vocab: int,
+            pos_enc: str=None,
+            embed_unit: int=128,
+            att_unit: int=256,
+            head: int=2,
+            unit: int=1024,
+            layer: int=4,
+            dropout_rate: float=0.5,
+            emb_dropout_rate: float = 0.0,
+            att_dropout_rate: float = 0.0,
+            tie_weights: bool = False,):
+        nn.Layer.__init__(self)
+
+        if pos_enc == "sinusoidal":
+            pos_enc_layer_type = "abs_pos"
+        elif pos_enc is None:
+            #TODO
+            pos_enc_layer_type = "None"
+        else:
+            raise ValueError(f"unknown pos-enc option: {pos_enc}")
+
+        self.embed = nn.Embedding(n_vocab, embed_unit)
+
+        if emb_dropout_rate == 0.0:
+            self.embed_drop = None
+        else:
+            self.embed_drop = nn.Dropout(emb_dropout_rate)
+
+        self.encoder = TransformerEncoder(
+            input_size=embed_unit,
+            output_size=att_unit,
+            attention_heads=head,
+            linear_units=unit,
+            num_blocks=layer,
+            dropout_rate=dropout_rate,
+            attention_dropout_rate=att_dropout_rate,
+            input_layer="linear",
+            pos_enc_layer_type=pos_enc_layer_type,
+            concat_after=False,
+            static_chunk_size=1,
+            use_dynamic_chunk=False,
+            use_dynamic_left_chunk=False)
+
+        self.decoder = nn.Linear(att_unit, n_vocab)
+
+        logging.info("Tie weights set to {}".format(tie_weights))
+        logging.info("Dropout set to {}".format(dropout_rate))
+        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
+        logging.info("Att Dropout set to {}".format(att_dropout_rate))
+
+        if tie_weights:
+            assert (
+                att_unit == embed_unit
+            ), "Tie Weights: True need embedding and final dimensions to match"
+            self.decoder.weight = self.embed.weight
+
+    def _target_mask(self, ys_in_pad):
+        ys_mask = ys_in_pad != 0
+        m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
+        return ys_mask.unsqueeze(-2) & m
+
+    def forward(
+        self, x: paddle.Tensor, xlens, t: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Compute LM loss value from buffer sequences.
+
+        Args:
+            x (paddle.Tensor): Input ids. (batch, len)
+            xlens: Not used; lengths are recomputed from the non-zero ids of x.
+            t (paddle.Tensor): Target ids. (batch, len)
+
+        Returns:
+            tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
+                loss to backward (scalar),
+                negative log-likelihood of t: -log p(t) (scalar) and
+                the number of elements in x (scalar)
+
+        Notes:
+            The last two return values are used
+            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
+
+        """
+        xm = x != 0
+        if self.embed_drop is not None:
+            emb = self.embed_drop(self.embed(x))
+        else:
+            emb = self.embed(x)
+        xlen = xm.sum(axis=1)
+        h, _ = self.encoder(emb, xlen)
+        y = self.decoder(h)
+        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+        mask = xm.to(dtype=loss.dtype)
+        logp = loss * mask.view(-1)
+        logp = logp.sum()
+        count = mask.sum()
+        return logp / count, logp, count
+
+    # beam search API (see ScorerInterface)
+    def score(self, y: paddle.Tensor, state: Any,
+              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
+        """Score new token.
+
+        Args:
+            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (paddle.Tensor): encoder feature that generates ys.
+
+        Returns:
+            tuple[paddle.Tensor, Any]: Tuple of
+                paddle.float32 scores for next token (n_vocab)
+                and next state for ys
+
+        """
+        y = y.unsqueeze(0)
+
+        if self.embed_drop is not None:
+            emb = self.embed_drop(self.embed(y))
+        else:
+            emb = self.embed(y)
+
+        h, _, cache = self.encoder.forward_one_step(
+            emb, self._target_mask(y), cache=state
+        )
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(axis=-1).squeeze(0)
+        return logp, cache
+
+    # batch beam search API (see BatchScorerInterface)
+    def batch_score(
+        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, List[Any]]:
+        """Score new token batch (required).
+
+        Args:
+            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (paddle.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[paddle.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, n_vocab)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.encoder.encoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [
+                paddle.stack([states[b][i] for b in range(n_batch)])
+                for i in range(n_layers)
+            ]
+
+        if self.embed_drop is not None:
+            emb = self.embed_drop(self.embed(ys))
+        else:
+            emb = self.embed(ys)
+
+        # batch decoding
+        h, _, states = self.encoder.forward_one_step(
+            emb, self._target_mask(ys), cache=batch_state
+        )
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(axis=-1)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
+
+
+if __name__ == "__main__":
+    tlm = TransformerLM(
+        n_vocab=5002,
+        pos_enc=None,
+        embed_unit=128,
+        att_unit=512,
+        head=8,
+        unit=2048,
+        layer=16,
+        dropout_rate=0.5, )
+
+    # n_vocab: int,
+    # pos_enc: str=None,
+    # embed_unit: int=128,
+    # att_unit: int=256,
+    # head: int=2,
+    # unit: int=1024,
+    # layer: int=4,
+    # dropout_rate: float=0.5,
+    # emb_dropout_rate: float = 0.0,
+    # att_dropout_rate: float = 0.0,
+    # tie_weights: bool = False,):
+    paddle.set_device("cpu")
+    model_dict = paddle.load("transformerLM.pdparams")
+    tlm.set_state_dict(model_dict)
+
+    tlm.eval()
+    #Test the score
+    input2 = np.array([5])
+    input2 = paddle.to_tensor(input2)
+    state = (None, None, 0)
+    output, state = tlm.score(input2, state, None)
+
+    input3 = np.array([10])
+    input3 = paddle.to_tensor(input3)
+    output, state = tlm.score(input3, state, None)
+
+    input4 = np.array([0])
+    input4 = paddle.to_tensor(input4)
+    output, state = tlm.score(input4, state, None)
+    print("output", output)
+    """
+    #Test the batch score
+    batch_size = 2
+    inp2 = np.array([[5], [10]])
+    inp2 = paddle.to_tensor(inp2)
+    output, states = tlm.batch_score(
+        inp2, [(None,None,0)] * batch_size)
+    inp3 = np.array([[100], [30]])
+    inp3 = paddle.to_tensor(inp3)
+    output, states = tlm.batch_score(
+        inp3, states)
+    print("output", output)
+    #print("cache", cache)
+    #np.save("output_pd.npy", output)
+    """
\ No newline at end of file
diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py
index 6ffb6465c..6de1ef4dd 100644
--- a/deepspeech/modules/encoder.py
+++ b/deepspeech/modules/encoder.py
@@ -31,6 +31,7 @@ from deepspeech.modules.encoder_layer import TransformerEncoderLayer
 from deepspeech.modules.mask import add_optional_chunk_mask
 from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
+from deepspeech.modules.subsampling import Conv2dSubsampling
 from deepspeech.modules.subsampling import Conv2dSubsampling4
 from deepspeech.modules.subsampling import Conv2dSubsampling6
 from deepspeech.modules.subsampling import Conv2dSubsampling8
@@ -370,6 +371,48 @@ class TransformerEncoder(BaseEncoder):
                 concat_after=concat_after) for _ in range(num_blocks)
         ])
 
+    def forward_one_step(
+            self,
+            xs: paddle.Tensor,
+            masks: paddle.Tensor,
+            cache=None,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, list]:
+        """Encode input frame.
+
+        Args:
+            xs (paddle.Tensor): Input tensor. (B, T, D)
+            masks (paddle.Tensor): Mask tensor. (B, 1, T)
+            cache (List[paddle.Tensor]): List of cache tensors.
+
+        Returns:
+            paddle.Tensor: Output tensor.
+            paddle.Tensor: Mask tensor.
+            List[paddle.Tensor]: List of new cache tensors.
+        """
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+
+        if isinstance(self.embed, Conv2dSubsampling):
+            # xs, masks = self.embed(xs, masks)
+            #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
+            xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
+        else:
+            # NOTE(review): LinearNoSubsampling appears to return a 3-tuple
+            # (x, pos_emb, x_mask); confirm this call handles that layer.
+            xs = self.embed(xs)
+        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
+        masks = masks.astype(paddle.bool)
+
+        if cache is None:
+            cache = [None for _ in range(len(self.encoders))]
+        new_cache = []
+        for c, e in zip(cache, self.encoders):
+            xs, masks, _ = e(xs, masks, output_cache=c)
+            new_cache.append(xs)
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        return xs, masks, new_cache
+
 
 class ConformerEncoder(BaseEncoder):
     """Conformer encoder module."""
diff --git a/deepspeech/modules/encoder_layer.py b/deepspeech/modules/encoder_layer.py
index 1db556ca5..6f49cfc86 100644
--- a/deepspeech/modules/encoder_layer.py
+++ b/deepspeech/modules/encoder_layer.py
@@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer):
             self,
             x: paddle.Tensor,
             mask: paddle.Tensor,
-            pos_emb: paddle.Tensor,
+            pos_emb: Optional[paddle.Tensor]=None,
             mask_pad: Optional[paddle.Tensor]=None,
             output_cache: Optional[paddle.Tensor]=None,
             cnn_cache: Optional[paddle.Tensor]=None,
@@ -82,8 +82,8 @@ class TransformerEncoderLayer(nn.Layer):
             mask (paddle.Tensor): Mask tensor for the input (#batch, time).
             pos_emb (paddle.Tensor): just for interface compatibility
                 to ConformerEncoderLayer
-            mask_pad (paddle.Tensor): does not used in transformer layer,
-                just for unified api with conformer.
+            mask_pad (paddle.Tensor): not used here, it's for interface
+                compatibility to ConformerEncoderLayer
             output_cache (paddle.Tensor): Cache tensor of the output
                 (#batch, time2, size), time2 < time in x.
             cnn_cache (paddle.Tensor): not used here, it's for interface
diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py
index 3bed62f3c..f804907fb 100644
--- a/deepspeech/modules/subsampling.py
+++ b/deepspeech/modules/subsampling.py
@@ -82,8 +82,12 @@ class LinearNoSubsampling(BaseSubsampling):
         x, pos_emb = self.pos_enc(x, offset)
         return x, pos_emb, x_mask
 
+class Conv2dSubsampling(BaseSubsampling):
+    """Common base class of the Conv2dSubsampling4/6/8 layers."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
-class Conv2dSubsampling4(BaseSubsampling):
+class Conv2dSubsampling4(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/4 length)."""
 
     def __init__(self,
@@ -134,7 +138,7 @@ class Conv2dSubsampling4(BaseSubsampling):
         return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
 
 
-class Conv2dSubsampling6(BaseSubsampling):
+class Conv2dSubsampling6(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/6 length)."""
 
     def __init__(self,
@@ -187,7 +191,7 @@ class Conv2dSubsampling6(BaseSubsampling):
         return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
 
 
-class Conv2dSubsampling8(BaseSubsampling):
+class Conv2dSubsampling8(Conv2dSubsampling):
     """Convolutional 2D subsampling (to 1/8 length)."""
 
     def __init__(self,