Support replacing SpecAugment masks with the mean instead of zero

pull/776/head
Hui Zhang 3 years ago
parent 86d08f994b
commit 50f10f37ae

@ -352,45 +352,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn.functional #############
def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
    """Apply the gated linear unit (GLU) activation.

    ``x`` is split into two chunks along ``axis``; the first chunk is
    multiplied element-wise by the sigmoid of the second chunk.
    """
    value, gate = x.split(2, axis=axis)
    return value * F.sigmoid(gate)
# Install the local ``glu`` fallback only when ``paddle.nn.functional`` does
# not already provide one, so an existing implementation is never overridden.
if not hasattr(paddle.nn.functional, 'glu'):
    logger.warn(
        "register user glu to paddle.nn.functional, remove this when fixed!")
    setattr(paddle.nn.functional, 'glu', glu)
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
########### hack paddle.nn #############
class GLU(nn.Layer):
@ -401,7 +362,7 @@ class GLU(nn.Layer):
self.dim = dim
def forward(self, xs):
    """Apply the GLU activation to ``xs`` along the axis stored in ``self.dim``."""
    # The functional GLU names its split-axis argument ``axis``; passing it
    # as ``dim=`` raises TypeError at call time.
    return F.glu(xs, axis=self.dim)
if not hasattr(paddle.nn, 'GLU'):

@ -32,7 +32,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -38,7 +38,7 @@ class NoisePerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -46,7 +46,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -33,7 +33,7 @@ class ResampleAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -33,7 +33,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
W=40,
adaptive_number_ratio=0,
adaptive_size_ratio=0,
max_n_time_masks=20):
max_n_time_masks=20,
replace_with_zero=True):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
@ -54,9 +55,11 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): fill masked regions with zero if True, otherwise with the mean of the features
"""
super().__init__()
self._rng = rng
self.replace_with_zero = replace_with_zero
self.W = W
self.F = F
@ -124,15 +127,18 @@ class SpecAugmentor(AugmentorBase):
return f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}"
@staticmethod
def time_warp(xs, W=40):
    """Placeholder for time warping; returns ``xs`` unchanged.

    Time warping is not implemented yet. The method is static so that it
    works both as ``self.time_warp(xs)`` (how ``transform_feature`` calls
    it) and as a plain ``time_warp(xs)`` call; without the decorator the
    bound call would receive the instance as ``xs``.

    Args:
        xs: Feature array to (eventually) warp.
        W (int): Warp parameter; currently unused.

    Returns:
        The same ``xs`` object that was passed in.
    """
    return xs
def mask_freq(self, xs, replace_with_zero=False):
    """Mask random bands along the first axis of ``xs`` in place.

    ``self.n_freq_masks`` bands are drawn; each band has a random width
    below ``self.F`` and a random start position. The last band drawn is
    recorded in ``self._freq_mask`` as ``(start, stop)``.

    Args:
        xs: 2-D feature array; axis 0 is indexed as the frequency-bin axis.
        replace_with_zero (bool): Unused, kept so existing callers keep
            working. The fill is selected by ``self.replace_with_zero``.

    Returns:
        The same ``xs`` object, modified in place.
    """
    n_bins = xs.shape[0]
    for _ in range(self.n_freq_masks):
        # Clamp the width so the start range below never goes negative.
        width = min(int(self._rng.uniform(low=0, high=self.F)), n_bins)
        start = int(self._rng.uniform(low=0, high=n_bins - width))
        stop = start + width
        # Choose the fill before writing to the band: zeroing the band
        # first would skew the mean that is then written into it.
        fill = 0 if self.replace_with_zero else xs.mean()
        xs[start:stop, :] = fill
        self._freq_mask = (start, stop)
    return xs
@ -154,14 +160,17 @@ class SpecAugmentor(AugmentorBase):
t = int(self._rng.uniform(low=0, high=T))
t = min(t, int(n_frames * self.p))
t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
xs[:, t_0:t_0 + t] = 0
assert t_0 <= t_0 + t
if self.replace_with_zero:
xs[:, t_0:t_0 + t] = 0
else:
xs[:, t_0:t_0 + t] = xs.mean()
self._time_mask = (t_0, t_0 + t)
return xs
def __call__(self, x, train=True):
    """Apply the feature augmentation to ``x`` when training.

    Args:
        x: Feature array handed to ``self.transform_feature``.
        train (bool): When False the augmentation is skipped.

    Returns:
        ``x`` unchanged when ``train`` is False, otherwise the result of
        ``self.transform_feature(x)``.
    """
    if not train:
        # Return the input rather than ``None`` so evaluation pipelines
        # still receive the features.
        return x
    return self.transform_feature(x)
def transform_feature(self, xs: np.ndarray):
@ -171,7 +180,7 @@ class SpecAugmentor(AugmentorBase):
Returns:
xs (FloatTensor): `[F, T]`
"""
# xs = self.time_warp(xs)
xs = self.time_warp(xs)
xs = self.mask_freq(xs)
xs = self.mask_time(xs)
return xs

@ -81,7 +81,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -39,7 +39,7 @@ class VolumePerturbAugmentor(AugmentorBase):
def __call__(self, x, uttid=None, train=True):
    """Run the augmentation on ``x`` when training.

    Args:
        x: Object handed to ``self.transform_audio``; presumably an audio
            segment that is modified in place (TODO confirm).
        uttid: Utterance id; not used by this method.
        train (bool): When False the augmentation is skipped.

    Returns:
        The same ``x`` object that was passed in, in both modes.
    """
    if not train:
        # Hand the input back untouched so callers that chain augmentors
        # keep receiving the data instead of ``None``.
        return x
    # The return value of transform_audio is intentionally discarded.
    self.transform_audio(x)
    return x

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -1,10 +0,0 @@
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
}
]

@ -60,7 +60,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 0.0
}

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -10,7 +10,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

@ -1,4 +1,13 @@
[
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 1.0
},
{
"type": "shift",
"params": {
@ -6,5 +15,21 @@
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}
]

@ -27,7 +27,8 @@
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}

Loading…
Cancel
Save