@@ -200,7 +200,7 @@ class BaseEncoder(nn.Layer):
             offset: int,
             required_cache_size: int,
             att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            # cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
             att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)
     ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """ Forward just one chunk
@@ -252,7 +252,7 @@ class BaseEncoder(nn.Layer):
         next_cache_start = max(attention_key_size - required_cache_size, 0)
 
         r_att_cache = []
-        r_cnn_cache = []
+        # r_cnn_cache = []
         for i, layer in enumerate(self.encoders):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
             # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
@@ -262,25 +262,27 @@ class BaseEncoder(nn.Layer):
             # raw code as below:
             # att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
             # cnn_cache=cnn_cache[i:i+1] if cnn_cache.shape[0] > 0 else cnn_cache,
-            xs, _, new_att_cache, new_cnn_cache = layer(
+            xs, _, new_att_cache = layer(
                 xs,
                 att_mask,
                 pos_emb,
                 att_cache=att_cache[i:i + 1],
-                cnn_cache=cnn_cache[i:i + 1], )
+                # cnn_cache=cnn_cache[i:i + 1],
+            )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
             r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
-            r_cnn_cache.append(new_cnn_cache)  # add elayer dim
+            # r_cnn_cache.append(new_cnn_cache)  # add elayer dim
 
         if self.normalize_before:
             xs = self.after_norm(xs)
 
         # r_att_cache (elayers, head, T, d_k*2)
         # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
+        # breakpoint()
         r_att_cache = paddle.concat(r_att_cache, axis=0)
-        r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
-        return xs, r_att_cache, r_cnn_cache
+        # r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
+        return xs, r_att_cache  # , r_cnn_cache
 
     def forward_chunk_by_chunk(
             self,