Add the None option for pos_enc and forward_one_step

pull/925/head
huangyuxin 4 years ago
parent 3ee6aed57d
commit ef959bb49d

@@ -22,6 +22,8 @@ import paddle.nn.functional as F
from deepspeech.modules.encoder import TransformerEncoder
#LMInterface, BatchScorerInterface
class TransformerLM(nn.Layer):
def __init__(
@@ -39,7 +41,7 @@ class TransformerLM(nn.Layer):
pos_enc_layer_type = "abs_pos"
elif pos_enc is None:
#TODO
raise ValueError(f"unknown pos-enc option: {pos_enc}")
pos_enc_layer_type = "None"
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
@@ -66,8 +68,15 @@ class TransformerLM(nn.Layer):
model_dict = paddle.load("transformerLM.pdparams")
self.set_state_dict(model_dict)
def _target_len(self, ys_in_pad):
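# Count the non-padding (non-zero) tokens of each sequence to get its length.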
ys_len_tmp = paddle.where(
paddle.to_tensor(ys_in_pad != 0),
paddle.ones_like(ys_in_pad), paddle.zeros_like(ys_in_pad))
ys_len = paddle.sum(ys_len_tmp, axis=-1)
return ys_len
def forward(self, input: paddle.Tensor,
hidden: None) -> Tuple[paddle.Tensor, None]:
x_len: paddle.Tensor) -> Tuple[paddle.Tensor, None]:
x = self.embed(input)
x_len = self._target_len(input)
@@ -75,61 +84,46 @@ class TransformerLM(nn.Layer):
y = self.decoder(h)
return y, None
def score(
self,
y: paddle.Tensor,
subsampling_cache,
state: Any,
offset: int, ) -> Tuple[paddle.Tensor, Any]:
def score(self, y: paddle.Tensor, state: Any,
x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
# y, the chunk input
y = y.unsqueeze(0)
subsampling_cache = subsampling_cache
conformer_cnn_cache = None
elayers_output_cache = state
#subsampling_cache, elayers_output_cache, conformer_cnn_cache, offset = state
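# -1: keep the full history cache for the next step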
required_cache_size = -1
y = self.embed(y)
h, r_subsampling_cache, r_elayers_output_cache, r_conformer_cnn_cache = self.encoder.forward_chunk(
y, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
h, state = self.encoder.forward_one_step(y, required_cache_size, state)
h = self.decoder(h[:, -1])
logp = F.log_softmax(h).squeeze(0)
return h, r_subsampling_cache, r_elayers_output_cache
return h, state
def batch_score(
self,
ys: paddle.Tensor,
subsampling_caches: List[Any],
encoder_states: List[Any],
offset: int, ) -> Tuple[paddle.Tensor, List[Any]]:
states: List[Any], ) -> Tuple[paddle.Tensor, List[Any]]:
#ys, the batch chunk input
n_batch = ys.shape[0]
n_layers = len(self.encoder.encoders)
hs = []
new_subsampling_states = []
new_encoder_states = []
new_states = []
for i in range(n_batch):
y = ys[i:i + 1, :]
subsampling_cache = subsampling_caches[i]
elayers_output_cache = encoder_states[i]
conformer_cnn_cache = None
state = states[i]
required_cache_size = -1
y = self.embed(y)
h, r_subsampling_cache, r_elayers_output_cache, r_conformer_cnn_cache = self.encoder.forward_chunk(
y, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
h, state = self.encoder.forward_one_step(y, required_cache_size,
state)
h = self.decoder(h[:, -1])
hs.append(h)
new_subsampling_states.append(r_subsampling_cache)
new_encoder_states.append(r_elayers_output_cache)
new_states.append(state)
hs = paddle.concat(hs, axis=0)
hs = F.log_softmax(hs)
return hs, new_subsampling_states, new_encoder_states
return hs, new_states
if __name__ == "__main__":
tlm = TransformerLM(
vocab_size=5002,
pos_enc='sinusoidal',
pos_enc=None,
embed_unit=128,
att_unit=512,
head=8,
@@ -139,34 +133,33 @@ if __name__ == "__main__":
paddle.set_device("cpu")
tlm.eval()
"""
#Test the score
input2 = np.array([5])
input2 = paddle.to_tensor(input2)
output, sub_cache, cache =tlm.score(input2, None, None, 0)
state = (None, None, 0)
output, state = tlm.score(input2, state, None)
input3 = np.array([10])
input3 = paddle.to_tensor(input3)
output, sub_cache, cache = tlm.score(input3, sub_cache, cache, 1)
output, state = tlm.score(input3, state, None)
input4 = np.array([7])
input4 = np.array([0])
input4 = paddle.to_tensor(input4)
output, sub_cache, cache = tlm.score(input4, sub_cache, cache, 2)
output, state = tlm.score(input4, state, None)
print("output", output)
"""
#Test the batch score
batch_size = 2
offset = 0
inp2 = np.array([[5], [10]])
inp2 = paddle.to_tensor(inp2)
output, subsampling_caches, encoder_caches = tlm.batch_score(
inp2, [None] * batch_size, [None] * batch_size, offset)
output, states = tlm.batch_score(
inp2, [(None,None,0)] * batch_size)
offset += 1
inp3 = np.array([[100], [30]])
inp3 = paddle.to_tensor(inp3)
output, subsampling_caches, encoder_caches = tlm.batch_score(
inp3, subsampling_caches, encoder_caches, offset)
output, states = tlm.batch_score(
inp3, states)
print("output", output)
#print("cache", cache)
#np.save("output_pd.npy", output)
"""

@@ -25,6 +25,22 @@ logger = Log(__name__).getlog()
__all__ = ["PositionalEncoding", "RelPositionalEncoding"]
class NoPositionalEncoding(nn.Layer):
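"""Dummy positional encoding: returns the input unchanged and no position embedding."""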
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int=5000,
reverse: bool=False):
super().__init__()
def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
return x, None
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
return None
class PositionalEncoding(nn.Layer):
def __init__(self,
d_model: int,

@@ -24,6 +24,7 @@ from deepspeech.modules.activation import get_activation
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer
@@ -101,6 +102,8 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "None":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -155,11 +158,11 @@ class BaseEncoder(nn.Layer):
encoder output tensor, lens and mask
"""
masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#print("xs", xs)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks
@@ -168,8 +171,15 @@ class BaseEncoder(nn.Layer):
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
#print ("chunk_masks", chunk_masks)
i = 0
for layer in self.encoders:
if i == 3:
xs, chunk_masks, _ = layer(
xs, chunk_masks, pos_emb, mask_pad, is_print=True)
else:
xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
i += 1
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
@@ -248,6 +258,8 @@ class BaseEncoder(nn.Layer):
i]
cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[
i]
#print ("i", i)
#print ("xs", xs)
xs, _, new_cnn_cache = layer(
xs,
masks,
@@ -370,6 +382,80 @@ class TransformerEncoder(BaseEncoder):
concat_after=concat_after) for _ in range(num_blocks)
])
def forward_one_step(
self,
xs: paddle.Tensor,
required_cache_size: int,
state=(None, None, 0),
) -> Tuple[paddle.Tensor, Tuple[paddle.Tensor, List[paddle.Tensor], int]]:
""" Forward just one chunk
Args:
xs (paddle.Tensor): chunk input, [B=1, T, D]
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
Returns:
paddle.Tensor: output of current input xs
paddle.Tensor: subsampling cache required for next chunk computation
List[paddle.Tensor]: encoder layers output cache required for next
chunk computation
List[paddle.Tensor]: conformer cnn cache
"""
assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility
# TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
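# unpack the streaming state: subsampling cache, per-layer output caches, time offset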
subsampling_cache, elayers_output_cache, offset = state
tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, _ = self.embed(
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
if subsampling_cache is not None:
cache_size = subsampling_cache.shape[1] #T
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
# only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.shape[1])
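# decide where the cache kept for the next call starts in the concatenated sequence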
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = xs.shape[1]
else:
next_cache_start = xs.shape[1] - required_cache_size
r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T]
r_elayers_output_cache = []
for i, layer in enumerate(self.encoders):
attn_cache = None if elayers_output_cache is None else elayers_output_cache[
i]
xs, _, _ = layer(
xs, masks, pos_emb, output_cache=attn_cache, cnn_cache=None)
r_elayers_output_cache.append(xs[:, next_cache_start:, :])
if self.normalize_before:
xs = self.after_norm(xs)
new_state = (r_subsampling_cache, r_elayers_output_cache, offset + 1)
return (xs[:, cache_size:, :], new_state)
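For reference, a minimal sketch of how a caller threads the new (subsampling_cache, elayers_output_cache, offset) state tuple through forward_one_step during step-by-step decoding. This is not part of the commit; the constructor arguments, input_layer choice, and input shapes below are illustrative assumptions.

import paddle
from deepspeech.modules.encoder import TransformerEncoder

# Illustrative configuration (assumed values); TransformerLM wires up its own encoder.
encoder = TransformerEncoder(
    input_size=128, output_size=512, attention_heads=8, input_layer="linear")
encoder.eval()

state = (None, None, 0)  # empty subsampling cache, empty layer caches, offset 0
for _ in range(3):
    chunk = paddle.randn([1, 1, 128])  # one embedded step, [B=1, T=1, D]
    out, state = encoder.forward_one_step(
        chunk, required_cache_size=-1, state=state)
    # out holds only the newly produced frames; state carries the history forward

This mirrors the commented-out score test in the TransformerLM __main__ block above, where state starts as (None, None, 0) and is passed back in on every call.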
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
