add transformer lm and encoder score api

pull/930/head
Hui Zhang 4 years ago
parent c5f6692191
commit 4566351127

@@ -0,0 +1,259 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import List
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
#LMInterface
class TransformerLM(nn.Layer, BatchScorerInterface):
def __init__(
self,
n_vocab: int,
pos_enc: str=None,
embed_unit: int=128,
att_unit: int=256,
head: int=2,
unit: int=1024,
layer: int=4,
dropout_rate: float=0.5,
emb_dropout_rate: float = 0.0,
att_dropout_rate: float = 0.0,
tie_weights: bool = False,):
nn.Layer.__init__(self)
if pos_enc == "sinusoidal":
pos_enc_layer_type = "abs_pos"
elif pos_enc is None:
#TODO
pos_enc_layer_type = "None"
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(n_vocab, embed_unit)
if emb_dropout_rate == 0.0:
self.embed_drop = None
else:
self.embed_drop = nn.Dropout(emb_dropout_rate)
self.encoder = TransformerEncoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
attention_dropout_rate=att_dropout_rate,
input_layer="linear",
pos_enc_layer_type=pos_enc_layer_type,
concat_after=False,
static_chunk_size=1,
use_dynamic_chunk=False,
use_dynamic_left_chunk=False)
self.decoder = nn.Linear(att_unit, n_vocab)
logging.info("Tie weights set to {}".format(tie_weights))
logging.info("Dropout set to {}".format(dropout_rate))
logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
logging.info("Att Dropout set to {}".format(att_dropout_rate))
if tie_weights:
assert (
att_unit == embed_unit
), "Tie Weights: True need embedding and final dimensions to match"
self.decoder.weight = self.embed.weight
def _target_mask(self, ys_in_pad):
# combined mask: non-pad positions (B, 1, T) & causal mask (1, T, T)
ys_mask = ys_in_pad != 0
m = subsequent_mask(ys_mask.shape[-1]).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
def forward(
self, x: paddle.Tensor, xlens, t: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute LM loss value from buffer sequences.
Args:
x (paddle.Tensor): Input ids. (batch, len)
t (paddle.Tensor): Target ids. (batch, len)
Returns:
tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
loss to backward (scalar),
negative log-likelihood of t: -log p(t) (scalar) and
the number of elements in x (scalar)
Notes:
The last two return values are used
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
"""
xm = x != 0
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(x))
else:
emb = self.embed(x)
xlen = xm.sum(axis=1)
h, _ = self.encoder(emb, xlen)
y = self.decoder(h)
loss = F.cross_entropy(y.reshape([-1, y.shape[-1]]), t.reshape([-1]), reduction="none")
mask = xm.astype(loss.dtype)
logp = loss * mask.reshape([-1])
logp = logp.sum()
count = mask.sum()
return logp / count, logp, count
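# Illustrative note (not part of the original code): the last two return values
# give corpus-level perplexity as exp(nll / count), e.g.
#   _, nll, count = lm(x, xlens, t)
#   ppl = paddle.exp(nll / count)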
# beam search API (see ScorerInterface)
def score(self, y: paddle.Tensor, state: Any,
x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
"""Score new token.
Args:
y (paddle.Tensor): 1D paddle.int64 prefix tokens.
state: Scorer state for prefix tokens
x (paddle.Tensor): encoder feature that generates ys.
Returns:
tuple[paddle.Tensor, Any]: Tuple of
paddle.float32 scores for next token (n_vocab)
and next state for ys
"""
y = y.unsqueeze(0)
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(y))
else:
emb = self.embed(y)
h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state
)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h, axis=-1).squeeze(0)
return logp, cache
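# Illustrative sketch (not part of the original code): score() is called once per
# beam-search step with the growing token prefix, threading the returned cache
# back in; `lm` and `sos_id` below are placeholders.
#   state = None
#   ys = paddle.to_tensor([sos_id])
#   logp, state = lm.score(ys, state, None)      # log-probs over n_vocab tokens
#   next_id = int(paddle.argmax(logp))
#   ys = paddle.concat([ys, paddle.to_tensor([next_id])])
#   logp, state = lm.score(ys, state, None)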
# batch beam search API (see BatchScorerInterface)
def batch_score(
self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (paddle.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[paddle.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
paddle.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(ys))
else:
emb = self.embed(ys)
# batch decoding
h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state
)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h, axis=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
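# Illustrative sketch (not part of the original code): batch_score() behaves like
# score() applied to every hypothesis at once; initial states are None per
# hypothesis and the returned state list is threaded back in on the next step.
#   ys = paddle.to_tensor([[5], [10]])                        # two one-token prefixes
#   logps, states = lm.batch_score(ys, [None, None], None)    # logps: (2, n_vocab)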
if __name__ == "__main__":
tlm = TransformerLM(
n_vocab=5002,
pos_enc=None,
embed_unit=128,
att_unit=512,
head=8,
unit=2048,
layer=16,
dropout_rate=0.5, )
# n_vocab: int,
# pos_enc: str=None,
# embed_unit: int=128,
# att_unit: int=256,
# head: int=2,
# unit: int=1024,
# layer: int=4,
# dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0,
# tie_weights: bool = False,):
paddle.set_device("cpu")
model_dict = paddle.load("transformerLM.pdparams")
tlm.set_state_dict(model_dict)
tlm.eval()
#Test the score
input2 = np.array([5])
input2 = paddle.to_tensor(input2)
state = None
output, state = tlm.score(input2, state, None)
input3 = np.array([10])
input3 = paddle.to_tensor(input3)
output, state = tlm.score(input3, state, None)
input4 = np.array([0])
input4 = paddle.to_tensor(input4)
output, state = tlm.score(input4, state, None)
print("output", output)
"""
#Test the batch score
batch_size = 2
inp2 = np.array([[5], [10]])
inp2 = paddle.to_tensor(inp2)
output, states = tlm.batch_score(
inp2, [None] * batch_size, None)
inp3 = np.array([[100], [30]])
inp3 = paddle.to_tensor(inp3)
output, states = tlm.batch_score(
inp3, states, None)
print("output", output)
#print("cache", cache)
#np.save("output_pd.npy", output)
"""

@@ -31,6 +31,7 @@ from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.mask import add_optional_chunk_mask
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.modules.subsampling import Conv2dSubsampling
from deepspeech.modules.subsampling import Conv2dSubsampling4
from deepspeech.modules.subsampling import Conv2dSubsampling6
from deepspeech.modules.subsampling import Conv2dSubsampling8
@@ -370,6 +371,46 @@ class TransformerEncoder(BaseEncoder):
concat_after=concat_after) for _ in range(num_blocks)
])
def forward_one_step(
self,
xs: paddle.Tensor,
masks: paddle.Tensor,
cache=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor]]:
"""Encode input frame.
Args:
xs (paddle.Tensor): Input tensor. (B, T, D)
masks (paddle.Tensor): Mask tensor. (B, 1, T)
cache (List[paddle.Tensor]): List of cache tensors.
Returns:
paddle.Tensor: Output tensor.
paddle.Tensor: Mask tensor.
List[paddle.Tensor]: List of new cache tensors.
"""
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
if isinstance(self.embed, Conv2dSubsampling):
# xs, masks = self.embed(xs, masks)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
else:
xs = self.embed(xs)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
if cache is None:
cache = [None for _ in range(len(self.encoders))]
new_cache = []
for c, e in zip(cache, self.encoders):
xs, masks, _ = e(xs, masks, output_cache=c)
new_cache.append(xs)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache
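# Illustrative sketch (not part of the original commit): the returned cache holds
# each layer's output for the tokens seen so far, so incremental decoding re-feeds
# the full prefix together with the previous cache; `encoder`, `emb` and the
# subsequent_mask usage below mirror how TransformerLM.score drives this method.
#   cache = None
#   for t in range(1, emb.shape[1] + 1):
#       mask = subsequent_mask(t).unsqueeze(0)
#       out, _, cache = encoder.forward_one_step(emb[:, :t], mask, cache=cache)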
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""

@@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer):
self,
x: paddle.Tensor,
mask: paddle.Tensor,
- pos_emb: paddle.Tensor,
+ pos_emb: Optional[paddle.Tensor]=None,
mask_pad: Optional[paddle.Tensor]=None,
output_cache: Optional[paddle.Tensor]=None,
cnn_cache: Optional[paddle.Tensor]=None,
@@ -82,8 +82,8 @@
mask (paddle.Tensor): Mask tensor for the input (#batch, time).
pos_emb (paddle.Tensor): just for interface compatibility
to ConformerEncoderLayer
- mask_pad (paddle.Tensor): does not used in transformer layer,
- just for unified api with conformer.
+ mask_pad (paddle.Tensor): not used here, it's for interface
+ compatibility to ConformerEncoderLayer
output_cache (paddle.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (paddle.Tensor): not used here, it's for interface

@@ -82,8 +82,11 @@ class LinearNoSubsampling(BaseSubsampling):
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv2dSubsampling(BaseSubsampling):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
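# Note (added for clarity, not part of the original commit): this otherwise empty
# base class only provides a common ancestor, so that TransformerEncoder.forward_one_step
# can use a single isinstance(self.embed, Conv2dSubsampling) check to cover the
# 1/4, 1/6 and 1/8 subsampling variants below.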
- class Conv2dSubsampling4(BaseSubsampling):
+ class Conv2dSubsampling4(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/4 length)."""
def __init__(self,
@@ -134,7 +137,7 @@ class Conv2dSubsampling4(BaseSubsampling):
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
- class Conv2dSubsampling6(BaseSubsampling):
+ class Conv2dSubsampling6(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/6 length)."""
def __init__(self,
@@ -187,7 +190,7 @@ class Conv2dSubsampling6(BaseSubsampling):
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
- class Conv2dSubsampling8(BaseSubsampling):
+ class Conv2dSubsampling8(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/8 length)."""
def __init__(self,
