diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index cc0564da..a61ca37b 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -13,18 +13,27 @@ # limitations under the License. """Contains the data augmentation pipeline.""" import json +from collections.abc import Sequence +from inspect import signature import numpy as np -from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor -from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor -from deepspeech.frontend.augmentor.online_bayesian_normalization import \ - OnlineBayesianNormalizationAugmentor -from deepspeech.frontend.augmentor.resample import ResampleAugmentor -from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor -from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor -from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor -from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor +from deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.log import Log + +__all__ = ["AugmentationPipeline"] + +logger = Log(__name__).getlog() + +import_alias = dict( + volume="deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor", + shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor", + speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor", + resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor", + bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor", + noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor", + impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor", + specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", ) class AugmentationPipeline(): @@ -78,20 +87,74 @@ class AugmentationPipeline(): augmentor to take effect. If "prob" is zero, the augmentor does not take effect. - :param augmentation_config: Augmentation configuration in json string. - :type augmentation_config: str - :param random_seed: Random seed. - :type random_seed: int - :raises ValueError: If the augmentation json config is in incorrect format". + Params: + augmentation_config(str): Augmentation configuration in json string. + random_seed(int): Random seed. + train(bool): whether is train mode. + + Raises: + ValueError: If the augmentation json config is in incorrect format". """ - def __init__(self, augmentation_config: str, random_seed=0): + def __init__(self, augmentation_config: str, random_seed: int=0): self._rng = np.random.RandomState(random_seed) self._spec_types = ('specaug') - self._augmentors, self._rates = self._parse_pipeline_from( - augmentation_config, 'audio') + + if augmentation_config is None: + self.conf = {} + else: + self.conf = json.loads(augmentation_config) + + self._augmentors, self._rates = self._parse_pipeline_from('all') + self._audio_augmentors, self._audio_rates = self._parse_pipeline_from( + 'audio') self._spec_augmentors, self._spec_rates = self._parse_pipeline_from( - augmentation_config, 'feature') + 'feature') + + def __call__(self, xs, uttid_list=None, **kwargs): + if not isinstance(xs, Sequence): + is_batch = False + xs = [xs] + else: + is_batch = True + + if isinstance(uttid_list, str): + uttid_list = [uttid_list for _ in range(len(xs))] + + if self.conf.get("mode", "sequential") == "sequential": + for idx, (func, rate) in enumerate( + zip(self._augmentors, self._rates), 0): + if self._rng.uniform(0., 1.) >= rate: + continue + + # Derive only the args which the func has + try: + param = signature(func).parameters + except ValueError: + # Some function, e.g. built-in function, are failed + param = {} + _kwargs = {k: v for k, v in kwargs.items() if k in param} + + try: + if uttid_list is not None and "uttid" in param: + xs = [ + func(x, u, **_kwargs) + for x, u in zip(xs, uttid_list) + ] + else: + xs = [func(x, **_kwargs) for x in xs] + except Exception: + logger.fatal("Catch a exception from {}th func: {}".format( + idx, func)) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"])) + + if is_batch: + return xs + else: + return xs[0] def transform_audio(self, audio_segment): """Run the pre-processing pipeline for data augmentation. @@ -101,7 +164,9 @@ class AugmentationPipeline(): :param audio_segment: Audio segment to process. :type audio_segment: AudioSegmenet|SpeechSegment """ - for augmentor, rate in zip(self._augmentors, self._rates): + if not self._train: + return + for augmentor, rate in zip(self._audio_augmentors, self._audio_rates): if self._rng.uniform(0., 1.) < rate: augmentor.transform_audio(audio_segment) @@ -111,57 +176,44 @@ class AugmentationPipeline(): Args: spec_segment (np.ndarray): audio feature, (D, T). """ + if not self._train: + return for augmentor, rate in zip(self._spec_augmentors, self._spec_rates): if self._rng.uniform(0., 1.) < rate: spec_segment = augmentor.transform_feature(spec_segment) return spec_segment - def _parse_pipeline_from(self, config_json, aug_type='audio'): + def _parse_pipeline_from(self, aug_type='all'): """Parse the config json to build a augmentation pipelien.""" - assert aug_type in ('audio', 'feature'), aug_type - try: - configs = json.loads(config_json) - audio_confs = [] - feature_confs = [] - for config in configs: - if config["type"] in self._spec_types: - feature_confs.append(config) - else: - audio_confs.append(config) - - if aug_type == 'audio': - aug_confs = audio_confs - elif aug_type == 'feature': - aug_confs = feature_confs - - augmentors = [ - self._get_augmentor(config["type"], config["params"]) - for config in aug_confs - ] - rates = [config["prob"] for config in aug_confs] - - except Exception as e: - raise ValueError("Failed to parse the augmentation config json: " - "%s" % str(e)) + assert aug_type in ('audio', 'feature', 'all'), aug_type + audio_confs = [] + feature_confs = [] + all_confs = [] + for config in self.conf: + all_confs.append(config) + if config["type"] in self._spec_types: + feature_confs.append(config) + else: + audio_confs.append(config) + + if aug_type == 'audio': + aug_confs = audio_confs + elif aug_type == 'feature': + aug_confs = feature_confs + else: + aug_confs = all_confs + + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in aug_confs + ] + rates = [config["prob"] for config in aug_confs] return augmentors, rates def _get_augmentor(self, augmentor_type, params): """Return an augmentation model by the type name, and pass in params.""" - if augmentor_type == "volume": - return VolumePerturbAugmentor(self._rng, **params) - elif augmentor_type == "shift": - return ShiftPerturbAugmentor(self._rng, **params) - elif augmentor_type == "speed": - return SpeedPerturbAugmentor(self._rng, **params) - elif augmentor_type == "resample": - return ResampleAugmentor(self._rng, **params) - elif augmentor_type == "bayesian_normal": - return OnlineBayesianNormalizationAugmentor(self._rng, **params) - elif augmentor_type == "noise": - return NoisePerturbAugmentor(self._rng, **params) - elif augmentor_type == "impulse": - return ImpulseResponseAugmentor(self._rng, **params) - elif augmentor_type == "specaug": - return SpecAugmentor(self._rng, **params) - else: + class_obj = dynamic_import(augmentor_type, import_alias) + try: + obj = class_obj(self._rng, **params) + except Exception: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/deepspeech/frontend/augmentor/base.py b/deepspeech/frontend/augmentor/base.py index e6f5c1e9..87cb4ef7 100644 --- a/deepspeech/frontend/augmentor/base.py +++ b/deepspeech/frontend/augmentor/base.py @@ -28,6 +28,10 @@ class AugmentorBase(): def __init__(self): pass + @abstractmethod + def __call__(self, xs): + raise NotImplementedError + @abstractmethod def transform_audio(self, audio_segment): """Adds various effects to the input audio segment. Such effects diff --git a/deepspeech/frontend/augmentor/impulse_response.py b/deepspeech/frontend/augmentor/impulse_response.py index fbd617b4..01421fc6 100644 --- a/deepspeech/frontend/augmentor/impulse_response.py +++ b/deepspeech/frontend/augmentor/impulse_response.py @@ -30,6 +30,11 @@ class ImpulseResponseAugmentor(AugmentorBase): self._rng = rng self._impulse_manifest = read_manifest(impulse_manifest_path) + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Add impulse response effect. diff --git a/deepspeech/frontend/augmentor/noise_perturb.py b/deepspeech/frontend/augmentor/noise_perturb.py index b3c07f5c..11f5ed10 100644 --- a/deepspeech/frontend/augmentor/noise_perturb.py +++ b/deepspeech/frontend/augmentor/noise_perturb.py @@ -36,6 +36,11 @@ class NoisePerturbAugmentor(AugmentorBase): self._rng = rng self._noise_manifest = read_manifest(manifest_path=noise_manifest_path) + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Add background noise audio. diff --git a/deepspeech/frontend/augmentor/online_bayesian_normalization.py b/deepspeech/frontend/augmentor/online_bayesian_normalization.py index 5af3b9b0..dc32a180 100644 --- a/deepspeech/frontend/augmentor/online_bayesian_normalization.py +++ b/deepspeech/frontend/augmentor/online_bayesian_normalization.py @@ -44,6 +44,11 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase): self._rng = rng self._startup_delay = startup_delay + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Normalizes the input audio using the online Bayesian approach. diff --git a/deepspeech/frontend/augmentor/resample.py b/deepspeech/frontend/augmentor/resample.py index 9afce635..a862b184 100644 --- a/deepspeech/frontend/augmentor/resample.py +++ b/deepspeech/frontend/augmentor/resample.py @@ -31,6 +31,11 @@ class ResampleAugmentor(AugmentorBase): self._new_sample_rate = new_sample_rate self._rng = rng + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Resamples the input audio to a target sample rate. diff --git a/deepspeech/frontend/augmentor/shift_perturb.py b/deepspeech/frontend/augmentor/shift_perturb.py index 9cc3fe2d..6c78c528 100644 --- a/deepspeech/frontend/augmentor/shift_perturb.py +++ b/deepspeech/frontend/augmentor/shift_perturb.py @@ -31,6 +31,11 @@ class ShiftPerturbAugmentor(AugmentorBase): self._max_shift_ms = max_shift_ms self._rng = rng + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Shift audio. diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 1c2e09fc..94d23bf4 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -157,6 +157,11 @@ class SpecAugmentor(AugmentorBase): self._time_mask = (t_0, t_0 + t) return xs + def __call__(self, x, train=True): + if not train: + return + self.transform_audio(x) + def transform_feature(self, xs: np.ndarray): """ Args: diff --git a/deepspeech/frontend/augmentor/speed_perturb.py b/deepspeech/frontend/augmentor/speed_perturb.py index d0977c13..838c5cc2 100644 --- a/deepspeech/frontend/augmentor/speed_perturb.py +++ b/deepspeech/frontend/augmentor/speed_perturb.py @@ -79,6 +79,11 @@ class SpeedPerturbAugmentor(AugmentorBase): self._rates = np.linspace( self._min_rate, self._max_rate, self._num_rates, endpoint=True) + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Sample a new speed rate from the given range and changes the speed of the given audio clip. diff --git a/deepspeech/frontend/augmentor/volume_perturb.py b/deepspeech/frontend/augmentor/volume_perturb.py index 0d76e7a0..ffae1693 100644 --- a/deepspeech/frontend/augmentor/volume_perturb.py +++ b/deepspeech/frontend/augmentor/volume_perturb.py @@ -37,6 +37,11 @@ class VolumePerturbAugmentor(AugmentorBase): self._max_gain_dBFS = max_gain_dBFS self._rng = rng + def __call__(self, x, uttid=None, train=True): + if not train: + return + self.transform_audio(x) + def transform_audio(self, audio_segment): """Change audio loadness. diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index c5b6e737..e2db9340 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -16,6 +16,7 @@ from typing import Optional from paddle.io import Dataset from yacs.config import CfgNode +from deepspeech.frontend.utility import read_manifest from deepspeech.utils.log import Log __all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"] diff --git a/requirements.txt b/requirements.txt index 692f3499..af2600e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ coverage gpustat +kaldiio pre-commit pybind11 resampy==0.2.2 @@ -13,4 +14,3 @@ tensorboardX textgrid typeguard yacs -kaldiio