|
|
|
@ -466,7 +466,7 @@ def phones_masking(xs_pad: paddle.Tensor,
|
|
|
|
|
for s, e in zip(masked_start, masked_end):
|
|
|
|
|
masked_pos[idx, s:e] = 1
|
|
|
|
|
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
|
|
|
|
|
masked_pos = masked_pos * non_eos_mask
|
|
|
|
|
masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
|
|
|
|
|
masked_pos = paddle.cast(masked_pos, 'bool')
|
|
|
|
|
|
|
|
|
|
return masked_pos
|
|
|
|
@ -550,10 +550,11 @@ def phones_text_masking(xs_pad: paddle.Tensor,
|
|
|
|
|
for s, e in zip(masked_start, masked_end):
|
|
|
|
|
masked_pos[idx, s:e] = 1
|
|
|
|
|
non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
|
|
|
|
|
masked_pos = masked_pos * non_eos_mask
|
|
|
|
|
masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
|
|
|
|
|
non_eos_text_mask = paddle.reshape(
|
|
|
|
|
text_mask, shape=paddle.shape(text_pad)[:2])
|
|
|
|
|
text_masked_pos = text_masked_pos * non_eos_text_mask
|
|
|
|
|
text_masked_pos = text_masked_pos * non_eos_text_mask.astype(
|
|
|
|
|
text_masked_pos.dtype)
|
|
|
|
|
masked_pos = paddle.cast(masked_pos, 'bool')
|
|
|
|
|
text_masked_pos = paddle.cast(text_masked_pos, 'bool')
|
|
|
|
|
|
|
|
|
|