Merge pull request #1612 from Jackwaterveg/update

[ASR] Replace kaidi_fbank with paddleaudio
pull/1626/head
Hui Zhang 3 years ago committed by GitHub
commit 943d4ac1ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,7 +23,3 @@ process:
n_mask: 2
inplace: true
replace_with_zero: false

@ -14,8 +14,11 @@
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
import paddle
from python_speech_features import logfbank
import paddleaudio.compliance.kaldi as kaldi
def stft(x,
n_fft,
@ -309,6 +312,77 @@ class IStft():
class LogMelSpectrogramKaldi():
def __init__(
self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
energy_floor=0.0,
dither=0.1):
"""
The Kaldi implementation of LogMelSpectrogram
Args:
fs (int): sample rate of the audio
n_mels (int): number of mel filter banks
n_shift (int): number of points in a frame shift
win_length (int): number of points in a frame windows
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
dither (float): Dithering constant
Returns:
LogMelSpectrogramKaldi
"""
self.fs = fs
self.n_mels = n_mels
num_point_ms = fs / 1000
self.n_frame_length = win_length / num_point_ms
self.n_frame_shift = n_shift / num_point_ms
self.energy_floor = energy_floor
self.dither = dither
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, "
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
"dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_frame_shift=self.n_frame_shift,
n_frame_length=self.n_frame_length,
dither=self.dither, ))
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32)
mat = kaldi.fbank(
waveform,
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
dither=dither,
energy_floor=self.energy_floor,
sr=self.fs)
mat = np.squeeze(mat.numpy())
return mat
class LogMelSpectrogramKaldi_decay():
def __init__(
self,
fs=16000,

@ -31,6 +31,7 @@ import_alias = dict(
freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask",
spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment",
speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation",
speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox",
volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation",
noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection",
bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation",

Loading…
Cancel
Save