@@ -18,40 +18,11 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()
 
 __all__ = [
-    'sequence_mask', "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
+    "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
     "subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
     "mask_finished_preds"
 ]
 
 
-def sequence_mask(x_len, max_len=None, dtype='float32'):
-    """Batch sequence mask.
-
-    Args:
-        x_len (paddle.Tensor): xs length, [B]
-        max_len (int, optional): max sequence length. Defaults to None.
-        dtype (str, optional): mask data type. Defaults to 'float32'.
-
-    Returns:
-        paddle.Tensor: [B, Tmax]
-
-    Examples:
-        >>> sequence_mask([2, 4])
-        [[1., 1., 0., 0.],
-         [1., 1., 1., 1.]]
-    """
-    # TODO(Hui Zhang): jit does not support Tensor.dim() and Tensor.ndim
-    # assert x_len.dim() == 1, (x_len.dim(), x_len)
-    max_len = max_len or x_len.max()
-    x_len = paddle.unsqueeze(x_len, -1)
-    row_vector = paddle.arange(max_len)
-    # TODO(Hui Zhang): fix this bug
-    # mask = row_vector < x_len
-    mask = row_vector > x_len  # a bug: broadcasting goes wrong here
-    mask = paddle.cast(mask, dtype)
-    return mask
-
-
 def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
     """Make mask tensor containing indices of padded part.
 
     See description of make_non_pad_mask.
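
As context for the deletion above: a minimal sketch (not part of the diff) of what the removed sequence_mask intended, assuming Paddle 2.x, where comparison ops broadcast. The removed code's TODO says the intended `row_vector < x_len` hit a broadcasting bug on the Paddle version then in use, so the function shipped with the inverted, and therefore wrong, `>` comparison.

    import paddle

    x_len = paddle.to_tensor([2, 4], dtype='int64')  # [B]
    max_len = int(x_len.max())                       # Tmax = 4
    row_vector = paddle.arange(max_len)              # [0, 1, 2, 3]
    # Intended semantics: position < length marks valid (non-pad) steps.
    # x_len is unsqueezed to [B, 1] so the comparison broadcasts to [B, Tmax].
    mask = row_vector < x_len.unsqueeze(-1)
    print(paddle.cast(mask, 'float32').numpy())
    # [[1. 1. 0. 0.]
    #  [1. 1. 1. 1.]]

This matches the docstring example in the removed code, and is exactly the non-pad mask that make_non_pad_mask provides, which is presumably why sequence_mask could be dropped.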
@@ -66,7 +37,8 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
              [0, 0, 0, 1, 1],
              [0, 0, 1, 1, 1]]
     """
-    assert lengths.dim() == 1
+    # TODO(Hui Zhang): jit does not support Tensor.dim() and Tensor.ndim
+    # assert lengths.dim() == 1
     batch_size = int(lengths.shape[0])
     max_len = int(lengths.max())
     seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
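
For reference, a quick usage sketch of the surviving helper (again only a sketch, assuming Paddle 2.x), reproducing the docstring example rows kept as context in the hunk above:

    import paddle

    lengths = paddle.to_tensor([5, 3, 2], dtype='int64')  # [B]
    max_len = int(lengths.max())                          # Tmax = 5
    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
    # Positions at or beyond each sequence's length are padding (1/True).
    mask = seq_range.unsqueeze(0) >= lengths.unsqueeze(1)  # [B, Tmax]
    print(paddle.cast(mask, 'int64').numpy())
    # [[0 0 0 0 0]
    #  [0 0 0 1 1]
    #  [0 0 1 1 1]]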