You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
910 lines
33 KiB
910 lines
33 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# this is modified from SpeechBrain
|
|
# https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/lobes/augment.py
|
|
import math
|
|
import os
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn as nn
|
|
import paddle.nn.functional as F
|
|
|
|
from paddlespeech.s2t.utils.log import Log
|
|
from paddlespeech.vector.io.dataset import CSVDataset
|
|
from paddlespeech.vector.io.signal_processing import compute_amplitude
|
|
from paddlespeech.vector.io.signal_processing import convolve1d
|
|
from paddlespeech.vector.io.signal_processing import dB_to_amplitude
|
|
from paddlespeech.vector.io.signal_processing import notch_filter
|
|
from paddlespeech.vector.io.signal_processing import reverberate
|
|
|
|
# Module-level logger obtained from the project's `Log` helper; used by
# `build_augment_pipeline` below to report progress.
logger = Log(__name__).getlog()
|
|
|
|
|
|
# TODO: Complete type-hint and doc string.
|
|
class DropFreq(nn.Layer):
    """Randomly notch-filter a few frequencies out of a batch of waveforms.

    A delta filter is repeatedly convolved with notch kernels (one per
    randomly chosen frequency) and the resulting filter is applied to the
    input with ``convolve1d``.

    Args:
        drop_freq_low: Lower bound of the range frequencies are drawn from.
        drop_freq_high: Upper bound of the range frequencies are drawn from.
            NOTE(review): the unit is whatever ``notch_filter`` expects
            (presumably a fraction of the Nyquist frequency) -- confirm.
        drop_count_low: Minimum number of frequencies to drop (inclusive).
        drop_count_high: Maximum number of frequencies to drop (inclusive).
        drop_width: Width argument forwarded to ``notch_filter``.
        drop_prob: Probability that a batch is filtered at all.
    """

    def __init__(
            self,
            drop_freq_low=1e-14,
            drop_freq_high=1,
            drop_count_low=1,
            drop_count_high=2,
            drop_width=0.05,
            drop_prob=1, ):
        super(DropFreq, self).__init__()
        self.drop_freq_low = drop_freq_low
        self.drop_freq_high = drop_freq_high
        self.drop_count_low = drop_count_low
        self.drop_count_high = drop_count_high
        self.drop_width = drop_width
        self.drop_prob = drop_prob

    def forward(self, waveforms):
        """Return a copy of ``waveforms`` with random frequencies removed.

        Args:
            waveforms: Tensor of rank 2 or 3. A rank-2 input gets a trailing
                channel axis for the convolution, which is removed again
                before returning.

        Returns:
            A new tensor with the same rank as ``waveforms``.
        """
        # Don't drop (return early) 1-`drop_prob` portion of the batches
        dropped_waveform = waveforms.clone()
        if paddle.rand([1]) > self.drop_prob:
            return dropped_waveform

        # Add channels dimension, remembering whether we did so that the
        # output rank always matches the input rank.
        channel_added = False
        if len(waveforms.shape) == 2:
            dropped_waveform = dropped_waveform.unsqueeze(-1)
            channel_added = True

        # Pick number of frequencies to drop (upper bound is inclusive)
        drop_count = paddle.randint(
            low=self.drop_count_low, high=self.drop_count_high + 1, shape=[1])

        # Pick the frequencies to drop, uniformly in [low, high)
        drop_range = self.drop_freq_high - self.drop_freq_low
        drop_frequency = (
            paddle.rand([drop_count]) * drop_range + self.drop_freq_low)

        # Filter parameters
        filter_length = 101
        pad = filter_length // 2

        # Start with delta function (identity filter)
        drop_filter = paddle.zeros([1, filter_length, 1])
        drop_filter[0, pad, 0] = 1

        # Fold one notch kernel per frequency into the filter
        for frequency in drop_frequency:
            notch_kernel = notch_filter(frequency, filter_length,
                                        self.drop_width)
            drop_filter = convolve1d(drop_filter, notch_kernel, pad)

        # Apply filter
        dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad)

        # Remove channels dimension only if it was added above; an input
        # that already had a (possibly size-1) channel axis keeps it.
        if channel_added:
            return dropped_waveform.squeeze(-1)
        return dropped_waveform
|
|
|
|
|
|
class DropChunk(nn.Layer):
    """Replace random chunks of each waveform with zeros or uniform noise.

    Args:
        drop_length_low: Minimum chunk length in samples (inclusive).
        drop_length_high: Maximum chunk length in samples (inclusive).
        drop_count_low: Minimum number of chunks per waveform (inclusive).
        drop_count_high: Maximum number of chunks per waveform (inclusive).
        drop_start: First sample where a chunk may start. A negative value
            is taken relative to the end of each waveform.
        drop_end: Last sample of the region chunks are placed in. ``None``
            means the end of each waveform; a negative value is taken
            relative to the end of each waveform.
        drop_prob: Probability that a batch is modified at all.
        noise_factor: When falsy, chunks are zeroed. Otherwise chunks are
            filled with uniform noise in ``[-m, m]`` where
            ``m = 2 * clean_amplitude * noise_factor``.

    Raises:
        ValueError: If a low limit exceeds the matching high limit, or if
            ``drop_start`` exceeds a non-negative ``drop_end``.
    """

    def __init__(
            self,
            drop_length_low=100,
            drop_length_high=1000,
            drop_count_low=1,
            drop_count_high=10,
            drop_start=0,
            drop_end=None,
            drop_prob=1,
            noise_factor=0.0, ):
        super(DropChunk, self).__init__()
        self.drop_length_low = drop_length_low
        self.drop_length_high = drop_length_high
        self.drop_count_low = drop_count_low
        self.drop_count_high = drop_count_high
        self.drop_start = drop_start
        self.drop_end = drop_end
        self.drop_prob = drop_prob
        self.noise_factor = noise_factor

        # Validate low <= high
        if drop_length_low > drop_length_high:
            raise ValueError("Low limit must not be more than high limit")
        if drop_count_low > drop_count_high:
            raise ValueError("Low limit must not be more than high limit")

        # Make sure the length doesn't exceed end - start
        if drop_end is not None and drop_end >= 0:
            if drop_start > drop_end:
                raise ValueError("Low limit must not be more than high limit")

            drop_range = drop_end - drop_start
            self.drop_length_low = min(drop_length_low, drop_range)
            self.drop_length_high = min(drop_length_high, drop_range)

    def forward(self, waveforms, lengths):
        """Return a copy of ``waveforms`` with random chunks dropped.

        Args:
            waveforms: Batch tensor whose axis 1 is time.
            lengths: Per-item lengths relative to ``waveforms.shape[1]``
                (they are multiplied by it and cast to int64).

        Returns:
            A new tensor with the same shape as ``waveforms``.
        """
        # Convert relative lengths to absolute sample counts
        lengths = (lengths * waveforms.shape[1]).astype('int64')
        batch_size = waveforms.shape[0]
        dropped_waveform = waveforms.clone()

        # Don't drop (return early) 1-`drop_prob` portion of the batches
        if paddle.rand([1]) > self.drop_prob:
            return dropped_waveform

        # The clean amplitude is only needed to scale the replacement
        # noise, so skip the computation when chunks are simply zeroed.
        clean_amplitude = None
        if self.noise_factor:
            clean_amplitude = compute_amplitude(waveforms,
                                                lengths.unsqueeze(1))

        # Pick a number of times to drop, per batch item
        drop_times = paddle.randint(
            low=self.drop_count_low,
            high=self.drop_count_high + 1,
            shape=[batch_size], )

        # Iterate batch to set mask
        for i in range(batch_size):
            num_drops = int(drop_times[i])
            if num_drops == 0:
                continue

            # Pick lengths
            length = paddle.randint(
                low=self.drop_length_low,
                high=self.drop_length_high + 1,
                shape=[num_drops], )

            # Compute range of starting locations
            start_min = self.drop_start
            if start_min < 0:
                start_min += lengths[i]
            start_max = self.drop_end
            if start_max is None:
                start_max = lengths[i]
            if start_max < 0:
                start_max += lengths[i]
            # Leave room for the longest chunk picked for this item
            start_max = max(0, start_max - length.max())

            # Pick starting locations
            start = paddle.randint(
                low=start_min,
                high=start_max + 1,
                shape=[num_drops], )

            end = start + length

            # Update waveform; both branches index with plain Python ints.
            if not self.noise_factor:
                for j in range(num_drops):
                    begin, stop = int(start[j]), int(end[j])
                    if begin < stop:
                        dropped_waveform[i, begin:stop] = 0.0
            else:
                # Uniform distribution of -2 to +2 * avg amplitude should
                # preserve the average for normalization
                noise_max = 2 * clean_amplitude[i] * self.noise_factor
                for j in range(num_drops):
                    # zero-center the noise distribution
                    noise_vec = paddle.rand([int(length[j])], dtype='float32')
                    noise_vec = 2 * noise_max * noise_vec - noise_max
                    dropped_waveform[i, int(start[j]):int(end[j])] = noise_vec

        return dropped_waveform
|
|
|
|
|
|
class Resample(nn.Layer):
    """Resample waveforms with a windowed-sinc polyphase filter.

    The filter bank (``self.weights``) and the first input index of each
    phase (``self.first_indices``) are built lazily on the first call that
    actually needs to resample.

    Args:
        orig_freq: Sampling rate of the input.
        new_freq: Sampling rate of the output.
        lowpass_filter_width: Controls the width of the low-pass filter
            window (larger means a wider window and more filter taps).
    """

    def __init__(
            self,
            orig_freq=16000,
            new_freq=16000,
            lowpass_filter_width=6, ):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.lowpass_filter_width = lowpass_filter_width

        # Compute rate for striding
        self._compute_strides()
        assert self.orig_freq % self.conv_stride == 0
        assert self.new_freq % self.conv_transpose_stride == 0

    def _compute_strides(self):
        """Derive the conv / transposed-conv strides from the rate ratio."""
        # Compute new unit based on ratio of in/out frequencies
        base_freq = math.gcd(self.orig_freq, self.new_freq)
        input_samples_in_unit = self.orig_freq // base_freq
        self.output_samples = self.new_freq // base_freq

        # Store the appropriate stride based on the new units
        self.conv_stride = input_samples_in_unit
        self.conv_transpose_stride = self.output_samples

    def forward(self, waveforms):
        """Resample ``waveforms`` from ``orig_freq`` to ``new_freq``.

        Args:
            waveforms: Tensor of shape ``[batch, time]`` or
                ``[batch, time, channels]``.

        Returns:
            The resampled tensor in the same layout as the input. When both
            frequencies are equal the input object itself is returned.

        Raises:
            ValueError: If the input is not 2- or 3-dimensional.
        """
        # Don't do anything if the frequencies are the same; checked first
        # so the filter bank is never built when it would not be used.
        if self.orig_freq == self.new_freq:
            return waveforms

        if not hasattr(self, "first_indices"):
            self._indices_and_weights(waveforms)

        # Bring the input into [batch, channels, time] layout
        unsqueezed = False
        if len(waveforms.shape) == 2:
            waveforms = waveforms.unsqueeze(1)
            unsqueezed = True
        elif len(waveforms.shape) == 3:
            waveforms = waveforms.transpose([0, 2, 1])
        else:
            raise ValueError("Input must be 2 or 3 dimensions")

        # Do resampling
        resampled_waveform = self._perform_resample(waveforms)

        # Restore the caller's layout
        if unsqueezed:
            resampled_waveform = resampled_waveform.squeeze(1)
        else:
            resampled_waveform = resampled_waveform.transpose([0, 2, 1])

        return resampled_waveform

    def _perform_resample(self, waveforms):
        """Apply every phase of the filter bank and interleave the results.

        Args:
            waveforms: Tensor of shape ``[batch, channels, time]``.

        Returns:
            Tensor of shape ``[batch, channels, output_time]``.
        """
        # Compute output size and initialize
        batch_size, num_channels, wave_len = waveforms.shape
        window_size = self.weights.shape[1]
        tot_output_samp = self._output_samples(wave_len)
        resampled_waveform = paddle.zeros((batch_size, num_channels,
                                           tot_output_samp))

        # eye size: (num_channels, num_channels, 1)
        eye = paddle.eye(num_channels).unsqueeze(2)

        # Iterate over the phases in the polyphase filter
        for i in range(self.first_indices.shape[0]):
            wave_to_conv = waveforms
            first_index = int(self.first_indices[i].item())
            if first_index >= 0:
                # trim the signal as the filter will not be applied
                # before the first_index
                wave_to_conv = wave_to_conv[:, :, first_index:]

            # pad the right of the signal to allow partial convolutions
            # meaning compute values for partial windows (e.g. end of the
            # window is outside the signal length)
            max_index = (tot_output_samp - 1) // self.output_samples
            end_index = max_index * self.conv_stride + window_size
            current_wave_len = wave_len - first_index
            right_padding = max(0, end_index + 1 - current_wave_len)
            left_padding = max(0, -first_index)
            wave_to_conv = paddle.nn.functional.pad(
                wave_to_conv, [left_padding, right_padding], data_format='NCL')
            conv_wave = paddle.nn.functional.conv1d(
                x=wave_to_conv,
                # same filter for every channel (grouped convolution)
                weight=self.weights[i].expand((num_channels, 1, -1)),
                stride=self.conv_stride,
                groups=num_channels, )

            # we want conv_wave[:, i] to be at
            # output[:, i + n*conv_transpose_stride]
            dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
                conv_wave, eye, stride=self.conv_transpose_stride)

            # pad dilated_conv_wave so it reaches the output length if needed.
            left_padding = i
            previous_padding = left_padding + dilated_conv_wave.shape[-1]
            right_padding = max(0, tot_output_samp - previous_padding)
            dilated_conv_wave = paddle.nn.functional.pad(
                dilated_conv_wave, [left_padding, right_padding],
                data_format='NCL')
            dilated_conv_wave = dilated_conv_wave[:, :, :tot_output_samp]

            resampled_waveform += dilated_conv_wave

        return resampled_waveform

    def _output_samples(self, input_num_samp):
        """Return the number of output samples for ``input_num_samp`` inputs.

        Works on a common "tick" grid (the least common multiple of both
        rates) so that only integer arithmetic is needed.
        """
        samp_in = int(self.orig_freq)
        samp_out = int(self.new_freq)

        tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
        ticks_per_input_period = tick_freq // samp_in

        # work out the number of ticks in the time interval
        # [ 0, input_num_samp/samp_in ).
        interval_length = input_num_samp * ticks_per_input_period
        if interval_length <= 0:
            return 0
        ticks_per_output_period = tick_freq // samp_out

        # Get the last output-sample in the closed interval,
        # i.e. replacing [ ) with [ ]. Note: integer division rounds down.
        # See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
        # explanation of the notation.
        last_output_samp = interval_length // ticks_per_output_period

        # We need the last output-sample in the open interval, so if it
        # takes us to the end of the interval exactly, subtract one.
        if last_output_samp * ticks_per_output_period == interval_length:
            last_output_samp -= 1

        # First output-sample index is zero, so the number of output samples
        # is the last output-sample plus one.
        num_output_samp = last_output_samp + 1

        return num_output_samp

    def _indices_and_weights(self, waveforms):
        """Build ``self.first_indices`` and ``self.weights``.

        ``waveforms`` is accepted for interface compatibility but is not
        read here.
        """
        # Lowpass filter frequency depends on smaller of two frequencies
        min_freq = min(self.orig_freq, self.new_freq)
        lowpass_cutoff = 0.99 * 0.5 * min_freq

        assert lowpass_cutoff * 2 <= min_freq
        window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)

        assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
        # Time stamps of the output samples within one unit
        output_t = paddle.arange(start=0.0, end=self.output_samples)
        output_t /= self.new_freq
        min_t = output_t - window_width
        max_t = output_t + window_width

        # Input indices covered by the window of each output sample
        min_input_index = paddle.ceil(min_t * self.orig_freq)
        max_input_index = paddle.floor(max_t * self.orig_freq)
        num_indices = max_input_index - min_input_index + 1

        max_weight_width = num_indices.max()
        j = paddle.arange(max_weight_width, dtype='float32')
        input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
        delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)

        weights = paddle.zeros_like(delta_t)
        inside_window_indices = delta_t.abs().less_than(
            paddle.to_tensor(window_width))

        # raised-cosine (Hanning) window with width `window_width`
        weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
            2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
            delta_t.masked_select(inside_window_indices)))

        t_eq_zero_indices = delta_t.equal(paddle.zeros_like(delta_t))
        t_not_eq_zero_indices = delta_t.not_equal(paddle.zeros_like(delta_t))

        # sinc filter function (only where delta_t != 0; the division result
        # at delta_t == 0 is discarded by `where`)
        weights = paddle.where(
            t_not_eq_zero_indices,
            weights * paddle.sin(2 * math.pi * lowpass_cutoff * delta_t) /
            (math.pi * delta_t), weights)

        # limit of the function at t = 0
        weights = paddle.where(t_eq_zero_indices, weights * 2 * lowpass_cutoff,
                               weights)

        # size (output_samples, max_weight_width)
        weights /= self.orig_freq

        self.first_indices = min_input_index
        self.weights = weights
|
|
|
|
|
|
class SpeedPerturb(nn.Layer):
    """Change the speed of waveforms by resampling them.

    One ``Resample`` layer is created per entry of ``speeds``; each call
    picks one of them at random.

    Args:
        orig_freq: Sampling rate of the input waveforms.
        speeds: Speeds in percent of the original speed. Each one yields a
            resampler targeting ``orig_freq * speed // 100``.
        perturb_prob: Probability that a batch is perturbed at all.
    """

    def __init__(
            self,
            orig_freq,
            speeds=[90, 100, 110],
            perturb_prob=1.0, ):
        super(SpeedPerturb, self).__init__()
        self.orig_freq = orig_freq
        # Copy so that neither the shared default list nor a caller-owned
        # list is aliased by this instance.
        self.speeds = list(speeds)
        self.perturb_prob = perturb_prob

        # Initialize index of perturbation
        self.samp_index = 0

        # Initialize resamplers, one per configured speed
        self.resamplers = []
        for speed in self.speeds:
            config = {
                "orig_freq": self.orig_freq,
                "new_freq": self.orig_freq * speed // 100,
            }
            self.resamplers.append(Resample(**config))

    def forward(self, waveform):
        """Return ``waveform`` resampled by a randomly chosen resampler.

        The index of the chosen resampler is stored in ``self.samp_index``.
        With probability ``1 - perturb_prob`` an unmodified clone of the
        input is returned instead.
        """
        # Don't perturb (return early) 1-`perturb_prob` portion of the batches
        if paddle.rand([1]) > self.perturb_prob:
            return waveform.clone()

        # Perform a random perturbation
        self.samp_index = paddle.randint(len(self.speeds), shape=[1]).item()
        perturbed_waveform = self.resamplers[self.samp_index](waveform)

        return perturbed_waveform
|
|
|
|
|
|
class AddNoise(nn.Layer):
    """Mix noise into waveforms at a random signal-to-noise ratio.

    Args:
        noise_dataset: Dataset of noise clips whose items provide the keys
            ``'utt_id'`` and ``'feat'``. ``None`` selects white noise.
        num_workers: Number of workers for the noise ``DataLoader``.
        snr_low: Lower bound of the SNR range (the value is passed to
            ``dB_to_amplitude``).
        snr_high: Upper bound of the SNR range.
        mix_prob: Probability that a batch is mixed with noise at all.
        start_index: Fixed offset into the noise clips; ``None`` picks a
            random offset on every call.
        normalize: If True, divide each noisy waveform by its absolute
            maximum whenever that maximum exceeds 1.
    """

    def __init__(
            self,
            noise_dataset=None,  # None for white noise
            num_workers=0,
            snr_low=0,
            snr_high=0,
            mix_prob=1.0,
            start_index=None,
            normalize=False, ):
        super(AddNoise, self).__init__()

        self.num_workers = num_workers
        self.snr_low = snr_low
        self.snr_high = snr_high
        self.mix_prob = mix_prob
        self.start_index = start_index
        self.normalize = normalize
        self.noise_dataset = noise_dataset
        # Created lazily on the first call of `_load_noise`
        self.noise_dataloader = None

    def forward(self, waveforms, lengths=None):
        """Return a noisy copy of ``waveforms``.

        Args:
            waveforms: Batch tensor whose axis 1 is time.
            lengths: Per-item lengths relative to ``waveforms.shape[1]``.
                Defaults to all ones (full length).

        Returns:
            A new tensor with the same shape as ``waveforms``.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        # Copy clean waveform to initialize noisy waveform
        noisy_waveform = waveforms.clone()
        lengths = (lengths * waveforms.shape[1]).astype('int64').unsqueeze(1)

        # Don't add noise (return early) 1-`mix_prob` portion of the batches
        if paddle.rand([1]) > self.mix_prob:
            return noisy_waveform

        # Compute the average amplitude of the clean waveforms
        clean_amplitude = compute_amplitude(waveforms, lengths)

        # Pick an SNR and use it to compute the mixture amplitude factors
        SNR = paddle.rand((len(waveforms), 1))
        SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low
        noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1)
        new_noise_amplitude = noise_amplitude_factor * clean_amplitude

        # Scale clean signal appropriately
        noisy_waveform *= 1 - noise_amplitude_factor

        # Loop through clean samples and create mixture
        if self.noise_dataset is None:
            white_noise = paddle.normal(shape=waveforms.shape)
            noisy_waveform += new_noise_amplitude * white_noise
        else:
            tensor_length = waveforms.shape[1]
            noise_waveform, noise_length = self._load_noise(
                lengths,
                tensor_length, )

            # Rescale and add
            noise_amplitude = compute_amplitude(noise_waveform, noise_length)
            noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14)
            noisy_waveform += noise_waveform

        # Normalizing to prevent clipping
        if self.normalize:
            # paddle.max returns a single Tensor (no indices), so it must
            # not be unpacked into a (values, indices) pair.
            abs_max = paddle.max(
                paddle.abs(noisy_waveform), axis=1, keepdim=True)
            noisy_waveform = noisy_waveform / abs_max.clip(min=1.0)

        return noisy_waveform

    def _load_noise(self, lengths, max_length):
        """
        Load a batch of noises

        args
        lengths(Paddle.Tensor): Num samples of waveforms with shape (N, 1).
        max_length(int): Width of a batch.

        returns
        Tuple ``(noise_batch, noise_len)``: the noise clips cut to at most
        ``max_length`` samples and their lengths with shape (N, 1).
        """
        lengths = lengths.squeeze(1)
        batch_size = len(lengths)

        # Load a noise batch
        if self.noise_dataloader is None:
            # NOTE(review): the loader is built only once, so `batch_size`
            # and the `max_length` captured by the collate function below
            # are those of the first call -- confirm later batches are not
            # wider than the first one.

            def noise_collate_fn(batch):
                def pad(x, target_length, mode='constant', **kwargs):
                    x = np.asarray(x)
                    w = target_length - x.shape[0]
                    assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
                    return np.pad(x, [0, w], mode=mode, **kwargs)

                ids = [item['utt_id'] for item in batch]
                lengths = np.asarray([item['feat'].shape[0] for item in batch])
                # Pad every clip to the longest clip, but at least to
                # `max_length`
                waveforms = list(
                    map(lambda x: pad(x, max(max_length, lengths.max().item())),
                        [item['feat'] for item in batch]))
                waveforms = np.stack(waveforms)
                return {'ids': ids, 'feats': waveforms, 'lengths': lengths}

            # Create noise data loader.
            self.noise_dataloader = paddle.io.DataLoader(
                self.noise_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                collate_fn=noise_collate_fn,
                return_list=True, )
            self.noise_data = iter(self.noise_dataloader)

        noise_batch, noise_len = self._load_noise_batch_of_size(batch_size)

        # Select a random starting location in the waveform unless a fixed
        # one was configured
        start_index = self.start_index
        if self.start_index is None:
            max_chop = (noise_len - lengths).min().clip(min=1)
            start_index = paddle.randint(high=max_chop, shape=[1])

        # Truncate noise_batch to max_length
        noise_batch = noise_batch[:, start_index:start_index + max_length]
        noise_len = (noise_len - start_index).clip(max=max_length).unsqueeze(1)
        return noise_batch, noise_len

    def _load_noise_batch_of_size(self, batch_size):
        """Concatenate noise batches, then chop to correct size"""
        noise_batch, noise_lens = self._load_noise_batch()

        # Expand: double the batch until it is large enough
        while len(noise_batch) < batch_size:
            noise_batch = paddle.concat((noise_batch, noise_batch))
            noise_lens = paddle.concat((noise_lens, noise_lens))

        # Contract
        if len(noise_batch) > batch_size:
            noise_batch = noise_batch[:batch_size]
            noise_lens = noise_lens[:batch_size]

        return noise_batch, noise_lens

    def _load_noise_batch(self):
        """Load a batch of noises, restarting iteration if necessary."""
        try:
            batch = next(self.noise_data)
        except StopIteration:
            self.noise_data = iter(self.noise_dataloader)
            batch = next(self.noise_data)

        noises, lens = batch['feats'], batch['lengths']
        return noises, lens
|
|
|
|
|
|
class AddReverb(nn.Layer):
    """Convolve waveforms with room impulse responses (RIRs).

    Args:
        rir_dataset: Dataset of RIRs whose items provide the keys
            ``'utt_id'`` and ``'feat'``; its ``sample_rate`` attribute is
            forwarded to ``reverberate``.
        reverb_prob: Probability that a batch is reverberated at all.
        rir_scale_factor: Factor by which each RIR is linearly interpolated
            along time; ``1`` leaves the RIR untouched.
        num_workers: Number of workers for the RIR ``DataLoader``.
    """

    def __init__(
            self,
            rir_dataset,
            reverb_prob=1.0,
            rir_scale_factor=1.0,
            num_workers=0, ):
        super(AddReverb, self).__init__()
        self.rir_dataset = rir_dataset
        self.reverb_prob = reverb_prob
        self.rir_scale_factor = rir_scale_factor

        def rir_collate_fn(batch):
            """Zero-pad all RIRs of a batch to the longest one and stack."""

            def pad(x, target_length, mode='constant', **kwargs):
                x = np.asarray(x)
                w = target_length - x.shape[0]
                assert w >= 0, f'Target length {target_length} is less than origin length {x.shape[0]}'
                return np.pad(x, [0, w], mode=mode, **kwargs)

            ids = []
            feats = []
            for item in batch:
                ids.append(item['utt_id'])
                feats.append(item['feat'])
            lengths = np.asarray([feat.shape[0] for feat in feats])
            longest = lengths.max().item()
            waveforms = np.stack([pad(feat, longest) for feat in feats])
            return {'ids': ids, 'feats': waveforms, 'lengths': lengths}

        # Create rir data loader and keep one iterator over it.
        self.rir_dataloader = paddle.io.DataLoader(
            self.rir_dataset,
            collate_fn=rir_collate_fn,
            num_workers=num_workers,
            shuffle=True,
            return_list=True, )
        self.rir_data = iter(self.rir_dataloader)

    def forward(self, waveforms, lengths=None):
        """
        Arguments
        ---------
        waveforms : tensor
            Shape should be `[batch, time]` or `[batch, time, channels]`.
        lengths : tensor
            Shape should be a single dimension, `[batch]`.

        Returns
        -------
        Tensor of shape `[batch, time]` or `[batch, time, channels]`.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        # Leave 1-`reverb_prob` portion of the batches untouched
        if paddle.rand([1]) > self.reverb_prob:
            return waveforms.clone()

        # Work on [batch, time, channels]; remember if the axis is ours
        channel_added = len(waveforms.shape) == 2
        if channel_added:
            waveforms = waveforms.unsqueeze(-1)

        rir_waveform = self._load_rir()

        # Compress or dilate the RIR along time
        if self.rir_scale_factor != 1:
            # (N, L, C) -> (N, C, L) for interpolation
            rir_ncl = rir_waveform.transpose([0, 2, 1])
            rir_ncl = F.interpolate(
                rir_ncl,
                scale_factor=self.rir_scale_factor,
                mode="linear",
                align_corners=False,
                data_format='NCW', )
            # (N, C, L) -> (N, L, C)
            rir_waveform = rir_ncl.transpose([0, 2, 1])

        rev_waveform = reverberate(
            waveforms,
            rir_waveform,
            self.rir_dataset.sample_rate,
            rescale_amp="avg")

        # Drop the channel axis again if it was added above
        return rev_waveform.squeeze(-1) if channel_added else rev_waveform

    def _load_rir(self):
        """Fetch the next RIR batch, restarting the loader when exhausted."""
        try:
            batch = next(self.rir_data)
        except StopIteration:
            self.rir_data = iter(self.rir_dataloader)
            batch = next(self.rir_data)

        rir = batch['feats']

        # Make sure RIR has a channel axis
        if len(rir.shape) == 2:
            rir = rir.unsqueeze(-1)
        return rir
|
|
|
|
|
|
class AddBabble(nn.Layer):
    """Mix each waveform with other waveforms of the same batch.

    The "babble" for an item is the sum of ``speaker_count`` copies of the
    batch rolled along the batch axis by 1, 2, ... positions.

    Args:
        speaker_count: Number of rolled copies summed into the babble.
        snr_low: Lower bound of the SNR range (the value is passed to
            ``dB_to_amplitude``).
        snr_high: Upper bound of the SNR range.
        mix_prob: Probability that a batch is mixed at all.
    """

    def __init__(
            self,
            speaker_count=3,
            snr_low=0,
            snr_high=0,
            mix_prob=1, ):
        super(AddBabble, self).__init__()
        self.speaker_count = speaker_count
        self.snr_low = snr_low
        self.snr_high = snr_high
        self.mix_prob = mix_prob

    def forward(self, waveforms, lengths=None):
        """Return a copy of ``waveforms`` with babble noise mixed in.

        Args:
            waveforms: Batch tensor whose axis 1 is time.
            lengths: Per-item lengths relative to ``waveforms.shape[1]``.
                Defaults to all ones (full length).
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        mixture = waveforms.clone()
        lengths = (lengths * waveforms.shape[1]).unsqueeze(1)
        batch_size = len(waveforms)

        # Leave 1-`mix_prob` portion of the batches untouched
        if paddle.rand([1]) > self.mix_prob:
            return mixture

        # Draw one SNR per item and turn it into amplitude factors
        clean_amplitude = compute_amplitude(waveforms, lengths)
        snr = paddle.rand((batch_size, 1))
        snr = snr * (self.snr_high - self.snr_low) + self.snr_low
        noise_amplitude_factor = 1 / (dB_to_amplitude(snr) + 1)
        new_noise_amplitude = noise_amplitude_factor * clean_amplitude

        # Attenuate the clean signal to make room for the babble
        mixture *= 1 - noise_amplitude_factor

        # Sum the rolled copies of the batch; the first shift is 1
        babble_waveform = waveforms.roll((1, ), axis=0)
        babble_len = lengths.roll((1, ), axis=0)
        for shift in range(2, self.speaker_count + 1):
            babble_waveform += waveforms.roll((shift, ), axis=0)
            # Element-wise maximum of the current and the rolled lengths
            paired_len = paddle.concat(
                [babble_len, babble_len.roll((1, ), axis=0)], axis=-1)
            babble_len = paired_len.max(axis=-1, keepdim=True)

        # Rescale the babble to the target amplitude and add it
        babble_amplitude = compute_amplitude(babble_waveform, babble_len)
        babble_waveform *= new_noise_amplitude / (babble_amplitude + 1e-14)
        mixture += babble_waveform

        return mixture
|
|
|
|
|
|
class TimeDomainSpecAugment(nn.Layer):
    """Chain speed perturbation, frequency dropping and chunk dropping.

    Args:
        perturb_prob: ``perturb_prob`` of the ``SpeedPerturb`` stage.
        drop_freq_prob: ``drop_prob`` of the ``DropFreq`` stage.
        drop_chunk_prob: ``drop_prob`` of the ``DropChunk`` stage.
        speeds: ``speeds`` of the ``SpeedPerturb`` stage.
        sample_rate: ``orig_freq`` of the ``SpeedPerturb`` stage.
        drop_freq_count_low: ``drop_count_low`` of ``DropFreq``.
        drop_freq_count_high: ``drop_count_high`` of ``DropFreq``.
        drop_chunk_count_low: ``drop_count_low`` of ``DropChunk``.
        drop_chunk_count_high: ``drop_count_high`` of ``DropChunk``.
        drop_chunk_length_low: ``drop_length_low`` of ``DropChunk``.
        drop_chunk_length_high: ``drop_length_high`` of ``DropChunk``.
        drop_chunk_noise_factor: ``noise_factor`` of ``DropChunk``.
    """

    def __init__(
            self,
            perturb_prob=1.0,
            drop_freq_prob=1.0,
            drop_chunk_prob=1.0,
            speeds=[95, 100, 105],
            sample_rate=16000,
            drop_freq_count_low=0,
            drop_freq_count_high=3,
            drop_chunk_count_low=0,
            drop_chunk_count_high=5,
            drop_chunk_length_low=1000,
            drop_chunk_length_high=2000,
            drop_chunk_noise_factor=0, ):
        super(TimeDomainSpecAugment, self).__init__()

        # Stage 1: speed perturbation
        speed_kwargs = {
            "perturb_prob": perturb_prob,
            "orig_freq": sample_rate,
            "speeds": speeds,
        }
        self.speed_perturb = SpeedPerturb(**speed_kwargs)

        # Stage 2: frequency dropping
        freq_kwargs = {
            "drop_prob": drop_freq_prob,
            "drop_count_low": drop_freq_count_low,
            "drop_count_high": drop_freq_count_high,
        }
        self.drop_freq = DropFreq(**freq_kwargs)

        # Stage 3: chunk dropping
        chunk_kwargs = {
            "drop_prob": drop_chunk_prob,
            "drop_count_low": drop_chunk_count_low,
            "drop_count_high": drop_chunk_count_high,
            "drop_length_low": drop_chunk_length_low,
            "drop_length_high": drop_chunk_length_high,
            "noise_factor": drop_chunk_noise_factor,
        }
        self.drop_chunk = DropChunk(**chunk_kwargs)

    def forward(self, waveforms, lengths=None):
        """Run the three stages in order, with gradients disabled.

        Args:
            waveforms: Batch of waveforms.
            lengths: Per-item relative lengths handed to the chunk-dropping
                stage. Defaults to all ones (full length).
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        with paddle.no_grad():
            perturbed = self.speed_perturb(waveforms)
            filtered = self.drop_freq(perturbed)
            waveforms = self.drop_chunk(filtered, lengths)

        return waveforms
|
|
|
|
|
|
class EnvCorrupt(nn.Layer):
    """Environmental corruption: optional reverb, babble and noise stages.

    A stage is only created when it is enabled by its arguments; ``forward``
    applies the existing stages in the order reverb, babble, noise.

    Args:
        reverb_prob: Probability of applying reverb; the stage is created
            only if this is > 0 and ``rir_dataset`` is given.
        babble_prob: Probability of applying babble; the stage is created
            only if this is > 0 and ``babble_speaker_count`` is > 0.
        noise_prob: Probability of applying noise; the stage is created
            only if this is > 0 and ``noise_dataset`` is given.
        rir_dataset: Dataset of room impulse responses for ``AddReverb``.
        noise_dataset: Dataset of noise clips for ``AddNoise``.
        num_workers: Worker count forwarded to ``AddReverb``/``AddNoise``.
        babble_speaker_count: ``speaker_count`` of ``AddBabble``.
        babble_snr_low: ``snr_low`` of ``AddBabble``.
        babble_snr_high: ``snr_high`` of ``AddBabble``.
        noise_snr_low: ``snr_low`` of ``AddNoise``.
        noise_snr_high: ``snr_high`` of ``AddNoise``.
        rir_scale_factor: ``rir_scale_factor`` of ``AddReverb``.
    """

    def __init__(
            self,
            reverb_prob=1.0,
            babble_prob=1.0,
            noise_prob=1.0,
            rir_dataset=None,
            noise_dataset=None,
            num_workers=0,
            babble_speaker_count=0,
            babble_snr_low=0,
            babble_snr_high=0,
            noise_snr_low=0,
            noise_snr_high=0,
            rir_scale_factor=1.0, ):
        super(EnvCorrupt, self).__init__()

        # Initialize corrupters
        if rir_dataset is not None and reverb_prob > 0.0:
            self.add_reverb = AddReverb(
                rir_dataset=rir_dataset,
                num_workers=num_workers,
                reverb_prob=reverb_prob,
                rir_scale_factor=rir_scale_factor, )

        if babble_speaker_count > 0 and babble_prob > 0.0:
            self.add_babble = AddBabble(
                speaker_count=babble_speaker_count,
                snr_low=babble_snr_low,
                snr_high=babble_snr_high,
                mix_prob=babble_prob, )

        if noise_dataset is not None and noise_prob > 0.0:
            self.add_noise = AddNoise(
                noise_dataset=noise_dataset,
                num_workers=num_workers,
                snr_low=noise_snr_low,
                snr_high=noise_snr_high,
                mix_prob=noise_prob, )

    def forward(self, waveforms, lengths=None):
        """Apply the configured corruption stages, with gradients disabled.

        Args:
            waveforms: Batch of waveforms.
            lengths: Per-item relative lengths. Defaults to all ones.

        Returns:
            The corrupted waveforms. If the reverb stage raises, the batch
            continues un-reverberated and a warning is logged.
        """
        if lengths is None:
            lengths = paddle.ones([len(waveforms)])

        # Augmentation
        with paddle.no_grad():
            if hasattr(self, "add_reverb"):
                try:
                    waveforms = self.add_reverb(waveforms, lengths)
                except Exception as err:
                    # Reverb stays best-effort, but the failure is reported
                    # instead of being dropped silently.
                    logger.warning(
                        f"AddReverb failed, skipping reverb for this batch: {err!r}"
                    )
            if hasattr(self, "add_babble"):
                waveforms = self.add_babble(waveforms, lengths)
            if hasattr(self, "add_noise"):
                waveforms = self.add_noise(waveforms, lengths)

        return waveforms
|
|
|
|
|
|
def build_augment_pipeline(target_dir=None) -> List[paddle.nn.Layer]:
    """build augment pipeline
    Note: this pipeline cannot be used in the paddle.DataLoader

    Args:
        target_dir: Directory that contains ``rir_noise/csv/noise.csv`` and
            ``rir_noise/csv/rir.csv``. Required despite the ``None`` default.

    Returns:
        List[paddle.nn.Layer]: all augment process, in the order
        wavedrop, speed perturbation, noise, reverb, reverb + noise

    Raises:
        TypeError: If ``target_dir`` is None.
    """
    # Fail early with a clear message rather than inside os.path.join.
    if target_dir is None:
        raise TypeError(
            "target_dir must be the directory containing 'rir_noise/csv', "
            "got None")

    logger.info("start to build the augment pipeline")
    noise_dataset = CSVDataset(csv_path=os.path.join(target_dir,
                                                     "rir_noise/csv/noise.csv"))
    rir_dataset = CSVDataset(csv_path=os.path.join(target_dir,
                                                   "rir_noise/csv/rir.csv"))

    # Frequency/chunk dropping only (a single speed of 100%)
    wavedrop = TimeDomainSpecAugment(
        sample_rate=16000,
        speeds=[100], )
    # Speed perturbation plus frequency/chunk dropping
    speed_perturb = TimeDomainSpecAugment(
        sample_rate=16000,
        speeds=[95, 100, 105], )
    # Additive noise only
    add_noise = EnvCorrupt(
        noise_dataset=noise_dataset,
        reverb_prob=0.0,
        noise_prob=1.0,
        noise_snr_low=0,
        noise_snr_high=15,
        rir_scale_factor=1.0, )
    # Reverberation only
    add_rev = EnvCorrupt(
        rir_dataset=rir_dataset,
        reverb_prob=1.0,
        noise_prob=0.0,
        rir_scale_factor=1.0, )
    # Reverberation followed by additive noise
    add_rev_noise = EnvCorrupt(
        noise_dataset=noise_dataset,
        rir_dataset=rir_dataset,
        reverb_prob=1.0,
        noise_prob=1.0,
        noise_snr_low=0,
        noise_snr_high=15,
        rir_scale_factor=1.0, )

    return [wavedrop, speed_perturb, add_noise, add_rev, add_rev_noise]
|
|
|
|
|
|
def waveform_augment(waveforms: paddle.Tensor,
                     augment_pipeline: List[paddle.nn.Layer]) -> paddle.Tensor:
    """Run every augmentation on the batch and stack all results.

    Args:
        waveforms (paddle.Tensor): original batch waveform, (N, L)
        augment_pipeline (List[paddle.nn.Layer]): augment pipeline process

    Returns:
        paddle.Tensor: the original waveforms followed by one augmented
            copy per pipeline entry, concatenated along the batch axis.
            Every augmented copy is truncated or right-padded to the
            original length L.
    """
    target_len = waveforms.shape[1]

    # The original batch always comes first
    collected = [waveforms]

    for augment in augment_pipeline:
        augmented = augment(waveforms)  # (N, L')
        missing = target_len - augmented.shape[1]
        if missing <= 0:
            # Too long (or exact): truncate to the original length
            augmented = augmented[:, :target_len]
        else:
            # Too short: pad the end of the time axis
            padded = F.pad(
                augmented.unsqueeze(-1), [0, missing], data_format='NLC')
            augmented = padded.squeeze(-1)
        collected.append(augmented)

    return paddle.concat(collected, axis=0)
|