|
|
@ -743,7 +743,7 @@ class SpecAugment(paddle.nn.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
time = x.shape[2]
|
|
|
|
time = x.shape[2]
|
|
|
|
if time - window <= window:
|
|
|
|
if time - window <= window:
|
|
|
|
return x.view(*original_size)
|
|
|
|
return x.reshape([*original_size])
|
|
|
|
|
|
|
|
|
|
|
|
# compute center and corresponding window
|
|
|
|
# compute center and corresponding window
|
|
|
|
c = paddle.randint(window, time - window, (1, ))[0]
|
|
|
|
c = paddle.randint(window, time - window, (1, ))[0]
|
|
|
@ -762,7 +762,7 @@ class SpecAugment(paddle.nn.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
x[:, :, :w] = left
|
|
|
|
x[:, :, :w] = left
|
|
|
|
x[:, :, w:] = right
|
|
|
|
x[:, :, w:] = right
|
|
|
|
return x.view(*original_size)
|
|
|
|
return x.reshape([*original_size])
|
|
|
|
|
|
|
|
|
|
|
|
def mask_along_axis(self, x, dim):
|
|
|
|
def mask_along_axis(self, x, dim):
|
|
|
|
"""Mask along time or frequency axis.
|
|
|
|
"""Mask along time or frequency axis.
|
|
|
@ -775,7 +775,7 @@ class SpecAugment(paddle.nn.Layer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
original_size = x.shape
|
|
|
|
original_size = x.shape
|
|
|
|
if x.dim() == 4:
|
|
|
|
if x.dim() == 4:
|
|
|
|
x = x.view(-1, x.shape[2], x.shape[3])
|
|
|
|
x = x.reshape([-1, x.shape[2], x.shape[3]])
|
|
|
|
|
|
|
|
|
|
|
|
batch, time, fea = x.shape
|
|
|
|
batch, time, fea = x.shape
|
|
|
|
|
|
|
|
|
|
|
@ -795,7 +795,7 @@ class SpecAugment(paddle.nn.Layer):
|
|
|
|
(batch, n_mask)).unsqueeze(2)
|
|
|
|
(batch, n_mask)).unsqueeze(2)
|
|
|
|
|
|
|
|
|
|
|
|
# compute masks
|
|
|
|
# compute masks
|
|
|
|
arange = paddle.arange(end=D).view(1, 1, -1)
|
|
|
|
arange = paddle.arange(end=D).reshape([1, 1, -1])
|
|
|
|
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
|
|
|
|
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
|
|
|
|
mask = mask.any(axis=1)
|
|
|
|
mask = mask.any(axis=1)
|
|
|
|
|
|
|
|
|
|
|
@ -811,7 +811,7 @@ class SpecAugment(paddle.nn.Layer):
|
|
|
|
# same to x.masked_fill_(mask, val)
|
|
|
|
# same to x.masked_fill_(mask, val)
|
|
|
|
y = paddle.full(x.shape, val, x.dtype)
|
|
|
|
y = paddle.full(x.shape, val, x.dtype)
|
|
|
|
x = paddle.where(mask, y, x)
|
|
|
|
x = paddle.where(mask, y, x)
|
|
|
|
return x.view(*original_size)
|
|
|
|
return x.reshape([*original_size])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimeDomainSpecAugment(nn.Layer):
|
|
|
|
class TimeDomainSpecAugment(nn.Layer):
|
|
|
|