|
|
@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
|
|
|
|
x: paddle.Tensor,
|
|
|
|
x: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
pos_emb: paddle.Tensor,
|
|
|
|
pos_emb: paddle.Tensor,
|
|
|
|
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
|
|
|
|
mask_pad: paddle.Tensor,
|
|
|
|
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
|
|
|
|
att_cache: paddle.Tensor,
|
|
|
|
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
|
|
|
|
cnn_cache: paddle.Tensor,
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
"""Compute encoded features.
|
|
|
|
"""Compute encoded features.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -105,7 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
|
|
|
|
if self.normalize_before:
|
|
|
|
if self.normalize_before:
|
|
|
|
x = self.norm1(x)
|
|
|
|
x = self.norm1(x)
|
|
|
|
|
|
|
|
|
|
|
|
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
|
|
|
|
x_att, new_att_cache = self.self_attn(x, x, x, mask, paddle.empty([0]), cache=att_cache)
|
|
|
|
|
|
|
|
|
|
|
|
if self.concat_after:
|
|
|
|
if self.concat_after:
|
|
|
|
x_concat = paddle.concat((x, x_att), axis=-1)
|
|
|
|
x_concat = paddle.concat((x, x_att), axis=-1)
|
|
|
@ -193,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
|
|
|
|
x: paddle.Tensor,
|
|
|
|
x: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
mask: paddle.Tensor,
|
|
|
|
pos_emb: paddle.Tensor,
|
|
|
|
pos_emb: paddle.Tensor,
|
|
|
|
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
|
|
|
|
mask_pad: paddle.Tensor,
|
|
|
|
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
|
|
|
|
att_cache: paddle.Tensor,
|
|
|
|
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
|
|
|
|
cnn_cache: paddle.Tensor,
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
"""Compute encoded features.
|
|
|
|
"""Compute encoded features.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|