You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
218 lines
6.8 KiB
218 lines
6.8 KiB
3 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import numpy as np
|
||
|
import paddle
|
||
|
|
||
|
# TODO: Complete type-hint and doc string.
|
||
|
|
||
|
|
||
|
def blackman_window(win_len, dtype=np.float32):
    """Build a Blackman window and return it as a paddle tensor.

    Sample ``k`` (``0 <= k < win_len``) is::

        0.42 - 0.5 * cos(2 * pi * k / win_len) + 0.08 * cos(4 * pi * k / win_len)

    The denominator is ``win_len`` (not ``win_len - 1``), i.e. the periodic
    form of the window.

    Args:
        win_len: Number of samples in the window.
        dtype: Numpy dtype the samples are cast to before the conversion to
            a tensor.

    Returns:
        The 1-D window, converted with ``paddle.to_tensor``.
    """
    arcs = np.pi * np.arange(win_len) / float(win_len)
    # Evaluate the whole window in one vectorised expression rather than
    # looping over the samples in Python.
    win = 0.42 - 0.5 * np.cos(2 * arcs) + 0.08 * np.cos(4 * arcs)
    return paddle.to_tensor(np.asarray(win, dtype=dtype))
|
||
|
|
||
|
|
||
|
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
    """Measure the amplitude of each waveform along axis 1.

    A 1-D input gets a leading axis added first. The reduction keeps the
    reduced axis (``keepdim=True``).

    Args:
        waveforms: Tensor of waveforms; axis 1 is the axis reduced over.
        lengths: Optional divisor used for the "avg" amplitude. When given,
            the summed absolute values are divided by it; when ``None`` the
            plain mean over axis 1 is used.
        amp_type: "avg" for the mean absolute value, "peak" for the maximum
            absolute value.
        scale: "linear" to return the amplitude as is, "dB" to return
            ``20 * log10(amplitude)`` clipped from below at -80.

    Returns:
        Tensor holding the amplitude in the requested scale.

    Raises:
        AssertionError: If ``amp_type`` or ``scale`` is not a supported value.
    """
    if len(waveforms.shape) == 1:
        waveforms = waveforms.unsqueeze(0)

    assert amp_type in ["avg", "peak"]
    assert scale in ["linear", "dB"]

    if amp_type == "avg":
        if lengths is not None:
            # Divide by the caller-supplied length instead of the axis size.
            total = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
            amplitude = total / lengths
        else:
            amplitude = paddle.mean(
                paddle.abs(waveforms), axis=1, keepdim=True)
    elif amp_type == "peak":
        amplitude = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)
    else:
        raise NotImplementedError

    if scale == "linear":
        return amplitude
    if scale == "dB":
        # The lower bound keeps log10(0) from producing -inf.
        return paddle.clip(20 * paddle.log10(amplitude), min=-80)
    raise NotImplementedError
|
||
|
|
||
|
|
||
|
def dB_to_amplitude(SNR):
    """Convert a level in decibels to a linear amplitude ratio.

    Args:
        SNR: Level in dB.

    Returns:
        ``10 ** (SNR / 20)``.
    """
    exponent = SNR / 20
    return 10**exponent
|
||
|
|
||
|
|
||
|
def convolve1d(
        waveform,
        kernel,
        padding=0,
        pad_type="constant",
        stride=1,
        groups=1, ):
    """Run a 1-D convolution over a time-major waveform tensor.

    Args:
        waveform: 3-D tensor laid out as (N, L, C), i.e. with time on axis 1.
        kernel: Convolution kernel; its last two axes are swapped in the same
            way as the waveform's before it is handed to ``conv1d``.
        padding: Either an int, which is forwarded to ``conv1d``, or an
            explicit ``[left_pad, right_pad]`` list / ``(left_pad, right_pad)``
            tuple, which is applied to the time axis with ``pad_type`` before
            convolving.
        pad_type: Padding mode passed to ``paddle.nn.functional.pad``; only
            used for explicit (list/tuple) padding.
        stride: Stride forwarded to ``conv1d``.
        groups: Number of groups forwarded to ``conv1d``.

    Returns:
        The convolved tensor with time moved back to axis 1.

    Raises:
        ValueError: If ``waveform`` is not 3-dimensional.
    """
    if len(waveform.shape) != 3:
        raise ValueError("Convolve1D expects a 3-dimensional tensor")

    # Padding can be a tuple (left_pad, right_pad) or an int. Lists were
    # always padded explicitly; a 2-tuple is now handled the same way so that
    # `pad_type` is honoured for it as well.
    explicit_pad = isinstance(padding, list) or (
        isinstance(padding, tuple) and len(padding) == 2)
    if explicit_pad:
        waveform = paddle.nn.functional.pad(
            x=waveform,
            pad=list(padding),
            mode=pad_type,
            data_format='NLC', )
        # The padding has already been applied, conv1d must not add more.
        conv_padding = 0
    else:
        conv_padding = padding

    # Move time dimension last, which pad and fft and conv expect.
    # (N, L, C) -> (N, C, L)
    waveform = waveform.transpose([0, 2, 1])
    kernel = kernel.transpose([0, 2, 1])

    convolved = paddle.nn.functional.conv1d(
        x=waveform,
        weight=kernel,
        stride=stride,
        groups=groups,
        padding=conv_padding, )

    # Return time dimension to the second dimension.
    return convolved.transpose([0, 2, 1])
|
||
|
|
||
|
|
||
|
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
    """Build a notch filter as the sum of a low-pass and a high-pass filter.

    Both filters are sinc kernels multiplied by ``blackman_window`` (which
    stands in for ``torch.blackman_window`` of the reference implementation).

    Args:
        notch_freq: Notch frequency; must satisfy ``0 < notch_freq <= 1``.
        filter_width: Number of filter taps; must be odd.
        notch_width: Offset between the notch frequency and the cutoff of each
            of the two filters.

    Returns:
        The filter taps reshaped to ``[1, filter_width, 1]``.

    Raises:
        AssertionError: If ``notch_freq`` is out of range or ``filter_width``
            is even.
    """
    # Check inputs
    assert 0 < notch_freq <= 1
    assert filter_width % 2 != 0
    half_width = filter_width // 2
    # Tap positions centred on zero: -half_width, ..., 0, ..., half_width.
    taps = paddle.arange(filter_width, dtype='float32') - half_width

    # Avoid frequencies that are too low
    notch_freq += notch_width

    def sinc(x):
        # sin(x) / x, with the centre tap (where x is zero) replaced by 1 to
        # avoid a division by zero.
        left = x[:half_width]
        right = x[half_width + 1:]
        return paddle.concat(
            [paddle.sin(left) / left, paddle.ones([1]),
             paddle.sin(right) / right])

    # Low-pass part, normalised so that its taps sum to one.
    low_pass = sinc(3 * (notch_freq - notch_width) * taps)
    low_pass *= blackman_window(filter_width)
    low_pass /= paddle.sum(low_pass)

    # High-pass part: a negated, normalised low-pass plus a unit impulse at
    # the centre tap.
    high_pass = sinc(3 * (notch_freq + notch_width) * taps)
    high_pass *= blackman_window(filter_width)
    high_pass /= -paddle.sum(high_pass)
    high_pass[half_width] += 1

    # Adding filters creates notch filter
    notch = low_pass + high_pass
    return notch.reshape([1, -1, 1])
|
||
|
|
||
|
|
||
|
def reverberate(waveforms,
                rir_waveform,
                sample_rate,
                impulse_duration=0.3,
                rescale_amp="avg"):
    """Convolve waveforms with a room impulse response (RIR).

    The RIR is cropped to start at its largest absolute sample and to last at
    most ``impulse_duration`` seconds, divided by its L2 norm, and convolved
    with the waveforms. The result is rescaled to the amplitude the input had
    before the convolution and returned with the input's original rank.

    Args:
        waveforms: Tensor of rank 1, 2 or 3. Lower ranks are expanded to 3-D
            by adding a leading batch axis and/or a trailing channel axis.
        rir_waveform: Impulse response, rank 1, 2 or 3, expanded the same way.
            NOTE(review): the start index is read with ``.item()``, so this
            presumably expects a single, single-channel impulse response.
        sample_rate: Sampling rate used to turn ``impulse_duration`` into a
            number of samples.
        impulse_duration: Maximum length, in seconds, of the RIR segment kept.
        rescale_amp: Amplitude type ("avg" or "peak") used both to measure the
            input and to rescale the output.

    Returns:
        The reverberated waveforms, with the same rank as the input.

    Raises:
        NotImplementedError: If either input has more than 3 dimensions.
    """
    input_shape = waveforms.shape

    if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3:
        raise NotImplementedError

    def _to_3d(tensor):
        # convolve1d expects a 3d tensor: add the missing batch / channel axes.
        rank = len(tensor.shape)
        if rank == 1:
            return tensor.unsqueeze(0).unsqueeze(-1)
        if rank == 2:
            return tensor.unsqueeze(-1)
        return tensor

    waveforms = _to_3d(waveforms)
    rir_waveform = _to_3d(rir_waveform)

    # Remember the amplitude of the clean signal so it can be restored later.
    clean_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
                                        rescale_amp)

    # Compute index of the direct signal, so we can preserve alignment
    start = rir_waveform.abs().argmax(axis=1).item()
    stop = min(start + int(sample_rate * impulse_duration),
               rir_waveform.shape[1])
    rir_waveform = rir_waveform[:, start:stop, :]
    rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
    # Reverse the kernel along time before passing it to convolve1d.
    rir_waveform = paddle.flip(rir_waveform, [1])

    # Pad only on the left, by the kernel length minus one.
    left_pad = rir_waveform.shape[1] - 1
    waveforms = convolve1d(
        waveform=waveforms,
        kernel=rir_waveform,
        padding=[left_pad, 0], )

    # Rescale to the amplitude of the clean waveform
    waveforms = rescale(waveforms, waveforms.shape[1], clean_amplitude,
                        rescale_amp)

    # Drop the axes that were added above.
    if len(input_shape) == 1:
        waveforms = waveforms.squeeze(0).squeeze(-1)
    elif len(input_shape) == 2:
        waveforms = waveforms.squeeze(-1)

    return waveforms
|
||
|
|
||
|
|
||
|
def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
    """Normalize waveforms and scale them to a target level.

    Args:
        waveforms: Tensor of waveforms. A 1-D input gets a leading axis added
            for the computation, which is removed again before returning.
        lengths: Forwarded to ``normalize``.
        target_lvl: Target level, interpreted according to ``scale``.
        amp_type: "avg" or "peak"; forwarded to ``normalize``.
        scale: "linear" to multiply by ``target_lvl`` directly, "dB" to
            multiply by ``dB_to_amplitude(target_lvl)``.

    Returns:
        The rescaled waveforms.

    Raises:
        AssertionError: If ``amp_type`` or ``scale`` is not a supported value.
    """
    assert amp_type in ["peak", "avg"]
    assert scale in ["linear", "dB"]

    was_1d = len(waveforms.shape) == 1
    if was_1d:
        waveforms = waveforms.unsqueeze(0)

    normalized = normalize(waveforms, lengths, amp_type)

    # Pick the linear gain that corresponds to the requested level.
    if scale == "linear":
        gain = target_lvl
    elif scale == "dB":
        gain = dB_to_amplitude(target_lvl)
    else:
        raise NotImplementedError("Invalid scale, choose between dB and linear")

    rescaled = gain * normalized
    return rescaled.squeeze(0) if was_1d else rescaled
|
||
|
|
||
|
|
||
|
def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
    """Divide each waveform by its amplitude.

    Args:
        waveforms: Tensor of waveforms. A 1-D input gets a leading axis added
            for the computation, which is removed again before returning.
        lengths: Forwarded to ``compute_amplitude``.
        amp_type: "avg" or "peak"; selects the amplitude used as divisor.
        eps: Small constant added to the amplitude so that the division never
            has a zero denominator.

    Returns:
        The normalized waveforms, with the same rank as the input.

    Raises:
        AssertionError: If ``amp_type`` is not a supported value.
    """
    assert amp_type in ["avg", "peak"]

    batch_added = False
    if len(waveforms.shape) == 1:
        batch_added = True
        waveforms = waveforms.unsqueeze(0)

    den = compute_amplitude(waveforms, lengths, amp_type) + eps
    # Divide while the temporary batch axis is still present. Squeezing first
    # and then dividing by the 2-D denominator would broadcast the result
    # back to 2-D, so a 1-D input would come back with an extra axis.
    normalized = waveforms / den
    if batch_added:
        normalized = normalized.squeeze(0)
    return normalized
|