|
|
@ -175,9 +175,7 @@ class BaseEncoder(nn.Layer):
|
|
|
|
decoding_chunk_size, self.static_chunk_size,
|
|
|
|
decoding_chunk_size, self.static_chunk_size,
|
|
|
|
num_decoding_left_chunks)
|
|
|
|
num_decoding_left_chunks)
|
|
|
|
for layer in self.encoders:
|
|
|
|
for layer in self.encoders:
|
|
|
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
|
|
|
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
|
|
|
paddle.zeros([0, 0, 0, 0]),
|
|
|
|
|
|
|
|
paddle.zeros([0, 0, 0, 0]))
|
|
|
|
|
|
|
|
if self.normalize_before:
|
|
|
|
if self.normalize_before:
|
|
|
|
xs = self.after_norm(xs)
|
|
|
|
xs = self.after_norm(xs)
|
|
|
|
# Here we assume the mask is not changed in encoder layers, so just
|
|
|
|
# Here we assume the mask is not changed in encoder layers, so just
|
|
|
@ -190,9 +188,9 @@ class BaseEncoder(nn.Layer):
|
|
|
|
xs: paddle.Tensor,
|
|
|
|
xs: paddle.Tensor,
|
|
|
|
offset: int,
|
|
|
|
offset: int,
|
|
|
|
required_cache_size: int,
|
|
|
|
required_cache_size: int,
|
|
|
|
att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
|
|
|
|
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
|
|
|
|
cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]),
|
|
|
|
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
|
|
|
|
att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
|
|
|
|
att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
""" Forward just one chunk
|
|
|
|
""" Forward just one chunk
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -255,7 +253,6 @@ class BaseEncoder(nn.Layer):
|
|
|
|
xs,
|
|
|
|
xs,
|
|
|
|
att_mask,
|
|
|
|
att_mask,
|
|
|
|
pos_emb,
|
|
|
|
pos_emb,
|
|
|
|
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool),
|
|
|
|
|
|
|
|
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
|
|
|
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
|
|
|
cnn_cache=cnn_cache[i:i + 1]
|
|
|
|
cnn_cache=cnn_cache[i:i + 1]
|
|
|
|
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
|
|
|
|
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
|
|
|
@ -328,8 +325,7 @@ class BaseEncoder(nn.Layer):
|
|
|
|
chunk_xs = xs[:, cur:end, :]
|
|
|
|
chunk_xs = xs[:, cur:end, :]
|
|
|
|
|
|
|
|
|
|
|
|
(y, att_cache, cnn_cache) = self.forward_chunk(
|
|
|
|
(y, att_cache, cnn_cache) = self.forward_chunk(
|
|
|
|
chunk_xs, offset, required_cache_size, att_cache, cnn_cache,
|
|
|
|
chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
|
|
|
|
paddle.ones([0, 0, 0], dtype=paddle.bool))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs.append(y)
|
|
|
|
outputs.append(y)
|
|
|
|
offset += y.shape[1]
|
|
|
|
offset += y.shape[1]
|
|
|
|