@@ -250,11 +250,11 @@ class BaseEncoder(nn.Layer):
         r_cnn_cache = []
         for i, layer in enumerate(self.encoders):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
-            # cnn_cache[i] = (B=1, hidden-dim, cache_t2)
+            # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
             xs, _, new_att_cache, new_cnn_cache = layer(
                 xs, att_mask, pos_emb,
                 att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
-                cnn_cache=cnn_cache[i] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
+                cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
             )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
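Why the slice matters (an illustrative sketch, not part of the patch): integer indexing `cnn_cache[i]` drops the leading per-layer axis, while `cnn_cache[i:i+1]` keeps it, which is what the updated shape comment `(1, B=1, hidden-dim, cache_t2)` describes. The sizes below are hypothetical, chosen only to show the rank difference.

```python
import paddle

# Hypothetical sizes for illustration only; real values come from the model config.
num_layers, batch, hidden, cache_t2 = 12, 1, 256, 14
cnn_cache = paddle.zeros([num_layers, batch, hidden, cache_t2])

print(cnn_cache[0].shape)    # [1, 256, 14]    -> integer index drops the layer axis
print(cnn_cache[0:1].shape)  # [1, 1, 256, 14] -> slice keeps it: (1, B=1, hidden-dim, cache_t2)
```

Keeping the rank fixed means the per-layer slice and the empty initial cache can be passed through the same layer signature without shape-dependent branching.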