You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/transform/spectrogram.py

306 lines
8.7 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa
import numpy as np
def stft(x,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
# x: [Time, Channel]
if x.ndim == 1:
single_channel = True
# x: [Time] -> [Time, Channel]
x = x[:, None]
else:
single_channel = False
x = x.astype(np.float32)
# FIXME(kamo): librosa.stft can't use multi-channel?
# x: [Time, Channel, Freq]
x = np.stack(
[
librosa.stft(
x[:, ch],
n_fft=n_fft,
hop_length=n_shift,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode, ).T for ch in range(x.shape[1])
],
axis=1, )
if single_channel:
# x: [Time, Channel, Freq] -> [Time, Freq]
x = x[:, 0]
return x
def istft(x, n_shift, win_length=None, window="hann", center=True):
# x: [Time, Channel, Freq]
if x.ndim == 2:
single_channel = True
# x: [Time, Freq] -> [Time, Channel, Freq]
x = x[:, None, :]
else:
single_channel = False
# x: [Time, Channel]
x = np.stack(
[
librosa.istft(
x[:, ch].T, # [Time, Freq] -> [Freq, Time]
hop_length=n_shift,
win_length=win_length,
window=window,
center=center, ) for ch in range(x.shape[1])
],
axis=1, )
if single_channel:
# x: [Time, Channel] -> [Time]
x = x[:, 0]
return x
def stft2logmelspectrogram(x_stft,
fs,
n_mels,
n_fft,
fmin=None,
fmax=None,
eps=1e-10):
# x_stft: (Time, Channel, Freq) or (Time, Freq)
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
# spc: (Time, Channel, Freq) or (Time, Freq)
spc = np.abs(x_stft)
# mel_basis: (Mel_freq, Freq)
mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax)
# lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq)
lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
return lmspc
def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"):
# x: (Time, Channel) -> spc: (Time, Channel, Freq)
spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window))
return spc
def logmelspectrogram(
x,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10,
pad_mode="reflect", ):
# stft: (Time, Channel, Freq) or (Time, Freq)
x_stft = stft(
x,
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
pad_mode=pad_mode, )
return stft2logmelspectrogram(
x_stft,
fs=fs,
n_mels=n_mels,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
eps=eps)
class Spectrogram():
def __init__(self, n_fft, n_shift, win_length=None, window="hann"):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
def __repr__(self):
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, ))
def __call__(self, x):
return spectrogram(
x,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, )
class LogMelSpectrogram():
def __init__(
self,
fs,
n_mels,
n_fft,
n_shift,
win_length=None,
window="hann",
fmin=None,
fmax=None,
eps=1e-10, ):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
def __call__(self, x):
return logmelspectrogram(
x,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window, )
class Stft2LogMelSpectrogram():
def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
self.fmin = fmin
self.fmax = fmax
self.eps = eps
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
def __call__(self, x):
return stft2logmelspectrogram(
x,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
fmin=self.fmin,
fmax=self.fmax, )
class Stft():
def __init__(
self,
n_fft,
n_shift,
win_length=None,
window="hann",
center=True,
pad_mode="reflect", ):
self.n_fft = n_fft
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
self.pad_mode = pad_mode
def __repr__(self):
return ("{name}(n_fft={n_fft}, n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center}, pad_mode={pad_mode})".format(
name=self.__class__.__name__,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode, ))
def __call__(self, x):
return stft(
x,
self.n_fft,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode, )
class IStft():
def __init__(self, n_shift, win_length=None, window="hann", center=True):
self.n_shift = n_shift
self.win_length = win_length
self.window = window
self.center = center
def __repr__(self):
return ("{name}(n_shift={n_shift}, "
"win_length={win_length}, window={window},"
"center={center})".format(
name=self.__class__.__name__,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center, ))
def __call__(self, x):
return istft(
x,
self.n_shift,
win_length=self.win_length,
window=self.window,
center=self.center, )