@@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
             x: paddle.Tensor,
             mask: paddle.Tensor,
             pos_emb: paddle.Tensor,
-            mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
-            att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
+            mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
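Note on the new defaults: `paddle.zeros([0, 0, 0, 0])` and `paddle.ones([0, 0, 0], dtype=paddle.bool)` are empty-tensor sentinels meaning "no cache / no padding mask yet". Branching on a zero-sized dimension instead of an `Optional[Tensor]` keeps the argument rank fixed, which dynamic-to-static export relies on. A minimal sketch of the consumer-side check (illustrative only; the real cache handling lives inside `self_attn`, and `apply_att_cache` is a made-up name):

    import paddle

    def apply_att_cache(key: paddle.Tensor, att_cache: paddle.Tensor) -> paddle.Tensor:
        # att_cache: (batch, head, cache_t1, d_k * 2); cache_t1 == 0 means "no cache yet".
        if att_cache.shape[2] > 0:
            # Per the docstring, keys occupy the first half of the packed d_k * 2 dim.
            cached_k = att_cache[:, :, :, :att_cache.shape[-1] // 2]
            key = paddle.concat([cached_k, key], axis=2)  # prepend cache along time
        return key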
@@ -105,9 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
         if self.normalize_before:
             x = self.norm1(x)

-        x_att, new_att_cache = self.self_attn(
-            x, x, x, mask, cache=att_cache
-        )
+        x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)

         if self.concat_after:
             x_concat = paddle.concat((x, x_att), axis=-1)
@@ -124,7 +122,7 @@ class TransformerEncoderLayer(nn.Layer):
         if not self.normalize_before:
             x = self.norm2(x)

-        fake_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
+        fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
         return x, mask, new_att_cache, fake_cnn_cache


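`fake_cnn_cache` is an empty placeholder: `TransformerEncoderLayer` has no convolution module, but returning four tensors keeps its interface identical to `ConformerEncoderLayer`'s, so encoder code can drive both layer types through one loop. A caller that only stores real conv caches could test for the sentinel like this (sketch; `cnn_caches` is an assumed name, not from this patch):

    # An empty batch dim marks the placeholder returned by transformer layers.
    if new_cnn_cache.shape[0] > 0:
        cnn_caches[i] = new_cnn_cache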
@@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
             x: paddle.Tensor,
             mask: paddle.Tensor,
             pos_emb: paddle.Tensor,
-            mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
-            att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
+            mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute encoded features.
         Args:
@@ -211,7 +209,8 @@ class ConformerEncoderLayer(nn.Layer):
             att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
                 (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
             cnn_cache (paddle.Tensor): Convolution cache in conformer layer
-                (#batch=1, size, cache_t2)
+                (1, #batch=1, size, cache_t2). First dim will not be used, just
+                for dy2st.
         Returns:
             paddle.Tensor: Output tensor (#batch, time, size).
             paddle.Tensor: Mask tensor (#batch, time, time).
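To make the documented shapes concrete, with assumed values `head=4` and `d_k=64` (so `size = head * d_k = 256`), a running cache of `cache_t1=16` frames, and an 8-frame chunk, the tensors would be (worked example, not part of the patch):

    att_cache:     (1, 4, 16, 128)        # (batch=1, head, cache_t1, d_k * 2), K and V packed
    cnn_cache:     (1, 1, 256, cache_t2)  # leading 1 exists only for dy2st
    new_att_cache: (1, 4, 24, 128)        # time axis grows to cache_t1 + time = 16 + 8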
@@ -219,6 +218,8 @@
                 (#batch=1, head, cache_t1 + time, d_k * 2).
             paddle.Tensor: cnn_cache tensor (#batch, size, cache_t2).
         """
+        # (1, #batch=1, size, cache_t2) -> (#batch=1, size, cache_t2)
+        cnn_cache = paddle.squeeze(cnn_cache, axis=0)

         # whether to use macaron style FFN
         if self.feed_forward_macaron is not None:
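With the squeeze hoisted to the top of `forward`, everything below the docstring sees the documented 3-D `(#batch=1, size, cache_t2)` layout. A plausible reason for the dummy leading dim (the encoder side is not in this patch, so this is an assumption): if the encoder keeps all layers' conv caches stacked in one tensor, slicing with `i:i + 1` keeps the rank constant across layers, which dy2st handles better than a rank-dropping `cnn_caches[i]`:

    # Assumed encoder-side bookkeeping, not from this patch:
    # cnn_caches: (num_layers, batch=1, size, cache_t2)
    layer_cache = cnn_caches[i:i + 1]  # rank stays 4: (1, 1, size, cache_t2)
    x, mask, new_att_cache, new_cnn_cache = layer(
        x, mask, pos_emb, cnn_cache=layer_cache)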
@@ -249,8 +250,7 @@ class ConformerEncoderLayer(nn.Layer):

         # convolution module
         # Fake new cnn cache here, and then change it in conv_module
-        new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
-        cnn_cache = paddle.squeeze(cnn_cache, axis=0)
+        new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
         if self.conv_module is not None:
             residual = x
             if self.normalize_before:
@@ -275,4 +275,4 @@ class ConformerEncoderLayer(nn.Layer):
         if self.conv_module is not None:
             x = self.norm_final(x)

-        return x, mask, new_att_cache, new_cnn_cache
+        return x, mask, new_att_cache, new_cnn_cache
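Taken together, the changes leave a streaming call pattern like the sketch below (`layer` is a constructed `ConformerEncoderLayer`; `chunks` yielding `(x, pos_emb, mask)` tuples is hypothetical):

    import paddle

    # Start from the empty-tensor sentinels, matching the new defaults.
    att_cache = paddle.zeros([0, 0, 0, 0])
    cnn_cache = paddle.zeros([0, 0, 0, 0])

    for x, pos_emb, mask in chunks:
        x, mask, new_att_cache, new_cnn_cache = layer(
            x, mask, pos_emb,
            att_cache=att_cache, cnn_cache=cnn_cache)
        att_cache = new_att_cache                            # (1, head, cache_t1 + time, d_k * 2)
        cnn_cache = paddle.unsqueeze(new_cnn_cache, axis=0)  # restore the 4-D dy2st layout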