From 13f708739ba956aa3c63b91e529827bc73d3e160 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Wed, 21 Jun 2017 20:52:30 +0800 Subject: [PATCH 1/2] Improve audio featurizer and add shift augmentor. 1. Improve audio featurizer. 2. Add shift augmentor. 3. Update default argument to be the current best seggestion. 4. Add checkpoints with pass id. --- README.md | 4 +- data_utils/audio.py | 157 ++++++++++++--------- data_utils/augmentor/augmentation.py | 3 + data_utils/augmentor/volume_perturb.py | 2 +- data_utils/data.py | 7 +- data_utils/featurizer/audio_featurizer.py | 42 +++++- data_utils/featurizer/speech_featurizer.py | 24 +++- infer.py | 2 +- setup.sh | 3 + train.py | 19 ++- 10 files changed, 180 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 0cdb203d..2912ff31 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,13 @@ python compute_mean_std.py --help For GPU Training: ``` -CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py ``` For CPU Training: ``` -python train.py --trainer_count 8 --use_gpu False +python train.py --use_gpu False ``` More help for arguments: diff --git a/data_utils/audio.py b/data_utils/audio.py index 5d02feb6..1faeb48a 100644 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -66,6 +66,54 @@ class AudioSegment(object): samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) + @classmethod + def slice_from_file(cls, file, start=None, end=None): + """Loads a small section of an audio without having to load + the entire file into the memory which can be incredibly wasteful. + + :param file: Input audio filepath or file object. + :type file: basestring|file + :param start: Start time in seconds. If start is negative, it wraps + around from the end. If not provided, this function + reads from the very beginning. + :type start: float + :param end: End time in seconds. If end is negative, it wraps around + from the end. If not provided, the default behvaior is + to read to the end of the file. + :type end: float + :return: AudioSegment instance of the specified slice of the input + audio file. + :rtype: AudioSegment + :raise ValueError: If start or end is incorrectly set, e.g. out of + bounds in time. + """ + sndfile = soundfile.SoundFile(file) + sample_rate = sndfile.samplerate + duration = float(len(sndfile)) / sample_rate + start = 0. if start is None else start + end = 0. if end is None else end + if start < 0.0: + start += duration + if end < 0.0: + end += duration + if start < 0.0: + raise ValueError("The slice start position (%f s) is out of " + "bounds." % start) + if end < 0.0: + raise ValueError("The slice end position (%f s) is out of bounds." % + end) + if start > end: + raise ValueError("The slice start position (%f s) is later than " + "the slice end position (%f s)." % (start, end)) + if end > duration: + raise ValueError("The slice end position (%f s) is out of bounds " + "(> %f s)" % (end, duration)) + start_frame = int(start * sample_rate) + end_frame = int(end * sample_rate) + sndfile.seek(start_frame) + data = sndfile.read(frames=end_frame - start_frame, dtype='float32') + return cls(data, sample_rate) + @classmethod def from_bytes(cls, bytes): """Create audio segment from a byte string containing audio samples. @@ -105,6 +153,20 @@ class AudioSegment(object): samples = np.concatenate([seg.samples for seg in segments]) return cls(samples, sample_rate) + @classmethod + def make_silence(cls, duration, sample_rate): + """Creates a silent audio segment of the given duration and sample rate. + + :param duration: Length of silence in seconds. + :type duration: float + :param sample_rate: Sample rate. + :type sample_rate: float + :return: Silent AudioSegment instance of the given duration. + :rtype: AudioSegment + """ + samples = np.zeros(int(duration * sample_rate)) + return cls(samples, sample_rate) + def to_wav_file(self, filepath, dtype='float32'): """Save audio segment to disk as wav file. @@ -130,68 +192,6 @@ class AudioSegment(object): format='WAV', subtype=subtype_map[dtype]) - @classmethod - def slice_from_file(cls, file, start=None, end=None): - """Loads a small section of an audio without having to load - the entire file into the memory which can be incredibly wasteful. - - :param file: Input audio filepath or file object. - :type file: basestring|file - :param start: Start time in seconds. If start is negative, it wraps - around from the end. If not provided, this function - reads from the very beginning. - :type start: float - :param end: End time in seconds. If end is negative, it wraps around - from the end. If not provided, the default behvaior is - to read to the end of the file. - :type end: float - :return: AudioSegment instance of the specified slice of the input - audio file. - :rtype: AudioSegment - :raise ValueError: If start or end is incorrectly set, e.g. out of - bounds in time. - """ - sndfile = soundfile.SoundFile(file) - sample_rate = sndfile.samplerate - duration = float(len(sndfile)) / sample_rate - start = 0. if start is None else start - end = 0. if end is None else end - if start < 0.0: - start += duration - if end < 0.0: - end += duration - if start < 0.0: - raise ValueError("The slice start position (%f s) is out of " - "bounds." % start) - if end < 0.0: - raise ValueError("The slice end position (%f s) is out of bounds." % - end) - if start > end: - raise ValueError("The slice start position (%f s) is later than " - "the slice end position (%f s)." % (start, end)) - if end > duration: - raise ValueError("The slice end position (%f s) is out of bounds " - "(> %f s)" % (end, duration)) - start_frame = int(start * sample_rate) - end_frame = int(end * sample_rate) - sndfile.seek(start_frame) - data = sndfile.read(frames=end_frame - start_frame, dtype='float32') - return cls(data, sample_rate) - - @classmethod - def make_silence(cls, duration, sample_rate): - """Creates a silent audio segment of the given duration and sample rate. - - :param duration: Length of silence in seconds. - :type duration: float - :param sample_rate: Sample rate. - :type sample_rate: float - :return: Silent AudioSegment instance of the given duration. - :rtype: AudioSegment - """ - samples = np.zeros(int(duration * sample_rate)) - return cls(samples, sample_rate) - def superimpose(self, other): """Add samples from another segment to those of this segment (sample-wise addition, not segment concatenation). @@ -225,7 +225,7 @@ class AudioSegment(object): samples = self._convert_samples_from_float32(self._samples, dtype) return samples.tostring() - def apply_gain(self, gain): + def gain_db(self, gain): """Apply gain in decibels to samples. Note that this is an in-place transformation. @@ -278,7 +278,7 @@ class AudioSegment(object): "Unable to normalize segment to %f dB because the " "the probable gain have exceeds max_gain_db (%f dB)" % (target_db, max_gain_db)) - self.apply_gain(min(max_gain_db, target_db - self.rms_db)) + self.gain_db(min(max_gain_db, target_db - self.rms_db)) def normalize_online_bayesian(self, target_db, @@ -319,7 +319,7 @@ class AudioSegment(object): rms_estimate_db = 10 * np.log10(mean_squared_estimate) # Compute required time-varying gain. gain_db = target_db - rms_estimate_db - self.apply_gain(gain_db) + self.gain_db(gain_db) def resample(self, target_sample_rate, quality='sinc_medium'): """Resample the audio to a target sample rate. @@ -366,6 +366,31 @@ class AudioSegment(object): raise ValueError("Unknown value for the sides %s" % sides) self._samples = padded._samples + def shift(self, shift_ms): + """Shift the audio in time. If `shift_ms` is positive, shift with time + advance; if negative, shift with time delay. Silence are padded to + keep the duration unchanged. + + Note that this is an in-place transformation. + + :param shift_ms: Shift time in millseconds. If positive, shift with + time advance; if negative; shift with time delay. + :type shift_ms: float + :raises ValueError: If shift_ms is longer than audio duration. + """ + if shift_ms / 1000.0 > self.duration: + raise ValueError("Absolute value of shift_ms should be smaller " + "than audio duration.") + shift_samples = int(shift_ms * self._sample_rate / 1000) + if shift_samples > 0: + # time advance + self._samples[:-shift_samples] = self._samples[shift_samples:] + self._samples[-shift_samples:] = 0 + elif shift_samples < 0: + # time delay + self._samples[-shift_samples:] = self._samples[:shift_samples] + self._samples[:-shift_samples] = 0 + def subsegment(self, start_sec=None, end_sec=None): """Cut the AudioSegment between given boundaries. @@ -505,7 +530,7 @@ class AudioSegment(object): noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db) noise_new = copy.deepcopy(noise) noise_new.random_subsegment(self.duration, rng=rng) - noise_new.apply_gain(noise_gain_db) + noise_new.gain_db(noise_gain_db) self.superimpose(noise_new) @property diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py index abe1a0ec..0d60bbdb 100644 --- a/data_utils/augmentor/augmentation.py +++ b/data_utils/augmentor/augmentation.py @@ -6,6 +6,7 @@ from __future__ import print_function import json import random from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor +from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor class AugmentationPipeline(object): @@ -76,5 +77,7 @@ class AugmentationPipeline(object): """Return an augmentation model by the type name, and pass in params.""" if augmentor_type == "volume": return VolumePerturbAugmentor(self._rng, **params) + elif augmentor_type == "shift": + return ShiftPerturbAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py index a5a9f6ca..62631fb0 100644 --- a/data_utils/augmentor/volume_perturb.py +++ b/data_utils/augmentor/volume_perturb.py @@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase): :param audio_segment: Audio segment to add effects to. :type audio_segment: AudioSegmenet|SpeechSegment """ - gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) + gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS) audio_segment.apply_gain(gain) diff --git a/data_utils/data.py b/data_utils/data.py index 44af7ffa..d01ca8cc 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -45,6 +45,9 @@ class DataGenerator(object): :types max_freq: None|float :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str + :param use_dB_normalization: Whether to normalize the audio to -20 dB + before extracting the features. + :type use_dB_normalization: bool :param num_threads: Number of CPU threads for processing data. :type num_threads: int :param random_seed: Random seed. @@ -61,6 +64,7 @@ class DataGenerator(object): window_ms=20.0, max_freq=None, specgram_type='linear', + use_dB_normalization=True, num_threads=multiprocessing.cpu_count(), random_seed=0): self._max_duration = max_duration @@ -73,7 +77,8 @@ class DataGenerator(object): specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, - max_freq=max_freq) + max_freq=max_freq, + use_dB_normalization=use_dB_normalization) self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py index 9f9d4e50..4b4d02c6 100644 --- a/data_utils/featurizer/audio_featurizer.py +++ b/data_utils/featurizer/audio_featurizer.py @@ -24,26 +24,64 @@ class AudioFeaturizer(object): corresponding to frequencies between [0, max_freq] are returned. :types max_freq: None|float + :param target_sample_rate: Audio are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float """ def __init__(self, specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None): + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): self._specgram_type = specgram_type self._stride_ms = stride_ms self._window_ms = window_ms self._max_freq = max_freq + self._target_sample_rate = target_sample_rate + self._use_dB_normalization = use_dB_normalization + self._target_dB = target_dB - def featurize(self, audio_segment): + def featurize(self, + audio_segment, + allow_downsampling=True, + allow_upsamplling=True): """Extract audio features from AudioSegment or SpeechSegment. :param audio_segment: Audio/speech segment to extract features from. :type audio_segment: AudioSegment|SpeechSegment + :param allow_downsampling: Whether to allow audio downsampling before + featurizing. + :type allow_downsampling: bool + :param allow_upsampling: Whether to allow audio upsampling before + featurizing. + :type allow_upsampling: bool :return: Spectrogram audio feature in 2darray. :rtype: ndarray + :raises ValueError: If audio sample rate is not supported. """ + # upsampling or downsampling + if ((audio_segment.sample_rate > self._target_sample_rate and + allow_downsampling) or + (audio_segment.sample_rate < self._target_sample_rate and + allow_upsampling)): + audio_segment.resample(self._target_sample_rate) + if audio_segment.sample_rate != self._target_sample_rate: + raise ValueError("Audio sample rate is not supported. " + "Turn allow_downsampling or allow up_sampling on.") + # decibel normalization + if self._use_dB_normalization: + audio_segment.normalize(target_db=self._target_dB) + # extract spectrogram return self._compute_specgram(audio_segment.samples, audio_segment.sample_rate) diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py index 77020455..26283892 100644 --- a/data_utils/featurizer/speech_featurizer.py +++ b/data_utils/featurizer/speech_featurizer.py @@ -29,6 +29,15 @@ class SpeechFeaturizer(object): corresponding to frequencies between [0, max_freq] are returned. :types max_freq: None|float + :param target_sample_rate: Speech are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float """ def __init__(self, @@ -36,9 +45,18 @@ class SpeechFeaturizer(object): specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None): - self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms, - window_ms, max_freq) + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + self._audio_featurizer = AudioFeaturizer( + specgram_type=specgram_type, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB) self._text_featurizer = TextFeaturizer(vocab_filepath) def featurize(self, speech_segment): diff --git a/infer.py b/infer.py index 71518133..9037a108 100644 --- a/infer.py +++ b/infer.py @@ -56,7 +56,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( diff --git a/setup.sh b/setup.sh index 1ae2a5ee..cdec34ff 100644 --- a/setup.sh +++ b/setup.sh @@ -27,4 +27,7 @@ if [ $? != 0 ]; then exit 1 fi +# prepare ./checkpoints +mkdir checkpoints + echo "Install all dependencies successfully." diff --git a/train.py b/train.py index fc23ec72..3a2d0cad 100644 --- a/train.py +++ b/train.py @@ -17,10 +17,10 @@ import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--batch_size", default=32, type=int, help="Minibatch size.") + "--batch_size", default=256, type=int, help="Minibatch size.") parser.add_argument( "--num_passes", - default=20, + default=200, type=int, help="Training pass number. (default: %(default)s)") parser.add_argument( @@ -55,7 +55,7 @@ parser.add_argument( help="Use sortagrad or not. (default: %(default)s)") parser.add_argument( "--max_duration", - default=100.0, + default=27.0, type=float, help="Audios with duration larger than this will be discarded. " "(default: %(default)s)") @@ -67,13 +67,13 @@ parser.add_argument( "(default: %(default)s)") parser.add_argument( "--shuffle_method", - default='instance_shuffle', + default='batch_shuffle_clipped', type=str, help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " "'batch_shuffle_batch'. (default: %(default)s)") parser.add_argument( "--trainer_count", - default=4, + default=8, type=int, help="Trainer number. (default: %(default)s)") parser.add_argument( @@ -110,7 +110,9 @@ parser.add_argument( "the existing model of this path. (default: %(default)s)") parser.add_argument( "--augmentation_config", - default='{}', + default='[{"type": "shift", ' + '"params": {"min_shift_ms": -5, "max_shift_ms": 5},' + '"prob": 1.0}]', type=str, help="Augmentation configuration in json-format. " "(default: %(default)s)") @@ -189,7 +191,7 @@ def train(): print("\nPass: %d, Batch: %d, TrainCost: %f" % ( event.pass_id, event.batch_id + 1, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 - with gzip.open("params.tar.gz", 'w') as f: + with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f: parameters.to_tar(f) else: sys.stdout.write('.') @@ -202,6 +204,9 @@ def train(): reader=test_batch_reader, feeding=test_generator.feeding) print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % (time.time() - start_time, event.pass_id, result.cost)) + with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id, + 'w') as f: + parameters.to_tar(f) # run train trainer.train( From cdd52ac2706929ea993038aedce3080eb2de8af8 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Mon, 26 Jun 2017 14:17:22 +0800 Subject: [PATCH 2/2] Fix a missing abs bug for DS2 AudioSegment. --- data_utils/audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_utils/audio.py b/data_utils/audio.py index 1faeb48a..d55fae1e 100644 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -378,7 +378,7 @@ class AudioSegment(object): :type shift_ms: float :raises ValueError: If shift_ms is longer than audio duration. """ - if shift_ms / 1000.0 > self.duration: + if abs(shift_ms) / 1000.0 > self.duration: raise ValueError("Absolute value of shift_ms should be smaller " "than audio duration.") shift_samples = int(shift_ms * self._sample_rate / 1000)