|
|
|
@ -129,7 +129,7 @@ def _compute_mask_indices(
|
|
|
|
|
[sequence_length for _ in range(batch_size)])
|
|
|
|
|
|
|
|
|
|
# SpecAugment mask to fill
|
|
|
|
|
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
|
|
|
|
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool_)
|
|
|
|
|
spec_aug_mask_idxs = []
|
|
|
|
|
|
|
|
|
|
max_num_masked_span = compute_num_masked_span(sequence_length)
|
|
|
|
@ -207,9 +207,9 @@ def _sample_negative_indices(features_shape: Tuple,
|
|
|
|
|
sampled_negative_indices = np.zeros(
|
|
|
|
|
shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
|
|
|
|
|
|
|
|
|
mask_time_indices = (mask_time_indices.astype(bool)
|
|
|
|
|
mask_time_indices = (mask_time_indices.astype(np.bool_)
|
|
|
|
|
if mask_time_indices is not None else
|
|
|
|
|
np.ones(features_shape, dtype=bool))
|
|
|
|
|
np.ones(features_shape, dtype=np.bool_))
|
|
|
|
|
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
high = mask_time_indices[batch_idx].sum() - 1
|
|
|
|
|