You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/vector/io/signal_processing.py

218 lines
6.8 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
# TODO: Complete type-hint and doc string.
def blackman_window(win_len, dtype=np.float32):
arcs = np.pi * np.arange(win_len) / float(win_len)
win = np.asarray(
[0.42 - 0.5 * np.cos(2 * arc) + 0.08 * np.cos(4 * arc) for arc in arcs],
dtype=dtype)
return paddle.to_tensor(win)
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0)
assert amp_type in ["avg", "peak"]
assert scale in ["linear", "dB"]
if amp_type == "avg":
if lengths is None:
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)
else:
raise NotImplementedError
if scale == "linear":
return out
elif scale == "dB":
return paddle.clip(20 * paddle.log10(out), min=-80)
else:
raise NotImplementedError
def dB_to_amplitude(SNR):
return 10**(SNR / 20)
def convolve1d(
waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1, ):
if len(waveform.shape) != 3:
raise ValueError("Convolve1D expects a 3-dimensional tensor")
# Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, list):
waveform = paddle.nn.functional.pad(
x=waveform,
pad=padding,
mode=pad_type,
data_format='NLC', )
# Move time dimension last, which pad and fft and conv expect.
# (N, L, C) -> (N, C, L)
waveform = waveform.transpose([0, 2, 1])
kernel = kernel.transpose([0, 2, 1])
convolved = paddle.nn.functional.conv1d(
x=waveform,
weight=kernel,
stride=stride,
groups=groups,
padding=padding if not isinstance(padding, list) else 0, )
# Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1])
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
# Check inputs
assert 0 < notch_freq <= 1
assert filter_width % 2 != 0
pad = filter_width // 2
inputs = paddle.arange(filter_width, dtype='float32') - pad
# Avoid frequencies that are too low
notch_freq += notch_width
# Define sinc function, avoiding division by zero
def sinc(x):
def _sinc(x):
return paddle.sin(x) / x
# The zero is at the middle index
res = paddle.concat(
[_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
return res
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
# import torch
# hlpf *= paddle.to_tensor(torch.blackman_window(filter_width).detach().numpy())
hlpf *= blackman_window(filter_width)
hlpf /= paddle.sum(hlpf)
# Compute a high-pass filter with cutoff frequency notch_freq.
hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
# hhpf *= paddle.to_tensor(torch.blackman_window(filter_width).detach().numpy())
hhpf *= blackman_window(filter_width)
hhpf /= -paddle.sum(hhpf)
hhpf[pad] += 1
# Adding filters creates notch filter
return (hlpf + hhpf).reshape([1, -1, 1])
def reverberate(waveforms,
rir_waveform,
sample_rate,
impulse_duration=0.3,
rescale_amp="avg"):
orig_shape = waveforms.shape
if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3:
raise NotImplementedError
# if inputs are mono tensors we reshape to 1, samples
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0).unsqueeze(-1)
elif len(waveforms.shape) == 2:
waveforms = waveforms.unsqueeze(-1)
if len(rir_waveform.shape) == 1: # convolve1d expects a 3d tensor !
rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1)
elif len(rir_waveform.shape) == 2:
rir_waveform = rir_waveform.unsqueeze(-1)
# Compute the average amplitude of the clean
orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
rescale_amp)
# Compute index of the direct signal, so we can preserve alignment
impulse_index_start = rir_waveform.abs().argmax(axis=1).item()
impulse_index_end = min(
impulse_index_start + int(sample_rate * impulse_duration),
rir_waveform.shape[1])
rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :]
rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
rir_waveform = paddle.flip(rir_waveform, [1])
waveforms = convolve1d(
waveform=waveforms,
kernel=rir_waveform,
padding=[rir_waveform.shape[1] - 1, 0], )
# Rescale to the peak amplitude of the clean waveform
waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude,
rescale_amp)
if len(orig_shape) == 1:
waveforms = waveforms.squeeze(0).squeeze(-1)
if len(orig_shape) == 2:
waveforms = waveforms.squeeze(-1)
return waveforms
def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
assert amp_type in ["peak", "avg"]
assert scale in ["linear", "dB"]
batch_added = False
if len(waveforms.shape) == 1:
batch_added = True
waveforms = waveforms.unsqueeze(0)
waveforms = normalize(waveforms, lengths, amp_type)
if scale == "linear":
out = target_lvl * waveforms
elif scale == "dB":
out = dB_to_amplitude(target_lvl) * waveforms
else:
raise NotImplementedError("Invalid scale, choose between dB and linear")
if batch_added:
out = out.squeeze(0)
return out
def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
assert amp_type in ["avg", "peak"]
batch_added = False
if len(waveforms.shape) == 1:
batch_added = True
waveforms = waveforms.unsqueeze(0)
den = compute_amplitude(waveforms, lengths, amp_type) + eps
if batch_added:
waveforms = waveforms.squeeze(0)
return waveforms / den