From 65e34c535b4444c42c28f14b16a2617a73d296d1 Mon Sep 17 00:00:00 2001 From: chrisxu2016 <823254351@qq.com> Date: Thu, 15 Jun 2017 03:08:30 +0800 Subject: [PATCH] add augmentation --- data_utils/audio.py | 396 ++++++++++++++++- data_utils/augmentor/audio_database.py | 401 ++++++++++++++++++ data_utils/augmentor/augmentation.py | 15 + data_utils/augmentor/implus_response.py | 76 ++++ data_utils/augmentor/noise_speech.py | 318 ++++++++++++++ .../online_bayesian_normalization.py | 57 +++ data_utils/augmentor/resampler.py | 30 ++ data_utils/augmentor/speed_perturb.py | 53 +++ data_utils/augmentor/volume_perturb.py | 4 +- 9 files changed, 1337 insertions(+), 13 deletions(-) create mode 100755 data_utils/augmentor/audio_database.py create mode 100755 data_utils/augmentor/implus_response.py create mode 100755 data_utils/augmentor/noise_speech.py create mode 100755 data_utils/augmentor/online_bayesian_normalization.py create mode 100755 data_utils/augmentor/resampler.py create mode 100755 data_utils/augmentor/speed_perturb.py diff --git a/data_utils/audio.py b/data_utils/audio.py index 916c8ac1a..aef13c30f 100755 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -6,6 +6,8 @@ from __future__ import print_function import numpy as np import io import soundfile +import scikits.samplerate +from scipy import signal class AudioSegment(object): @@ -62,6 +64,69 @@ class AudioSegment(object): samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) + @classmethod + def slice_from_file(cls, fname, start=None, end=None): + """ + Loads a small section of an audio without having to load + the entire file into the memory which can be incredibly wasteful. + + :param fname: input audio file name + :type fname: bsaestring + :param start: start time in seconds (supported granularity is ms) + If start is negative, it wraps around from the end. If not + provided, this function reads from the very beginning. + :type start: float + :param end: start time in seconds (supported granularity is ms) + If end is negative, it wraps around from the end. If not + provided, the default behvaior is to read to the end of the + file. + :type end: float + + :return:the specified slice of input audio in the audio.AudioSegment + format. + """ + sndfile = soundfile.SoundFile(fname) + + sample_rate = sndfile.samplerate + if sndfile.channels != 1: + raise TypeError("{} has more than 1 channel.".format(fname)) + + duration = float(len(sndfile)) / sample_rate + + if start is None: + start = 0.0 + if end is None: + end = duration + + if start < 0.0: + start += duration + if end < 0.0: + end += duration + + if start < 0.0: + raise IndexError("The slice start position ({} s) is out of " + "bounds. Filename: {}".format(start, fname)) + if end < 0.0: + raise IndexError("The slice end position ({} s) is out of bounds " + "Filename: {}".format(end, fname)) + + if start > end: + raise IndexError("The slice start position ({} s) is later than " + "the slice end position ({} s)." + .format(start, end)) + + if end > duration: + raise ValueError("The slice end time ({} s) is out of " + "bounds (> {} s) Filename: {}" + .format(end, duration, fname)) + + start_frame = int(start * sample_rate) + end_frame = int(end * sample_rate) + sndfile.seek(start_frame) + data = sndfile.read(frames=end_frame - start_frame, dtype='float32') + + return cls(data, sample_rate) + @classmethod def from_bytes(cls, bytes): """Create audio segment from a byte string containing audio samples. 
@@ -75,6 +140,44 @@ class AudioSegment(object): io.BytesIO(bytes), dtype='float32') return cls(samples, sample_rate) + @classmethod + def make_silence(cls, duration, sample_rate): + """Creates a silent audio segment of the given duration and + sample rate. + + :param duration: length of silence in seconds + :type duration: scalar + :param sample_rate: sample rate + :type sample_rate: scalar + :returns: silence of the given duration + :rtype: AudioSegment + """ + samples = np.zeros(int(float(duration) * sample_rate)) + return cls(samples, sample_rate) + + @classmethod + def concatenate(cls, *segments): + """Concatenate an arbitrary number of audio segments together. + + :param *segments: input audio segments + :type *segments: [AudioSegment] + """ + # Perform basic sanity-checks. + N = len(segments) + if N == 0: + raise ValueError("No audio segments are given to concatenate.") + sample_rate = segments[0]._sample_rate + for segment in segments: + if sample_rate != segment._sample_rate: + raise ValueError("Can't concatenate segments with " + "different sample rates") + if type(segment) is not cls: + raise TypeError("Only audio segments of the same type " + "instance can be concatenated.") + + samples = np.concatenate([seg.samples for seg in segments]) + return cls(samples, sample_rate) + def to_wav_file(self, filepath, dtype='float32'): """Save audio segment to disk as wav file. @@ -143,23 +246,288 @@ class AudioSegment(object): new_indices = np.linspace(start=0, stop=old_length, num=new_length) self._samples = np.interp(new_indices, old_indices, self._samples) - def normalize(self, target_sample_rate): - raise NotImplementedError() + def normalize(self, target_db=-20, max_gain_db=300.0): + """Normalize audio to desired RMS value in decibels. + + Note that this is an in-place transformation. + + :param target_db: Target RMS value in decibels.This value + should be less than 0.0 as 0.0 is full-scale audio. + :type target_db: float, optional + :param max_gain_db: Max amount of gain in dB that can be applied + for normalization. This is to prevent nans when attempting + to normalize a signal consisting of all zeros. + :type max_gain_db: float, optional - def resample(self, target_sample_rate): - raise NotImplementedError() + :raises NormalizationWarning: if the required gain to normalize the + segment to the target_db value exceeds max_gain_db. + """ + gain = target_db - self.rms_db + if gain > max_gain_db: + raise ValueError( + "Unable to normalize segment to {} dB because it has an RMS " + "value of {} dB and the difference exceeds max_gain_db ({} dB)" + .format(target_db, self.rms_db, max_gain_db)) + gain = min(max_gain_db, target_db - self.rms_db) + self.apply_gain(gain) + + def normalize_online_bayesian(self, + target_db, + prior_db, + prior_samples, + startup_delay=0.0): + """ + Normalize audio using a production-compatible online/causal algorithm. + This uses an exponential likelihood and gamma prior to make + online estimates of the RMS even when there are very few samples. + + Note that this is an in-place transformation. + + :param target_db: Target RMS value in decibels + :type target_bd: scalar + :param prior_db: Prior RMS estimate in decibels + :type prior_db: scalar + :param prior_samples: Prior strength in number of samples + :type prior_samples: scalar + :param startup_delay: Default: 0.0 s. If provided, this + function will accrue statistics for the first startup_delay + seconds before applying online normalization. 
+        :type startup_delay: scalar
+        """
+        # Estimate total RMS online.
+        startup_sample_idx = min(self.num_samples - 1,
+                                 int(self.sample_rate * startup_delay))
+        prior_mean_squared = 10.**(prior_db / 10.)
+        prior_sum_of_squares = prior_mean_squared * prior_samples
+        cumsum_of_squares = np.cumsum(self.samples**2)
+        sample_count = np.arange(len(self)) + 1
+        if startup_sample_idx > 0:
+            cumsum_of_squares[:startup_sample_idx] = \
+                cumsum_of_squares[startup_sample_idx]
+            sample_count[:startup_sample_idx] = \
+                sample_count[startup_sample_idx]
+        mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
+                                 (sample_count + prior_samples))
+        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
+
+        # Compute the required time-varying gain.
+        gain_db = target_db - rms_estimate_db
+
+        # Apply the gain to this segment.
+        self.apply_gain(gain_db)
+
+    def normalize_ewma(self,
+                       target_db,
+                       decay_rate,
+                       startup_delay,
+                       rms_eps=1e-6,
+                       max_gain_db=300.0):
+        """Normalize audio using an exponentially weighted moving average
+        (EWMA) estimate of the per-sample RMS.
+
+        Note that this is an in-place transformation.
+
+        :param target_db: Target RMS value in decibels.
+        :type target_db: float
+        :param decay_rate: Decay rate of the EWMA filter, in (0, 1).
+        :type decay_rate: float
+        :param startup_delay: Seconds over which statistics are accrued
+                              before normalization is applied.
+        :type startup_delay: float
+        :param rms_eps: Small constant to avoid taking the log of zero.
+        :type rms_eps: float
+        :param max_gain_db: Max amount of gain in dB that can be applied.
+        :type max_gain_db: float
+        """
+        startup_sample_idx = min(self.num_samples - 1,
+                                 int(self.sample_rate * startup_delay))
+        mean_sq = self.samples**2
+        if startup_sample_idx > 0:
+            mean_sq[:startup_sample_idx] = \
+                np.sum(mean_sq[:startup_sample_idx]) / startup_sample_idx
+        idx_start = max(0, startup_sample_idx - 1)
+        initial_condition = mean_sq[idx_start] * decay_rate
+        mean_sq[idx_start:] = signal.lfilter(
+            [1.0 - decay_rate], [1.0, -decay_rate],
+            mean_sq[idx_start:],
+            axis=0,
+            zi=[initial_condition])[0]
+        rms_estimate_db = 10.0 * np.log10(mean_sq + rms_eps)
+        gain_db = target_db - rms_estimate_db
+        # Clip the per-sample gain at max_gain_db to avoid applying huge
+        # gain to near-silent stretches of the signal.
+        gain_db = np.minimum(gain_db, max_gain_db)
+        self.apply_gain(gain_db)
+
+    def resample(self, target_sample_rate, quality='sinc_medium'):
+        """Resample this audio segment to the target sample rate.
+
+        Note that this is an in-place transformation.
+
+        :param target_sample_rate: Target sample rate.
+        :type target_sample_rate: scalar
+        :param quality: One of {'sinc_fastest', 'sinc_medium', 'sinc_best'}.
+                        Sets the resampling speed/quality tradeoff.
+                        See http://www.mega-nerd.com/SRC/api_misc.html#Converters
+        :type quality: basestring
+        """
+        resample_ratio = float(target_sample_rate) / self._sample_rate
+        new_samples = scikits.samplerate.resample(
+            self._samples, r=resample_ratio, type=quality)
+        self._samples = new_samples
+        self._sample_rate = target_sample_rate
 
     def pad_silence(self, duration, sides='both'):
-        raise NotImplementedError()
+        """Pad this audio segment with a period of silence.
+
+        Note that this is an in-place transformation.
+
+        :param duration: Length of silence in seconds to pad.
+        :type duration: float
+        :param sides:
+                     'beginning' - adds silence in the beginning
+                     'end' - adds silence in the end
+                     'both' - adds silence in both the beginning and the end.
+        :type sides: basestring
+        """
+        if duration == 0.0:
+            return
+        cls = type(self)
+        silence = cls.make_silence(duration, self._sample_rate)
+        if sides == "beginning":
+            padded = cls.concatenate(silence, self)
+        elif sides == "end":
+            padded = cls.concatenate(self, silence)
+        elif sides == "both":
+            padded = cls.concatenate(silence, self, silence)
+        else:
+            raise ValueError("Unknown value for the kwarg 'sides'.")
+        self._samples = padded._samples
+        self._sample_rate = padded._sample_rate
 
     def subsegment(self, start_sec=None, end_sec=None):
-        raise NotImplementedError()
+        """Return a new AudioSegment containing audio between the given
+        boundaries.
+
+        :param start_sec: Beginning of subsegment in seconds
+                          (beginning of segment if None).
+        :type start_sec: scalar
+        :param end_sec: End of subsegment in seconds
+                        (end of segment if None).
+        :type end_sec: scalar
+
+        :return: New AudioSegment containing the specified subsegment.
+        :rtype: AudioSegment
+        """
+        # Default boundaries.
+        if start_sec is None:
+            start_sec = 0.0
+        if end_sec is None:
+            end_sec = self.duration
+
+        # Negative boundaries are relative to the end of the segment.
+        if start_sec < 0.0:
+            start_sec = self.duration + start_sec
+        if end_sec < 0.0:
+            end_sec = self.duration + end_sec
 
-    def convolve(self, filter, allow_resample=False):
-        raise NotImplementedError()
+        start_sample = int(round(start_sec * self._sample_rate))
+        end_sample = int(round(end_sec * self._sample_rate))
+        samples = self._samples[start_sample:end_sample]
 
-    def convolve_and_normalize(self, filter, allow_resample=False):
-        raise NotImplementedError()
+        return type(self)(samples, sample_rate=self._sample_rate)
+
+    def random_subsegment(self, subsegment_length, rng=None):
+        """Return a random subsegment of a specified length in seconds.
+
+        :param subsegment_length: Subsegment length in seconds.
+        :type subsegment_length: scalar
+        :param rng: Random number generator state.
+        :type rng: random.Random [optional]
+
+        :return: New AudioSegment containing a random subsegment of the
+                 original segment.
+        :rtype: AudioSegment
+        """
+        if rng is None:
+            rng = random.Random()
+
+        if subsegment_length > self.duration:
+            raise ValueError("Length of subsegment must not be greater "
+                             "than original segment.")
+        start_time = rng.uniform(0.0, self.duration - subsegment_length)
+        return self.subsegment(start_time, start_time + subsegment_length)
+
+    def convolve(self, ir, allow_resampling=False):
+        """Convolve this audio segment with the given impulse response.
+
+        Note that this is an in-place transformation.
+
+        :param ir: Impulse response.
+        :type ir: AudioSegment
+        :param allow_resampling: Indicates whether resampling is allowed
+                                 when the ir has a different sample rate
+                                 from this signal.
+        :type allow_resampling: boolean
+        """
+        if allow_resampling and self.sample_rate != ir.sample_rate:
+            ir.resample(self.sample_rate)
+
+        if self.sample_rate != ir.sample_rate:
+            raise ValueError("Impulse response sample rate ({}Hz) is not "
+                             "equal to base signal sample rate ({}Hz)."
+                             .format(ir.sample_rate, self.sample_rate))
+
+        samples = signal.fftconvolve(self.samples, ir.samples, "full")
+        self._samples = samples
+
+    def convolve_and_normalize(self, ir, allow_resampling=False):
+        """Convolve and normalize the resulting audio segment so that it
+        has the same average power as the input signal.
+
+        :param ir: Impulse response.
+        :type ir: AudioSegment
+        :param allow_resampling: Indicates whether resampling is allowed
+                                 when the ir has a different sample rate
+                                 from this signal.
+ :type allow_resampling: boolean + """ + self.convolve(ir, allow_resampling=allow_resampling) + self.normalize(target_db=self.rms_db) + + def add_noise(self, + noise, + snr_dB, + allow_downsampling=False, + max_gain_db=300.0, + rng=None): + """Adds the given noise segment at a specific signal-to-noise ratio. + If the noise segment is longer than this segment, a random subsegment + of matching length is sampled from it and used instead. + + :param noise: Noise signal to add. + :type noise: SpeechDLSegment + :param snr_dB: Signal-to-Noise Ratio, in decibels. + :type snr_dB: scalar + :param allow_downsampling: whether to allow the noise signal + to be downsampled to match the base signal sample rate. + :type allow_downsampling: boolean + :param max_gain_db: Maximum amount of gain to apply to noise + signal before adding it in. This is to prevent attempting + to apply infinite gain to a zero signal. + :type max_gain_db: scalar + :param rng: Random number generator state. + :type rng: random.Random + + Returns: + SpeechDLSegment: signal with noise added. + """ + if rng is None: + rng = random.Random() + + if allow_downsampling and noise.sample_rate > self.sample_rate: + noise = noise.resample(self.sample_rate) + + if noise.sample_rate != self.sample_rate: + raise ValueError("Noise sample rate ({}Hz) is not equal to " + "base signal sample rate ({}Hz)." + .format(noise.sample_rate, self.sample_rate)) + if noise.duration < self.duration: + raise ValueError("Noise signal ({} sec) must be at " + "least as long as base signal ({} sec)." + .format(noise.duration, self.duration)) + noise_gain_db = self.rms_db - noise.rms_db - snr_dB + noise_gain_db = min(max_gain_db, noise_gain_db) + noise_subsegment = noise.random_subsegment(self.duration, rng=rng) + output = self + self.tranform_noise(noise_subsegment, noise_gain_db) + self._samples = output._samples + self._sample_rate = output._sample_rate @property def samples(self): @@ -186,7 +554,7 @@ class AudioSegment(object): :return: Number of samples. :rtype: int """ - return self._samples.shape(0) + return self._samples.shape[0] @property def duration(self): @@ -250,3 +618,9 @@ class AudioSegment(object): else: raise TypeError("Unsupported sample type: %s." % samples.dtype) return output_samples.astype(dtype) + + def tranform_noise(self, noise_subsegment, noise_gain_db): + """ tranform noise file + """ + return type(self)(noise_subsegment._samples * (10.**( + noise_gain_db / 20.)), noise_subsegment._sample_rate) diff --git a/data_utils/augmentor/audio_database.py b/data_utils/augmentor/audio_database.py new file mode 100755 index 000000000..e41c6dd72 --- /dev/null +++ b/data_utils/augmentor/audio_database.py @@ -0,0 +1,401 @@ +from __future__ import print_function +from collections import defaultdict +import bisect +import logging +import numpy as np +import os +import random +import sys + +UNK_TAG = "" + + +def stream_audio_index(fname, UNK=UNK_TAG): + """Reads an audio index file and emits one record in the index at a time. + + :param fname: audio index path + :type fname: basestring + :param UNK: UNK token to denote that certain audios are not tagged. + :type UNK: basesring + + Yields: + idx, duration, size, relpath, tags (int, float, int, str, list(str)): + audio file id, length of the audio in seconds, size in byte, + relative path w.r.t. 
to the root noise directory, list of tags + """ + with open(fname) as audio_index_file: + for i, line in enumerate(audio_index_file): + tok = line.strip().split("\t") + assert len(tok) >= 4, \ + "Invalid line at line {} in file {}".format( + i + 1, audio_index_file) + idx = int(tok[0]) + duration = float(tok[1]) + # Sometimes, the duration can round down to 0.0 + assert duration >= 0.0, \ + "Invalid duration at line {} in file {}".format( + i + 1, audio_index_file) + size = int(tok[2]) + assert size > 0, \ + "Invalid size at line {} in file {}".format( + i + 1, audio_index_file) + relpath = tok[3] + if len(tok) == 4: + tags = [UNK_TAG] + else: + tags = tok[4:] + yield idx, duration, size, relpath, tags + + +def truncate_float(val, ndigits=6): + """ Truncates a floating-point value to have the desired number of + digits after the decimal point. + + :param val: input value. + :type val: float + :parma ndigits: desired number of digits. + :type ndigits: int + + :return: truncated value + :rtype: float + """ + p = 10.0**ndigits + return float(int(val * p)) / p + + +def print_audio_index(idx, duration, size, relpath, tags, file=sys.stdout): + """Prints an audio record to the index file. + + :param idx: Audio file id. + :type idx: int + :param duration: length of the audio in seconds + :type duration: float + :param size: size of the file in bytes + :type size: int + :param relpath: relative path w.r.t. to the root noise directory. + :type relpath: basestring + :parma tags: list of tags + :parma tags: list(str) + :parma file: file to which we want to write an audio record. + :type file: sys.stdout + """ + file.write("{}\t{:.6f}\t{}\t{}" + .format(idx, truncate_float(duration, ndigits=6), size, relpath)) + for tag in tags: + file.write("\t{}".format(tag)) + file.write("\n") + + +class AudioIndex(object): + """ In-memory index of audio files that do not have annotations. + This supports duration-based sampling and sampling from a target + distribution. + + Each line in the index file consists of the following fields: + (id (int), duration (float), size (int), relative path (str), + list of tags ([str])) + """ + + def __init__(self): + self.audio_dir = None + self.index_fname = None + self.tags = None + self.bin_size = 2.0 + self.clear() + + def clear(self): + """ Clears the index + + Returns: + None + """ + self.idx_to_record = {} + # The list of indices correspond to audio files whose duration is + # greater than or equal to the key. + self.duration_to_id_set = {} + self.duration_to_id_set_per_tag = defaultdict(lambda: {}) + self.duration_to_list = defaultdict(lambda: []) + self.duration_to_list_per_tag = defaultdict( + lambda: defaultdict(lambda: [])) + self.tag_to_id_set = defaultdict(lambda: set()) + self.shared_duration_bins = [] + self.id_set_complete = set() + self.id_set = set() + self.duration_bins = [] + + def has_audio(self, distr=None): + """ + :param distr: The target distribution of audio tags that we want to + match. If this is not supplied, the function simply checks that + there are some audio files. + :parma distr: dict + :return: True if there are audio files. + :rtype: boolean + """ + if distr is None: + return len(self.id_set) > 0 + else: + for tag in distr: + if tag not in self.duration_to_list_per_tag: + return False + return True + + def _load_all_records_from_disk(self, audio_dir, idx_fname, bin_size): + """Loads all audio records from the disk into memory and groups them + into chunks based on their duration and the bin_size granalarity. 
+ + Once all the records are read, indices are built from these records + by another function so that the audio samples can be drawn efficiently. + + Updates: + self.audio_dir (path): audio root directory + self.idx_fname (path): audio database index filename + self.bin_size (float): granularity of bins + self.idx_to_record (dict): maps from the audio id to + (duration, file_size, relative_path, tags) + self.tag_to_id_set (dict): maps from the tag to + the set of id's of audios that have this tag. + self.id_set_complete (set): set of all audio id's in the index file + self.min_duration (float): minimum audio duration observed in the + index file + self.duration_bins (list): the lower bounds on the duration of + audio files falling in each bin + self.duration_to_id_set (dict): contains (k, v) where v is the set + of id's of audios whose lengths are longer than or equal to k. + (e.g. k is the duration lower bound of this bin). + self.duration_to_id_set_per_tag (dict): Something like above but + has a finer granularity mapping from the tag to + duration_to_id_set. + self.shared_duration_bins (list): list of sets where each set + contains duration lower bounds whose audio id sets are the + same. The rationale for having this is that there are a few + but extremely long audio files which lead to a lot of bins. + When the id sets do not change across various minimum duration + boundaries, we + cluster these together and make them point to the same id set + reference. + + :return: whether the records were read from the disk. The assumption is + that the audio index file on disk and the actual audio files + are constructed once and never change during training. We only + re-read when either the directory or the index file path change. + """ + if self.audio_dir == audio_dir and self.idx_fname == idx_fname and \ + self.bin_size == bin_size: + # The audio directory and/or the list of audio files + # haven't changed. No need to load the list again. + return False + + # Remember where the audio index is most recently read from. + self.audio_dir = audio_dir + self.idx_fname = idx_fname + self.bin_size = bin_size + + # Read in the idx and compute the number of bins necessary + self.clear() + rank = [] + min_duration = float('inf') + max_duration = float('-inf') + for idx, duration, file_size, relpath, tags in \ + stream_audio_index(idx_fname): + self.idx_to_record[idx] = (duration, file_size, relpath, tags) + max_duration = max(max_duration, duration) + min_duration = min(min_duration, duration) + rank.append((duration, idx)) + for tag in tags: + self.tag_to_id_set[tag].add(idx) + if len(rank) == 0: + # file is empty + raise IOError("Index file {} is empty".format(idx_fname)) + for tag in self.tag_to_id_set: + self.id_set_complete |= self.tag_to_id_set[tag] + dur = min_duration + self.min_duration = min_duration + while dur < max_duration + bin_size: + self.duration_bins.append(dur) + dur += bin_size + + # Sort in decreasing order of duration and populate + # the cumulative indices lists. + rank.sort(reverse=True) + + # These are indices for `rank` and used to keep track of whether + # there are new records to add in the current bin. + last = 0 + cur = 0 + + # The set of audios falling in the previous bin; in the case, + # where we don't find new audios for the current bin, we store + # the reference to the last set so as to conserve memory. + # This is not such a big problem if the audio duration is + # bounded by a small number like 30 seconds and the + # bin size is big enough. 
But, for raw freesound audios, + # some audios can be as long as a few hours! + last_audio_set = set() + + # The same but for each tag so that we can pick audios based on + # tags and also some user-specified tag distribution. + last_audio_set_per_tag = defaultdict(lambda: set()) + + # Set of lists of bins sharing the same audio sets. + shared = set() + + for i in range(len(self.duration_bins) - 1, -1, -1): + lower_bound = self.duration_bins[i] + new_audio_idxs = set() + new_audio_idxs_per_tag = defaultdict(lambda: set()) + while cur < len(rank) and rank[cur][0] >= lower_bound: + idx = rank[cur][1] + tags = self.idx_to_record[idx][3] + new_audio_idxs.add(idx) + for tag in tags: + new_audio_idxs_per_tag[tag].add(idx) + cur += 1 + # This makes certain that the same list is shared across + # different bins if no new indices are added. + if cur == last: + shared.add(lower_bound) + else: + last_audio_set = last_audio_set | new_audio_idxs + for tag in new_audio_idxs_per_tag: + last_audio_set_per_tag[tag] = \ + last_audio_set_per_tag[tag] | \ + new_audio_idxs_per_tag[tag] + if len(shared) > 0: + self.shared_duration_bins.append(shared) + shared = set([lower_bound]) + ### last_audio_set = set() should set blank + last = cur + self.duration_to_id_set[lower_bound] = last_audio_set + for tag in last_audio_set_per_tag: + self.duration_to_id_set_per_tag[lower_bound][tag] = \ + last_audio_set_per_tag[tag] + + # The last `shared` record isn't added to the `shared_duration_bins`. + self.shared_duration_bins.append(shared) + + # We make sure that the while loop above has exhausted through the + # `rank` list by checking if the `cur`rent index in `rank` equals + # the length of the array, which is the halting condition. + assert cur == len(rank) + + return True + + def _build_index_from_records(self, tag_list): + """ Uses the in-memory records read from the index file to build + an in-memory index restricted to the given tag list. + + :param tag_list: List of tags we are interested in sampling from. + :type tag_list: list(str) + + Updates: + self.id_set (set): the set of all audio id's that can be sampled. + self.duration_to_list (dict): maps from the duration lower bound + to the id's of audios longer than this duration. + self.duration_to_list_per_tag (dict): maps from the tag to + the same structure as self.duration_to_list. This is to support + sampling from a target noise distribution. + + :return: whether the index was built from scratch + """ + if self.tags == tag_list: + return False + + self.tags = tag_list + if len(tag_list) == 0: + self.id_set = self.id_set_complete + else: + self.id_set = set() + for tag in tag_list: + self.id_set |= self.tag_to_id_set[tag] + + # Next, we need to take a subset of the audio files + for shared in self.shared_duration_bins: + # All bins in `shared' have the same index lists + # so we can intersect once and set all of them to this list. 
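+            # `shared` holds the duration lower bounds of every bin whose
+            # id set is identical, so one representative bound is enough to
+            # look up the shared id set and intersect it with `self.id_set`.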
+ lb = list(shared)[0] + intersected = list(self.id_set & self.duration_to_id_set[lb]) + duration_to_id_set = self.duration_to_id_set_per_tag[lb] + intersected_per_tag = { + tag: self.tag_to_id_set[tag] & duration_to_id_set[tag] + for tag in duration_to_id_set + } + for bin_key in shared: + self.duration_to_list[bin_key] = intersected + for tag in intersected_per_tag: + self.duration_to_list_per_tag[tag][bin_key] = \ + intersected_per_tag[tag] + assert len(self.duration_to_list) == len(self.duration_to_id_set) + return True + + def refresh_records_from_index_file(self, + audio_dir, + idx_fname, + tag_list, + bin_size=2.0): + """ Loads the index file and populates the records + for building the internal index. + + If the audio directory or index file name has changed, the whole index + is reloaded from scratch. If only the tag_list is changed, then the + desired index is built from the complete, in-memory record. + + :param audio_dir: audio directory + :type audio_dir: basestring + :param idx_fname: audio index file name + :type idex_fname: basestring + :param tag_list: list of tags we are interested in loading; + if empty, we load all. + :type tag_list: list + :param bin_size: optional argument for controlling the granularity + of duration bins + :type bin_size: float + """ + if tag_list is None: + tag_list = [] + reloaded_records = self._load_all_records_from_disk(audio_dir, + idx_fname, bin_size) + if reloaded_records or self.tags != tag_list: + self._build_index_from_records(tag_list) + logger.info('loaded {} audio files from {}' + .format(len(self.id_set), idx_fname)) + + def sample_audio(self, duration, rng=None, distr=None): + """ Uniformly draws an audio record of at least the desired duration + + :param duration: minimum desired audio duration + :type duration: float + :param rng: random number generator + :type rng: random.Random + :param distr: target distribution of audio tags. If not provided, + :type distr: dict + all audio files are sampled uniformly at random. + + :returns: success, (duration, file_size, path) + """ + if duration < 0.0: + duration = self.min_duration + i = bisect.bisect_left(self.duration_bins, duration) + if i == len(self.duration_bins): + return False, None + bin_key = self.duration_bins[i] + if distr is None: + indices = self.duration_to_list[bin_key] + else: + # If a desired audio distribution is given, we sample from it. 
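+            # `distr` maps a tag to its (unnormalized) probability mass,
+            # e.g. {'crowd': 0.6, 'street': 0.4}; the masses are normalized
+            # below so they sum to one before a tag is drawn.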
+            if rng is None:
+                rng = random.Random()
+            nprng = np.random.RandomState(rng.getrandbits(32))
+            prob_masses = np.array(list(distr.values()), dtype=float)
+            prob_masses /= np.sum(prob_masses)
+            tag = nprng.choice(list(distr.keys()), p=prob_masses)
+            indices = self.duration_to_list_per_tag[tag][bin_key]
+        if len(indices) == 0:
+            return False, None
+        else:
+            if rng is None:
+                rng = random.Random()
+            # Duration, file size and path relative to the root directory.
+            s = self.idx_to_record[rng.sample(indices, 1)[0]]
+            s = (s[0], s[1], os.path.join(self.audio_dir, s[2]))
+            return True, s
diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py
index abe1a0ec8..c0a70ad18 100755
--- a/data_utils/augmentor/augmentation.py
+++ b/data_utils/augmentor/augmentation.py
@@ -6,6 +6,11 @@ from __future__ import print_function
 import json
 import random
 from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
+from data_utils.augmentor.resampler import ResamplerAugmentor
+from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
+from data_utils.augmentor.online_bayesian_normalization import OnlineBayesianNormalizationAugmentor
+from data_utils.augmentor.implus_response import ImpulseResponseAugmentor
+from data_utils.augmentor.noise_speech import NoiseSpeechAugmentor
 
 
 class AugmentationPipeline(object):
@@ -76,5 +81,15 @@ class AugmentationPipeline(object):
         """Return an augmentation model by the type name, and pass in params."""
         if augmentor_type == "volume":
             return VolumePerturbAugmentor(self._rng, **params)
+        if augmentor_type == "resample":
+            return ResamplerAugmentor(self._rng, **params)
+        if augmentor_type == "speed":
+            return SpeedPerturbAugmentor(self._rng, **params)
+        if augmentor_type == "online_bayesian_normalization":
+            return OnlineBayesianNormalizationAugmentor(self._rng, **params)
+        if augmentor_type == "impulse_response":
+            return ImpulseResponseAugmentor(self._rng, **params)
+        if augmentor_type == "noise_speech":
+            return NoiseSpeechAugmentor(self._rng, **params)
         else:
             raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
diff --git a/data_utils/augmentor/implus_response.py b/data_utils/augmentor/implus_response.py
new file mode 100755
index 000000000..cc2053421
--- /dev/null
+++ b/data_utils/augmentor/implus_response.py
@@ -0,0 +1,76 @@
+"""Impulse response augmentation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from . import base
+from . import audio_database
+from data_utils.speech import SpeechSegment
+
+
+class ImpulseResponseAugmentor(base.AugmentorBase):
+    """Instantiates an impulse response augmentation model.
+
+    :param ir_dir: Directory containing impulse responses.
+    :type ir_dir: basestring
+    :param index_file: Index file listing the impulse responses.
+    :type index_file: basestring
+    :param tags: Optional parameter for specifying which
+                 particular impulse responses to apply.
+    :type tags: list
+    :param tag_distr: Optional noise distribution.
+    :type tag_distr: dict
+    """
+
+    def __init__(self, rng, ir_dir, index_file, tags=None, tag_distr=None):
+        # Define all required parameter maps here.
+        self.ir_dir = ir_dir
+        self.index_file = index_file
+
+        self.tags = tags
+        self.tag_distr = tag_distr
+
+        self.audio_index = audio_database.AudioIndex()
+        self.rng = rng
+
+    def _init_data(self):
+        """Preloads data from disk (e.g. the list of files) in an attempt
+        to make later loading faster. If the data configuration remains the
+        same, this function does nothing.
+ + """ + self.audio_index.refresh_records_from_index_file( + self.ir_dir, self.index_file, self.tags) + + def transform_audio(self, audio_segment): + """ Convolves the input audio with an impulse response. + + :param audio_segment: input audio + :type audio_segment: AudioSegemnt + """ + # This handles the cases where the data source or directories change. + self._init_data() + + read_size = 0 + tag_distr = self.tag_distr + if not self.audio_index.has_audio(tag_distr): + if tag_distr is None: + if not self.tags: + raise RuntimeError("The ir index does not have audio " + "files to sample from.") + else: + raise RuntimeError("The ir index does not have audio " + "files of the given tags to sample " + "from.") + else: + raise RuntimeError("The ir index does not have audio " + "files to match the target ir " + "distribution.") + else: + # Querying with a negative duration triggers the index to search + # from all impulse responses. + success, record = self.audio_index.sample_audio( + -1.0, rng=self.rng, distr=tag_distr) + if success is True: + _, read_size, ir_fname = record + ir_wav = SpeechSegment.from_file(ir_fname) + audio_segment.convolve(ir_wav, allow_resampling=True) diff --git a/data_utils/augmentor/noise_speech.py b/data_utils/augmentor/noise_speech.py new file mode 100755 index 000000000..8cf7c27b6 --- /dev/null +++ b/data_utils/augmentor/noise_speech.py @@ -0,0 +1,318 @@ +""" noise speech +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import logging +import numpy as np +import os +from collections import defaultdict + +from . import base +from . import audio_database +from data_utils.speech import SpeechSegment + +TURK = "turk" +USE_AUDIO_DATABASE_SOURCES = frozenset(["freesound", "chime"]) +HALF_NOISE_LENGTH_MIN_THRESHOLD = 3.0 +FIND_NOISE_MAX_ATTEMPTS = 20 + +logger = logging.getLogger(__name__) + + +def get_first_smaller(items, value): + index = bisect.bisect_left(items, value) - 1 + assert items[index] < value, \ + 'get_first_smaller failed! %d %d' % (items[index], value) + return items[index] + + +def get_first_larger(items, value): + 'Find leftmost value greater than value' + index = bisect.bisect_right(items, value) + assert index < len(items), \ + "no noise bin exists for this audio length (%f)" % value + assert items[index] > value, \ + 'get_first_larger failed! %d %d' % (items[index], value) + return items[index] + + +def _get_turk_noise_files(noise_dir, index_file): + """ Creates a map from duration => a list of noise filenames + + :param noise_dir: Directory of noise files which contains + "noise-samples-list" + :type noise_dir: basestring + :param index_file: Noise list + :type index_file: basestring + + returns:noise_files (defaultdict): A map of bins to noise files. + Each key is the duration, and the value is a list of noise + files binned to this duration. Each bin is 2 secs. 
+ + Note: noise-samples-list should contain one line per noise (wav) file + along with its duration in milliseconds + """ + noise_files = defaultdict(list) + if not os.path.exists(index_file): + logger.error('No noise files were found at {}'.format(index_file)) + return noise_files + num_noise_files = 0 + rounded_durations = list(range(0, 65, 2)) + with open(index_file, 'r') as fl: + for line in fl: + fname = os.path.join(noise_dir, line.strip().split()[0]) + duration = float(line.strip().split()[1]) / 1000 + # bin the noise files into length bins rounded by 2 sec + bin_id = get_first_smaller(rounded_durations, duration) + noise_files[bin_id].append(fname) + num_noise_files += 1 + logger.info('Loaded {} turk noise files'.format(num_noise_files)) + return noise_files + + +class NoiseSpeechAugmentor(base.AugmentorBase): + """ Noise addition block + + :param snr_min: minimum signal-to-noise ratio + :type snr_min: float + :param snr_max: maximum signal-to-noise ratio + :type snr_max: float + :param noise_dir: root of where noise files are stored + :type noise_fir: basestring + :param index_file: index of noises of interest in noise_dir + :type index_file: basestring + :param source: select one from + - turk + - freesound + - chime + Note that this field is no longer required for the freesound + and chime + :type source: string + :param tags: optional parameter for specifying what + particular noises we want to add. See above for the available tags. + :type tags: list + :param tag_distr: optional noise distribution + :type tag_distr: dict + """ + + def __init__(self, + rng, + snr_min, + snr_max, + noise_dir, + source, + allow_downsampling=None, + index_file=None, + tags=None, + tag_distr=None): + # Define all required parameter maps here. + self.rng = rng + self.snr_min = snr_min + self.snr_max = snr_max + self.noise_dir = noise_dir + self.source = source + + self.allow_downsampling = allow_downsampling + self.index_file = index_file + self.tags = tags + self.tag_distr = tag_distr + + # When new noise sources are added, make sure to define the + # associated bookkeeping variables here. + self.turk_noise_files = [] + self.turk_noise_dir = None + self.audio_index = audio_database.AudioIndex() + + def _init_data(self): + """ Preloads stuff from disk in an attempt (e.g. list of files, etc) + to make later loading faster. If the data configuration remains the + same, this function does nothing. 
+
+        """
+        noise_dir = self.noise_dir
+        index_file = self.index_file
+        source = self.source
+        if not index_file:
+            if source == TURK:
+                index_file = os.path.join(noise_dir, 'noise-samples-list')
+                logger.debug("index_file not provided; "
+                             "defaulting to " + index_file)
+            else:
+                if source != "":
+                    assert source in USE_AUDIO_DATABASE_SOURCES, \
+                        "{} not supported by audio_database".format(source)
+                index_file = os.path.join(noise_dir,
+                                          "audio_index_commercial.txt")
+                logger.debug("index_file not provided; "
+                             "defaulting to " + index_file)
+
+        if source == TURK:
+            if self.turk_noise_dir != noise_dir:
+                self.turk_noise_dir = noise_dir
+                self.turk_noise_files = _get_turk_noise_files(noise_dir,
+                                                              index_file)
+        # elif source == TODO_SUPPORT_NON_AUDIO_DATABASE_BASED_SOURCES:
+        else:
+            if source != "":
+                assert source in USE_AUDIO_DATABASE_SOURCES, \
+                    "{} not supported by audio_database".format(source)
+            self.audio_index.refresh_records_from_index_file(
+                self.noise_dir, index_file, self.tags)
+
+    def transform_audio(self, audio_segment):
+        """Add noise to the input audio.
+
+        :param audio_segment: Input audio.
+        :type audio_segment: SpeechSegment
+        """
+        # This handles the cases where the data source or directories change.
+        self._init_data()
+        source = self.source
+        allow_downsampling = self.allow_downsampling
+        if source == TURK:
+            self._add_turk_noise(audio_segment, allow_downsampling)
+        # elif source == TODO_SUPPORT_NON_AUDIO_DATABASE_BASED_SOURCES:
+        else:
+            self._add_noise(audio_segment, allow_downsampling)
+
+    def _sample_snr(self):
+        """Return an SNR sampled uniformly in
+        [`self.snr_min`, `self.snr_max`].
+        """
+        return self.rng.uniform(self.snr_min, self.snr_max)
+
+    def _add_turk_noise(self, audio_segment, allow_downsampling):
+        """Add a turk noise to the input audio.
+
+        :param audio_segment: Input audio.
+        :type audio_segment: SpeechSegment
+        :param allow_downsampling: Indicates whether downsampling
+                                   is allowed.
+        :type allow_downsampling: boolean
+        """
+        read_size = 0
+        if len(self.turk_noise_files) > 0:
+            snr = self._sample_snr()
+            # Draw the noise file randomly from noise files that are
+            # slightly longer than the utterance.
+            noise_bins = sorted(self.turk_noise_files.keys())
+            # Note some bins can be empty, so we can't just round up
+            # to the nearest 2-sec interval.
+            rounded_duration = get_first_larger(noise_bins,
+                                                audio_segment.duration)
+            noise_fname = \
+                self.rng.sample(self.turk_noise_files[rounded_duration], 1)[0]
+            noise = SpeechSegment.from_file(noise_fname)
+            logger.debug('noise_fname {}'.format(noise_fname))
+            logger.debug('snr {}'.format(snr))
+            read_size = len(noise) * 2
+            # May throw exceptions, but this is caught by
+            # AudioFeaturizer.get_audio_files.
+            audio_segment.add_noise(
+                noise, snr, rng=self.rng, allow_downsampling=allow_downsampling)
+
+    def _add_noise(self, audio_segment, allow_downsampling):
+        """ Adds a noise indexed in audio_database.AudioIndex.
+ + :param audio_segment: input audio + :type audio_segment: SpeechSegment + :param allow_downsampling: indicates whether downsampling + is allowed + :type allow_downsampling: boolean + + Returns: + (SpeechSegment, int) + - sound with turk noise added + - number of bytes read from disk + """ + read_size = 0 + tag_distr = self.tag_distr + if not self.audio_index.has_audio(tag_distr): + if tag_distr is None: + if not self.tags: + raise RuntimeError("The noise index does not have audio " + "files to sample from.") + else: + raise RuntimeError("The noise index does not have audio " + "files of the given tags to sample " + "from.") + else: + raise RuntimeError("The noise index does not have audio " + "files to match the target noise " + "distribution.") + else: + # Compute audio segment related statistics + audio_duration = audio_segment.duration + + # Sample relevant augmentation parameters. + snr = self._sample_snr(self.rng) + + # Perhaps, we may not have a sufficiently long noise, so we need + # to search iteratively. + min_duration = audio_duration + 0.25 + for _ in range(FIND_NOISE_MAX_ATTEMPTS): + logger.debug("attempting to find noise of length " + "at least {}".format(min_duration)) + + success, record = \ + self.audio_index.sample_audio(min_duration, + rng=self.rng, + distr=tag_distr) + + if success is True: + noise_duration, read_size, noise_fname = record + + # Assert after logging so we know + # what caused augmentation to fail. + logger.debug("noise_fname {}".format(noise_fname)) + logger.debug("snr {}".format(snr)) + assert noise_duration >= min_duration + break + + # Decrease the desired minimum duration linearly. + # If the value becomes smaller than some threshold, + # we half the value instead. + if min_duration > HALF_NOISE_LENGTH_MIN_THRESHOLD: + min_duration -= 2.0 + else: + min_duration *= 0.5 + + if success is False: + logger.info("Failed to find a noise file") + return + + diff_duration = audio_duration + 0.25 - noise_duration + if diff_duration >= 0.0: + # Here, the noise is shorter than the audio file, so + # we pad with zeros to make sure the noise sound is applied + # with a uniformly random shift. + noise = SpeechSegment.from_file(noise_fname) + noise = noise.pad_silence(diff_duration, sides="both") + else: + # The noise clip is at least ~25 ms longer than the audio + # segment here. + diff_duration = int(noise_duration * audio_segment.sample_rate) - \ + int(audio_duration * audio_segment.sample_rate) - \ + int(0.02 * audio_segment.sample_rate) + start = float(self.rng.randint(0, diff_duration)) / \ + audio.sample_rate + finish = min(start + audio_duration + 0.2, noise_duration) + noise = SpeechSegment.slice_from_file(noise_fname, start, + finish) + + if len(noise) < len(audio_segment): + # This is to ensure that the noise clip is at least as + # long as the audio segment. + num_samples_to_pad = len(audio_segment) - len(noise) + # Padding this amount of silence on both ends ensures that + # the placement of the noise clip is uniformly random. 
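+                # After this padding the noise is at least as long as the
+                # audio segment, so add_noise can draw a matching-length
+                # subsegment from it at a uniformly random offset.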
+                silence = SpeechSegment(
+                    np.zeros(num_samples_to_pad), audio_segment.sample_rate)
+                noise = SpeechSegment.concatenate(silence, noise, silence)
+
+            audio_segment.add_noise(
+                noise, snr, rng=self.rng, allow_downsampling=allow_downsampling)
diff --git a/data_utils/augmentor/online_bayesian_normalization.py b/data_utils/augmentor/online_bayesian_normalization.py
new file mode 100755
index 000000000..bc2d6c1b6
--- /dev/null
+++ b/data_utils/augmentor/online_bayesian_normalization.py
@@ -0,0 +1,57 @@
+"""Online Bayesian normalization augmentation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from . import base
+
+
+class OnlineBayesianNormalizationAugmentor(base.AugmentorBase):
+    """Instantiates an online Bayesian normalization module.
+
+    :param target_db: Target RMS value in decibels.
+    :type target_db: float
+    :param prior_db: Prior RMS estimate in decibels.
+    :type prior_db: float
+    :param prior_samples: Prior strength in number of samples.
+    :type prior_samples: int
+    :param startup_delay: Start-up delay in seconds during
+                          which normalization statistics are accrued.
+    :type startup_delay: float
+    """
+
+    def __init__(self,
+                 rng,
+                 target_db,
+                 prior_db,
+                 prior_samples,
+                 startup_delay=0.0):
+        self.target_db = target_db
+        self.prior_db = prior_db
+        self.prior_samples = prior_samples
+        self.startup_delay = startup_delay
+        self.rng = rng
+
+    def transform_audio(self, audio_segment):
+        """Normalize the input audio using the online Bayesian approach.
+
+        Note that this is an in-place transformation.
+
+        :param audio_segment: Input audio.
+        :type audio_segment: SpeechSegment
+        """
+        audio_segment.normalize_online_bayesian(
+            self.target_db,
+            self.prior_db,
+            self.prior_samples,
+            startup_delay=self.startup_delay)
diff --git a/data_utils/augmentor/resampler.py b/data_utils/augmentor/resampler.py
new file mode 100755
index 000000000..1b959be56
--- /dev/null
+++ b/data_utils/augmentor/resampler.py
@@ -0,0 +1,30 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from . import base
+
+
+class ResamplerAugmentor(base.AugmentorBase):
+    """Instantiates a resampler module.
+
+    :param new_sample_rate: New sample rate in Hz.
+    :type new_sample_rate: scalar
+    :param rng: Random generator object.
+    :type rng: random.Random
+    """
+
+    def __init__(self, rng, new_sample_rate):
+        self.new_sample_rate = new_sample_rate
+        self._rng = rng
+
+    def transform_audio(self, audio_segment):
+        """Resample the input audio to the target sample rate.
+
+        Note that this is an in-place transformation.
+
+        :param audio_segment: Input audio.
+        :type audio_segment: SpeechSegment
+        """
+        audio_segment.resample(self.new_sample_rate)
diff --git a/data_utils/augmentor/speed_perturb.py b/data_utils/augmentor/speed_perturb.py
new file mode 100755
index 000000000..e09be5f74
--- /dev/null
+++ b/data_utils/augmentor/speed_perturb.py
@@ -0,0 +1,53 @@
+"""Speed perturbation module for making ASR robust to different voice
+types (high pitched, low pitched, etc.).
+Samples uniformly between speed_min and speed_max.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from . import base
+
+
+class SpeedPerturbAugmentor(base.AugmentorBase):
+    """Instantiates a speed perturbation module.
+
+    See reference paper here:
+
+    http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf
+
+    :param speed_min: Lower bound on the rate to sample.
+    :type speed_min: float
+    :param speed_max: Upper bound on the rate to sample.
+    :type speed_max: float
+    """
+
+    def __init__(self, rng, speed_min, speed_max):
+        if speed_min < 0.9:
+            raise ValueError(
+                "Sampling speed below 0.9 can cause unnatural effects")
+        if speed_max > 1.1:
+            raise ValueError(
+                "Sampling speed above 1.1 can cause unnatural effects")
+        self.speed_min = speed_min
+        self.speed_max = speed_max
+        self.rng = rng
+
+    def transform_audio(self, audio_segment):
+        """Sample a new speed rate from the given range and
+        change the speed of the given audio clip.
+
+        Note that this is an in-place transformation.
+
+        :param audio_segment: Input audio.
+        :type audio_segment: SpeechSegment
+        """
+        sampled_speed = self.rng.uniform(self.speed_min, self.speed_max)
+        audio_segment.change_speed(sampled_speed)
diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py
index a5a9f6cad..15055b915 100755
--- a/data_utils/augmentor/volume_perturb.py
+++ b/data_utils/augmentor/volume_perturb.py
@@ -3,10 +3,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from data_utils.augmentor.base import AugmentorBase
+from . import base
 
 
-class VolumePerturbAugmentor(AugmentorBase):
+class VolumePerturbAugmentor(base.AugmentorBase):
     """Augmentation model for adding random volume perturbation.
 
     This is used for multi-loudness training of PCEN. See