refactor augmentation interface

pull/756/head
Hui Zhang 3 years ago
parent 5ae639196c
commit 8939994d75

@ -13,18 +13,27 @@
# limitations under the License. # limitations under the License.
"""Contains the data augmentation pipeline.""" """Contains the data augmentation pipeline."""
import json import json
from collections.abc import Sequence
from inspect import signature
import numpy as np import numpy as np
from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor from deepspeech.utils.log import Log
from deepspeech.frontend.augmentor.online_bayesian_normalization import \
OnlineBayesianNormalizationAugmentor __all__ = ["AugmentationPipeline"]
from deepspeech.frontend.augmentor.resample import ResampleAugmentor
from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor logger = Log(__name__).getlog()
from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor
from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor import_alias = dict(
from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor",
shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor",
resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor",
bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor",
noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor",
impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor",
specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", )
class AugmentationPipeline(): class AugmentationPipeline():
@ -78,20 +87,74 @@ class AugmentationPipeline():
augmentor to take effect. If "prob" is zero, the augmentor does not take augmentor to take effect. If "prob" is zero, the augmentor does not take
effect. effect.
:param augmentation_config: Augmentation configuration in json string. Params:
:type augmentation_config: str augmentation_config(str): Augmentation configuration in json string.
:param random_seed: Random seed. random_seed(int): Random seed.
:type random_seed: int train(bool): whether is train mode.
:raises ValueError: If the augmentation json config is in incorrect format".
Raises:
ValueError: If the augmentation json config is in incorrect format".
""" """
def __init__(self, augmentation_config: str, random_seed=0): def __init__(self, augmentation_config: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed) self._rng = np.random.RandomState(random_seed)
self._spec_types = ('specaug') self._spec_types = ('specaug')
self._augmentors, self._rates = self._parse_pipeline_from(
augmentation_config, 'audio') if augmentation_config is None:
self.conf = {}
else:
self.conf = json.loads(augmentation_config)
self._augmentors, self._rates = self._parse_pipeline_from('all')
self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from( self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
augmentation_config, 'feature') 'feature')
def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx, (func, rate) in enumerate(
zip(self._augmentors, self._rates), 0):
if self._rng.uniform(0., 1.) >= rate:
continue
# Derive only the args which the func has
try:
param = signature(func).parameters
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
except Exception:
logger.fatal("Catch a exception from {}th func: {}".format(
idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs
else:
return xs[0]
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation. """Run the pre-processing pipeline for data augmentation.
@ -101,7 +164,9 @@ class AugmentationPipeline():
:param audio_segment: Audio segment to process. :param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment :type audio_segment: AudioSegmenet|SpeechSegment
""" """
for augmentor, rate in zip(self._augmentors, self._rates): if not self._train:
return
for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
if self._rng.uniform(0., 1.) < rate: if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment) augmentor.transform_audio(audio_segment)
@ -111,57 +176,44 @@ class AugmentationPipeline():
Args: Args:
spec_segment (np.ndarray): audio feature, (D, T). spec_segment (np.ndarray): audio feature, (D, T).
""" """
if not self._train:
return
for augmentor, rate in zip(self._spec_augmentors, self._spec_rates): for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
if self._rng.uniform(0., 1.) < rate: if self._rng.uniform(0., 1.) < rate:
spec_segment = augmentor.transform_feature(spec_segment) spec_segment = augmentor.transform_feature(spec_segment)
return spec_segment return spec_segment
def _parse_pipeline_from(self, config_json, aug_type='audio'): def _parse_pipeline_from(self, aug_type='all'):
"""Parse the config json to build a augmentation pipelien.""" """Parse the config json to build a augmentation pipelien."""
assert aug_type in ('audio', 'feature'), aug_type assert aug_type in ('audio', 'feature', 'all'), aug_type
try: audio_confs = []
configs = json.loads(config_json) feature_confs = []
audio_confs = [] all_confs = []
feature_confs = [] for config in self.conf:
for config in configs: all_confs.append(config)
if config["type"] in self._spec_types: if config["type"] in self._spec_types:
feature_confs.append(config) feature_confs.append(config)
else: else:
audio_confs.append(config) audio_confs.append(config)
if aug_type == 'audio': if aug_type == 'audio':
aug_confs = audio_confs aug_confs = audio_confs
elif aug_type == 'feature': elif aug_type == 'feature':
aug_confs = feature_confs aug_confs = feature_confs
else:
augmentors = [ aug_confs = all_confs
self._get_augmentor(config["type"], config["params"])
for config in aug_confs augmentors = [
] self._get_augmentor(config["type"], config["params"])
rates = [config["prob"] for config in aug_confs] for config in aug_confs
]
except Exception as e: rates = [config["prob"] for config in aug_confs]
raise ValueError("Failed to parse the augmentation config json: "
"%s" % str(e))
return augmentors, rates return augmentors, rates
def _get_augmentor(self, augmentor_type, params): def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params.""" """Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume": class_obj = dynamic_import(augmentor_type, import_alias)
return VolumePerturbAugmentor(self._rng, **params) try:
elif augmentor_type == "shift": obj = class_obj(self._rng, **params)
return ShiftPerturbAugmentor(self._rng, **params) except Exception:
elif augmentor_type == "speed":
return SpeedPerturbAugmentor(self._rng, **params)
elif augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
elif augmentor_type == "bayesian_normal":
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
elif augmentor_type == "noise":
return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "impulse":
return ImpulseResponseAugmentor(self._rng, **params)
elif augmentor_type == "specaug":
return SpecAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type) raise ValueError("Unknown augmentor type [%s]." % augmentor_type)

@ -28,6 +28,10 @@ class AugmentorBase():
def __init__(self): def __init__(self):
pass pass
@abstractmethod
def __call__(self, xs):
raise NotImplementedError
@abstractmethod @abstractmethod
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Adds various effects to the input audio segment. Such effects """Adds various effects to the input audio segment. Such effects

@ -30,6 +30,11 @@ class ImpulseResponseAugmentor(AugmentorBase):
self._rng = rng self._rng = rng
self._impulse_manifest = read_manifest(impulse_manifest_path) self._impulse_manifest = read_manifest(impulse_manifest_path)
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Add impulse response effect. """Add impulse response effect.

@ -36,6 +36,11 @@ class NoisePerturbAugmentor(AugmentorBase):
self._rng = rng self._rng = rng
self._noise_manifest = read_manifest(manifest_path=noise_manifest_path) self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Add background noise audio. """Add background noise audio.

@ -44,6 +44,11 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
self._rng = rng self._rng = rng
self._startup_delay = startup_delay self._startup_delay = startup_delay
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Normalizes the input audio using the online Bayesian approach. """Normalizes the input audio using the online Bayesian approach.

@ -31,6 +31,11 @@ class ResampleAugmentor(AugmentorBase):
self._new_sample_rate = new_sample_rate self._new_sample_rate = new_sample_rate
self._rng = rng self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate. """Resamples the input audio to a target sample rate.

@ -31,6 +31,11 @@ class ShiftPerturbAugmentor(AugmentorBase):
self._max_shift_ms = max_shift_ms self._max_shift_ms = max_shift_ms
self._rng = rng self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Shift audio. """Shift audio.

@ -157,6 +157,11 @@ class SpecAugmentor(AugmentorBase):
self._time_mask = (t_0, t_0 + t) self._time_mask = (t_0, t_0 + t)
return xs return xs
def __call__(self, x, train=True):
if not train:
return
self.transform_audio(x)
def transform_feature(self, xs: np.ndarray): def transform_feature(self, xs: np.ndarray):
""" """
Args: Args:

@ -79,6 +79,11 @@ class SpeedPerturbAugmentor(AugmentorBase):
self._rates = np.linspace( self._rates = np.linspace(
self._min_rate, self._max_rate, self._num_rates, endpoint=True) self._min_rate, self._max_rate, self._num_rates, endpoint=True)
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and """Sample a new speed rate from the given range and
changes the speed of the given audio clip. changes the speed of the given audio clip.

@ -37,6 +37,11 @@ class VolumePerturbAugmentor(AugmentorBase):
self._max_gain_dBFS = max_gain_dBFS self._max_gain_dBFS = max_gain_dBFS
self._rng = rng self._rng = rng
def __call__(self, x, uttid=None, train=True):
if not train:
return
self.transform_audio(x)
def transform_audio(self, audio_segment): def transform_audio(self, audio_segment):
"""Change audio loadness. """Change audio loadness.

@ -16,6 +16,7 @@ from typing import Optional
from paddle.io import Dataset from paddle.io import Dataset
from yacs.config import CfgNode from yacs.config import CfgNode
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"] __all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]

@ -1,5 +1,6 @@
coverage coverage
gpustat gpustat
kaldiio
pre-commit pre-commit
pybind11 pybind11
resampy==0.2.2 resampy==0.2.2
@ -13,4 +14,3 @@ tensorboardX
textgrid textgrid
typeguard typeguard
yacs yacs
kaldiio

Loading…
Cancel
Save