|
|
@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
|
|
|
|
W=40,
|
|
|
|
W=40,
|
|
|
|
adaptive_number_ratio=0,
|
|
|
|
adaptive_number_ratio=0,
|
|
|
|
adaptive_size_ratio=0,
|
|
|
|
adaptive_size_ratio=0,
|
|
|
|
max_n_time_masks=20):
|
|
|
|
max_n_time_masks=20,
|
|
|
|
|
|
|
|
replace_with_zero=True):
|
|
|
|
"""SpecAugment class.
|
|
|
|
"""SpecAugment class.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
rng (random.Random): random generator object.
|
|
|
|
rng (random.Random): random generator object.
|
|
|
@ -54,9 +55,11 @@ class SpecAugmentor(AugmentorBase):
|
|
|
|
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
|
|
|
|
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
|
|
|
|
adaptive_size_ratio (float): adaptive size ratio for time masking
|
|
|
|
adaptive_size_ratio (float): adaptive size ratio for time masking
|
|
|
|
max_n_time_masks (int): maximum number of time masking
|
|
|
|
max_n_time_masks (int): maximum number of time masking
|
|
|
|
|
|
|
|
replace_with_zero (bool): pad zero on mask if true else use mean
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self._rng = rng
|
|
|
|
self._rng = rng
|
|
|
|
|
|
|
|
self.replace_with_zero = replace_with_zero
|
|
|
|
|
|
|
|
|
|
|
|
self.W = W
|
|
|
|
self.W = W
|
|
|
|
self.F = F
|
|
|
|
self.F = F
|
|
|
@ -124,15 +127,18 @@ class SpecAugmentor(AugmentorBase):
|
|
|
|
return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
|
|
|
|
return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
|
|
|
|
|
|
|
|
|
|
|
|
def time_warp(xs, W=40):
|
|
|
|
def time_warp(xs, W=40):
|
|
|
|
raise NotImplementedError
|
|
|
|
return xs
|
|
|
|
|
|
|
|
|
|
|
|
def mask_freq(self, xs, replace_with_zero=False):
|
|
|
|
def mask_freq(self, xs, replace_with_zero=False):
|
|
|
|
n_bins = xs.shape[0]
|
|
|
|
n_bins = xs.shape[0]
|
|
|
|
for i in range(0, self.n_freq_masks):
|
|
|
|
for i in range(0, self.n_freq_masks):
|
|
|
|
f = int(self._rng.uniform(low=0, high=self.F))
|
|
|
|
f = int(self._rng.uniform(low=0, high=self.F))
|
|
|
|
f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
|
|
|
|
f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
|
|
|
|
xs[f_0:f_0 + f, :] = 0
|
|
|
|
|
|
|
|
assert f_0 <= f_0 + f
|
|
|
|
assert f_0 <= f_0 + f
|
|
|
|
|
|
|
|
if self.replace_with_zero:
|
|
|
|
|
|
|
|
xs[f_0:f_0 + f, :] = 0
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
xs[f_0:f_0 + f, :] = xs.mean()
|
|
|
|
self._freq_mask = (f_0, f_0 + f)
|
|
|
|
self._freq_mask = (f_0, f_0 + f)
|
|
|
|
return xs
|
|
|
|
return xs
|
|
|
|
|
|
|
|
|
|
|
@ -154,14 +160,17 @@ class SpecAugmentor(AugmentorBase):
|
|
|
|
t = int(self._rng.uniform(low=0, high=T))
|
|
|
|
t = int(self._rng.uniform(low=0, high=T))
|
|
|
|
t = min(t, int(n_frames * self.p))
|
|
|
|
t = min(t, int(n_frames * self.p))
|
|
|
|
t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
|
|
|
|
t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
|
|
|
|
xs[:, t_0:t_0 + t] = 0
|
|
|
|
|
|
|
|
assert t_0 <= t_0 + t
|
|
|
|
assert t_0 <= t_0 + t
|
|
|
|
|
|
|
|
if self.replace_with_zero:
|
|
|
|
|
|
|
|
xs[:, t_0:t_0 + t] = 0
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
xs[:, t_0:t_0 + t] = xs.mean()
|
|
|
|
self._time_mask = (t_0, t_0 + t)
|
|
|
|
self._time_mask = (t_0, t_0 + t)
|
|
|
|
return xs
|
|
|
|
return xs
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, x, train=True):
|
|
|
|
def __call__(self, x, train=True):
|
|
|
|
if not train:
|
|
|
|
if not train:
|
|
|
|
return
|
|
|
|
return x
|
|
|
|
return self.transform_feature(x)
|
|
|
|
return self.transform_feature(x)
|
|
|
|
|
|
|
|
|
|
|
|
def transform_feature(self, xs: np.ndarray):
|
|
|
|
def transform_feature(self, xs: np.ndarray):
|
|
|
@ -171,7 +180,7 @@ class SpecAugmentor(AugmentorBase):
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
xs (FloatTensor): `[F, T]`
|
|
|
|
xs (FloatTensor): `[F, T]`
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# xs = self.time_warp(xs)
|
|
|
|
xs = self.time_warp(xs)
|
|
|
|
xs = self.mask_freq(xs)
|
|
|
|
xs = self.mask_freq(xs)
|
|
|
|
xs = self.mask_time(xs)
|
|
|
|
xs = self.mask_time(xs)
|
|
|
|
return xs
|
|
|
|
return xs
|
|
|
|