(D,T) to (T, D); time warp

pull/777/head
Hui Zhang 3 years ago
parent d9a3864072
commit 782f6be42d

@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains the volume perturb augmentation model.""" """Contains the volume perturb augmentation model."""
import random
import numpy as np import numpy as np
from PIL import Image
from PIL.Image import BICUBIC
from deepspeech.frontend.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
@ -42,7 +46,8 @@ class SpecAugmentor(AugmentorBase):
adaptive_number_ratio=0, adaptive_number_ratio=0,
adaptive_size_ratio=0, adaptive_size_ratio=0,
max_n_time_masks=20, max_n_time_masks=20,
replace_with_zero=True): replace_with_zero=True,
warp_mode='PIL'):
"""SpecAugment class. """SpecAugment class.
Args: Args:
rng (random.Random): random generator object. rng (random.Random): random generator object.
@ -56,11 +61,15 @@ class SpecAugmentor(AugmentorBase):
adaptive_size_ratio (float): adaptive size ratio for time masking adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
""" """
super().__init__() super().__init__()
self._rng = rng self._rng = rng
self.inplace = True
self.replace_with_zero = replace_with_zero self.replace_with_zero = replace_with_zero
self.mode = warp_mode
self.W = W self.W = W
self.F = F self.F = F
self.T = T self.T = T
@ -126,24 +135,80 @@ class SpecAugmentor(AugmentorBase):
def __repr__(self):
    """Return a one-line summary of the frequency/time masking settings."""
    # The bare names F, T, n_freq_masks and n_time_masks are not defined in
    # this scope, so formatting them directly raises NameError; read them
    # from the instance instead.
    # NOTE(review): n_time_masks is assumed to be stored on self by __init__
    # like F, T and n_freq_masks are - confirm against the constructor.
    return (f"specaug: F-{self.F}, T-{self.T}, "
            f"F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}")
def time_warp(self, x, mode='PIL'):
    """Time warp for SpecAugment.

    Picks a random center frame and moves it to a random new position at
    most ``self.W`` frames away. The frames before and after the center are
    resized (bicubic) to their new lengths, so the total number of frames
    is unchanged.

    Args:
        x (np.ndarray): spectrogram (time, freq)
        mode (str): "PIL" or "sparse_image_warp"

    Raises:
        NotImplementedError: if mode is "sparse_image_warp" (not implemented).
        NotImplementedError: if mode is not one of the two known names.

    Returns:
        np.ndarray: time warped spectrogram (time, freq). ``x`` itself is
            returned unchanged when ``self.W`` is not positive or when
            ``x`` has at most ``2 * self.W`` frames. When ``self.inplace``
            is true the warped frames are written back into ``x``.
    """
    window = self.W
    if mode == "PIL":
        t = x.shape[0]
        # Nothing to warp when the window is non-positive (W == 0 would hand
        # random.randrange an empty range below and raise ValueError) or
        # when the input is too short to place a center frame.
        if window <= 0 or t - window <= window:
            return x
        # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
        center = random.randrange(window, t - window)
        warped = random.randrange(center - window,
                                  center + window) + 1  # 1 ... t - 1
        n_freq = x.shape[1]
        # PIL sizes are (width, height), i.e. (freq, time) for this layout.
        left = Image.fromarray(x[:center]).resize((n_freq, warped), BICUBIC)
        right = Image.fromarray(x[center:]).resize((n_freq, t - warped),
                                                   BICUBIC)
        if self.inplace:
            x[:warped] = left
            x[warped:] = right
            return x
        return np.concatenate((left, right), 0)
    elif mode == "sparse_image_warp":
        raise NotImplementedError('sparse_image_warp')
    else:
        raise NotImplementedError(
            "unknown resize mode: " + mode +
            ", choose one from (PIL, sparse_image_warp).")
def mask_freq(self, x, replace_with_zero=False):
"""freq mask
def mask_freq(self, xs, replace_with_zero=False): Args:
n_bins = xs.shape[0] x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins = x.shape[1]
for i in range(0, self.n_freq_masks): for i in range(0, self.n_freq_masks):
f = int(self._rng.uniform(low=0, high=self.F)) f = int(self._rng.uniform(low=0, high=self.F))
f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
assert f_0 <= f_0 + f assert f_0 <= f_0 + f
if self.replace_with_zero: if replace_with_zero:
xs[f_0:f_0 + f, :] = 0 x[:, f_0:f_0 + f] = 0
else: else:
xs[f_0:f_0 + f, :] = xs.mean() x[:, f_0:f_0 + f] = x.mean()
self._freq_mask = (f_0, f_0 + f) self._freq_mask = (f_0, f_0 + f)
return xs return x
def mask_time(self, x, replace_with_zero=False):
"""time mask
def mask_time(self, xs, replace_with_zero=False): Args:
n_frames = xs.shape[1] x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames = x.shape[0]
if self.adaptive_number_ratio > 0: if self.adaptive_number_ratio > 0:
n_masks = int(n_frames * self.adaptive_number_ratio) n_masks = int(n_frames * self.adaptive_number_ratio)
@ -161,26 +226,26 @@ class SpecAugmentor(AugmentorBase):
t = min(t, int(n_frames * self.p)) t = min(t, int(n_frames * self.p))
t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
assert t_0 <= t_0 + t assert t_0 <= t_0 + t
if self.replace_with_zero: if replace_with_zero:
xs[:, t_0:t_0 + t] = 0 x[t_0:t_0 + t, :] = 0
else: else:
xs[:, t_0:t_0 + t] = xs.mean() x[t_0:t_0 + t, :] = x.mean()
self._time_mask = (t_0, t_0 + t) self._time_mask = (t_0, t_0 + t)
return xs return x
def __call__(self, x, train=True): def __call__(self, x, train=True):
if not train: if not train:
return x return x
return self.transform_feature(x) return self.transform_feature(x)
def transform_feature(self, xs: np.ndarray): def transform_feature(self, x: np.ndarray):
""" """
Args: Args:
xs (FloatTensor): `[F, T]` x (np.ndarray): `[T, F]`
Returns: Returns:
xs (FloatTensor): `[F, T]` x (np.ndarray): `[T, F]`
""" """
xs = self.time_warp(xs) x = self.time_warp(x, self.mode)
xs = self.mask_freq(xs) x = self.mask_freq(x, self.replace_with_zero)
xs = self.mask_time(xs) x = self.mask_time(x, self.replace_with_zero)
return xs return x

@ -167,32 +167,6 @@ class AudioFeaturizer(object):
raise ValueError("Unknown specgram_type %s. " raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type) "Supported values: linear." % self._specgram_type)
def _compute_linear_specgram(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
eps=1e-14):
"""Compute the linear spectrogram from FFT energy.

Args:
samples: audio samples, passed unchanged to ``self._specgram_real``.
sample_rate: sample rate of ``samples``; used to convert the
millisecond sizes to sample counts and as the upper bound
(``sample_rate / 2``) for ``max_freq``.
stride_ms (float): hop between windows in milliseconds.
window_ms (float): window length in milliseconds.
max_freq: highest frequency to keep; ``None`` means ``sample_rate / 2``.
eps (float): constant added to the spectrogram before taking the log.

Raises:
ValueError: if ``max_freq`` is greater than ``sample_rate / 2``.
ValueError: if ``stride_ms`` is greater than ``window_ms``.

Returns:
Log spectrogram whose first axis is restricted to the frequency bins
not above ``max_freq``.
NOTE(review): presumably laid out as (freq, time) - confirm against
``_specgram_real``.
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# convert milliseconds to a whole number of samples
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
specgram, freqs = self._specgram_real(
samples,
window_size=window_size,
stride_size=stride_size,
sample_rate=sample_rate)
# one past the index of the last frequency bin that is <= max_freq
ind = np.where(freqs <= max_freq)[0][-1] + 1
return np.log(specgram[:ind, :] + eps)
def _specgram_real(self, samples, window_size, stride_size, sample_rate): def _specgram_real(self, samples, window_size, stride_size, sample_rate):
"""Compute the spectrogram for samples from a real signal.""" """Compute the spectrogram for samples from a real signal."""
# extract strided windows # extract strided windows
@ -217,26 +191,65 @@ class AudioFeaturizer(object):
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
return fft, freqs return fft, freqs
def _compute_linear_specgram(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
eps=1e-14):
"""Compute the linear spectrogram from FFT energy.
Args:
samples: audio samples, passed unchanged to ``self._specgram_real``.
sample_rate: sample rate of ``samples``; used to convert the
millisecond sizes to sample counts and as the upper bound
(``sample_rate / 2``) for ``max_freq``.
stride_ms (float, optional): hop between windows in milliseconds.
Defaults to 10.0.
window_ms (float, optional): window length in milliseconds.
Defaults to 20.0.
max_freq (optional): highest frequency to keep. Defaults to None,
which means ``sample_rate / 2``.
eps (float, optional): constant added to the spectrogram before
taking the log. Defaults to 1e-14.
Raises:
ValueError: if ``max_freq`` is greater than ``sample_rate / 2``.
ValueError: if ``stride_ms`` is greater than ``window_ms``.
Returns:
np.ndarray: log spectrogram, (time, freq)
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# convert milliseconds to a whole number of samples
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
specgram, freqs = self._specgram_real(
samples,
window_size=window_size,
stride_size=stride_size,
sample_rate=sample_rate)
# one past the index of the last frequency bin that is <= max_freq
ind = np.where(freqs <= max_freq)[0][-1] + 1
# (freq, time)
spec = np.log(specgram[:ind, :] + eps)
# transposed on return so callers receive (time, freq)
return np.transpose(spec)
def _concat_delta_delta(self, feat):
    """Append delta and delta-delta features.

    Args:
        feat (np.ndarray): (T, D)

    Returns:
        np.ndarray: feat with delta and delta-delta, (T, 3*D)
    """
    # NOTE(review): `delta` is imported elsewhere in this file; presumably
    # python_speech_features.delta(feat, N), which expects (T, D) - confirm.
    # Deltas
    d_feat = delta(feat, 2)
    # Deltas-Deltas: the delta of the deltas. Taking delta(feat, 2) again
    # would only duplicate d_feat in the last third of the output.
    dd_feat = delta(d_feat, 2)
    # concat above three features along the feature axis
    concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1)
    return concat_feat
def _compute_mfcc(self, def _compute_mfcc(self,
@ -292,7 +305,6 @@ class AudioFeaturizer(object):
ceplifter=22, ceplifter=22,
useEnergy=True, useEnergy=True,
winfunc='povey') winfunc='povey')
mfcc_feat = np.transpose(mfcc_feat)
if delta_delta: if delta_delta:
mfcc_feat = self._concat_delta_delta(mfcc_feat) mfcc_feat = self._concat_delta_delta(mfcc_feat)
return mfcc_feat return mfcc_feat
@ -346,8 +358,6 @@ class AudioFeaturizer(object):
remove_dc_offset=True, remove_dc_offset=True,
preemph=0.97, preemph=0.97,
wintype='povey') wintype='povey')
fbank_feat = np.transpose(fbank_feat)
if delta_delta: if delta_delta:
fbank_feat = self._concat_delta_delta(fbank_feat) fbank_feat = self._concat_delta_delta(fbank_feat)
return fbank_feat return fbank_feat

@ -40,21 +40,21 @@ class CollateFunc(object):
number = 0 number = 0
for item in batch: for item in batch:
audioseg = AudioSegment.from_file(item['feat']) audioseg = AudioSegment.from_file(item['feat'])
feat = self.feature_func(audioseg) #(D, T) feat = self.feature_func(audioseg) #(T, D)
sums = np.sum(feat, axis=1) sums = np.sum(feat, axis=0)
if mean_stat is None: if mean_stat is None:
mean_stat = sums mean_stat = sums
else: else:
mean_stat += sums mean_stat += sums
square_sums = np.sum(np.square(feat), axis=1) square_sums = np.sum(np.square(feat), axis=0)
if var_stat is None: if var_stat is None:
var_stat = square_sums var_stat = square_sums
else: else:
var_stat += square_sums var_stat += square_sums
number += feat.shape[1] number += feat.shape[0]
return number, mean_stat, var_stat return number, mean_stat, var_stat
@ -120,7 +120,7 @@ class FeatureNormalizer(object):
"""Normalize features to be of zero mean and unit stddev. """Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized. :param features: Input features to be normalized.
:type features: ndarray, shape (D, T) :type features: ndarray, shape (T, D)
:param eps: added to stddev to provide numerical stablibity. :param eps: added to stddev to provide numerical stablibity.
:type eps: float :type eps: float
:return: Normalized features. :return: Normalized features.
@ -131,8 +131,8 @@ class FeatureNormalizer(object):
def _read_mean_std_from_file(self, filepath, eps=1e-20): def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file.""" """Load mean and std from file."""
mean, istd = load_cmvn(filepath, filetype='json') mean, istd = load_cmvn(filepath, filetype='json')
self._mean = np.expand_dims(mean, axis=-1) self._mean = np.expand_dims(mean, axis=0)
self._istd = np.expand_dims(istd, axis=-1) self._istd = np.expand_dims(istd, axis=0)
def write_to_file(self, filepath): def write_to_file(self, filepath):
"""Write the mean and stddev to the file. """Write the mean and stddev to the file.

@ -242,7 +242,6 @@ class SpeechCollator():
# specgram augment # specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, transcript_part return specgram, transcript_part
def __call__(self, batch): def __call__(self, batch):
@ -250,7 +249,7 @@ class SpeechCollator():
Args: Args:
batch ([List]): batch is (audio, text) batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T) audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,) text (List[int] or str): shape (U,)
Returns: Returns:

@ -217,6 +217,34 @@ class SpeechCollator():
return self._local_data.tar2object[tarpath].extractfile( return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename]) self._local_data.tar2info[tarpath][filename])
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
def process_utterance(self, audio_file, translation): def process_utterance(self, audio_file, translation):
"""Load, augment, featurize and normalize for speech data. """Load, augment, featurize and normalize for speech data.
@ -244,7 +272,6 @@ class SpeechCollator():
# specgram augment # specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part return specgram, translation_part
def __call__(self, batch): def __call__(self, batch):
@ -252,7 +279,7 @@ class SpeechCollator():
Args: Args:
batch ([List]): batch is (audio, text) batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T) audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,) text (List[int] or str): shape (U,)
Returns: Returns:
@ -296,34 +323,6 @@ class SpeechCollator():
text_lens = np.array(text_lens).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
class TripletSpeechCollator(SpeechCollator): class TripletSpeechCollator(SpeechCollator):
def process_utterance(self, audio_file, translation, transcript): def process_utterance(self, audio_file, translation, transcript):
@ -355,7 +354,6 @@ class TripletSpeechCollator(SpeechCollator):
# specgram augment # specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part, transcript_part return specgram, translation_part, transcript_part
def __call__(self, batch): def __call__(self, batch):
@ -363,7 +361,7 @@ class TripletSpeechCollator(SpeechCollator):
Args: Args:
batch ([List]): batch is (audio, text) batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T) audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,) text (List[int] or str): shape (U,)
Returns: Returns:
@ -524,49 +522,19 @@ class KaldiPrePorocessedCollator(SpeechCollator):
:rtype: tuple of (2darray, list) :rtype: tuple of (2darray, list)
""" """
specgram = kaldiio.load_mat(audio_file) specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[ assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( 1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0]) self._feat_dim, specgram.shape[1])
# specgram augment # specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text: if self._keep_transcription_text:
return specgram, translation return specgram, translation
else: else:
text_ids = self._text_featurizer.featurize(translation) text_ids = self._text_featurizer.featurize(translation)
return specgram, text_ids return specgram, text_ids
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
return self._text_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
@property
def feature_size(self):
return self._feat_dim
@property
def stride_ms(self):
return self._stride_ms
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator): class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
def process_utterance(self, audio_file, translation, transcript): def process_utterance(self, audio_file, translation, transcript):
@ -583,15 +551,13 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
:rtype: tuple of (2darray, (list, list)) :rtype: tuple of (2darray, (list, list))
""" """
specgram = kaldiio.load_mat(audio_file) specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[ assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( 1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0]) self._feat_dim, specgram.shape[1])
# specgram augment # specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text: if self._keep_transcription_text:
return specgram, translation, transcript return specgram, translation, transcript
else: else:
@ -604,7 +570,7 @@ class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
Args: Args:
batch ([List]): batch is (audio, text) batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T) audio (np.ndarray) shape (T, D)
translation (List[int] or str): shape (U,) translation (List[int] or str): shape (U,)
transcription (List[int] or str): shape (V,) transcription (List[int] or str): shape (V,)

@ -1,6 +1,7 @@
coverage coverage
gpustat gpustat
kaldiio kaldiio
Pillow
pre-commit pre-commit
pybind11 pybind11
resampy==0.2.2 resampy==0.2.2

Loading…
Cancel
Save