|
|
@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
|
|
|
|
[1, 1, 1, 0, 0],
|
|
|
|
[1, 1, 1, 0, 0],
|
|
|
|
[1, 1, 0, 0, 0]]
|
|
|
|
[1, 1, 0, 0, 0]]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
#TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~
|
|
|
|
return ~make_pad_mask(lengths)
|
|
|
|
return make_pad_mask(lengths).logical_not()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def subsequent_mask(size: int) -> paddle.Tensor:
|
|
|
|
def subsequent_mask(size: int) -> paddle.Tensor:
|
|
|
@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor:
|
|
|
|
[1, 1, 1]]
|
|
|
|
[1, 1, 1]]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
ret = paddle.ones([size, size], dtype=paddle.bool)
|
|
|
|
ret = paddle.ones([size, size], dtype=paddle.bool)
|
|
|
|
#TODO(Hui Zhang): tril not support bool
|
|
|
|
return paddle.tril(ret)
|
|
|
|
#return paddle.tril(ret)
|
|
|
|
|
|
|
|
ret = ret.astype(paddle.float)
|
|
|
|
|
|
|
|
ret = paddle.tril(ret)
|
|
|
|
|
|
|
|
ret = ret.astype(paddle.bool)
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def subsequent_chunk_mask(
|
|
|
|
def subsequent_chunk_mask(
|
|
|
@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
|
|
|
|
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
|
|
|
|
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
|
|
|
|
num_left_chunks) # (L, L)
|
|
|
|
num_left_chunks) # (L, L)
|
|
|
|
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
|
|
|
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
|
|
|
# chunk_masks = masks & chunk_masks # (B, L, L)
|
|
|
|
chunk_masks = masks & chunk_masks # (B, L, L)
|
|
|
|
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
|
|
|
|
|
|
|
|
elif static_chunk_size > 0:
|
|
|
|
elif static_chunk_size > 0:
|
|
|
|
num_left_chunks = num_decoding_left_chunks
|
|
|
|
num_left_chunks = num_decoding_left_chunks
|
|
|
|
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
|
|
|
|
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
|
|
|
|
num_left_chunks) # (L, L)
|
|
|
|
num_left_chunks) # (L, L)
|
|
|
|
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
|
|
|
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
|
|
|
# chunk_masks = masks & chunk_masks # (B, L, L)
|
|
|
|
chunk_masks = masks & chunk_masks # (B, L, L)
|
|
|
|
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
chunk_masks = masks
|
|
|
|
chunk_masks = masks
|
|
|
|
return chunk_masks
|
|
|
|
return chunk_masks
|
|
|
|