【audio】remove paddleaudio from paddlespeech (#3986)
* remove paddleaudio from paddlespeech * use scikit-learn instead sklearn * add pathos * remove utils * add kaldiio * remove useless printpull/3994/head
parent
f3a5df2049
commit
0479cce8ff
@ -0,0 +1,20 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .soundfile_backend import depth_convert
|
||||
from .soundfile_backend import load
|
||||
from .soundfile_backend import normalize
|
||||
from .soundfile_backend import resample
|
||||
from .soundfile_backend import soundfile_load
|
||||
from .soundfile_backend import soundfile_save
|
||||
from .soundfile_backend import to_mono
|
@ -0,0 +1,53 @@
|
||||
# Token from https://github.com/pytorch/audio/blob/main/torchaudio/backend/common.py with modification.
|
||||
|
||||
|
||||
class AudioInfo:
|
||||
"""return of info function.
|
||||
|
||||
This class is used by :ref:`"sox_io" backend<sox_io_backend>` and
|
||||
:ref:`"soundfile" backend with the new interface<soundfile_backend>`.
|
||||
|
||||
:ivar int sample_rate: Sample rate
|
||||
:ivar int num_frames: The number of frames
|
||||
:ivar int num_channels: The number of channels
|
||||
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
|
||||
or when it cannot be accurately inferred.
|
||||
:ivar str encoding: Audio encoding
|
||||
The values encoding can take are one of the following:
|
||||
|
||||
* ``PCM_S``: Signed integer linear PCM
|
||||
* ``PCM_U``: Unsigned integer linear PCM
|
||||
* ``PCM_F``: Floating point linear PCM
|
||||
* ``FLAC``: Flac, Free Lossless Audio Codec
|
||||
* ``ULAW``: Mu-law
|
||||
* ``ALAW``: A-law
|
||||
* ``MP3`` : MP3, MPEG-1 Audio Layer III
|
||||
* ``VORBIS``: OGG Vorbis
|
||||
* ``AMR_WB``: Adaptive Multi-Rate
|
||||
* ``AMR_NB``: Adaptive Multi-Rate Wideband
|
||||
* ``OPUS``: Opus
|
||||
* ``HTK``: Single channel 16-bit PCM
|
||||
* ``UNKNOWN`` : None of above
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int,
|
||||
num_frames: int,
|
||||
num_channels: int,
|
||||
bits_per_sample: int,
|
||||
encoding: str, ):
|
||||
self.sample_rate = sample_rate
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.bits_per_sample = bits_per_sample
|
||||
self.encoding = encoding
|
||||
|
||||
def __str__(self):
|
||||
return (f"AudioMetaData("
|
||||
f"sample_rate={self.sample_rate}, "
|
||||
f"num_frames={self.num_frames}, "
|
||||
f"num_channels={self.num_channels}, "
|
||||
f"bits_per_sample={self.bits_per_sample}, "
|
||||
f"encoding={self.encoding}"
|
||||
f")")
|
@ -0,0 +1,677 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import resampy
|
||||
import soundfile
|
||||
from scipy.io import wavfile
|
||||
|
||||
from ..utils import depth_convert
|
||||
from ..utils import ParameterError
|
||||
from .common import AudioInfo
|
||||
|
||||
__all__ = [
|
||||
'resample',
|
||||
'to_mono',
|
||||
'normalize',
|
||||
'save',
|
||||
'soundfile_save',
|
||||
'load',
|
||||
'soundfile_load',
|
||||
'info',
|
||||
]
|
||||
NORMALMIZE_TYPES = ['linear', 'gaussian']
|
||||
MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
|
||||
RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
|
||||
EPS = 1e-8
|
||||
|
||||
|
||||
def resample(y: np.ndarray,
|
||||
src_sr: int,
|
||||
target_sr: int,
|
||||
mode: str='kaiser_fast') -> np.ndarray:
|
||||
"""Audio resampling.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
src_sr (int): Source sample rate.
|
||||
target_sr (int): Target sample rate.
|
||||
mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
|
||||
|
||||
Returns:
|
||||
np.ndarray: `y` resampled to `target_sr`
|
||||
"""
|
||||
|
||||
if mode == 'kaiser_best':
|
||||
warnings.warn(
|
||||
f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \
|
||||
we recommend the mode kaiser_fast in large scale audio training')
|
||||
|
||||
if not isinstance(y, np.ndarray):
|
||||
raise ParameterError(
|
||||
'Only support numpy np.ndarray, but received y in {type(y)}')
|
||||
|
||||
if mode not in RESAMPLE_MODES:
|
||||
raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
|
||||
|
||||
return resampy.resample(y, src_sr, target_sr, filter=mode)
|
||||
|
||||
|
||||
def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
|
||||
"""Convert sterior audio to mono.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'.
|
||||
|
||||
Returns:
|
||||
np.ndarray: `y` with mono channel.
|
||||
"""
|
||||
|
||||
if merge_type not in MERGE_TYPES:
|
||||
raise ParameterError(
|
||||
f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
|
||||
)
|
||||
if y.ndim > 2:
|
||||
raise ParameterError(
|
||||
f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
|
||||
if y.ndim == 1: # nothing to merge
|
||||
return y
|
||||
|
||||
if merge_type == 'ch0':
|
||||
return y[0]
|
||||
if merge_type == 'ch1':
|
||||
return y[1]
|
||||
if merge_type == 'random':
|
||||
return y[np.random.randint(0, 2)]
|
||||
|
||||
# need to do averaging according to dtype
|
||||
|
||||
if y.dtype == 'float32':
|
||||
y_out = (y[0] + y[1]) * 0.5
|
||||
elif y.dtype == 'int16':
|
||||
y_out = y.astype('int32')
|
||||
y_out = (y_out[0] + y_out[1]) // 2
|
||||
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
|
||||
np.iinfo(y.dtype).max).astype(y.dtype)
|
||||
|
||||
elif y.dtype == 'int8':
|
||||
y_out = y.astype('int16')
|
||||
y_out = (y_out[0] + y_out[1]) // 2
|
||||
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
|
||||
np.iinfo(y.dtype).max).astype(y.dtype)
|
||||
else:
|
||||
raise ParameterError(f'Unsupported dtype: {y.dtype}')
|
||||
return y_out
|
||||
|
||||
|
||||
def soundfile_load_(file: os.PathLike,
|
||||
offset: Optional[float]=None,
|
||||
dtype: str='int16',
|
||||
duration: Optional[int]=None) -> Tuple[np.ndarray, int]:
|
||||
"""Load audio using soundfile library. This function load audio file using libsndfile.
|
||||
|
||||
Args:
|
||||
file (os.PathLike): File of waveform.
|
||||
offset (Optional[float], optional): Offset to the start of waveform. Defaults to None.
|
||||
dtype (str, optional): Data type of waveform. Defaults to 'int16'.
|
||||
duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
|
||||
"""
|
||||
with soundfile.SoundFile(file) as sf_desc:
|
||||
sr_native = sf_desc.samplerate
|
||||
if offset:
|
||||
sf_desc.seek(int(offset * sr_native))
|
||||
if duration is not None:
|
||||
frame_duration = int(duration * sr_native)
|
||||
else:
|
||||
frame_duration = -1
|
||||
y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T
|
||||
|
||||
return y, sf_desc.samplerate
|
||||
|
||||
|
||||
def normalize(y: np.ndarray, norm_type: str='linear',
|
||||
mul_factor: float=1.0) -> np.ndarray:
|
||||
"""Normalize an input audio with additional multiplier.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
norm_type (str, optional): Type of normalization. Defaults to 'linear'.
|
||||
mul_factor (float, optional): Scaling factor. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: `y` after normalization.
|
||||
"""
|
||||
|
||||
if norm_type == 'linear':
|
||||
amax = np.max(np.abs(y))
|
||||
factor = 1.0 / (amax + EPS)
|
||||
y = y * factor * mul_factor
|
||||
elif norm_type == 'gaussian':
|
||||
amean = np.mean(y)
|
||||
astd = np.std(y)
|
||||
astd = max(astd, EPS)
|
||||
y = mul_factor * (y - amean) / astd
|
||||
else:
|
||||
raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}')
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
|
||||
"""Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
sr (int): Sample rate.
|
||||
file (os.PathLike): Path of audio file to save.
|
||||
"""
|
||||
if not file.endswith('.wav'):
|
||||
raise ParameterError(
|
||||
f'only .wav file supported, but dst file name is: {file}')
|
||||
|
||||
if sr <= 0:
|
||||
raise ParameterError(
|
||||
f'Sample rate should be larger than 0, received sr = {sr}')
|
||||
|
||||
if y.dtype not in ['int16', 'int8']:
|
||||
warnings.warn(
|
||||
f'input data type is {y.dtype}, will convert data to int16 format before saving'
|
||||
)
|
||||
y_out = depth_convert(y, 'int16')
|
||||
else:
|
||||
y_out = y
|
||||
|
||||
wavfile.write(file, sr, y_out)
|
||||
|
||||
|
||||
def soundfile_load(
|
||||
file: os.PathLike,
|
||||
sr: Optional[int]=None,
|
||||
mono: bool=True,
|
||||
merge_type: str='average', # ch0,ch1,random,average
|
||||
normal: bool=True,
|
||||
norm_type: str='linear',
|
||||
norm_mul_factor: float=1.0,
|
||||
offset: float=0.0,
|
||||
duration: Optional[int]=None,
|
||||
dtype: str='float32',
|
||||
resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]:
|
||||
"""Load audio file from disk. This function loads audio from disk using using audio backend.
|
||||
|
||||
Args:
|
||||
file (os.PathLike): Path of audio file to load.
|
||||
sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None.
|
||||
mono (bool, optional): Return waveform with mono channel. Defaults to True.
|
||||
merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'.
|
||||
normal (bool, optional): Waveform normalization. Defaults to True.
|
||||
norm_type (str, optional): Type of normalization. Defaults to 'linear'.
|
||||
norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0.
|
||||
offset (float, optional): Offset to the start of waveform. Defaults to 0.0.
|
||||
duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
|
||||
dtype (str, optional): Data type of waveform. Defaults to 'float32'.
|
||||
resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
|
||||
"""
|
||||
|
||||
y, r = soundfile_load_(file, offset=offset, dtype=dtype, duration=duration)
|
||||
|
||||
if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
|
||||
raise ParameterError(f'audio file {file} looks empty')
|
||||
|
||||
if mono:
|
||||
y = to_mono(y, merge_type)
|
||||
|
||||
if sr is not None and sr != r:
|
||||
y = resample(y, r, sr, mode=resample_mode)
|
||||
r = sr
|
||||
|
||||
if normal:
|
||||
y = normalize(y, norm_type, norm_mul_factor)
|
||||
elif dtype in ['int8', 'int16']:
|
||||
# still need to do normalization, before depth conversion
|
||||
y = normalize(y, 'linear', 1.0)
|
||||
|
||||
y = depth_convert(y, dtype)
|
||||
return y, r
|
||||
|
||||
|
||||
#The code below is taken from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py, with some modifications.
|
||||
|
||||
|
||||
def _get_subtype_for_wav(dtype: paddle.dtype,
|
||||
encoding: str,
|
||||
bits_per_sample: int):
|
||||
if not encoding:
|
||||
if not bits_per_sample:
|
||||
subtype = {
|
||||
paddle.uint8: "PCM_U8",
|
||||
paddle.int16: "PCM_16",
|
||||
paddle.int32: "PCM_32",
|
||||
paddle.float32: "FLOAT",
|
||||
paddle.float64: "DOUBLE",
|
||||
}.get(dtype)
|
||||
if not subtype:
|
||||
raise ValueError(f"Unsupported dtype for wav: {dtype}")
|
||||
return subtype
|
||||
if bits_per_sample == 8:
|
||||
return "PCM_U8"
|
||||
return f"PCM_{bits_per_sample}"
|
||||
if encoding == "PCM_S":
|
||||
if not bits_per_sample:
|
||||
return "PCM_32"
|
||||
if bits_per_sample == 8:
|
||||
raise ValueError("wav does not support 8-bit signed PCM encoding.")
|
||||
return f"PCM_{bits_per_sample}"
|
||||
if encoding == "PCM_U":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "PCM_U8"
|
||||
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
|
||||
if encoding == "PCM_F":
|
||||
if bits_per_sample in (None, 32):
|
||||
return "FLOAT"
|
||||
if bits_per_sample == 64:
|
||||
return "DOUBLE"
|
||||
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
|
||||
if encoding == "ULAW":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "ULAW"
|
||||
raise ValueError("wav only supports 8-bit mu-law encoding.")
|
||||
if encoding == "ALAW":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "ALAW"
|
||||
raise ValueError("wav only supports 8-bit a-law encoding.")
|
||||
raise ValueError(f"wav does not support {encoding}.")
|
||||
|
||||
|
||||
def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
|
||||
if encoding in (None, "PCM_S"):
|
||||
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
|
||||
if encoding in ("PCM_U", "PCM_F"):
|
||||
raise ValueError(f"sph does not support {encoding} encoding.")
|
||||
if encoding == "ULAW":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "ULAW"
|
||||
raise ValueError("sph only supports 8-bit for mu-law encoding.")
|
||||
if encoding == "ALAW":
|
||||
return "ALAW"
|
||||
raise ValueError(f"sph does not support {encoding}.")
|
||||
|
||||
|
||||
def _get_subtype(dtype: paddle.dtype,
|
||||
format: str,
|
||||
encoding: str,
|
||||
bits_per_sample: int):
|
||||
if format == "wav":
|
||||
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
|
||||
if format == "flac":
|
||||
if encoding:
|
||||
raise ValueError("flac does not support encoding.")
|
||||
if not bits_per_sample:
|
||||
return "PCM_16"
|
||||
if bits_per_sample > 24:
|
||||
raise ValueError("flac does not support bits_per_sample > 24.")
|
||||
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
|
||||
if format in ("ogg", "vorbis"):
|
||||
if encoding or bits_per_sample:
|
||||
raise ValueError(
|
||||
"ogg/vorbis does not support encoding/bits_per_sample.")
|
||||
return "VORBIS"
|
||||
if format == "sph":
|
||||
return _get_subtype_for_sphere(encoding, bits_per_sample)
|
||||
if format in ("nis", "nist"):
|
||||
return "PCM_16"
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
|
||||
def save(
|
||||
filepath: str,
|
||||
src: paddle.Tensor,
|
||||
sample_rate: int,
|
||||
channels_first: bool=True,
|
||||
compression: Optional[float]=None,
|
||||
format: Optional[str]=None,
|
||||
encoding: Optional[str]=None,
|
||||
bits_per_sample: Optional[int]=None, ):
|
||||
"""Save audio data to file.
|
||||
|
||||
Note:
|
||||
The formats this function can handle depend on the soundfile installation.
|
||||
This function is tested on the following formats;
|
||||
|
||||
* WAV
|
||||
|
||||
* 32-bit floating-point
|
||||
* 32-bit signed integer
|
||||
* 16-bit signed integer
|
||||
* 8-bit unsigned integer
|
||||
|
||||
* FLAC
|
||||
* OGG/VORBIS
|
||||
* SPHERE
|
||||
|
||||
Note:
|
||||
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
|
||||
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
|
||||
|
||||
Args:
|
||||
filepath (str or pathlib.Path): Path to audio file.
|
||||
src (paddle.Tensor): Audio data to save. must be 2D tensor.
|
||||
sample_rate (int): sampling rate
|
||||
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
|
||||
otherwise `[time, channel]`.
|
||||
compression (float of None, optional): Not used.
|
||||
It is here only for interface compatibility reason with "sox_io" backend.
|
||||
format (str or None, optional): Override the audio format.
|
||||
When ``filepath`` argument is path-like object, audio format is
|
||||
inferred from file extension. If the file extension is missing or
|
||||
different, you can specify the correct format with this argument.
|
||||
|
||||
When ``filepath`` argument is file-like object,
|
||||
this argument is required.
|
||||
|
||||
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
|
||||
``"flac"`` and ``"sph"``.
|
||||
encoding (str or None, optional): Changes the encoding for supported formats.
|
||||
This argument is effective only for supported formats, such as
|
||||
``"wav"``, ``""flac"`` and ``"sph"``. Valid values are:
|
||||
|
||||
- ``"PCM_S"`` (signed integer Linear PCM)
|
||||
- ``"PCM_U"`` (unsigned integer Linear PCM)
|
||||
- ``"PCM_F"`` (floating point PCM)
|
||||
- ``"ULAW"`` (mu-law)
|
||||
- ``"ALAW"`` (a-law)
|
||||
|
||||
bits_per_sample (int or None, optional): Changes the bit depth for the
|
||||
supported formats.
|
||||
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
|
||||
you can change the bit depth.
|
||||
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
|
||||
|
||||
Supported formats/encodings/bit depth/compression are:
|
||||
|
||||
``"wav"``
|
||||
- 32-bit floating-point PCM
|
||||
- 32-bit signed integer PCM
|
||||
- 24-bit signed integer PCM
|
||||
- 16-bit signed integer PCM
|
||||
- 8-bit unsigned integer PCM
|
||||
- 8-bit mu-law
|
||||
- 8-bit a-law
|
||||
|
||||
Note:
|
||||
Default encoding/bit depth is determined by the dtype of
|
||||
the input Tensor.
|
||||
|
||||
``"flac"``
|
||||
- 8-bit
|
||||
- 16-bit (default)
|
||||
- 24-bit
|
||||
|
||||
``"ogg"``, ``"vorbis"``
|
||||
- Doesn't accept changing configuration.
|
||||
|
||||
``"sph"``
|
||||
- 8-bit signed integer PCM
|
||||
- 16-bit signed integer PCM
|
||||
- 24-bit signed integer PCM
|
||||
- 32-bit signed integer PCM (default)
|
||||
- 8-bit mu-law
|
||||
- 8-bit a-law
|
||||
- 16-bit a-law
|
||||
- 24-bit a-law
|
||||
- 32-bit a-law
|
||||
|
||||
"""
|
||||
if src.ndim != 2:
|
||||
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
|
||||
if compression is not None:
|
||||
warnings.warn(
|
||||
'`save` function of "soundfile" backend does not support "compression" parameter. '
|
||||
"The argument is silently ignored.")
|
||||
if hasattr(filepath, "write"):
|
||||
if format is None:
|
||||
raise RuntimeError(
|
||||
"`format` is required when saving to file object.")
|
||||
ext = format.lower()
|
||||
else:
|
||||
ext = str(filepath).split(".")[-1].lower()
|
||||
|
||||
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
|
||||
raise ValueError("Invalid bits_per_sample.")
|
||||
if bits_per_sample == 24:
|
||||
warnings.warn(
|
||||
"Saving audio with 24 bits per sample might warp samples near -1. "
|
||||
"Using 16 bits per sample might be able to avoid this.")
|
||||
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
|
||||
|
||||
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
|
||||
# so we extend the extensions manually here
|
||||
if ext in ["nis", "nist", "sph"] and format is None:
|
||||
format = "NIST"
|
||||
|
||||
if channels_first:
|
||||
src = src.t()
|
||||
|
||||
soundfile.write(
|
||||
file=filepath,
|
||||
data=src,
|
||||
samplerate=sample_rate,
|
||||
subtype=subtype,
|
||||
format=format)
|
||||
|
||||
|
||||
_SUBTYPE2DTYPE = {
|
||||
"PCM_S8": "int8",
|
||||
"PCM_U8": "uint8",
|
||||
"PCM_16": "int16",
|
||||
"PCM_32": "int32",
|
||||
"FLOAT": "float32",
|
||||
"DOUBLE": "float64",
|
||||
}
|
||||
|
||||
|
||||
def load(
|
||||
filepath: str,
|
||||
frame_offset: int=0,
|
||||
num_frames: int=-1,
|
||||
normalize: bool=True,
|
||||
channels_first: bool=True,
|
||||
format: Optional[str]=None, ) -> Tuple[paddle.Tensor, int]:
|
||||
"""Load audio data from file.
|
||||
|
||||
Note:
|
||||
The formats this function can handle depend on the soundfile installation.
|
||||
This function is tested on the following formats;
|
||||
|
||||
* WAV
|
||||
|
||||
* 32-bit floating-point
|
||||
* 32-bit signed integer
|
||||
* 16-bit signed integer
|
||||
* 8-bit unsigned integer
|
||||
|
||||
* FLAC
|
||||
* OGG/VORBIS
|
||||
* SPHERE
|
||||
|
||||
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
|
||||
``float32`` dtype and the shape of `[channel, time]`.
|
||||
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
|
||||
|
||||
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
|
||||
signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
|
||||
by providing ``normalize=False``, this function can return integer Tensor, where the samples
|
||||
are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
|
||||
for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
|
||||
|
||||
``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
|
||||
``flac`` and ``mp3``.
|
||||
For these formats, this function always returns ``float32`` Tensor with values normalized to
|
||||
``[-1.0, 1.0]``.
|
||||
|
||||
Note:
|
||||
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
|
||||
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend.
|
||||
|
||||
Args:
|
||||
filepath (path-like object or file-like object):
|
||||
Source of audio data.
|
||||
frame_offset (int, optional):
|
||||
Number of frames to skip before start reading data.
|
||||
num_frames (int, optional):
|
||||
Maximum number of frames to read. ``-1`` reads all the remaining samples,
|
||||
starting from ``frame_offset``.
|
||||
This function may return the less number of frames if there is not enough
|
||||
frames in the given file.
|
||||
normalize (bool, optional):
|
||||
When ``True``, this function always return ``float32``, and sample values are
|
||||
normalized to ``[-1.0, 1.0]``.
|
||||
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
|
||||
integer type.
|
||||
This argument has no effect for formats other than integer WAV type.
|
||||
channels_first (bool, optional):
|
||||
When True, the returned Tensor has dimension `[channel, time]`.
|
||||
Otherwise, the returned Tensor's dimension is `[time, channel]`.
|
||||
format (str or None, optional):
|
||||
Not used. PySoundFile does not accept format hint.
|
||||
|
||||
Returns:
|
||||
(paddle.Tensor, int): Resulting Tensor and sample rate.
|
||||
If the input file has integer wav format and normalization is off, then it has
|
||||
integer type, else ``float32`` type. If ``channels_first=True``, it has
|
||||
`[channel, time]` else `[time, channel]`.
|
||||
"""
|
||||
with soundfile.SoundFile(filepath, "r") as file_:
|
||||
if file_.format != "WAV" or normalize:
|
||||
dtype = "float32"
|
||||
elif file_.subtype not in _SUBTYPE2DTYPE:
|
||||
raise ValueError(f"Unsupported subtype: {file_.subtype}")
|
||||
else:
|
||||
dtype = _SUBTYPE2DTYPE[file_.subtype]
|
||||
|
||||
frames = file_._prepare_read(frame_offset, None, num_frames)
|
||||
waveform = file_.read(frames, dtype, always_2d=True)
|
||||
sample_rate = file_.samplerate
|
||||
|
||||
waveform = paddle.to_tensor(waveform)
|
||||
if channels_first:
|
||||
waveform = paddle.transpose(waveform, perm=[1, 0])
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
# Mapping from soundfile subtype to number of bits per sample.
|
||||
# This is mostly heuristical and the value is set to 0 when it is irrelevant
|
||||
# (lossy formats) or when it can't be inferred.
|
||||
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
|
||||
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
|
||||
# the default seems to be 8 bits but it can be compressed further to 4 bits.
|
||||
# The dict is inspired from
|
||||
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
|
||||
_SUBTYPE_TO_BITS_PER_SAMPLE = {
|
||||
"PCM_S8": 8, # Signed 8 bit data
|
||||
"PCM_16": 16, # Signed 16 bit data
|
||||
"PCM_24": 24, # Signed 24 bit data
|
||||
"PCM_32": 32, # Signed 32 bit data
|
||||
"PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
|
||||
"FLOAT": 32, # 32 bit float data
|
||||
"DOUBLE": 64, # 64 bit float data
|
||||
"ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
|
||||
"ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
|
||||
"IMA_ADPCM": 0, # IMA ADPCM.
|
||||
"MS_ADPCM": 0, # Microsoft ADPCM.
|
||||
"GSM610":
|
||||
0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
|
||||
"VOX_ADPCM": 0, # OKI / Dialogix ADPCM
|
||||
"G721_32": 0, # 32kbs G721 ADPCM encoding.
|
||||
"G723_24": 0, # 24kbs G723 ADPCM encoding.
|
||||
"G723_40": 0, # 40kbs G723 ADPCM encoding.
|
||||
"DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
|
||||
"DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
|
||||
"DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
|
||||
"DWVW_N": 0, # N bit Delta Width Variable Word encoding.
|
||||
"DPCM_8": 8, # 8 bit differential PCM (XI only)
|
||||
"DPCM_16": 16, # 16 bit differential PCM (XI only)
|
||||
"VORBIS": 0, # Xiph Vorbis encoding. (lossy)
|
||||
"ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
|
||||
"ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
|
||||
"ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
|
||||
"ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
|
||||
}
|
||||
|
||||
|
||||
def _get_bit_depth(subtype):
|
||||
if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
|
||||
warnings.warn(
|
||||
f"The {subtype} subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
|
||||
"attribute will be set to 0. If you are seeing this warning, please "
|
||||
"report by opening an issue on github (after checking for existing/closed ones). "
|
||||
"You may otherwise ignore this warning.")
|
||||
return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
|
||||
|
||||
|
||||
_SUBTYPE_TO_ENCODING = {
|
||||
"PCM_S8": "PCM_S",
|
||||
"PCM_16": "PCM_S",
|
||||
"PCM_24": "PCM_S",
|
||||
"PCM_32": "PCM_S",
|
||||
"PCM_U8": "PCM_U",
|
||||
"FLOAT": "PCM_F",
|
||||
"DOUBLE": "PCM_F",
|
||||
"ULAW": "ULAW",
|
||||
"ALAW": "ALAW",
|
||||
"VORBIS": "VORBIS",
|
||||
}
|
||||
|
||||
|
||||
def _get_encoding(format: str, subtype: str):
|
||||
if format == "FLAC":
|
||||
return "FLAC"
|
||||
return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
|
||||
|
||||
|
||||
def info(filepath: str, format: Optional[str]=None) -> AudioInfo:
|
||||
"""Get signal information of an audio file.
|
||||
|
||||
Note:
|
||||
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
|
||||
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
|
||||
|
||||
Args:
|
||||
filepath (path-like object or file-like object):
|
||||
Source of audio data.
|
||||
format (str or None, optional):
|
||||
Not used. PySoundFile does not accept format hint.
|
||||
|
||||
Returns:
|
||||
AudioInfo: meta data of the given audio.
|
||||
|
||||
"""
|
||||
sinfo = soundfile.info(filepath)
|
||||
return AudioInfo(
|
||||
sinfo.samplerate,
|
||||
sinfo.frames,
|
||||
sinfo.channels,
|
||||
bits_per_sample=_get_bit_depth(sinfo.subtype),
|
||||
encoding=_get_encoding(sinfo.format, sinfo.subtype), )
|
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from . import kaldi
|
||||
from . import librosa
|
@ -0,0 +1,643 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from torchaudio(https://github.com/pytorch/audio)
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
from ..functional import create_dct
|
||||
from ..functional.window import get_window
|
||||
|
||||
__all__ = [
|
||||
'spectrogram',
|
||||
'fbank',
|
||||
'mfcc',
|
||||
]
|
||||
|
||||
# window types
|
||||
HANNING = 'hann'
|
||||
HAMMING = 'hamming'
|
||||
POVEY = 'povey'
|
||||
RECTANGULAR = 'rect'
|
||||
BLACKMAN = 'blackman'
|
||||
|
||||
|
||||
def _get_epsilon(dtype):
|
||||
return paddle.to_tensor(1e-07, dtype=dtype)
|
||||
|
||||
|
||||
def _next_power_of_2(x: int) -> int:
|
||||
return 1 if x == 0 else 2**(x - 1).bit_length()
|
||||
|
||||
|
||||
def _get_strided(waveform: Tensor,
|
||||
window_size: int,
|
||||
window_shift: int,
|
||||
snip_edges: bool) -> Tensor:
|
||||
assert waveform.dim() == 1
|
||||
num_samples = waveform.shape[0]
|
||||
|
||||
if snip_edges:
|
||||
if num_samples < window_size:
|
||||
return paddle.empty((0, 0), dtype=waveform.dtype)
|
||||
else:
|
||||
m = 1 + (num_samples - window_size) // window_shift
|
||||
else:
|
||||
reversed_waveform = paddle.flip(waveform, [0])
|
||||
m = (num_samples + (window_shift // 2)) // window_shift
|
||||
pad = window_size // 2 - window_shift // 2
|
||||
pad_right = reversed_waveform
|
||||
if pad > 0:
|
||||
pad_left = reversed_waveform[-pad:]
|
||||
waveform = paddle.concat((pad_left, waveform, pad_right), axis=0)
|
||||
else:
|
||||
waveform = paddle.concat((waveform[-pad:], pad_right), axis=0)
|
||||
|
||||
return paddle.signal.frame(waveform, window_size, window_shift)[:, :m].T
|
||||
|
||||
|
||||
def _feature_window_function(
|
||||
window_type: str,
|
||||
window_size: int,
|
||||
blackman_coeff: float,
|
||||
dtype: int, ) -> Tensor:
|
||||
if window_type == "hann":
|
||||
return get_window('hann', window_size, fftbins=False, dtype=dtype)
|
||||
elif window_type == "hamming":
|
||||
return get_window('hamming', window_size, fftbins=False, dtype=dtype)
|
||||
elif window_type == "povey":
|
||||
return get_window(
|
||||
'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
|
||||
elif window_type == "rect":
|
||||
return paddle.ones([window_size], dtype=dtype)
|
||||
elif window_type == "blackman":
|
||||
a = 2 * math.pi / (window_size - 1)
|
||||
window_function = paddle.arange(window_size, dtype=dtype)
|
||||
return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
|
||||
(0.5 - blackman_coeff) * paddle.cos(2 * a * window_function)
|
||||
).astype(dtype)
|
||||
else:
|
||||
raise Exception('Invalid window type ' + window_type)
|
||||
|
||||
|
||||
def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
|
||||
energy_floor: float) -> Tensor:
|
||||
log_energy = paddle.maximum(strided_input.pow(2).sum(1), epsilon).log()
|
||||
if energy_floor == 0.0:
|
||||
return log_energy
|
||||
return paddle.maximum(
|
||||
log_energy,
|
||||
paddle.to_tensor(math.log(energy_floor), dtype=strided_input.dtype))
|
||||
|
||||
|
||||
def _get_waveform_and_window_properties(
|
||||
waveform: Tensor,
|
||||
channel: int,
|
||||
sr: int,
|
||||
frame_shift: float,
|
||||
frame_length: float,
|
||||
round_to_power_of_two: bool,
|
||||
preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
|
||||
channel = max(channel, 0)
|
||||
assert channel < waveform.shape[0], (
|
||||
'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
|
||||
waveform = waveform[channel, :] # size (n)
|
||||
window_shift = int(
|
||||
sr * frame_shift *
|
||||
0.001) # pass frame_shift and frame_length in milliseconds
|
||||
window_size = int(sr * frame_length * 0.001)
|
||||
padded_window_size = _next_power_of_2(
|
||||
window_size) if round_to_power_of_two else window_size
|
||||
|
||||
assert 2 <= window_size <= len(waveform), (
|
||||
'choose a window size {} that is [2, {}]'.format(window_size,
|
||||
len(waveform)))
|
||||
assert 0 < window_shift, '`window_shift` must be greater than 0'
|
||||
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
|
||||
' use `round_to_power_of_two` or change `frame_length`'
|
||||
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
|
||||
assert sr > 0, '`sr` must be greater than zero'
|
||||
return waveform, window_shift, window_size, padded_window_size
|
||||
|
||||
|
||||
def _get_window(waveform: Tensor,
|
||||
padded_window_size: int,
|
||||
window_size: int,
|
||||
window_shift: int,
|
||||
window_type: str,
|
||||
blackman_coeff: float,
|
||||
snip_edges: bool,
|
||||
raw_energy: bool,
|
||||
energy_floor: float,
|
||||
dither: float,
|
||||
remove_dc_offset: bool,
|
||||
preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
|
||||
dtype = waveform.dtype
|
||||
epsilon = _get_epsilon(dtype)
|
||||
|
||||
# (m, window_size)
|
||||
strided_input = _get_strided(waveform, window_size, window_shift,
|
||||
snip_edges)
|
||||
|
||||
if dither != 0.0:
|
||||
x = paddle.maximum(epsilon,
|
||||
paddle.rand(strided_input.shape, dtype=dtype))
|
||||
rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
|
||||
strided_input = strided_input + rand_gauss * dither
|
||||
|
||||
if remove_dc_offset:
|
||||
row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) # (m, 1)
|
||||
strided_input = strided_input - row_means
|
||||
|
||||
if raw_energy:
|
||||
signal_log_energy = _get_log_energy(strided_input, epsilon,
|
||||
energy_floor) # (m)
|
||||
|
||||
if preemphasis_coefficient != 0.0:
|
||||
offset_strided_input = paddle.nn.functional.pad(
|
||||
strided_input.unsqueeze(0), (1, 0),
|
||||
data_format='NCL',
|
||||
mode='replicate').squeeze(0) # (m, window_size + 1)
|
||||
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
|
||||
-1]
|
||||
|
||||
window_function = _feature_window_function(
|
||||
window_type, window_size, blackman_coeff,
|
||||
dtype).unsqueeze(0) # (1, window_size)
|
||||
strided_input = strided_input * window_function # (m, window_size)
|
||||
|
||||
# (m, padded_window_size)
|
||||
if padded_window_size != window_size:
|
||||
padding_right = padded_window_size - window_size
|
||||
strided_input = paddle.nn.functional.pad(
|
||||
strided_input.unsqueeze(0), (0, padding_right),
|
||||
data_format='NCL',
|
||||
mode='constant',
|
||||
value=0).squeeze(0)
|
||||
|
||||
if not raw_energy:
|
||||
signal_log_energy = _get_log_energy(strided_input, epsilon,
|
||||
energy_floor) # size (m)
|
||||
|
||||
return strided_input, signal_log_energy
|
||||
|
||||
|
||||
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
|
||||
if subtract_mean:
|
||||
col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
|
||||
tensor = tensor - col_means
|
||||
return tensor
|
||||
|
||||
|
||||
def spectrogram(waveform: Tensor,
|
||||
blackman_coeff: float=0.42,
|
||||
channel: int=-1,
|
||||
dither: float=0.0,
|
||||
energy_floor: float=1.0,
|
||||
frame_length: float=25.0,
|
||||
frame_shift: float=10.0,
|
||||
preemphasis_coefficient: float=0.97,
|
||||
raw_energy: bool=True,
|
||||
remove_dc_offset: bool=True,
|
||||
round_to_power_of_two: bool=True,
|
||||
sr: int=16000,
|
||||
snip_edges: bool=True,
|
||||
subtract_mean: bool=False,
|
||||
window_type: str="povey") -> Tensor:
|
||||
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): A waveform tensor with shape `(C, T)`.
|
||||
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
||||
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
||||
dither (float, optional): Dithering constant . Defaults to 0.0.
|
||||
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
||||
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
||||
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
||||
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
||||
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
||||
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. Defaults to True.
|
||||
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
||||
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a signal frame when it
|
||||
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
||||
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
||||
window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
|
||||
|
||||
Returns:
|
||||
Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames
|
||||
depends on frame_length and frame_shift.
|
||||
"""
|
||||
dtype = waveform.dtype
|
||||
epsilon = _get_epsilon(dtype)
|
||||
|
||||
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
|
||||
preemphasis_coefficient)
|
||||
|
||||
strided_input, signal_log_energy = _get_window(
|
||||
waveform, padded_window_size, window_size, window_shift, window_type,
|
||||
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
|
||||
remove_dc_offset, preemphasis_coefficient)
|
||||
|
||||
# (m, padded_window_size // 2 + 1, 2)
|
||||
fft = paddle.fft.rfft(strided_input)
|
||||
|
||||
power_spectrum = paddle.maximum(
|
||||
fft.abs().pow(2.), epsilon).log() # (m, padded_window_size // 2 + 1)
|
||||
power_spectrum[:, 0] = signal_log_energy
|
||||
|
||||
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
|
||||
return power_spectrum
|
||||
|
||||
|
||||
def _inverse_mel_scale_scalar(mel_freq: float) -> float:
|
||||
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
|
||||
|
||||
|
||||
def _inverse_mel_scale(mel_freq: Tensor) -> Tensor:
|
||||
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
|
||||
|
||||
|
||||
def _mel_scale_scalar(freq: float) -> float:
|
||||
return 1127.0 * math.log(1.0 + freq / 700.0)
|
||||
|
||||
|
||||
def _mel_scale(freq: Tensor) -> Tensor:
|
||||
return 1127.0 * (1.0 + freq / 700.0).log()
|
||||
|
||||
|
||||
def _vtln_warp_freq(vtln_low_cutoff: float,
|
||||
vtln_high_cutoff: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_warp_factor: float,
|
||||
freq: Tensor) -> Tensor:
|
||||
assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
|
||||
assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
|
||||
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
|
||||
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
|
||||
scale = 1.0 / vtln_warp_factor
|
||||
Fl = scale * l
|
||||
Fh = scale * h
|
||||
assert l > low_freq and h < high_freq
|
||||
scale_left = (Fl - low_freq) / (l - low_freq)
|
||||
scale_right = (high_freq - Fh) / (high_freq - h)
|
||||
res = paddle.empty_like(freq)
|
||||
|
||||
outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
|
||||
| paddle.greater_than(freq, paddle.to_tensor(high_freq))
|
||||
before_l = paddle.less_than(freq, paddle.to_tensor(l))
|
||||
before_h = paddle.less_than(freq, paddle.to_tensor(h))
|
||||
after_h = paddle.greater_equal(freq, paddle.to_tensor(h))
|
||||
|
||||
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
|
||||
res[before_h] = scale * freq[before_h]
|
||||
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
|
||||
res[outside_low_high_freq] = freq[outside_low_high_freq]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _vtln_warp_mel_freq(vtln_low_cutoff: float,
|
||||
vtln_high_cutoff: float,
|
||||
low_freq,
|
||||
high_freq: float,
|
||||
vtln_warp_factor: float,
|
||||
mel_freq: Tensor) -> Tensor:
|
||||
return _mel_scale(
|
||||
_vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
|
||||
vtln_warp_factor, _inverse_mel_scale(mel_freq)))
|
||||
|
||||
|
||||
def _get_mel_banks(num_bins: int,
|
||||
window_length_padded: int,
|
||||
sample_freq: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_low: float,
|
||||
vtln_high: float,
|
||||
vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
|
||||
assert num_bins > 3, 'Must have at least 3 mel bins'
|
||||
assert window_length_padded % 2 == 0
|
||||
num_fft_bins = window_length_padded / 2
|
||||
nyquist = 0.5 * sample_freq
|
||||
|
||||
if high_freq <= 0.0:
|
||||
high_freq += nyquist
|
||||
|
||||
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
|
||||
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
|
||||
|
||||
fft_bin_width = sample_freq / window_length_padded
|
||||
mel_low_freq = _mel_scale_scalar(low_freq)
|
||||
mel_high_freq = _mel_scale_scalar(high_freq)
|
||||
|
||||
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
||||
|
||||
if vtln_high < 0.0:
|
||||
vtln_high += nyquist
|
||||
|
||||
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
|
||||
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
|
||||
('Bad values in options: vtln-low {} and vtln-high {}, versus '
|
||||
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
|
||||
|
||||
bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1)
|
||||
# left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
|
||||
# center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
|
||||
# right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
|
||||
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
|
||||
center_mel = left_mel + mel_freq_delta
|
||||
right_mel = center_mel + mel_freq_delta
|
||||
|
||||
if vtln_warp_factor != 1.0:
|
||||
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
|
||||
vtln_warp_factor, left_mel)
|
||||
center_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
|
||||
high_freq, vtln_warp_factor,
|
||||
center_mel)
|
||||
right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
|
||||
high_freq, vtln_warp_factor, right_mel)
|
||||
|
||||
center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
|
||||
# (1, num_fft_bins)
|
||||
mel = _mel_scale(fft_bin_width * paddle.arange(
|
||||
num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
|
||||
|
||||
# (num_bins, num_fft_bins)
|
||||
up_slope = (mel - left_mel) / (center_mel - left_mel)
|
||||
down_slope = (right_mel - mel) / (right_mel - center_mel)
|
||||
|
||||
if vtln_warp_factor == 1.0:
|
||||
bins = paddle.maximum(
|
||||
paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
|
||||
else:
|
||||
bins = paddle.zeros_like(up_slope)
|
||||
up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
|
||||
mel, center_mel)
|
||||
down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
|
||||
mel, right_mel)
|
||||
bins[up_idx] = up_slope[up_idx]
|
||||
bins[down_idx] = down_slope[down_idx]
|
||||
|
||||
return bins, center_freqs
|
||||
|
||||
|
||||
def fbank(waveform: Tensor,
|
||||
blackman_coeff: float=0.42,
|
||||
channel: int=-1,
|
||||
dither: float=0.0,
|
||||
energy_floor: float=1.0,
|
||||
frame_length: float=25.0,
|
||||
frame_shift: float=10.0,
|
||||
high_freq: float=0.0,
|
||||
htk_compat: bool=False,
|
||||
low_freq: float=20.0,
|
||||
n_mels: int=23,
|
||||
preemphasis_coefficient: float=0.97,
|
||||
raw_energy: bool=True,
|
||||
remove_dc_offset: bool=True,
|
||||
round_to_power_of_two: bool=True,
|
||||
sr: int=16000,
|
||||
snip_edges: bool=True,
|
||||
subtract_mean: bool=False,
|
||||
use_energy: bool=False,
|
||||
use_log_fbank: bool=True,
|
||||
use_power: bool=True,
|
||||
vtln_high: float=-500.0,
|
||||
vtln_low: float=100.0,
|
||||
vtln_warp: float=1.0,
|
||||
window_type: str="povey") -> Tensor:
|
||||
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): A waveform tensor with shape `(C, T)`. `C` is in the range [0,1].
|
||||
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
||||
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
||||
dither (float, optional): Dithering constant . Defaults to 0.0.
|
||||
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
||||
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
||||
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
||||
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
|
||||
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
|
||||
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
|
||||
n_mels (int, optional): Number of output mel bins. Defaults to 23.
|
||||
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
||||
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
||||
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. Defaults to True.
|
||||
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
||||
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a signal frame when it
|
||||
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
||||
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
||||
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
|
||||
use_log_fbank (bool, optional): Return log fbank when it is set True. Defaults to True.
|
||||
use_power (bool, optional): Whether to use power instead of magnitude. Defaults to True.
|
||||
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
|
||||
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
|
||||
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
|
||||
window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
|
||||
|
||||
Returns:
|
||||
Tensor: A filter banks tensor with shape `(m, n_mels)`.
|
||||
"""
|
||||
dtype = waveform.dtype
|
||||
|
||||
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
|
||||
preemphasis_coefficient)
|
||||
|
||||
strided_input, signal_log_energy = _get_window(
|
||||
waveform, padded_window_size, window_size, window_shift, window_type,
|
||||
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
|
||||
remove_dc_offset, preemphasis_coefficient)
|
||||
|
||||
# (m, padded_window_size // 2 + 1)
|
||||
spectrum = paddle.fft.rfft(strided_input).abs()
|
||||
if use_power:
|
||||
spectrum = spectrum.pow(2.)
|
||||
|
||||
# (n_mels, padded_window_size // 2)
|
||||
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
|
||||
high_freq, vtln_low, vtln_high, vtln_warp)
|
||||
# mel_energies = mel_energies.astype(dtype)
|
||||
assert mel_energies.dtype == dtype
|
||||
|
||||
# (n_mels, padded_window_size // 2 + 1)
|
||||
mel_energies = paddle.nn.functional.pad(
|
||||
mel_energies.unsqueeze(0), (0, 1),
|
||||
data_format='NCL',
|
||||
mode='constant',
|
||||
value=0).squeeze(0)
|
||||
|
||||
# (m, n_mels)
|
||||
mel_energies = paddle.mm(spectrum, mel_energies.T)
|
||||
if use_log_fbank:
|
||||
mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
|
||||
|
||||
if use_energy:
|
||||
signal_log_energy = signal_log_energy.unsqueeze(1)
|
||||
if htk_compat:
|
||||
mel_energies = paddle.concat(
|
||||
(mel_energies, signal_log_energy), axis=1)
|
||||
else:
|
||||
mel_energies = paddle.concat(
|
||||
(signal_log_energy, mel_energies), axis=1)
|
||||
|
||||
# (m, n_mels + 1)
|
||||
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
|
||||
return mel_energies
|
||||
|
||||
|
||||
def _get_dct_matrix(n_mfcc: int, n_mels: int) -> Tensor:
|
||||
dct_matrix = create_dct(n_mels, n_mels, 'ortho')
|
||||
dct_matrix[:, 0] = math.sqrt(1 / float(n_mels))
|
||||
dct_matrix = dct_matrix[:, :n_mfcc] # (n_mels, n_mfcc)
|
||||
return dct_matrix
|
||||
|
||||
|
||||
def _get_lifter_coeffs(n_mfcc: int, cepstral_lifter: float) -> Tensor:
|
||||
i = paddle.arange(n_mfcc)
|
||||
return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
|
||||
cepstral_lifter)
|
||||
|
||||
|
||||
def mfcc(waveform: Tensor,
|
||||
blackman_coeff: float=0.42,
|
||||
cepstral_lifter: float=22.0,
|
||||
channel: int=-1,
|
||||
dither: float=0.0,
|
||||
energy_floor: float=1.0,
|
||||
frame_length: float=25.0,
|
||||
frame_shift: float=10.0,
|
||||
high_freq: float=0.0,
|
||||
htk_compat: bool=False,
|
||||
low_freq: float=20.0,
|
||||
n_mfcc: int=13,
|
||||
n_mels: int=23,
|
||||
preemphasis_coefficient: float=0.97,
|
||||
raw_energy: bool=True,
|
||||
remove_dc_offset: bool=True,
|
||||
round_to_power_of_two: bool=True,
|
||||
sr: int=16000,
|
||||
snip_edges: bool=True,
|
||||
subtract_mean: bool=False,
|
||||
use_energy: bool=False,
|
||||
vtln_high: float=-500.0,
|
||||
vtln_low: float=100.0,
|
||||
vtln_warp: float=1.0,
|
||||
window_type: str="povey") -> Tensor:
|
||||
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is
|
||||
identical to Kaldi's.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): A waveform tensor with shape `(C, T)`.
|
||||
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
||||
cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
|
||||
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
||||
dither (float, optional): Dithering constant . Defaults to 0.0.
|
||||
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
||||
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
||||
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
||||
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
|
||||
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
|
||||
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
|
||||
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 13.
|
||||
n_mels (int, optional): Number of output mel bins. Defaults to 23.
|
||||
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
||||
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
||||
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. Defaults to True.
|
||||
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
||||
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a signal frame when it
|
||||
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
||||
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
||||
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
|
||||
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
|
||||
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
|
||||
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
|
||||
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
||||
|
||||
Returns:
|
||||
Tensor: A mel frequency cepstral coefficients tensor with shape `(m, n_mfcc)`.
|
||||
"""
|
||||
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
|
||||
n_mfcc, n_mels)
|
||||
|
||||
dtype = waveform.dtype
|
||||
|
||||
# (m, n_mels + use_energy)
|
||||
feature = fbank(
|
||||
waveform=waveform,
|
||||
blackman_coeff=blackman_coeff,
|
||||
channel=channel,
|
||||
dither=dither,
|
||||
energy_floor=energy_floor,
|
||||
frame_length=frame_length,
|
||||
frame_shift=frame_shift,
|
||||
high_freq=high_freq,
|
||||
htk_compat=htk_compat,
|
||||
low_freq=low_freq,
|
||||
n_mels=n_mels,
|
||||
preemphasis_coefficient=preemphasis_coefficient,
|
||||
raw_energy=raw_energy,
|
||||
remove_dc_offset=remove_dc_offset,
|
||||
round_to_power_of_two=round_to_power_of_two,
|
||||
sr=sr,
|
||||
snip_edges=snip_edges,
|
||||
subtract_mean=False,
|
||||
use_energy=use_energy,
|
||||
use_log_fbank=True,
|
||||
use_power=True,
|
||||
vtln_high=vtln_high,
|
||||
vtln_low=vtln_low,
|
||||
vtln_warp=vtln_warp,
|
||||
window_type=window_type)
|
||||
|
||||
if use_energy:
|
||||
# (m)
|
||||
signal_log_energy = feature[:, n_mels if htk_compat else 0]
|
||||
mel_offset = int(not htk_compat)
|
||||
feature = feature[:, mel_offset:(n_mels + mel_offset)]
|
||||
|
||||
# (n_mels, n_mfcc)
|
||||
dct_matrix = _get_dct_matrix(n_mfcc, n_mels).astype(dtype=dtype)
|
||||
|
||||
# (m, n_mfcc)
|
||||
feature = feature.matmul(dct_matrix)
|
||||
|
||||
if cepstral_lifter != 0.0:
|
||||
# (1, n_mfcc)
|
||||
lifter_coeffs = _get_lifter_coeffs(n_mfcc, cepstral_lifter).unsqueeze(0)
|
||||
feature *= lifter_coeffs.astype(dtype=dtype)
|
||||
|
||||
if use_energy:
|
||||
feature[:, 0] = signal_log_energy
|
||||
|
||||
if htk_compat:
|
||||
energy = feature[:, 0].unsqueeze(1) # (m, 1)
|
||||
feature = feature[:, 1:] # (m, n_mfcc - 1)
|
||||
if not use_energy:
|
||||
energy *= math.sqrt(2)
|
||||
|
||||
feature = paddle.concat((feature, energy), axis=1)
|
||||
|
||||
feature = _subtract_column_mean(feature, subtract_mean)
|
||||
return feature
|
@ -0,0 +1,788 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from librosa(https://github.com/librosa/librosa)
|
||||
import warnings
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
from numpy.lib.stride_tricks import as_strided
|
||||
from scipy import signal
|
||||
|
||||
from ..utils import depth_convert
|
||||
from ..utils import ParameterError
|
||||
|
||||
__all__ = [
|
||||
# dsp
|
||||
'stft',
|
||||
'mfcc',
|
||||
'hz_to_mel',
|
||||
'mel_to_hz',
|
||||
'mel_frequencies',
|
||||
'power_to_db',
|
||||
'compute_fbank_matrix',
|
||||
'melspectrogram',
|
||||
'spectrogram',
|
||||
'mu_encode',
|
||||
'mu_decode',
|
||||
# augmentation
|
||||
'depth_augment',
|
||||
'spect_augment',
|
||||
'random_crop1d',
|
||||
'random_crop2d',
|
||||
'adaptive_spect_augment',
|
||||
]
|
||||
|
||||
|
||||
def _pad_center(data: np.ndarray, size: int, axis: int=-1,
|
||||
**kwargs) -> np.ndarray:
|
||||
"""Pad an array to a target length along a target axis.
|
||||
|
||||
This differs from `np.pad` by centering the data prior to padding,
|
||||
analogous to `str.center`
|
||||
"""
|
||||
|
||||
kwargs.setdefault("mode", "constant")
|
||||
n = data.shape[axis]
|
||||
lpad = int((size - n) // 2)
|
||||
lengths = [(0, 0)] * data.ndim
|
||||
lengths[axis] = (lpad, int(size - n - lpad))
|
||||
|
||||
if lpad < 0:
|
||||
raise ParameterError(("Target size ({size:d}) must be "
|
||||
"at least input size ({n:d})"))
|
||||
|
||||
return np.pad(data, lengths, **kwargs)
|
||||
|
||||
|
||||
def _split_frames(x: np.ndarray,
|
||||
frame_length: int,
|
||||
hop_length: int,
|
||||
axis: int=-1) -> np.ndarray:
|
||||
"""Slice a data array into (overlapping) frames.
|
||||
|
||||
This function is aligned with librosa.frame
|
||||
"""
|
||||
|
||||
if not isinstance(x, np.ndarray):
|
||||
raise ParameterError(
|
||||
f"Input must be of type numpy.ndarray, given type(x)={type(x)}")
|
||||
|
||||
if x.shape[axis] < frame_length:
|
||||
raise ParameterError(f"Input is too short (n={x.shape[axis]:d})"
|
||||
f" for frame_length={frame_length:d}")
|
||||
|
||||
if hop_length < 1:
|
||||
raise ParameterError(f"Invalid hop_length: {hop_length:d}")
|
||||
|
||||
if axis == -1 and not x.flags["F_CONTIGUOUS"]:
|
||||
warnings.warn(f"librosa.util.frame called with axis={axis} "
|
||||
"on a non-contiguous input. This will result in a copy.")
|
||||
x = np.asfortranarray(x)
|
||||
elif axis == 0 and not x.flags["C_CONTIGUOUS"]:
|
||||
warnings.warn(f"librosa.util.frame called with axis={axis} "
|
||||
"on a non-contiguous input. This will result in a copy.")
|
||||
x = np.ascontiguousarray(x)
|
||||
|
||||
n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
|
||||
strides = np.asarray(x.strides)
|
||||
|
||||
new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
|
||||
|
||||
if axis == -1:
|
||||
shape = list(x.shape)[:-1] + [frame_length, n_frames]
|
||||
strides = list(strides) + [hop_length * new_stride]
|
||||
|
||||
elif axis == 0:
|
||||
shape = [n_frames, frame_length] + list(x.shape)[1:]
|
||||
strides = [hop_length * new_stride] + list(strides)
|
||||
|
||||
else:
|
||||
raise ParameterError(f"Frame axis={axis} must be either 0 or -1")
|
||||
|
||||
return as_strided(x, shape=shape, strides=strides)
|
||||
|
||||
|
||||
def _check_audio(y, mono=True) -> bool:
|
||||
"""Determine whether a variable contains valid audio data.
|
||||
|
||||
The audio y must be a np.ndarray, ether 1-channel or two channel
|
||||
"""
|
||||
if not isinstance(y, np.ndarray):
|
||||
raise ParameterError("Audio data must be of type numpy.ndarray")
|
||||
if y.ndim > 2:
|
||||
raise ParameterError(
|
||||
f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}")
|
||||
|
||||
if mono and y.ndim == 2:
|
||||
raise ParameterError(
|
||||
f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}")
|
||||
|
||||
if (mono and len(y) == 0) or (not mono and y.shape[1] < 0):
|
||||
raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}")
|
||||
|
||||
if not np.issubdtype(y.dtype, np.floating):
|
||||
raise ParameterError("Audio data must be floating-point")
|
||||
|
||||
if not np.isfinite(y).all():
|
||||
raise ParameterError("Audio buffer is not finite everywhere")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def hz_to_mel(frequencies: Union[float, List[float], np.ndarray],
|
||||
htk: bool=False) -> np.ndarray:
|
||||
"""Convert Hz to Mels.
|
||||
|
||||
Args:
|
||||
frequencies (Union[float, List[float], np.ndarray]): Frequencies in Hz.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Frequency in mels.
|
||||
"""
|
||||
freq = np.asanyarray(frequencies)
|
||||
|
||||
if htk:
|
||||
return 2595.0 * np.log10(1.0 + freq / 700.0)
|
||||
|
||||
# Fill in the linear part
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
|
||||
mels = (freq - f_min) / f_sp
|
||||
|
||||
# Fill in the log-scale part
|
||||
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
if freq.ndim:
|
||||
# If we have array data, vectorize
|
||||
log_t = freq >= min_log_hz
|
||||
mels[log_t] = min_log_mel + \
|
||||
np.log(freq[log_t] / min_log_hz) / logstep
|
||||
elif freq >= min_log_hz:
|
||||
# If we have scalar data, heck directly
|
||||
mels = min_log_mel + np.log(freq / min_log_hz) / logstep
|
||||
|
||||
return mels
|
||||
|
||||
|
||||
def mel_to_hz(mels: Union[float, List[float], np.ndarray],
|
||||
htk: int=False) -> np.ndarray:
|
||||
"""Convert mel bin numbers to frequencies.
|
||||
|
||||
Args:
|
||||
mels (Union[float, List[float], np.ndarray]): Frequency in mels.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Frequencies in Hz.
|
||||
"""
|
||||
mel_array = np.asanyarray(mels)
|
||||
|
||||
if htk:
|
||||
return 700.0 * (10.0**(mel_array / 2595.0) - 1.0)
|
||||
|
||||
# Fill in the linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mel_array
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
if mel_array.ndim:
|
||||
# If we have vector data, vectorize
|
||||
log_t = mel_array >= min_log_mel
|
||||
freqs[log_t] = min_log_hz * \
|
||||
np.exp(logstep * (mel_array[log_t] - min_log_mel))
|
||||
elif mel_array >= min_log_mel:
|
||||
# If we have scalar data, check directly
|
||||
freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel))
|
||||
|
||||
return freqs
|
||||
|
||||
|
||||
def mel_frequencies(n_mels: int=128,
|
||||
fmin: float=0.0,
|
||||
fmax: float=11025.0,
|
||||
htk: bool=False) -> np.ndarray:
|
||||
"""Compute mel frequencies.
|
||||
|
||||
Args:
|
||||
n_mels (int, optional): Number of mel bins. Defaults to 128.
|
||||
fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0.
|
||||
fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Vector of n_mels frequencies in Hz with shape `(n_mels,)`.
|
||||
"""
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
min_mel = hz_to_mel(fmin, htk=htk)
|
||||
max_mel = hz_to_mel(fmax, htk=htk)
|
||||
|
||||
mels = np.linspace(min_mel, max_mel, n_mels)
|
||||
|
||||
return mel_to_hz(mels, htk=htk)
|
||||
|
||||
|
||||
def fft_frequencies(sr: int, n_fft: int) -> np.ndarray:
|
||||
"""Compute fourier frequencies.
|
||||
|
||||
Args:
|
||||
sr (int): Sample rate.
|
||||
n_fft (int): FFT size.
|
||||
|
||||
Returns:
|
||||
np.ndarray: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`.
|
||||
"""
|
||||
return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)
|
||||
|
||||
|
||||
def compute_fbank_matrix(sr: int,
|
||||
n_fft: int,
|
||||
n_mels: int=128,
|
||||
fmin: float=0.0,
|
||||
fmax: Optional[float]=None,
|
||||
htk: bool=False,
|
||||
norm: str="slaney",
|
||||
dtype: type=np.float32) -> np.ndarray:
|
||||
"""Compute fbank matrix.
|
||||
|
||||
Args:
|
||||
sr (int): Sample rate.
|
||||
n_fft (int): FFT size.
|
||||
n_mels (int, optional): Number of mel bins. Defaults to 128.
|
||||
fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0.
|
||||
fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
norm (str, optional): Type of normalization. Defaults to "slaney".
|
||||
dtype (type, optional): Data type. Defaults to np.float32.
|
||||
|
||||
|
||||
Returns:
|
||||
np.ndarray: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`.
|
||||
"""
|
||||
if norm != "slaney":
|
||||
raise ParameterError('norm must set to slaney')
|
||||
|
||||
if fmax is None:
|
||||
fmax = float(sr) / 2
|
||||
|
||||
# Initialize the weights
|
||||
n_mels = int(n_mels)
|
||||
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
||||
|
||||
# Center freqs of each FFT bin
|
||||
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
|
||||
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
|
||||
|
||||
fdiff = np.diff(mel_f)
|
||||
ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||
|
||||
for i in range(n_mels):
|
||||
# lower and upper slopes for all bins
|
||||
lower = -ramps[i] / fdiff[i]
|
||||
upper = ramps[i + 2] / fdiff[i + 1]
|
||||
|
||||
# .. then intersect them with each other and zero
|
||||
weights[i] = np.maximum(0, np.minimum(lower, upper))
|
||||
|
||||
if norm == "slaney":
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
|
||||
weights *= enorm[:, np.newaxis]
|
||||
|
||||
# Only check weights if f_mel[0] is positive
|
||||
if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
|
||||
# This means we have an empty channel somewhere
|
||||
warnings.warn("Empty filters detected in mel frequency basis. "
|
||||
"Some channels will produce empty responses. "
|
||||
"Try increasing your sampling rate (and fmax) or "
|
||||
"reducing n_mels.")
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def stft(x: np.ndarray,
|
||||
n_fft: int=2048,
|
||||
hop_length: Optional[int]=None,
|
||||
win_length: Optional[int]=None,
|
||||
window: str="hann",
|
||||
center: bool=True,
|
||||
dtype: type=np.complex64,
|
||||
pad_mode: str="reflect") -> np.ndarray:
|
||||
"""Short-time Fourier transform (STFT).
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input waveform in one dimension.
|
||||
n_fft (int, optional): FFT size. Defaults to 2048.
|
||||
hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None.
|
||||
win_length (Optional[int], optional): The size of window. Defaults to None.
|
||||
window (str, optional): A string of window specification. Defaults to "hann".
|
||||
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
|
||||
dtype (type, optional): Data type of STFT results. Defaults to np.complex64.
|
||||
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
|
||||
|
||||
Returns:
|
||||
np.ndarray: The complex STFT output with shape `(n_fft//2 + 1, num_frames)`.
|
||||
"""
|
||||
_check_audio(x)
|
||||
|
||||
# By default, use the entire frame
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
# Set the default hop, if it's not already specified
|
||||
if hop_length is None:
|
||||
hop_length = int(win_length // 4)
|
||||
|
||||
fft_window = signal.get_window(window, win_length, fftbins=True)
|
||||
|
||||
# Pad the window out to n_fft size
|
||||
fft_window = _pad_center(fft_window, n_fft)
|
||||
|
||||
# Reshape so that the window can be broadcast
|
||||
fft_window = fft_window.reshape((-1, 1))
|
||||
|
||||
# Pad the time series so that frames are centered
|
||||
if center:
|
||||
if n_fft > x.shape[-1]:
|
||||
warnings.warn(
|
||||
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
|
||||
)
|
||||
x = np.pad(x, int(n_fft // 2), mode=pad_mode)
|
||||
|
||||
elif n_fft > x.shape[-1]:
|
||||
raise ParameterError(
|
||||
f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
|
||||
)
|
||||
|
||||
# Window the time series.
|
||||
x_frames = _split_frames(x, frame_length=n_fft, hop_length=hop_length)
|
||||
# Pre-allocate the STFT matrix
|
||||
stft_matrix = np.empty(
|
||||
(int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F")
|
||||
fft = np.fft # use numpy fft as default
|
||||
# Constrain STFT block sizes to 256 KB
|
||||
MAX_MEM_BLOCK = 2**8 * 2**10
|
||||
# how many columns can we fit within MAX_MEM_BLOCK?
|
||||
n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
|
||||
n_columns = max(n_columns, 1)
|
||||
|
||||
for bl_s in range(0, stft_matrix.shape[1], n_columns):
|
||||
bl_t = min(bl_s + n_columns, stft_matrix.shape[1])
|
||||
stft_matrix[:, bl_s:bl_t] = fft.rfft(
|
||||
fft_window * x_frames[:, bl_s:bl_t], axis=0)
|
||||
|
||||
return stft_matrix
|
||||
|
||||
|
||||
def power_to_db(spect: np.ndarray,
|
||||
ref: float=1.0,
|
||||
amin: float=1e-10,
|
||||
top_db: Optional[float]=80.0) -> np.ndarray:
|
||||
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way.
|
||||
|
||||
Args:
|
||||
spect (np.ndarray): STFT power spectrogram of an input waveform.
|
||||
ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
|
||||
amin (float, optional): Minimum threshold. Defaults to 1e-10.
|
||||
top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to 80.0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Power spectrogram in db scale.
|
||||
"""
|
||||
spect = np.asarray(spect)
|
||||
|
||||
if amin <= 0:
|
||||
raise ParameterError("amin must be strictly positive")
|
||||
|
||||
if np.issubdtype(spect.dtype, np.complexfloating):
|
||||
warnings.warn(
|
||||
"power_to_db was called on complex input so phase "
|
||||
"information will be discarded. To suppress this warning, "
|
||||
"call power_to_db(np.abs(D)**2) instead.")
|
||||
magnitude = np.abs(spect)
|
||||
else:
|
||||
magnitude = spect
|
||||
|
||||
if callable(ref):
|
||||
# User supplied a function to calculate reference power
|
||||
ref_value = ref(magnitude)
|
||||
else:
|
||||
ref_value = np.abs(ref)
|
||||
|
||||
log_spec = 10.0 * np.log10(np.maximum(amin, magnitude))
|
||||
log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
|
||||
|
||||
if top_db is not None:
|
||||
if top_db < 0:
|
||||
raise ParameterError("top_db must be non-negative")
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
|
||||
|
||||
return log_spec
|
||||
|
||||
|
||||
def mfcc(x: np.ndarray,
|
||||
sr: int=16000,
|
||||
spect: Optional[np.ndarray]=None,
|
||||
n_mfcc: int=20,
|
||||
dct_type: int=2,
|
||||
norm: str="ortho",
|
||||
lifter: int=0,
|
||||
**kwargs) -> np.ndarray:
|
||||
"""Mel-frequency cepstral coefficients (MFCCs)
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input waveform in one dimension.
|
||||
sr (int, optional): Sample rate. Defaults to 16000.
|
||||
spect (Optional[np.ndarray], optional): Input log-power Mel spectrogram. Defaults to None.
|
||||
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 20.
|
||||
dct_type (int, optional): Discrete cosine transform (DCT) type. Defaults to 2.
|
||||
norm (str, optional): Type of normalization. Defaults to "ortho".
|
||||
lifter (int, optional): Cepstral filtering. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Mel frequency cepstral coefficients array with shape `(n_mfcc, num_frames)`.
|
||||
"""
|
||||
if spect is None:
|
||||
spect = melspectrogram(x, sr=sr, **kwargs)
|
||||
|
||||
M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc]
|
||||
|
||||
if lifter > 0:
|
||||
factor = np.sin(np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) /
|
||||
lifter)
|
||||
return M * factor[:, np.newaxis]
|
||||
elif lifter == 0:
|
||||
return M
|
||||
else:
|
||||
raise ParameterError(
|
||||
f"MFCC lifter={lifter} must be a non-negative number")
|
||||
|
||||
|
||||
def melspectrogram(x: np.ndarray,
|
||||
sr: int=16000,
|
||||
window_size: int=512,
|
||||
hop_length: int=320,
|
||||
n_mels: int=64,
|
||||
fmin: float=50.0,
|
||||
fmax: Optional[float]=None,
|
||||
window: str='hann',
|
||||
center: bool=True,
|
||||
pad_mode: str='reflect',
|
||||
power: float=2.0,
|
||||
to_db: bool=True,
|
||||
ref: float=1.0,
|
||||
amin: float=1e-10,
|
||||
top_db: Optional[float]=None) -> np.ndarray:
|
||||
"""Compute mel-spectrogram.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input waveform in one dimension.
|
||||
sr (int, optional): Sample rate. Defaults to 16000.
|
||||
window_size (int, optional): Size of FFT and window length. Defaults to 512.
|
||||
hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
|
||||
n_mels (int, optional): Number of mel bins. Defaults to 64.
|
||||
fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0.
|
||||
fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
|
||||
window (str, optional): A string of window specification. Defaults to "hann".
|
||||
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
|
||||
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
|
||||
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
|
||||
to_db (bool, optional): Enable db scale. Defaults to True.
|
||||
ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
|
||||
amin (float, optional): Minimum threshold. Defaults to 1e-10.
|
||||
top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The mel-spectrogram in power scale or db scale with shape `(n_mels, num_frames)`.
|
||||
"""
|
||||
_check_audio(x, mono=True)
|
||||
if len(x) <= 0:
|
||||
raise ParameterError('The input waveform is empty')
|
||||
|
||||
if fmax is None:
|
||||
fmax = sr // 2
|
||||
if fmin < 0 or fmin >= fmax:
|
||||
raise ParameterError('fmin and fmax must statisfy 0<fmin<fmax')
|
||||
|
||||
s = stft(
|
||||
x,
|
||||
n_fft=window_size,
|
||||
hop_length=hop_length,
|
||||
win_length=window_size,
|
||||
window=window,
|
||||
center=center,
|
||||
pad_mode=pad_mode)
|
||||
|
||||
spect_power = np.abs(s)**power
|
||||
fb_matrix = compute_fbank_matrix(
|
||||
sr=sr, n_fft=window_size, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
mel_spect = np.matmul(fb_matrix, spect_power)
|
||||
if to_db:
|
||||
return power_to_db(mel_spect, ref=ref, amin=amin, top_db=top_db)
|
||||
else:
|
||||
return mel_spect
|
||||
|
||||
|
||||
def spectrogram(x: np.ndarray,
|
||||
sr: int=16000,
|
||||
window_size: int=512,
|
||||
hop_length: int=320,
|
||||
window: str='hann',
|
||||
center: bool=True,
|
||||
pad_mode: str='reflect',
|
||||
power: float=2.0) -> np.ndarray:
|
||||
"""Compute spectrogram.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input waveform in one dimension.
|
||||
sr (int, optional): Sample rate. Defaults to 16000.
|
||||
window_size (int, optional): Size of FFT and window length. Defaults to 512.
|
||||
hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
|
||||
window (str, optional): A string of window specification. Defaults to "hann".
|
||||
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
|
||||
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
|
||||
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The STFT spectrogram in power scale `(n_fft//2 + 1, num_frames)`.
|
||||
"""
|
||||
|
||||
s = stft(
|
||||
x,
|
||||
n_fft=window_size,
|
||||
hop_length=hop_length,
|
||||
win_length=window_size,
|
||||
window=window,
|
||||
center=center,
|
||||
pad_mode=pad_mode)
|
||||
|
||||
return np.abs(s)**power
|
||||
|
||||
|
||||
def mu_encode(x: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
|
||||
"""Mu-law encoding. Encode waveform based on mu-law companding. When quantized is True, the result will be converted to integer in range `[0,mu-1]`. Otherwise, the resulting waveform is in range `[-1,1]`.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): The input waveform to encode.
|
||||
mu (int, optional): The endoceding parameter. Defaults to 255.
|
||||
quantized (bool, optional): If `True`, quantize the encoded values into `1 + mu` distinct integer values. Defaults to True.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The mu-law encoded waveform.
|
||||
"""
|
||||
mu = 255
|
||||
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
|
||||
if quantized:
|
||||
y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1]
|
||||
return y
|
||||
|
||||
|
||||
def mu_decode(y: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
|
||||
"""Mu-law decoding. Compute the mu-law decoding given an input code. It assumes that the input `y` is in range `[0,mu-1]` when quantize is True and `[-1,1]` otherwise.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): The encoded waveform.
|
||||
mu (int, optional): The endoceding parameter. Defaults to 255.
|
||||
quantized (bool, optional): If `True`, the input is assumed to be quantized to `1 + mu` distinct integer values. Defaults to True.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The mu-law decoded waveform.
|
||||
"""
|
||||
if mu < 1:
|
||||
raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...')
|
||||
|
||||
mu = mu - 1
|
||||
if quantized: # undo the quantization
|
||||
y = y * 2 / mu - 1
|
||||
x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)
|
||||
return x
|
||||
|
||||
|
||||
def _randint(high: int) -> int:
|
||||
"""Generate one random integer in range [0 high)
|
||||
|
||||
This is a helper function for random data augmentation
|
||||
"""
|
||||
return int(np.random.randint(0, high=high))
|
||||
|
||||
|
||||
def depth_augment(y: np.ndarray,
|
||||
choices: List=['int8', 'int16'],
|
||||
probs: List[float]=[0.5, 0.5]) -> np.ndarray:
|
||||
""" Audio depth augmentation. Do audio depth augmentation to simulate the distortion brought by quantization.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
choices (List, optional): A list of data type to depth conversion. Defaults to ['int8', 'int16'].
|
||||
probs (List[float], optional): Probabilities to depth conversion. Defaults to [0.5, 0.5].
|
||||
|
||||
Returns:
|
||||
np.ndarray: The augmented waveform.
|
||||
"""
|
||||
assert len(probs) == len(
|
||||
choices
|
||||
), 'number of choices {} must be equal to size of probs {}'.format(
|
||||
len(choices), len(probs))
|
||||
depth = np.random.choice(choices, p=probs)
|
||||
src_depth = y.dtype
|
||||
y1 = depth_convert(y, depth)
|
||||
y2 = depth_convert(y1, src_depth)
|
||||
|
||||
return y2
|
||||
|
||||
|
||||
def adaptive_spect_augment(spect: np.ndarray,
|
||||
tempo_axis: int=0,
|
||||
level: float=0.1) -> np.ndarray:
|
||||
"""Do adaptive spectrogram augmentation. The level of the augmentation is govern by the parameter level, ranging from 0 to 1, with 0 represents no augmentation.
|
||||
|
||||
Args:
|
||||
spect (np.ndarray): Input spectrogram.
|
||||
tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
|
||||
level (float, optional): The level factor of masking. Defaults to 0.1.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The augmented spectrogram.
|
||||
"""
|
||||
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
|
||||
if tempo_axis == 0:
|
||||
nt, nf = spect.shape
|
||||
else:
|
||||
nf, nt = spect.shape
|
||||
|
||||
time_mask_width = int(nt * level * 0.5)
|
||||
freq_mask_width = int(nf * level * 0.5)
|
||||
|
||||
num_time_mask = int(10 * level)
|
||||
num_freq_mask = int(10 * level)
|
||||
|
||||
if tempo_axis == 0:
|
||||
for _ in range(num_time_mask):
|
||||
start = _randint(nt - time_mask_width)
|
||||
spect[start:start + time_mask_width, :] = 0
|
||||
for _ in range(num_freq_mask):
|
||||
start = _randint(nf - freq_mask_width)
|
||||
spect[:, start:start + freq_mask_width] = 0
|
||||
else:
|
||||
for _ in range(num_time_mask):
|
||||
start = _randint(nt - time_mask_width)
|
||||
spect[:, start:start + time_mask_width] = 0
|
||||
for _ in range(num_freq_mask):
|
||||
start = _randint(nf - freq_mask_width)
|
||||
spect[start:start + freq_mask_width, :] = 0
|
||||
|
||||
return spect
|
||||
|
||||
|
||||
def spect_augment(spect: np.ndarray,
|
||||
tempo_axis: int=0,
|
||||
max_time_mask: int=3,
|
||||
max_freq_mask: int=3,
|
||||
max_time_mask_width: int=30,
|
||||
max_freq_mask_width: int=20) -> np.ndarray:
|
||||
"""Do spectrogram augmentation in both time and freq axis.
|
||||
|
||||
Args:
|
||||
spect (np.ndarray): Input spectrogram.
|
||||
tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
|
||||
max_time_mask (int, optional): Maximum number of time masking. Defaults to 3.
|
||||
max_freq_mask (int, optional): Maximum number of frequency masking. Defaults to 3.
|
||||
max_time_mask_width (int, optional): Maximum width of time masking. Defaults to 30.
|
||||
max_freq_mask_width (int, optional): Maximum width of frequency masking. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The augmented spectrogram.
|
||||
"""
|
||||
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
|
||||
if tempo_axis == 0:
|
||||
nt, nf = spect.shape
|
||||
else:
|
||||
nf, nt = spect.shape
|
||||
|
||||
num_time_mask = _randint(max_time_mask)
|
||||
num_freq_mask = _randint(max_freq_mask)
|
||||
|
||||
time_mask_width = _randint(max_time_mask_width)
|
||||
freq_mask_width = _randint(max_freq_mask_width)
|
||||
|
||||
if tempo_axis == 0:
|
||||
for _ in range(num_time_mask):
|
||||
start = _randint(nt - time_mask_width)
|
||||
spect[start:start + time_mask_width, :] = 0
|
||||
for _ in range(num_freq_mask):
|
||||
start = _randint(nf - freq_mask_width)
|
||||
spect[:, start:start + freq_mask_width] = 0
|
||||
else:
|
||||
for _ in range(num_time_mask):
|
||||
start = _randint(nt - time_mask_width)
|
||||
spect[:, start:start + time_mask_width] = 0
|
||||
for _ in range(num_freq_mask):
|
||||
start = _randint(nf - freq_mask_width)
|
||||
spect[start:start + freq_mask_width, :] = 0
|
||||
|
||||
return spect
|
||||
|
||||
|
||||
def random_crop1d(y: np.ndarray, crop_len: int) -> np.ndarray:
|
||||
""" Random cropping on a input waveform.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D.
|
||||
crop_len (int): Length of waveform to crop.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The cropped waveform.
|
||||
"""
|
||||
if y.ndim != 1:
|
||||
'only accept 1d tensor or numpy array'
|
||||
n = len(y)
|
||||
idx = _randint(n - crop_len)
|
||||
return y[idx:idx + crop_len]
|
||||
|
||||
|
||||
def random_crop2d(s: np.ndarray, crop_len: int,
|
||||
tempo_axis: int=0) -> np.ndarray:
|
||||
""" Random cropping on a spectrogram.
|
||||
|
||||
Args:
|
||||
s (np.ndarray): Input spectrogram in 2D.
|
||||
crop_len (int): Length of spectrogram to crop.
|
||||
tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The cropped spectrogram.
|
||||
"""
|
||||
if tempo_axis >= s.ndim:
|
||||
raise ParameterError('axis out of range')
|
||||
|
||||
n = s.shape[tempo_axis]
|
||||
idx = _randint(high=n - crop_len)
|
||||
sli = [slice(None) for i in range(s.ndim)]
|
||||
sli[tempo_axis] = slice(idx, idx + crop_len)
|
||||
out = s[tuple(sli)]
|
||||
return out
|
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .esc50 import ESC50
|
||||
from .voxceleb import VoxCeleb
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from ..backends.soundfile_backend import soundfile_load as load_audio
|
||||
from ..compliance.kaldi import fbank as kaldi_fbank
|
||||
from ..compliance.kaldi import mfcc as kaldi_mfcc
|
||||
from ..compliance.librosa import melspectrogram
|
||||
from ..compliance.librosa import mfcc
|
||||
|
||||
feat_funcs = {
|
||||
'raw': None,
|
||||
'melspectrogram': melspectrogram,
|
||||
'mfcc': mfcc,
|
||||
'kaldi_fbank': kaldi_fbank,
|
||||
'kaldi_mfcc': kaldi_mfcc,
|
||||
}
|
||||
|
||||
|
||||
class AudioClassificationDataset(paddle.io.Dataset):
|
||||
"""
|
||||
Base class of audio classification dataset.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
files: List[str],
|
||||
labels: List[int],
|
||||
feat_type: str='raw',
|
||||
sample_rate: int=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Ags:
|
||||
files (:obj:`List[str]`): A list of absolute path of audio files.
|
||||
labels (:obj:`List[int]`): Labels of audio files.
|
||||
feat_type (:obj:`str`, `optional`, defaults to `raw`):
|
||||
It identifies the feature type that user wants to extract of an audio file.
|
||||
"""
|
||||
super(AudioClassificationDataset, self).__init__()
|
||||
|
||||
if feat_type not in feat_funcs.keys():
|
||||
raise RuntimeError(
|
||||
f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
|
||||
)
|
||||
|
||||
self.files = files
|
||||
self.labels = labels
|
||||
|
||||
self.feat_type = feat_type
|
||||
self.sample_rate = sample_rate
|
||||
self.feat_config = kwargs # Pass keyword arguments to customize feature config
|
||||
|
||||
def _get_data(self, input_file: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def _convert_to_record(self, idx):
|
||||
file, label = self.files[idx], self.labels[idx]
|
||||
|
||||
if self.sample_rate is None:
|
||||
waveform, sample_rate = load_audio(file)
|
||||
else:
|
||||
waveform, sample_rate = load_audio(file, sr=self.sample_rate)
|
||||
|
||||
feat_func = feat_funcs[self.feat_type]
|
||||
|
||||
record = {}
|
||||
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
|
||||
waveform = paddle.to_tensor(waveform).unsqueeze(0) # (C, T)
|
||||
record['feat'] = feat_func(
|
||||
waveform=waveform, sr=self.sample_rate, **self.feat_config)
|
||||
else:
|
||||
record['feat'] = feat_func(
|
||||
waveform, sample_rate,
|
||||
**self.feat_config) if feat_func else waveform
|
||||
record['label'] = label
|
||||
return record
|
||||
|
||||
def __getitem__(self, idx):
|
||||
record = self._convert_to_record(idx)
|
||||
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
|
||||
return self.keys[idx], record['feat'], record['label']
|
||||
else:
|
||||
return np.array(record['feat']).transpose(), np.array(
|
||||
record['label'], dtype=np.int64)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files)
|
@ -0,0 +1,152 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
from ...utils.env import DATA_HOME
|
||||
from ..utils.download import download_and_decompress
|
||||
from .dataset import AudioClassificationDataset
|
||||
|
||||
__all__ = ['ESC50']
|
||||
|
||||
|
||||
class ESC50(AudioClassificationDataset):
|
||||
"""
|
||||
The ESC-50 dataset is a labeled collection of 2000 environmental audio recordings
|
||||
suitable for benchmarking methods of environmental sound classification. The dataset
|
||||
consists of 5-second-long recordings organized into 50 semantical classes (with
|
||||
40 examples per class)
|
||||
|
||||
Reference:
|
||||
ESC: Dataset for Environmental Sound Classification
|
||||
http://dx.doi.org/10.1145/2733373.2806390
|
||||
"""
|
||||
|
||||
archieves = [
|
||||
{
|
||||
'url':
|
||||
'https://paddleaudio.bj.bcebos.com/datasets/ESC-50-master.zip',
|
||||
'md5': '7771e4b9d86d0945acce719c7a59305a',
|
||||
},
|
||||
]
|
||||
label_list = [
|
||||
# Animals
|
||||
'Dog',
|
||||
'Rooster',
|
||||
'Pig',
|
||||
'Cow',
|
||||
'Frog',
|
||||
'Cat',
|
||||
'Hen',
|
||||
'Insects (flying)',
|
||||
'Sheep',
|
||||
'Crow',
|
||||
# Natural soundscapes & water sounds
|
||||
'Rain',
|
||||
'Sea waves',
|
||||
'Crackling fire',
|
||||
'Crickets',
|
||||
'Chirping birds',
|
||||
'Water drops',
|
||||
'Wind',
|
||||
'Pouring water',
|
||||
'Toilet flush',
|
||||
'Thunderstorm',
|
||||
# Human, non-speech sounds
|
||||
'Crying baby',
|
||||
'Sneezing',
|
||||
'Clapping',
|
||||
'Breathing',
|
||||
'Coughing',
|
||||
'Footsteps',
|
||||
'Laughing',
|
||||
'Brushing teeth',
|
||||
'Snoring',
|
||||
'Drinking, sipping',
|
||||
# Interior/domestic sounds
|
||||
'Door knock',
|
||||
'Mouse click',
|
||||
'Keyboard typing',
|
||||
'Door, wood creaks',
|
||||
'Can opening',
|
||||
'Washing machine',
|
||||
'Vacuum cleaner',
|
||||
'Clock alarm',
|
||||
'Clock tick',
|
||||
'Glass breaking',
|
||||
# Exterior/urban noises
|
||||
'Helicopter',
|
||||
'Chainsaw',
|
||||
'Siren',
|
||||
'Car horn',
|
||||
'Engine',
|
||||
'Train',
|
||||
'Church bells',
|
||||
'Airplane',
|
||||
'Fireworks',
|
||||
'Hand saw',
|
||||
]
|
||||
meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
|
||||
meta_info = collections.namedtuple(
|
||||
'META_INFO',
|
||||
('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'))
|
||||
audio_path = os.path.join('ESC-50-master', 'audio')
|
||||
|
||||
def __init__(self,
|
||||
mode: str='train',
|
||||
split: int=1,
|
||||
feat_type: str='raw',
|
||||
**kwargs):
|
||||
"""
|
||||
Ags:
|
||||
mode (:obj:`str`, `optional`, defaults to `train`):
|
||||
It identifies the dataset mode (train or dev).
|
||||
split (:obj:`int`, `optional`, defaults to 1):
|
||||
It specify the fold of dev dataset.
|
||||
feat_type (:obj:`str`, `optional`, defaults to `raw`):
|
||||
It identifies the feature type that user wants to extract of an audio file.
|
||||
"""
|
||||
files, labels = self._get_data(mode, split)
|
||||
super(ESC50, self).__init__(
|
||||
files=files, labels=labels, feat_type=feat_type, **kwargs)
|
||||
|
||||
def _get_meta_info(self) -> List[collections.namedtuple]:
|
||||
ret = []
|
||||
with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
|
||||
for line in rf.readlines()[1:]:
|
||||
ret.append(self.meta_info(*line.strip().split(',')))
|
||||
return ret
|
||||
|
||||
def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
|
||||
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
|
||||
not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
|
||||
download_and_decompress(self.archieves, DATA_HOME)
|
||||
|
||||
meta_info = self._get_meta_info()
|
||||
|
||||
files = []
|
||||
labels = []
|
||||
for sample in meta_info:
|
||||
filename, fold, target, _, _, _, _ = sample
|
||||
if mode == 'train' and int(fold) != split:
|
||||
files.append(os.path.join(DATA_HOME, self.audio_path, filename))
|
||||
labels.append(int(target))
|
||||
|
||||
if mode != 'train' and int(fold) == split:
|
||||
files.append(os.path.join(DATA_HOME, self.audio_path, filename))
|
||||
labels.append(int(target))
|
||||
|
||||
return files, labels
|
@ -0,0 +1,356 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from multiprocessing import cpu_count
|
||||
from typing import List
|
||||
|
||||
from paddle.io import Dataset
|
||||
from pathos.multiprocessing import Pool
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...utils.env import DATA_HOME
|
||||
from ..backends.soundfile_backend import soundfile_load as load_audio
|
||||
from ..utils.download import decompress
|
||||
from ..utils.download import download_and_decompress
|
||||
from .dataset import feat_funcs
|
||||
|
||||
__all__ = ['VoxCeleb']
|
||||
|
||||
|
||||
class VoxCeleb(Dataset):
|
||||
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
|
||||
archieves_audio_dev = [
|
||||
{
|
||||
'url': source_url + 'vox1_dev_wav_partaa',
|
||||
'md5': 'e395d020928bc15670b570a21695ed96',
|
||||
},
|
||||
{
|
||||
'url': source_url + 'vox1_dev_wav_partab',
|
||||
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
|
||||
},
|
||||
{
|
||||
'url': source_url + 'vox1_dev_wav_partac',
|
||||
'md5': '017d579a2a96a077f40042ec33e51512',
|
||||
},
|
||||
{
|
||||
'url': source_url + 'vox1_dev_wav_partad',
|
||||
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
|
||||
},
|
||||
]
|
||||
archieves_audio_test = [
|
||||
{
|
||||
'url': source_url + 'vox1_test_wav.zip',
|
||||
'md5': '185fdc63c3c739954633d50379a3d102',
|
||||
},
|
||||
]
|
||||
archieves_meta = [
|
||||
{
|
||||
'url':
|
||||
'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
|
||||
'md5':
|
||||
'b73110731c9223c1461fe49cb48dddfc',
|
||||
},
|
||||
]
|
||||
|
||||
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
|
||||
sample_rate = 16000
|
||||
meta_info = collections.namedtuple(
|
||||
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
|
||||
base_path = os.path.join(DATA_HOME, 'vox1')
|
||||
wav_path = os.path.join(base_path, 'wav')
|
||||
meta_path = os.path.join(base_path, 'meta')
|
||||
veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
|
||||
csv_path = os.path.join(base_path, 'csv')
|
||||
subsets = ['train', 'dev', 'enroll', 'test']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subset: str='train',
|
||||
feat_type: str='raw',
|
||||
random_chunk: bool=True,
|
||||
chunk_duration: float=3.0, # seconds
|
||||
split_ratio: float=0.9, # train split ratio
|
||||
seed: int=0,
|
||||
target_dir: str=None,
|
||||
vox2_base_path=None,
|
||||
**kwargs):
|
||||
"""VoxCeleb data prepare and get the specific dataset audio info
|
||||
|
||||
Args:
|
||||
subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'.
|
||||
feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'.
|
||||
random_chunk (bool, optional): random select a duration from audio. Defaults to True.
|
||||
chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0.
|
||||
target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None.
|
||||
vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None.
|
||||
"""
|
||||
assert subset in self.subsets, \
|
||||
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
|
||||
|
||||
self.subset = subset
|
||||
self.spk_id2label = {}
|
||||
self.feat_type = feat_type
|
||||
self.feat_config = kwargs
|
||||
self.random_chunk = random_chunk
|
||||
self.chunk_duration = chunk_duration
|
||||
self.split_ratio = split_ratio
|
||||
self.target_dir = target_dir if target_dir else VoxCeleb.base_path
|
||||
self.vox2_base_path = vox2_base_path
|
||||
|
||||
# if we set the target dir, we will change the vox data info data from base path to target dir
|
||||
VoxCeleb.csv_path = os.path.join(
|
||||
target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
|
||||
VoxCeleb.meta_path = os.path.join(
|
||||
target_dir, "voxceleb",
|
||||
'meta') if target_dir else VoxCeleb.meta_path
|
||||
VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
|
||||
'veri_test2.txt')
|
||||
# self._data = self._get_data()[:1000] # KP: Small dataset test.
|
||||
self._data = self._get_data()
|
||||
super(VoxCeleb, self).__init__()
|
||||
|
||||
# Set up a seed to reproduce training or predicting result.
|
||||
# random.seed(seed)
|
||||
|
||||
def _get_data(self):
|
||||
# Download audio files.
|
||||
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
|
||||
# so, we check the vox1/wav dir status
|
||||
print(f"wav base path: {self.wav_path}")
|
||||
if not os.path.isdir(self.wav_path):
|
||||
print("start to download the voxceleb1 dataset")
|
||||
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
|
||||
self.archieves_audio_dev,
|
||||
self.base_path,
|
||||
decompress=False)
|
||||
download_and_decompress( # download the vox1_test_wav.zip and unzip
|
||||
self.archieves_audio_test,
|
||||
self.base_path,
|
||||
decompress=True)
|
||||
|
||||
# Download all parts and concatenate the files into one zip file.
|
||||
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
|
||||
print(f'Concatenating all parts to: {dev_zipfile}')
|
||||
os.system(
|
||||
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
|
||||
)
|
||||
|
||||
# Extract all audio files of dev and test set.
|
||||
decompress(dev_zipfile, self.base_path)
|
||||
|
||||
# Download meta files.
|
||||
if not os.path.isdir(self.meta_path):
|
||||
print("prepare the meta data")
|
||||
download_and_decompress(
|
||||
self.archieves_meta, self.meta_path, decompress=False)
|
||||
|
||||
# Data preparation.
|
||||
if not os.path.isdir(self.csv_path):
|
||||
os.makedirs(self.csv_path)
|
||||
self.prepare_data()
|
||||
|
||||
data = []
|
||||
print(
|
||||
f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}"
|
||||
)
|
||||
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
|
||||
for line in rf.readlines()[1:]:
|
||||
audio_id, duration, wav, start, stop, spk_id = line.strip(
|
||||
).split(',')
|
||||
data.append(
|
||||
self.meta_info(audio_id,
|
||||
float(duration), wav,
|
||||
int(start), int(stop), spk_id))
|
||||
|
||||
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
|
||||
for line in f.readlines():
|
||||
spk_id, label = line.strip().split(' ')
|
||||
self.spk_id2label[spk_id] = int(label)
|
||||
|
||||
return data
|
||||
|
||||
def _convert_to_record(self, idx: int):
|
||||
sample = self._data[idx]
|
||||
|
||||
record = {}
|
||||
# To show all fields in a namedtuple: `type(sample)._fields`
|
||||
for field in type(sample)._fields:
|
||||
record[field] = getattr(sample, field)
|
||||
|
||||
waveform, sr = load_audio(record['wav'])
|
||||
|
||||
# random select a chunk audio samples from the audio
|
||||
if self.random_chunk:
|
||||
num_wav_samples = waveform.shape[0]
|
||||
num_chunk_samples = int(self.chunk_duration * sr)
|
||||
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
|
||||
stop = start + num_chunk_samples
|
||||
else:
|
||||
start = record['start']
|
||||
stop = record['stop']
|
||||
|
||||
waveform = waveform[start:stop]
|
||||
|
||||
assert self.feat_type in feat_funcs.keys(), \
|
||||
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
|
||||
feat_func = feat_funcs[self.feat_type]
|
||||
feat = feat_func(
|
||||
waveform, sr=sr, **self.feat_config) if feat_func else waveform
|
||||
|
||||
record.update({'feat': feat})
|
||||
if self.subset in ['train',
|
||||
'dev']: # Labels are available in train and dev.
|
||||
record.update({'label': self.spk_id2label[record['spk_id']]})
|
||||
|
||||
return record
|
||||
|
||||
@staticmethod
|
||||
def _get_chunks(seg_dur, audio_id, audio_duration):
|
||||
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
|
||||
|
||||
chunk_lst = [
|
||||
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
|
||||
for i in range(num_chunks)
|
||||
]
|
||||
return chunk_lst
|
||||
|
||||
def _get_audio_info(self, wav_file: str,
|
||||
split_chunks: bool) -> List[List[str]]:
|
||||
waveform, sr = load_audio(wav_file)
|
||||
spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
|
||||
audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
|
||||
audio_duration = waveform.shape[0] / sr
|
||||
|
||||
ret = []
|
||||
if split_chunks: # Split into pieces of self.chunk_duration seconds.
|
||||
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
|
||||
audio_duration)
|
||||
|
||||
for chunk in uniq_chunks_list:
|
||||
s, e = chunk.split("_")[-2:] # Timestamps of start and end
|
||||
start_sample = int(float(s) * sr)
|
||||
end_sample = int(float(e) * sr)
|
||||
# id, duration, wav, start, stop, spk_id
|
||||
ret.append([
|
||||
chunk, audio_duration, wav_file, start_sample, end_sample,
|
||||
spk_id
|
||||
])
|
||||
else: # Keep whole audio.
|
||||
ret.append([
|
||||
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
|
||||
])
|
||||
return ret
|
||||
|
||||
def generate_csv(self,
|
||||
wav_files: List[str],
|
||||
output_file: str,
|
||||
split_chunks: bool=True):
|
||||
print(f'Generating csv: {output_file}')
|
||||
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
|
||||
# Note: this may occurs c++ exception, but the program will execute fine
|
||||
# so we can ignore the exception
|
||||
with Pool(cpu_count()) as p:
|
||||
infos = list(
|
||||
tqdm(
|
||||
p.imap(lambda x: self._get_audio_info(x, split_chunks),
|
||||
wav_files),
|
||||
total=len(wav_files)))
|
||||
|
||||
csv_lines = []
|
||||
for info in infos:
|
||||
csv_lines.extend(info)
|
||||
|
||||
with open(output_file, mode="w") as csv_f:
|
||||
csv_writer = csv.writer(
|
||||
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||
csv_writer.writerow(header)
|
||||
for line in csv_lines:
|
||||
csv_writer.writerow(line)
|
||||
|
||||
def prepare_data(self):
|
||||
# Audio of speakers in veri_test_file should not be included in training set.
|
||||
print("start to prepare the data csv file")
|
||||
enroll_files = set()
|
||||
test_files = set()
|
||||
# get the enroll and test audio file path
|
||||
with open(self.veri_test_file, 'r') as f:
|
||||
for line in f.readlines():
|
||||
_, enrol_file, test_file = line.strip().split(' ')
|
||||
enroll_files.add(os.path.join(self.wav_path, enrol_file))
|
||||
test_files.add(os.path.join(self.wav_path, test_file))
|
||||
enroll_files = sorted(enroll_files)
|
||||
test_files = sorted(test_files)
|
||||
|
||||
# get the enroll and test speakers
|
||||
test_spks = set()
|
||||
for file in (enroll_files + test_files):
|
||||
spk = file.split('/wav/')[1].split('/')[0]
|
||||
test_spks.add(spk)
|
||||
|
||||
# get all the train and dev audios file path
|
||||
audio_files = []
|
||||
speakers = set()
|
||||
print("Getting file list...")
|
||||
for path in [self.wav_path, self.vox2_base_path]:
|
||||
# if vox2 directory is not set and vox2 is not a directory
|
||||
# we will not process this directory
|
||||
if not path or not os.path.exists(path):
|
||||
print(f"{path} is an invalid path, please check again, "
|
||||
"and we will ignore the vox2 base path")
|
||||
continue
|
||||
for file in glob.glob(
|
||||
os.path.join(path, "**", "*.wav"), recursive=True):
|
||||
spk = file.split('/wav/')[1].split('/')[0]
|
||||
if spk in test_spks:
|
||||
continue
|
||||
speakers.add(spk)
|
||||
audio_files.append(file)
|
||||
|
||||
print(
|
||||
f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}"
|
||||
)
|
||||
# encode the train and dev speakers label to spk_id2label.txt
|
||||
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
|
||||
for label, spk_id in enumerate(
|
||||
sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
|
||||
f.write(f'{spk_id} {label}\n')
|
||||
|
||||
audio_files = sorted(audio_files)
|
||||
random.shuffle(audio_files)
|
||||
split_idx = int(self.split_ratio * len(audio_files))
|
||||
# split_ratio to train
|
||||
train_files, dev_files = audio_files[:split_idx], audio_files[
|
||||
split_idx:]
|
||||
|
||||
self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
|
||||
self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
|
||||
|
||||
self.generate_csv(
|
||||
enroll_files,
|
||||
os.path.join(self.csv_path, 'enroll.csv'),
|
||||
split_chunks=False)
|
||||
self.generate_csv(
|
||||
test_files,
|
||||
os.path.join(self.csv_path, 'test.csv'),
|
||||
split_chunks=False)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._convert_to_record(idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
@ -0,0 +1,20 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .functional import compute_fbank_matrix
|
||||
from .functional import create_dct
|
||||
from .functional import fft_frequencies
|
||||
from .functional import hz_to_mel
|
||||
from .functional import mel_frequencies
|
||||
from .functional import mel_to_hz
|
||||
from .functional import power_to_db
|
@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from librosa(https://github.com/librosa/librosa)
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
__all__ = [
|
||||
'hz_to_mel',
|
||||
'mel_to_hz',
|
||||
'mel_frequencies',
|
||||
'fft_frequencies',
|
||||
'compute_fbank_matrix',
|
||||
'power_to_db',
|
||||
'create_dct',
|
||||
]
|
||||
|
||||
|
||||
def hz_to_mel(freq: Union[Tensor, float],
|
||||
htk: bool=False) -> Union[Tensor, float]:
|
||||
"""Convert Hz to Mels.
|
||||
|
||||
Args:
|
||||
freq (Union[Tensor, float]): The input tensor with arbitrary shape.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Union[Tensor, float]: Frequency in mels.
|
||||
"""
|
||||
|
||||
if htk:
|
||||
if isinstance(freq, Tensor):
|
||||
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
|
||||
else:
|
||||
return 2595.0 * math.log10(1.0 + freq / 700.0)
|
||||
|
||||
# Fill in the linear part
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
|
||||
mels = (freq - f_min) / f_sp
|
||||
|
||||
# Fill in the log-scale part
|
||||
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = math.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
if isinstance(freq, Tensor):
|
||||
target = min_log_mel + paddle.log(
|
||||
freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
|
||||
mask = (freq > min_log_hz).astype(freq.dtype)
|
||||
mels = target * mask + mels * (
|
||||
1 - mask) # will replace by masked_fill OP in future
|
||||
else:
|
||||
if freq >= min_log_hz:
|
||||
mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
|
||||
|
||||
return mels
|
||||
|
||||
|
||||
def mel_to_hz(mel: Union[float, Tensor],
|
||||
htk: bool=False) -> Union[float, Tensor]:
|
||||
"""Convert mel bin numbers to frequencies.
|
||||
|
||||
Args:
|
||||
mel (Union[float, Tensor]): The mel frequency represented as a tensor with arbitrary shape.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Union[float, Tensor]: Frequencies in Hz.
|
||||
"""
|
||||
if htk:
|
||||
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
|
||||
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mel
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = math.log(6.4) / 27.0 # step size for log region
|
||||
if isinstance(mel, Tensor):
|
||||
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
|
||||
mask = (mel > min_log_mel).astype(mel.dtype)
|
||||
freqs = target * mask + freqs * (
|
||||
1 - mask) # will replace by masked_fill OP in future
|
||||
else:
|
||||
if mel >= min_log_mel:
|
||||
freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
|
||||
|
||||
return freqs
|
||||
|
||||
|
||||
def mel_frequencies(n_mels: int=64,
|
||||
f_min: float=0.0,
|
||||
f_max: float=11025.0,
|
||||
htk: bool=False,
|
||||
dtype: str='float32') -> Tensor:
|
||||
"""Compute mel frequencies.
|
||||
|
||||
Args:
|
||||
n_mels (int, optional): Number of mel bins. Defaults to 64.
|
||||
f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0.
|
||||
fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'.
|
||||
|
||||
Returns:
|
||||
Tensor: Tensor of n_mels frequencies in Hz with shape `(n_mels,)`.
|
||||
"""
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
min_mel = hz_to_mel(f_min, htk=htk)
|
||||
max_mel = hz_to_mel(f_max, htk=htk)
|
||||
mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
|
||||
freqs = mel_to_hz(mels, htk=htk)
|
||||
return freqs
|
||||
|
||||
|
||||
def fft_frequencies(sr: int, n_fft: int, dtype: str='float32') -> Tensor:
|
||||
"""Compute fourier frequencies.
|
||||
|
||||
Args:
|
||||
sr (int): Sample rate.
|
||||
n_fft (int): Number of fft bins.
|
||||
dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'.
|
||||
|
||||
Returns:
|
||||
Tensor: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`.
|
||||
"""
|
||||
return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
|
||||
|
||||
|
||||
def compute_fbank_matrix(sr: int,
|
||||
n_fft: int,
|
||||
n_mels: int=64,
|
||||
f_min: float=0.0,
|
||||
f_max: Optional[float]=None,
|
||||
htk: bool=False,
|
||||
norm: Union[str, float]='slaney',
|
||||
dtype: str='float32') -> Tensor:
|
||||
"""Compute fbank matrix.
|
||||
|
||||
Args:
|
||||
sr (int): Sample rate.
|
||||
n_fft (int): Number of fft bins.
|
||||
n_mels (int, optional): Number of mel bins. Defaults to 64.
|
||||
f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0.
|
||||
f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
|
||||
htk (bool, optional): Use htk scaling. Defaults to False.
|
||||
norm (Union[str, float], optional): Type of normalization. Defaults to 'slaney'.
|
||||
dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
|
||||
|
||||
Returns:
|
||||
Tensor: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`.
|
||||
"""
|
||||
|
||||
if f_max is None:
|
||||
f_max = float(sr) / 2
|
||||
|
||||
# Initialize the weights
|
||||
weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
||||
|
||||
# Center freqs of each FFT bin
|
||||
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
|
||||
|
||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||
mel_f = mel_frequencies(
|
||||
n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
|
||||
|
||||
fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
|
||||
ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
|
||||
#ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||
|
||||
for i in range(n_mels):
|
||||
# lower and upper slopes for all bins
|
||||
lower = -ramps[i] / fdiff[i]
|
||||
upper = ramps[i + 2] / fdiff[i + 1]
|
||||
|
||||
# .. then intersect them with each other and zero
|
||||
weights[i] = paddle.maximum(
|
||||
paddle.zeros_like(lower), paddle.minimum(lower, upper))
|
||||
|
||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||
if norm == 'slaney':
|
||||
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
|
||||
weights *= enorm.unsqueeze(1)
|
||||
elif isinstance(norm, int) or isinstance(norm, float):
|
||||
weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def power_to_db(spect: Tensor,
|
||||
ref_value: float=1.0,
|
||||
amin: float=1e-10,
|
||||
top_db: Optional[float]=None) -> Tensor:
|
||||
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way.
|
||||
|
||||
Args:
|
||||
spect (Tensor): STFT power spectrogram.
|
||||
ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
|
||||
amin (float, optional): Minimum threshold. Defaults to 1e-10.
|
||||
top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Power spectrogram in db scale.
|
||||
"""
|
||||
if amin <= 0:
|
||||
raise Exception("amin must be strictly positive")
|
||||
|
||||
if ref_value <= 0:
|
||||
raise Exception("ref_value must be strictly positive")
|
||||
|
||||
ones = paddle.ones_like(spect)
|
||||
log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, spect))
|
||||
log_spec -= 10.0 * math.log10(max(ref_value, amin))
|
||||
|
||||
if top_db is not None:
|
||||
if top_db < 0:
|
||||
raise Exception("top_db must be non-negative")
|
||||
log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
|
||||
|
||||
return log_spec
|
||||
|
||||
|
||||
def create_dct(n_mfcc: int,
|
||||
n_mels: int,
|
||||
norm: Optional[str]='ortho',
|
||||
dtype: str='float32') -> Tensor:
|
||||
"""Create a discrete cosine transform(DCT) matrix.
|
||||
|
||||
Args:
|
||||
n_mfcc (int): Number of mel frequency cepstral coefficients.
|
||||
n_mels (int): Number of mel filterbanks.
|
||||
norm (Optional[str], optional): Normalization type. Defaults to 'ortho'.
|
||||
dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
|
||||
|
||||
Returns:
|
||||
Tensor: The DCT matrix with shape `(n_mels, n_mfcc)`.
|
||||
"""
|
||||
n = paddle.arange(n_mels, dtype=dtype)
|
||||
k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1)
|
||||
dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
|
||||
k) # size (n_mfcc, n_mels)
|
||||
if norm is None:
|
||||
dct *= 2.0
|
||||
else:
|
||||
assert norm == "ortho"
|
||||
dct[0] *= 1.0 / math.sqrt(2.0)
|
||||
dct *= math.sqrt(2.0 / float(n_mels))
|
||||
return dct.T
|
@ -0,0 +1,373 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
|
||||
class WindowFunctionRegister(object):
|
||||
def __init__(self):
|
||||
self._functions_dict = dict()
|
||||
|
||||
def register(self):
|
||||
def add_subfunction(func):
|
||||
name = func.__name__
|
||||
self._functions_dict[name] = func
|
||||
return func
|
||||
|
||||
return add_subfunction
|
||||
|
||||
def get(self, name):
|
||||
return self._functions_dict[name]
|
||||
|
||||
|
||||
window_function_register = WindowFunctionRegister()
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _cat(x: List[Tensor], data_type: str) -> Tensor:
|
||||
l = [paddle.to_tensor(_, data_type) for _ in x]
|
||||
return paddle.concat(l)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _acosh(x: Union[Tensor, float]) -> Tensor:
|
||||
if isinstance(x, float):
|
||||
return math.log(x + math.sqrt(x**2 - 1))
|
||||
return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _extend(M: int, sym: bool) -> bool:
|
||||
"""Extend window by 1 sample if needed for DFT-even symmetry."""
|
||||
if not sym:
|
||||
return M + 1, True
|
||||
else:
|
||||
return M, False
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _len_guards(M: int) -> bool:
|
||||
"""Handle small or incorrect window lengths."""
|
||||
if int(M) != M or M < 0:
|
||||
raise ValueError('Window length M must be a non-negative integer')
|
||||
|
||||
return M <= 1
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _truncate(w: Tensor, needed: bool) -> Tensor:
|
||||
"""Truncate window by 1 sample if needed for DFT-even symmetry."""
|
||||
if needed:
|
||||
return w[:-1]
|
||||
else:
|
||||
return w
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _general_gaussian(M: int, p, sig, sym: bool=True,
|
||||
dtype: str='float64') -> Tensor:
|
||||
"""Compute a window with a generalized Gaussian shape.
|
||||
This function is consistent with scipy.signal.windows.general_gaussian().
|
||||
"""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
|
||||
n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
|
||||
w = paddle.exp(-0.5 * paddle.abs(n / sig)**(2 * p))
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _general_cosine(M: int, a: float, sym: bool=True,
|
||||
dtype: str='float64') -> Tensor:
|
||||
"""Compute a generic weighted sum of cosine terms window.
|
||||
This function is consistent with scipy.signal.windows.general_cosine().
|
||||
"""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype)
|
||||
w = paddle.zeros((M, ), dtype=dtype)
|
||||
for k in range(len(a)):
|
||||
w += a[k] * paddle.cos(k * fac)
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _general_hamming(M: int, alpha: float, sym: bool=True,
|
||||
dtype: str='float64') -> Tensor:
|
||||
"""Compute a generalized Hamming window.
|
||||
This function is consistent with scipy.signal.windows.general_hamming()
|
||||
"""
|
||||
return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _taylor(M: int,
|
||||
nbar=4,
|
||||
sll=30,
|
||||
norm=True,
|
||||
sym: bool=True,
|
||||
dtype: str='float64') -> Tensor:
|
||||
"""Compute a Taylor window.
|
||||
The Taylor window taper function approximates the Dolph-Chebyshev window's
|
||||
constant sidelobe level for a parameterized number of near-in sidelobes.
|
||||
"""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
# Original text uses a negative sidelobe level parameter and then negates
|
||||
# it in the calculation of B. To keep consistent with other methods we
|
||||
# assume the sidelobe level parameter to be positive.
|
||||
B = 10**(sll / 20)
|
||||
A = _acosh(B) / math.pi
|
||||
s2 = nbar**2 / (A**2 + (nbar - 0.5)**2)
|
||||
ma = paddle.arange(1, nbar, dtype=dtype)
|
||||
|
||||
Fm = paddle.empty((nbar - 1, ), dtype=dtype)
|
||||
signs = paddle.empty_like(ma)
|
||||
signs[::2] = 1
|
||||
signs[1::2] = -1
|
||||
m2 = ma * ma
|
||||
for mi in range(len(ma)):
|
||||
numer = signs[mi] * paddle.prod(1 - m2[mi] / s2 / (A**2 + (ma - 0.5)**2
|
||||
))
|
||||
if mi == 0:
|
||||
denom = 2 * paddle.prod(1 - m2[mi] / m2[mi + 1:])
|
||||
elif mi == len(ma) - 1:
|
||||
denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi])
|
||||
else:
|
||||
denom = (2 * paddle.prod(1 - m2[mi] / m2[:mi]) *
|
||||
paddle.prod(1 - m2[mi] / m2[mi + 1:]))
|
||||
|
||||
Fm[mi] = numer / denom
|
||||
|
||||
def W(n):
|
||||
return 1 + 2 * paddle.matmul(
|
||||
Fm.unsqueeze(0),
|
||||
paddle.cos(2 * math.pi * ma.unsqueeze(1) *
|
||||
(n - M / 2.0 + 0.5) / M), )
|
||||
|
||||
w = W(paddle.arange(0, M, dtype=dtype))
|
||||
|
||||
# normalize (Note that this is not described in the original text [1])
|
||||
if norm:
|
||||
scale = 1.0 / W((M - 1) / 2)
|
||||
w *= scale
|
||||
w = w.squeeze()
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a Hamming window.
|
||||
The Hamming window is a taper formed by using a raised cosine with
|
||||
non-zero endpoints, optimized to minimize the nearest side lobe.
|
||||
"""
|
||||
return _general_hamming(M, 0.54, sym, dtype=dtype)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a Hann window.
|
||||
The Hann window is a taper formed by using a raised cosine or sine-squared
|
||||
with ends that touch zero.
|
||||
"""
|
||||
return _general_hamming(M, 0.5, sym, dtype=dtype)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a Tukey window.
|
||||
The Tukey window is also known as a tapered cosine window.
|
||||
"""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
|
||||
if alpha <= 0:
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
elif alpha >= 1.0:
|
||||
return hann(M, sym=sym)
|
||||
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
|
||||
n = paddle.arange(0, M, dtype=dtype)
|
||||
width = int(alpha * (M - 1) / 2.0)
|
||||
n1 = n[0:width + 1]
|
||||
n2 = n[width + 1:M - width - 1]
|
||||
n3 = n[M - width - 1:]
|
||||
|
||||
w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
|
||||
w2 = paddle.ones(n2.shape, dtype=dtype)
|
||||
w3 = 0.5 * (1 + paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha /
|
||||
(M - 1))))
|
||||
w = paddle.concat([w1, w2, w3])
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _gaussian(M: int, std: float, sym: bool=True,
|
||||
dtype: str='float64') -> Tensor:
|
||||
"""Compute a Gaussian window.
|
||||
The Gaussian widows has a Gaussian shape defined by the standard deviation(std).
|
||||
"""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
|
||||
n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
|
||||
sig2 = 2 * std * std
|
||||
w = paddle.exp(-(n**2) / sig2)
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _exponential(M: int,
|
||||
center=None,
|
||||
tau=1.0,
|
||||
sym: bool=True,
|
||||
dtype: str='float64') -> Tensor:
|
||||
"""Compute an exponential (or Poisson) window."""
|
||||
if sym and center is not None:
|
||||
raise ValueError("If sym==True, center must be None.")
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
|
||||
if center is None:
|
||||
center = (M - 1) / 2
|
||||
|
||||
n = paddle.arange(0, M, dtype=dtype)
|
||||
w = paddle.exp(-paddle.abs(n - center) / tau)
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a triangular window."""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
|
||||
n = paddle.arange(1, (M + 1) // 2 + 1, dtype=dtype)
|
||||
if M % 2 == 0:
|
||||
w = (2 * n - 1.0) / M
|
||||
w = paddle.concat([w, w[::-1]])
|
||||
else:
|
||||
w = 2 * n / (M + 1.0)
|
||||
w = paddle.concat([w, w[-2::-1]])
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a Bohman window.
|
||||
The Bohman window is the autocorrelation of a cosine window.
|
||||
"""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
|
||||
fac = paddle.abs(paddle.linspace(-1, 1, M, dtype=dtype)[1:-1])
|
||||
w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin(
|
||||
math.pi * fac)
|
||||
w = _cat([0, w, 0], dtype)
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a Blackman window.
|
||||
The Blackman window is a taper formed by using the first three terms of
|
||||
a summation of cosines. It was designed to have close to the minimal
|
||||
leakage possible. It is close to optimal, only slightly worse than a
|
||||
Kaiser window.
|
||||
"""
|
||||
return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)
|
||||
|
||||
|
||||
@window_function_register.register()
|
||||
def _cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
|
||||
"""Compute a window with a simple cosine shape."""
|
||||
if _len_guards(M):
|
||||
return paddle.ones((M, ), dtype=dtype)
|
||||
M, needs_trunc = _extend(M, sym)
|
||||
w = paddle.sin(math.pi / M * (paddle.arange(0, M, dtype=dtype) + 0.5))
|
||||
|
||||
return _truncate(w, needs_trunc)
|
||||
|
||||
|
||||
def get_window(
|
||||
window: Union[str, Tuple[str, float]],
|
||||
win_length: int,
|
||||
fftbins: bool=True,
|
||||
dtype: str='float64', ) -> Tensor:
|
||||
"""Return a window of a given length and type.
|
||||
|
||||
Args:
|
||||
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
|
||||
win_length (int): Number of samples.
|
||||
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
|
||||
dtype (str, optional): The data type of the return window. Defaults to 'float64'.
|
||||
|
||||
Returns:
|
||||
Tensor: The window represented as a tensor.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle
|
||||
|
||||
n_fft = 512
|
||||
cosine_window = paddle.audio.functional.get_window('cosine', n_fft)
|
||||
|
||||
std = 7
|
||||
gaussian_window = paddle.audio.functional.get_window(('gaussian',std), n_fft)
|
||||
"""
|
||||
sym = not fftbins
|
||||
|
||||
args = ()
|
||||
if isinstance(window, tuple):
|
||||
winstr = window[0]
|
||||
if len(window) > 1:
|
||||
args = window[1:]
|
||||
elif isinstance(window, str):
|
||||
if window in ['gaussian', 'exponential']:
|
||||
raise ValueError("The '" + window + "' window needs one or "
|
||||
"more parameters -- pass a tuple.")
|
||||
else:
|
||||
winstr = window
|
||||
else:
|
||||
raise ValueError("%s as window type is not supported." %
|
||||
str(type(window)))
|
||||
|
||||
try:
|
||||
winfunc = window_function_register.get('_' + winstr)
|
||||
except KeyError as e:
|
||||
raise ValueError("Unknown window type.") from e
|
||||
|
||||
params = (win_length, ) + args
|
||||
kwargs = {'sym': sym}
|
||||
return winfunc(*params, dtype=dtype, **kwargs)
|
@ -0,0 +1,677 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import resampy
|
||||
import soundfile
|
||||
from scipy.io import wavfile
|
||||
|
||||
from ..utils import depth_convert
|
||||
from ..utils import ParameterError
|
||||
from .common import AudioInfo
|
||||
|
||||
__all__ = [
|
||||
'resample',
|
||||
'to_mono',
|
||||
'normalize',
|
||||
'save',
|
||||
'soundfile_save',
|
||||
'load',
|
||||
'soundfile_load',
|
||||
'info',
|
||||
]
|
||||
NORMALMIZE_TYPES = ['linear', 'gaussian']
|
||||
MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
|
||||
RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
|
||||
EPS = 1e-8
|
||||
|
||||
|
||||
def resample(y: np.ndarray,
|
||||
src_sr: int,
|
||||
target_sr: int,
|
||||
mode: str='kaiser_fast') -> np.ndarray:
|
||||
"""Audio resampling.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
src_sr (int): Source sample rate.
|
||||
target_sr (int): Target sample rate.
|
||||
mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
|
||||
|
||||
Returns:
|
||||
np.ndarray: `y` resampled to `target_sr`
|
||||
"""
|
||||
|
||||
if mode == 'kaiser_best':
|
||||
warnings.warn(
|
||||
f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \
|
||||
we recommend the mode kaiser_fast in large scale audio training')
|
||||
|
||||
if not isinstance(y, np.ndarray):
|
||||
raise ParameterError(
|
||||
'Only support numpy np.ndarray, but received y in {type(y)}')
|
||||
|
||||
if mode not in RESAMPLE_MODES:
|
||||
raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
|
||||
|
||||
return resampy.resample(y, src_sr, target_sr, filter=mode)
|
||||
|
||||
|
||||
def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
|
||||
"""Convert sterior audio to mono.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'.
|
||||
|
||||
Returns:
|
||||
np.ndarray: `y` with mono channel.
|
||||
"""
|
||||
|
||||
if merge_type not in MERGE_TYPES:
|
||||
raise ParameterError(
|
||||
f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
|
||||
)
|
||||
if y.ndim > 2:
|
||||
raise ParameterError(
|
||||
f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
|
||||
if y.ndim == 1: # nothing to merge
|
||||
return y
|
||||
|
||||
if merge_type == 'ch0':
|
||||
return y[0]
|
||||
if merge_type == 'ch1':
|
||||
return y[1]
|
||||
if merge_type == 'random':
|
||||
return y[np.random.randint(0, 2)]
|
||||
|
||||
# need to do averaging according to dtype
|
||||
|
||||
if y.dtype == 'float32':
|
||||
y_out = (y[0] + y[1]) * 0.5
|
||||
elif y.dtype == 'int16':
|
||||
y_out = y.astype('int32')
|
||||
y_out = (y_out[0] + y_out[1]) // 2
|
||||
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
|
||||
np.iinfo(y.dtype).max).astype(y.dtype)
|
||||
|
||||
elif y.dtype == 'int8':
|
||||
y_out = y.astype('int16')
|
||||
y_out = (y_out[0] + y_out[1]) // 2
|
||||
y_out = np.clip(y_out, np.iinfo(y.dtype).min,
|
||||
np.iinfo(y.dtype).max).astype(y.dtype)
|
||||
else:
|
||||
raise ParameterError(f'Unsupported dtype: {y.dtype}')
|
||||
return y_out
|
||||
|
||||
|
||||
def soundfile_load_(file: os.PathLike,
|
||||
offset: Optional[float]=None,
|
||||
dtype: str='int16',
|
||||
duration: Optional[int]=None) -> Tuple[np.ndarray, int]:
|
||||
"""Load audio using soundfile library. This function load audio file using libsndfile.
|
||||
|
||||
Args:
|
||||
file (os.PathLike): File of waveform.
|
||||
offset (Optional[float], optional): Offset to the start of waveform. Defaults to None.
|
||||
dtype (str, optional): Data type of waveform. Defaults to 'int16'.
|
||||
duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
|
||||
"""
|
||||
with soundfile.SoundFile(file) as sf_desc:
|
||||
sr_native = sf_desc.samplerate
|
||||
if offset:
|
||||
sf_desc.seek(int(offset * sr_native))
|
||||
if duration is not None:
|
||||
frame_duration = int(duration * sr_native)
|
||||
else:
|
||||
frame_duration = -1
|
||||
y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T
|
||||
|
||||
return y, sf_desc.samplerate
|
||||
|
||||
|
||||
def normalize(y: np.ndarray, norm_type: str='linear',
|
||||
mul_factor: float=1.0) -> np.ndarray:
|
||||
"""Normalize an input audio with additional multiplier.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
norm_type (str, optional): Type of normalization. Defaults to 'linear'.
|
||||
mul_factor (float, optional): Scaling factor. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: `y` after normalization.
|
||||
"""
|
||||
|
||||
if norm_type == 'linear':
|
||||
amax = np.max(np.abs(y))
|
||||
factor = 1.0 / (amax + EPS)
|
||||
y = y * factor * mul_factor
|
||||
elif norm_type == 'gaussian':
|
||||
amean = np.mean(y)
|
||||
astd = np.std(y)
|
||||
astd = max(astd, EPS)
|
||||
y = mul_factor * (y - amean) / astd
|
||||
else:
|
||||
raise NotImplementedError(f'norm_type should be in {NORMALMIZE_TYPES}')
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
|
||||
"""Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Input waveform array in 1D or 2D.
|
||||
sr (int): Sample rate.
|
||||
file (os.PathLike): Path of audio file to save.
|
||||
"""
|
||||
if not file.endswith('.wav'):
|
||||
raise ParameterError(
|
||||
f'only .wav file supported, but dst file name is: {file}')
|
||||
|
||||
if sr <= 0:
|
||||
raise ParameterError(
|
||||
f'Sample rate should be larger than 0, received sr = {sr}')
|
||||
|
||||
if y.dtype not in ['int16', 'int8']:
|
||||
warnings.warn(
|
||||
f'input data type is {y.dtype}, will convert data to int16 format before saving'
|
||||
)
|
||||
y_out = depth_convert(y, 'int16')
|
||||
else:
|
||||
y_out = y
|
||||
|
||||
wavfile.write(file, sr, y_out)
|
||||
|
||||
|
||||
def soundfile_load(
|
||||
file: os.PathLike,
|
||||
sr: Optional[int]=None,
|
||||
mono: bool=True,
|
||||
merge_type: str='average', # ch0,ch1,random,average
|
||||
normal: bool=True,
|
||||
norm_type: str='linear',
|
||||
norm_mul_factor: float=1.0,
|
||||
offset: float=0.0,
|
||||
duration: Optional[int]=None,
|
||||
dtype: str='float32',
|
||||
resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]:
|
||||
"""Load audio file from disk. This function loads audio from disk using using audio backend.
|
||||
|
||||
Args:
|
||||
file (os.PathLike): Path of audio file to load.
|
||||
sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None.
|
||||
mono (bool, optional): Return waveform with mono channel. Defaults to True.
|
||||
merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'.
|
||||
normal (bool, optional): Waveform normalization. Defaults to True.
|
||||
norm_type (str, optional): Type of normalization. Defaults to 'linear'.
|
||||
norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0.
|
||||
offset (float, optional): Offset to the start of waveform. Defaults to 0.0.
|
||||
duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
|
||||
dtype (str, optional): Data type of waveform. Defaults to 'float32'.
|
||||
resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
|
||||
"""
|
||||
|
||||
y, r = soundfile_load_(file, offset=offset, dtype=dtype, duration=duration)
|
||||
|
||||
if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
|
||||
raise ParameterError(f'audio file {file} looks empty')
|
||||
|
||||
if mono:
|
||||
y = to_mono(y, merge_type)
|
||||
|
||||
if sr is not None and sr != r:
|
||||
y = resample(y, r, sr, mode=resample_mode)
|
||||
r = sr
|
||||
|
||||
if normal:
|
||||
y = normalize(y, norm_type, norm_mul_factor)
|
||||
elif dtype in ['int8', 'int16']:
|
||||
# still need to do normalization, before depth conversion
|
||||
y = normalize(y, 'linear', 1.0)
|
||||
|
||||
y = depth_convert(y, dtype)
|
||||
return y, r
|
||||
|
||||
|
||||
#The code below is taken from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py, with some modifications.
|
||||
|
||||
|
||||
def _get_subtype_for_wav(dtype: paddle.dtype,
|
||||
encoding: str,
|
||||
bits_per_sample: int):
|
||||
if not encoding:
|
||||
if not bits_per_sample:
|
||||
subtype = {
|
||||
paddle.uint8: "PCM_U8",
|
||||
paddle.int16: "PCM_16",
|
||||
paddle.int32: "PCM_32",
|
||||
paddle.float32: "FLOAT",
|
||||
paddle.float64: "DOUBLE",
|
||||
}.get(dtype)
|
||||
if not subtype:
|
||||
raise ValueError(f"Unsupported dtype for wav: {dtype}")
|
||||
return subtype
|
||||
if bits_per_sample == 8:
|
||||
return "PCM_U8"
|
||||
return f"PCM_{bits_per_sample}"
|
||||
if encoding == "PCM_S":
|
||||
if not bits_per_sample:
|
||||
return "PCM_32"
|
||||
if bits_per_sample == 8:
|
||||
raise ValueError("wav does not support 8-bit signed PCM encoding.")
|
||||
return f"PCM_{bits_per_sample}"
|
||||
if encoding == "PCM_U":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "PCM_U8"
|
||||
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
|
||||
if encoding == "PCM_F":
|
||||
if bits_per_sample in (None, 32):
|
||||
return "FLOAT"
|
||||
if bits_per_sample == 64:
|
||||
return "DOUBLE"
|
||||
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
|
||||
if encoding == "ULAW":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "ULAW"
|
||||
raise ValueError("wav only supports 8-bit mu-law encoding.")
|
||||
if encoding == "ALAW":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "ALAW"
|
||||
raise ValueError("wav only supports 8-bit a-law encoding.")
|
||||
raise ValueError(f"wav does not support {encoding}.")
|
||||
|
||||
|
||||
def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
|
||||
if encoding in (None, "PCM_S"):
|
||||
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
|
||||
if encoding in ("PCM_U", "PCM_F"):
|
||||
raise ValueError(f"sph does not support {encoding} encoding.")
|
||||
if encoding == "ULAW":
|
||||
if bits_per_sample in (None, 8):
|
||||
return "ULAW"
|
||||
raise ValueError("sph only supports 8-bit for mu-law encoding.")
|
||||
if encoding == "ALAW":
|
||||
return "ALAW"
|
||||
raise ValueError(f"sph does not support {encoding}.")
|
||||
|
||||
|
||||
def _get_subtype(dtype: paddle.dtype,
|
||||
format: str,
|
||||
encoding: str,
|
||||
bits_per_sample: int):
|
||||
if format == "wav":
|
||||
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
|
||||
if format == "flac":
|
||||
if encoding:
|
||||
raise ValueError("flac does not support encoding.")
|
||||
if not bits_per_sample:
|
||||
return "PCM_16"
|
||||
if bits_per_sample > 24:
|
||||
raise ValueError("flac does not support bits_per_sample > 24.")
|
||||
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
|
||||
if format in ("ogg", "vorbis"):
|
||||
if encoding or bits_per_sample:
|
||||
raise ValueError(
|
||||
"ogg/vorbis does not support encoding/bits_per_sample.")
|
||||
return "VORBIS"
|
||||
if format == "sph":
|
||||
return _get_subtype_for_sphere(encoding, bits_per_sample)
|
||||
if format in ("nis", "nist"):
|
||||
return "PCM_16"
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
|
||||
def save(
|
||||
filepath: str,
|
||||
src: paddle.Tensor,
|
||||
sample_rate: int,
|
||||
channels_first: bool=True,
|
||||
compression: Optional[float]=None,
|
||||
format: Optional[str]=None,
|
||||
encoding: Optional[str]=None,
|
||||
bits_per_sample: Optional[int]=None, ):
|
||||
"""Save audio data to file.
|
||||
|
||||
Note:
|
||||
The formats this function can handle depend on the soundfile installation.
|
||||
This function is tested on the following formats;
|
||||
|
||||
* WAV
|
||||
|
||||
* 32-bit floating-point
|
||||
* 32-bit signed integer
|
||||
* 16-bit signed integer
|
||||
* 8-bit unsigned integer
|
||||
|
||||
* FLAC
|
||||
* OGG/VORBIS
|
||||
* SPHERE
|
||||
|
||||
Note:
|
||||
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
|
||||
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
|
||||
|
||||
Args:
|
||||
filepath (str or pathlib.Path): Path to audio file.
|
||||
src (paddle.Tensor): Audio data to save. must be 2D tensor.
|
||||
sample_rate (int): sampling rate
|
||||
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
|
||||
otherwise `[time, channel]`.
|
||||
compression (float of None, optional): Not used.
|
||||
It is here only for interface compatibility reason with "sox_io" backend.
|
||||
format (str or None, optional): Override the audio format.
|
||||
When ``filepath`` argument is path-like object, audio format is
|
||||
inferred from file extension. If the file extension is missing or
|
||||
different, you can specify the correct format with this argument.
|
||||
|
||||
When ``filepath`` argument is file-like object,
|
||||
this argument is required.
|
||||
|
||||
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
|
||||
``"flac"`` and ``"sph"``.
|
||||
encoding (str or None, optional): Changes the encoding for supported formats.
|
||||
This argument is effective only for supported formats, such as
|
||||
``"wav"``, ``""flac"`` and ``"sph"``. Valid values are:
|
||||
|
||||
- ``"PCM_S"`` (signed integer Linear PCM)
|
||||
- ``"PCM_U"`` (unsigned integer Linear PCM)
|
||||
- ``"PCM_F"`` (floating point PCM)
|
||||
- ``"ULAW"`` (mu-law)
|
||||
- ``"ALAW"`` (a-law)
|
||||
|
||||
bits_per_sample (int or None, optional): Changes the bit depth for the
|
||||
supported formats.
|
||||
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
|
||||
you can change the bit depth.
|
||||
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
|
||||
|
||||
Supported formats/encodings/bit depth/compression are:
|
||||
|
||||
``"wav"``
|
||||
- 32-bit floating-point PCM
|
||||
- 32-bit signed integer PCM
|
||||
- 24-bit signed integer PCM
|
||||
- 16-bit signed integer PCM
|
||||
- 8-bit unsigned integer PCM
|
||||
- 8-bit mu-law
|
||||
- 8-bit a-law
|
||||
|
||||
Note:
|
||||
Default encoding/bit depth is determined by the dtype of
|
||||
the input Tensor.
|
||||
|
||||
``"flac"``
|
||||
- 8-bit
|
||||
- 16-bit (default)
|
||||
- 24-bit
|
||||
|
||||
``"ogg"``, ``"vorbis"``
|
||||
- Doesn't accept changing configuration.
|
||||
|
||||
``"sph"``
|
||||
- 8-bit signed integer PCM
|
||||
- 16-bit signed integer PCM
|
||||
- 24-bit signed integer PCM
|
||||
- 32-bit signed integer PCM (default)
|
||||
- 8-bit mu-law
|
||||
- 8-bit a-law
|
||||
- 16-bit a-law
|
||||
- 24-bit a-law
|
||||
- 32-bit a-law
|
||||
|
||||
"""
|
||||
if src.ndim != 2:
|
||||
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
|
||||
if compression is not None:
|
||||
warnings.warn(
|
||||
'`save` function of "soundfile" backend does not support "compression" parameter. '
|
||||
"The argument is silently ignored.")
|
||||
if hasattr(filepath, "write"):
|
||||
if format is None:
|
||||
raise RuntimeError(
|
||||
"`format` is required when saving to file object.")
|
||||
ext = format.lower()
|
||||
else:
|
||||
ext = str(filepath).split(".")[-1].lower()
|
||||
|
||||
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
|
||||
raise ValueError("Invalid bits_per_sample.")
|
||||
if bits_per_sample == 24:
|
||||
warnings.warn(
|
||||
"Saving audio with 24 bits per sample might warp samples near -1. "
|
||||
"Using 16 bits per sample might be able to avoid this.")
|
||||
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
|
||||
|
||||
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
|
||||
# so we extend the extensions manually here
|
||||
if ext in ["nis", "nist", "sph"] and format is None:
|
||||
format = "NIST"
|
||||
|
||||
if channels_first:
|
||||
src = src.t()
|
||||
|
||||
soundfile.write(
|
||||
file=filepath,
|
||||
data=src,
|
||||
samplerate=sample_rate,
|
||||
subtype=subtype,
|
||||
format=format)
|
||||
|
||||
|
||||
_SUBTYPE2DTYPE = {
|
||||
"PCM_S8": "int8",
|
||||
"PCM_U8": "uint8",
|
||||
"PCM_16": "int16",
|
||||
"PCM_32": "int32",
|
||||
"FLOAT": "float32",
|
||||
"DOUBLE": "float64",
|
||||
}
|
||||
|
||||
|
||||
def load(
|
||||
filepath: str,
|
||||
frame_offset: int=0,
|
||||
num_frames: int=-1,
|
||||
normalize: bool=True,
|
||||
channels_first: bool=True,
|
||||
format: Optional[str]=None, ) -> Tuple[paddle.Tensor, int]:
|
||||
"""Load audio data from file.
|
||||
|
||||
Note:
|
||||
The formats this function can handle depend on the soundfile installation.
|
||||
This function is tested on the following formats;
|
||||
|
||||
* WAV
|
||||
|
||||
* 32-bit floating-point
|
||||
* 32-bit signed integer
|
||||
* 16-bit signed integer
|
||||
* 8-bit unsigned integer
|
||||
|
||||
* FLAC
|
||||
* OGG/VORBIS
|
||||
* SPHERE
|
||||
|
||||
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
|
||||
``float32`` dtype and the shape of `[channel, time]`.
|
||||
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
|
||||
|
||||
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
|
||||
signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
|
||||
by providing ``normalize=False``, this function can return integer Tensor, where the samples
|
||||
are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
|
||||
for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
|
||||
|
||||
``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
|
||||
``flac`` and ``mp3``.
|
||||
For these formats, this function always returns ``float32`` Tensor with values normalized to
|
||||
``[-1.0, 1.0]``.
|
||||
|
||||
Note:
|
||||
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
|
||||
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend.
|
||||
|
||||
Args:
|
||||
filepath (path-like object or file-like object):
|
||||
Source of audio data.
|
||||
frame_offset (int, optional):
|
||||
Number of frames to skip before start reading data.
|
||||
num_frames (int, optional):
|
||||
Maximum number of frames to read. ``-1`` reads all the remaining samples,
|
||||
starting from ``frame_offset``.
|
||||
This function may return the less number of frames if there is not enough
|
||||
frames in the given file.
|
||||
normalize (bool, optional):
|
||||
When ``True``, this function always return ``float32``, and sample values are
|
||||
normalized to ``[-1.0, 1.0]``.
|
||||
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
|
||||
integer type.
|
||||
This argument has no effect for formats other than integer WAV type.
|
||||
channels_first (bool, optional):
|
||||
When True, the returned Tensor has dimension `[channel, time]`.
|
||||
Otherwise, the returned Tensor's dimension is `[time, channel]`.
|
||||
format (str or None, optional):
|
||||
Not used. PySoundFile does not accept format hint.
|
||||
|
||||
Returns:
|
||||
(paddle.Tensor, int): Resulting Tensor and sample rate.
|
||||
If the input file has integer wav format and normalization is off, then it has
|
||||
integer type, else ``float32`` type. If ``channels_first=True``, it has
|
||||
`[channel, time]` else `[time, channel]`.
|
||||
"""
|
||||
with soundfile.SoundFile(filepath, "r") as file_:
|
||||
if file_.format != "WAV" or normalize:
|
||||
dtype = "float32"
|
||||
elif file_.subtype not in _SUBTYPE2DTYPE:
|
||||
raise ValueError(f"Unsupported subtype: {file_.subtype}")
|
||||
else:
|
||||
dtype = _SUBTYPE2DTYPE[file_.subtype]
|
||||
|
||||
frames = file_._prepare_read(frame_offset, None, num_frames)
|
||||
waveform = file_.read(frames, dtype, always_2d=True)
|
||||
sample_rate = file_.samplerate
|
||||
|
||||
waveform = paddle.to_tensor(waveform)
|
||||
if channels_first:
|
||||
waveform = paddle.transpose(waveform, perm=[1, 0])
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
# Mapping from soundfile subtype to number of bits per sample.
|
||||
# This is mostly heuristical and the value is set to 0 when it is irrelevant
|
||||
# (lossy formats) or when it can't be inferred.
|
||||
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
|
||||
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
|
||||
# the default seems to be 8 bits but it can be compressed further to 4 bits.
|
||||
# The dict is inspired from
|
||||
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
|
||||
_SUBTYPE_TO_BITS_PER_SAMPLE = {
|
||||
"PCM_S8": 8, # Signed 8 bit data
|
||||
"PCM_16": 16, # Signed 16 bit data
|
||||
"PCM_24": 24, # Signed 24 bit data
|
||||
"PCM_32": 32, # Signed 32 bit data
|
||||
"PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
|
||||
"FLOAT": 32, # 32 bit float data
|
||||
"DOUBLE": 64, # 64 bit float data
|
||||
"ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
|
||||
"ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
|
||||
"IMA_ADPCM": 0, # IMA ADPCM.
|
||||
"MS_ADPCM": 0, # Microsoft ADPCM.
|
||||
"GSM610":
|
||||
0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
|
||||
"VOX_ADPCM": 0, # OKI / Dialogix ADPCM
|
||||
"G721_32": 0, # 32kbs G721 ADPCM encoding.
|
||||
"G723_24": 0, # 24kbs G723 ADPCM encoding.
|
||||
"G723_40": 0, # 40kbs G723 ADPCM encoding.
|
||||
"DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
|
||||
"DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
|
||||
"DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
|
||||
"DWVW_N": 0, # N bit Delta Width Variable Word encoding.
|
||||
"DPCM_8": 8, # 8 bit differential PCM (XI only)
|
||||
"DPCM_16": 16, # 16 bit differential PCM (XI only)
|
||||
"VORBIS": 0, # Xiph Vorbis encoding. (lossy)
|
||||
"ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
|
||||
"ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
|
||||
"ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
|
||||
"ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
|
||||
}
|
||||
|
||||
|
||||
def _get_bit_depth(subtype):
|
||||
if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
|
||||
warnings.warn(
|
||||
f"The {subtype} subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
|
||||
"attribute will be set to 0. If you are seeing this warning, please "
|
||||
"report by opening an issue on github (after checking for existing/closed ones). "
|
||||
"You may otherwise ignore this warning.")
|
||||
return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
|
||||
|
||||
|
||||
_SUBTYPE_TO_ENCODING = {
|
||||
"PCM_S8": "PCM_S",
|
||||
"PCM_16": "PCM_S",
|
||||
"PCM_24": "PCM_S",
|
||||
"PCM_32": "PCM_S",
|
||||
"PCM_U8": "PCM_U",
|
||||
"FLOAT": "PCM_F",
|
||||
"DOUBLE": "PCM_F",
|
||||
"ULAW": "ULAW",
|
||||
"ALAW": "ALAW",
|
||||
"VORBIS": "VORBIS",
|
||||
}
|
||||
|
||||
|
||||
def _get_encoding(format: str, subtype: str):
|
||||
if format == "FLAC":
|
||||
return "FLAC"
|
||||
return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
|
||||
|
||||
|
||||
def info(filepath: str, format: Optional[str]=None) -> AudioInfo:
|
||||
"""Get signal information of an audio file.
|
||||
|
||||
Note:
|
||||
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
|
||||
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
|
||||
|
||||
Args:
|
||||
filepath (path-like object or file-like object):
|
||||
Source of audio data.
|
||||
format (str or None, optional):
|
||||
Not used. PySoundFile does not accept format hint.
|
||||
|
||||
Returns:
|
||||
AudioInfo: meta data of the given audio.
|
||||
|
||||
"""
|
||||
sinfo = soundfile.info(filepath)
|
||||
return AudioInfo(
|
||||
sinfo.samplerate,
|
||||
sinfo.frames,
|
||||
sinfo.channels,
|
||||
bits_per_sample=_get_bit_depth(sinfo.subtype),
|
||||
encoding=_get_encoding(sinfo.format, sinfo.subtype), )
|
Loading…
Reference in new issue