From 115a06bb3739715d75cdadc3b6bc813acd328c99 Mon Sep 17 00:00:00 2001 From: chrisxu2016 <823254351@qq.com> Date: Tue, 20 Jun 2017 16:24:03 +0800 Subject: [PATCH] add augmentor class --- data_utils/audio.py | 2 +- data_utils/augmentor/augmentation.py | 9 ++++ .../online_bayesian_normalization.py | 50 +++++++++++++++++++ data_utils/augmentor/resample.py | 30 +++++++++++ data_utils/augmentor/speed_perturb.py | 43 ++++++++++++++++ data_utils/augmentor/volume_perturb.py | 2 +- 6 files changed, 134 insertions(+), 2 deletions(-) mode change 100644 => 100755 data_utils/audio.py mode change 100644 => 100755 data_utils/augmentor/augmentation.py create mode 100755 data_utils/augmentor/online_bayesian_normalization.py create mode 100755 data_utils/augmentor/resample.py create mode 100755 data_utils/augmentor/speed_perturb.py mode change 100644 => 100755 data_utils/augmentor/volume_perturb.py diff --git a/data_utils/audio.py b/data_utils/audio.py old mode 100644 new mode 100755 index 5d02feb6..03e2d5e4 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -308,7 +308,7 @@ class AudioSegment(object): prior_mean_squared = 10.**(prior_db / 10.) prior_sum_of_squares = prior_mean_squared * prior_samples cumsum_of_squares = np.cumsum(self.samples**2) - sample_count = np.arange(len(self.num_samples)) + 1 + sample_count = np.arange(self.num_samples) + 1 if startup_sample_idx > 0: cumsum_of_squares[:startup_sample_idx] = \ cumsum_of_squares[startup_sample_idx] diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py old mode 100644 new mode 100755 index abe1a0ec..bfe7075e --- a/data_utils/augmentor/augmentation.py +++ b/data_utils/augmentor/augmentation.py @@ -6,6 +6,9 @@ from __future__ import print_function import json import random from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor +from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor +from data_utils.augmentor.resample import ResampleAugmentor +from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor class AugmentationPipeline(object): @@ -76,5 +79,11 @@ class AugmentationPipeline(object): """Return an augmentation model by the type name, and pass in params.""" if augmentor_type == "volume": return VolumePerturbAugmentor(self._rng, **params) + if augmentor_type == "speed": + return SpeedPerturbAugmentor(self._rng, **params) + if augmentor_type == "resample": + return ResampleAugmentor(self._rng, **params) + if augmentor_type == "baysian_normal": + return OnlineBayesianNormalizationAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/online_bayesian_normalization.py b/data_utils/augmentor/online_bayesian_normalization.py new file mode 100755 index 00000000..bb999912 --- /dev/null +++ b/data_utils/augmentor/online_bayesian_normalization.py @@ -0,0 +1,50 @@ +"""Contain the online bayesian normalization augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class OnlineBayesianNormalizationAugmentor(AugmentorBase): + """Augmentation model for adding online bayesian normalization. + + :param rng: Random generator object. + :type rng: random.Random + :param target_db: Target RMS value in decibels. + :type target_db: float + :param prior_db: Prior RMS estimate in decibels. + :type prior_db: float + :param prior_samples: Prior strength in number of samples. + :type prior_samples: int + :param startup_delay: Default 0.0s. If provided, this function will + accrue statistics for the first startup_delay + seconds before applying online normalization. + :type starup_delay: float. + """ + + def __init__(self, + rng, + target_db, + prior_db, + prior_samples, + startup_delay=0.0): + self._target_db = target_db + self._prior_db = prior_db + self._prior_samples = prior_samples + self._startup_delay = startup_delay + self._rng = rng + self._startup_delay=startup_delay + + def transform_audio(self, audio_segment): + """Normalizes the input audio using the online Bayesian approach. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegment|SpeechSegment + """ + audio_segment.normalize_online_bayesian(self._target_db, + self._prior_db, + self._prior_samples, + self._startup_delay) diff --git a/data_utils/augmentor/resample.py b/data_utils/augmentor/resample.py new file mode 100755 index 00000000..88ef7ed0 --- /dev/null +++ b/data_utils/augmentor/resample.py @@ -0,0 +1,30 @@ +"""Contain the resample augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class ResampleAugmentor(AugmentorBase): + """Augmentation model for resampling. + + :param rng: Random generator object. + :type rng: random.Random + :param new_sample_rate: New sample rate in Hz + :type new_sample_rate: int + """ + + def __init__(self, rng, new_sample_rate): + self._new_sample_rate = new_sample_rate + self._rng = rng + + def transform_audio(self, audio_segment): + """Resamples the input audio to a target sample rate. + + Note that this is an in-place transformation. + + :param audio: Audio segment to add effects to. + :type audio: AudioSegment|SpeechSegment + """ + audio_segment.resample(self._new_sample_rate) \ No newline at end of file diff --git a/data_utils/augmentor/speed_perturb.py b/data_utils/augmentor/speed_perturb.py new file mode 100755 index 00000000..67de344c --- /dev/null +++ b/data_utils/augmentor/speed_perturb.py @@ -0,0 +1,43 @@ +"""Contain the speech perturbation augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class SpeedPerturbAugmentor(AugmentorBase): + """Augmentation model for adding speed perturbation. + + See reference paper here: + http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf + + :param rng: Random generator object. + :type rng: random.Random + :param min_speed_rate: Lower bound of new speed rate to sample. + :type min_speed_rate: float + :param max_speed_rate: Upper bound of new speed rate to sample. + :type max_speed_rate: float + """ + + def __init__(self, rng, min_speed_rate, max_speed_rate): + + if (min_speed_rate < 0.5): + raise ValueError("Sampling speed below 0.9 can cause unnatural effects") + if (max_speed_rate > 1.5): + raise ValueError("Sampling speed above 1.1 can cause unnatural effects") + self._min_speed_rate = min_speed_rate + self._max_speed_rate = max_speed_rate + self._rng = rng + + def transform_audio(self, audio_segment): + """Sample a new speed rate from the given range and + changes the speed of the given audio clip. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegment|SpeechSegment + """ + sampled_speed = self._rng.uniform(self._min_speed_rate, self._max_speed_rate) + audio_segment.change_speed(sampled_speed) diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py old mode 100644 new mode 100755 index a5a9f6ca..62631fb0 --- a/data_utils/augmentor/volume_perturb.py +++ b/data_utils/augmentor/volume_perturb.py @@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase): :param audio_segment: Audio segment to add effects to. :type audio_segment: AudioSegmenet|SpeechSegment """ - gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) + gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS) audio_segment.apply_gain(gain)