diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 3e96c40f1..fa5ef0439 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -49,10 +49,18 @@ if not hasattr(paddle, 'softmax'):
     logger.warn("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
 
+if not hasattr(paddle, 'log_softmax'):
+    logger.warn("register user log_softmax to paddle, remove this when fixed!")
+    setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
+
 if not hasattr(paddle, 'sigmoid'):
     logger.warn("register user sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
 
+if not hasattr(paddle, 'log_sigmoid'):
+    logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
+    setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
+
 if not hasattr(paddle, 'relu'):
     logger.warn("register user relu to paddle, remove this when fixed!")
     setattr(paddle, 'relu', paddle.nn.functional.relu)
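Note (not part of the diff): a minimal sketch of what these shims provide once `import deepspeech` has run; the tensor shape is an arbitrary assumption.

    import paddle
    import deepspeech  # triggers the compatibility registrations above

    x = paddle.rand([2, 5])
    # After registration the functional helpers are reachable from the top-level
    # paddle namespace, mirroring the torch-style API used by the ported modules.
    log_probs = paddle.log_softmax(x, axis=-1)  # same as paddle.nn.functional.log_softmax
    log_gates = paddle.log_sigmoid(x)           # same as paddle.nn.functional.log_sigmoid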
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
new file mode 100644
index 000000000..4dd989d58
--- /dev/null
+++ b/deepspeech/modules/decoder.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Decoder definition."""
+from typing import Tuple, List, Optional
+from typeguard import check_argument_types
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+from deepspeech.modules.attention import MultiHeadedAttention
+from deepspeech.modules.decoder_layer import DecoderLayer
+from deepspeech.modules.embedding import PositionalEncoding
+from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
+from deepspeech.modules.mask import subsequent_mask
+from deepspeech.modules.mask import make_non_pad_mask
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["TransformerDecoder"]
+
+
+class TransformerDecoder(nn.Module):
+    """Base class of Transformer decoder module.
+    Args:
+        vocab_size: output dim
+        encoder_output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the hidden units number of position-wise feedforward
+        num_blocks: the number of decoder blocks
+        dropout_rate: dropout rate
+        self_attention_dropout_rate: dropout rate for attention
+        input_layer: input layer type, `embed`
+        use_output_layer: whether to use output layer
+        pos_enc_class: PositionalEncoding module
+        normalize_before:
+            True: use layer_norm before each sub-block of a layer.
+            False: use layer_norm after each sub-block of a layer.
+        concat_after: whether to concat attention layer's input and output
+            True: x -> x + linear(concat(x, att(x)))
+            False: x -> x + att(x)
+    """
+
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int=4,
+            linear_units: int=2048,
+            num_blocks: int=6,
+            dropout_rate: float=0.1,
+            positional_dropout_rate: float=0.1,
+            self_attention_dropout_rate: float=0.0,
+            src_attention_dropout_rate: float=0.0,
+            input_layer: str="embed",
+            use_output_layer: bool=True,
+            normalize_before: bool=True,
+            concat_after: bool=False, ):
+
+        assert check_argument_types()
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = nn.Sequential(
+                nn.Embedding(vocab_size, attention_dim),
+                PositionalEncoding(attention_dim, positional_dropout_rate), )
+        else:
+            raise ValueError(f"only 'embed' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        self.after_norm = nn.LayerNorm(attention_dim, epsilon=1e-12)
+        self.use_output_layer = use_output_layer
+        self.output_layer = nn.Linear(attention_dim, vocab_size)
+
+        self.decoders = nn.ModuleList([
+            DecoderLayer(
+                size=attention_dim,
+                self_attn=MultiHeadedAttention(attention_heads, attention_dim,
+                                               self_attention_dropout_rate),
+                src_attn=MultiHeadedAttention(attention_heads, attention_dim,
+                                              src_attention_dropout_rate),
+                feed_forward=PositionwiseFeedForward(
+                    attention_dim, linear_units, dropout_rate),
+                dropout_rate=dropout_rate,
+                normalize_before=normalize_before,
+                concat_after=concat_after, ) for _ in range(num_blocks)
+        ])
+
+    def forward(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            ys_in_pad: paddle.Tensor,
+            ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Forward decoder.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+            ys_in_lens: input lengths of this batch (batch)
+        Returns:
+            (tuple): tuple containing:
+                x: decoded token score before softmax (batch, maxlen_out, vocab_size)
+                    if use_output_layer is True
+                olens: (batch, )
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = make_non_pad_mask(ys_in_lens).unsqueeze(1)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        x, _ = self.embed(tgt)
+        for layer in self.decoders:
+            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
+                                                     memory_mask)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.use_output_layer:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, olens
+
+    def forward_one_step(
+            self,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            tgt: paddle.Tensor,
+            tgt_mask: paddle.Tensor,
+            cache: Optional[List[paddle.Tensor]]=None,
+    ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
+        """Forward one step.
+            This is only used for decoding.
+        Args:
+            memory: encoded memory, float32 (batch, maxlen_in, feat)
+            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out)
+                      dtype=paddle.bool
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, maxlen_out, token)
+        """
+        x, _ = self.embed(tgt)
+        new_cache = []
+        for i, decoder in enumerate(self.decoders):
+            if cache is None:
+                c = None
+            else:
+                c = cache[i]
+            x, tgt_mask, memory, memory_mask = decoder(
+                x, tgt_mask, memory, memory_mask, cache=c)
+            new_cache.append(x)
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.use_output_layer:
+            y = paddle.log_softmax(self.output_layer(y), axis=-1)
+        return y, new_cache
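Usage sketch (not part of the diff): the shapes, hyper-parameters, and start-token id below are illustrative assumptions; `PositionalEncoding`, `MultiHeadedAttention`, and `PositionwiseFeedForward` are the existing deepspeech modules imported above.

    import paddle
    from deepspeech.modules.decoder import TransformerDecoder
    from deepspeech.modules.mask import subsequent_mask

    B, T_in, D, V = 2, 50, 256, 1000   # batch, encoder frames, model dim, vocab (assumed)
    decoder = TransformerDecoder(vocab_size=V, encoder_output_size=D, num_blocks=2)

    memory = paddle.rand([B, T_in, D])                      # stand-in for encoder output
    memory_mask = paddle.ones([B, 1, T_in], dtype='bool')   # all encoder frames valid

    # Teacher-forced path: padded target ids plus their true lengths.
    ys_in_pad = paddle.randint(0, V, [B, 12])
    ys_in_lens = paddle.to_tensor([12, 9])
    logits, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
    # logits: (B, 12, V) token scores before softmax; olens: per-sample output lengths

    # Incremental path: grow one hypothesis a step at a time with forward_one_step.
    hyp = paddle.to_tensor([[1]])   # assumed <sos> id
    cache = None
    for _ in range(3):
        tgt_mask = subsequent_mask(hyp.shape[1]).unsqueeze(0)   # (1, L, L)
        logp, cache = decoder.forward_one_step(
            memory[:1], memory_mask[:1], hyp, tgt_mask, cache)
        next_id = paddle.argmax(logp, axis=-1, keepdim=True)    # greedy pick
        hyp = paddle.concat([hyp, next_id], axis=1)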
diff --git a/deepspeech/modules/decoder_layer.py b/deepspeech/modules/decoder_layer.py
new file mode 100644
index 000000000..8e5ae1ac1
--- /dev/null
+++ b/deepspeech/modules/decoder_layer.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Decoder self-attention layer definition."""
+from typing import Optional, Tuple
+import logging
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["DecoderLayer"]
+
+
+class DecoderLayer(nn.Module):
+    """Single decoder layer module.
+    Args:
+        size (int): Input dimension.
+        self_attn (nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (nn.Module): Source-attention (encoder-decoder attention) module
+            instance. `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool):
+            True: use layer_norm before each sub-block.
+            False: use layer_norm after each sub-block.
+        concat_after (bool): Whether to concat attention layer's input
+            and output.
+            True: x -> x + linear(concat(x, att(x)))
+            False: x -> x + att(x)
+    """
+
+    def __init__(
+            self,
+            size: int,
+            self_attn: nn.Module,
+            src_attn: nn.Module,
+            feed_forward: nn.Module,
+            dropout_rate: float,
+            normalize_before: bool=True,
+            concat_after: bool=False, ):
+        """Construct a DecoderLayer object."""
+        super().__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
+        self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
+        self.norm3 = nn.LayerNorm(size, epsilon=1e-12)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        self.concat_linear1 = nn.Linear(size + size, size)
+        self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(
+            self,
+            tgt: paddle.Tensor,
+            tgt_mask: paddle.Tensor,
+            memory: paddle.Tensor,
+            memory_mask: paddle.Tensor,
+            cache: Optional[paddle.Tensor]=None
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Compute decoded features.
+        Args:
+            tgt (paddle.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (paddle.Tensor): Mask for input tensor
+                (#batch, maxlen_out).
+            memory (paddle.Tensor): Encoded memory
+                (#batch, maxlen_in, size).
+            memory_mask (paddle.Tensor): Encoded memory mask
+                (#batch, maxlen_in).
+            cache (paddle.Tensor): cached tensors.
+                (#batch, maxlen_out - 1, size).
+        Returns:
+            paddle.Tensor: Output tensor (#batch, maxlen_out, size).
+            paddle.Tensor: Mask for output tensor (#batch, maxlen_out).
+            paddle.Tensor: Encoded memory (#batch, maxlen_in, size).
+            paddle.Tensor: Encoded memory mask (#batch, maxlen_in).
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == [
+                tgt.shape[0], tgt.shape[1] - 1, self.size
+            ], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = paddle.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(
+                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        if self.concat_after:
+            x_concat = paddle.cat(
+                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(
+                self.src_attn(x, memory, memory, memory_mask))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm3(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm3(x)
+
+        if cache is not None:
+            x = paddle.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index 007b18e0a..4351a7cb8 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -50,8 +50,7 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
     return mask
 
 
-def subsequent_mask(
-        size: int, ) -> paddle.Tensor:
+def subsequent_mask(size: int) -> paddle.Tensor:
     """Create mask for subsequent steps (size, size).
     This mask is used only in decoder which works in an auto-regressive mode.
     This means the current step could only do attention with its left steps.
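For reference, a small sketch (not part of the diff) of how `subsequent_mask` combines with the padding mask into the (B, L, L) self-attention mask built in `TransformerDecoder.forward`; it assumes `make_non_pad_mask` in `deepspeech.modules.mask` returns a bool mask of valid positions, and the lengths 3 and 2 are made-up values.

    import paddle
    from deepspeech.modules.mask import make_non_pad_mask, subsequent_mask

    ys_in_lens = paddle.to_tensor([3, 2])
    tgt_mask = make_non_pad_mask(ys_in_lens).unsqueeze(1)   # (B, 1, L): hide padded positions
    m = subsequent_mask(3).unsqueeze(0)                     # (1, L, L): hide future steps
    full_mask = paddle.logical_and(tgt_mask, m)             # (B, L, L), broadcast on both sides
    # Same result as the `tgt_mask & m` expression in forward(): sample 0 (length 3)
    # keeps the full lower triangle, sample 1 (length 2) also has its last column masked.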