diff --git a/audio/paddleaudio/backends/common.py b/audio/paddleaudio/backends/common.py
index 9d3edf812..3065fe89f 100644
--- a/audio/paddleaudio/backends/common.py
+++ b/audio/paddleaudio/backends/common.py
@@ -1,4 +1,5 @@
-# Token form https://github.com/pytorch/audio/blob/main/torchaudio/backend/common.py with modification.
+# Taken from https://github.com/pytorch/audio/blob/main/torchaudio/backend/common.py with modification.
+
class AudioInfo:
"""return of info function.
@@ -30,13 +31,12 @@ class AudioInfo:
"""
def __init__(
- self,
- sample_rate: int,
- num_frames: int,
- num_channels: int,
- bits_per_sample: int,
- encoding: str,
- ):
+ self,
+ sample_rate: int,
+ num_frames: int,
+ num_channels: int,
+ bits_per_sample: int,
+ encoding: str, ):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
@@ -44,12 +44,10 @@ class AudioInfo:
self.encoding = encoding
def __str__(self):
- return (
- f"AudioMetaData("
- f"sample_rate={self.sample_rate}, "
- f"num_frames={self.num_frames}, "
- f"num_channels={self.num_channels}, "
- f"bits_per_sample={self.bits_per_sample}, "
- f"encoding={self.encoding}"
- f")"
- )
+ return (f"AudioMetaData("
+ f"sample_rate={self.sample_rate}, "
+ f"num_frames={self.num_frames}, "
+ f"num_channels={self.num_channels}, "
+ f"bits_per_sample={self.bits_per_sample}, "
+ f"encoding={self.encoding}"
+ f")")
diff --git a/docs/source/cls/custom_dataset.md b/docs/source/cls/custom_dataset.md
index 7482d5edf..26bd60b25 100644
--- a/docs/source/cls/custom_dataset.md
+++ b/docs/source/cls/custom_dataset.md
@@ -2,7 +2,7 @@
Following this tutorial you can customize your dataset for audio classification task by using `paddlespeech`.
-A base class of classification dataset is `paddlespeech.audio.dataset.AudioClassificationDataset`. To customize your dataset you should write a dataset class derived from `AudioClassificationDataset`.
+The base class for classification datasets is `paddlespeech.audio.datasets.dataset.AudioClassificationDataset`. To customize your dataset, write a dataset class derived from `AudioClassificationDataset`.
Assuming you have some wave files stored in your own directory, you should prepare a meta file with the filepath and label information. For example, its absolute path is `/PATH/TO/META_FILE.txt`:
```
@@ -14,7 +14,7 @@ Assuming you have some wave files that stored in your own directory. You should
Here is an example to build your custom dataset in `custom_dataset.py`:
```python
-from paddleaudio.datasets.dataset import AudioClassificationDataset
+from paddlespeech.audio.datasets.dataset import AudioClassificationDataset
class CustomDataset(AudioClassificationDataset):
meta_file = '/PATH/TO/META_FILE.txt'
@@ -48,7 +48,7 @@ class CustomDataset(AudioClassificationDataset):
Then you can build dataset and data loader from `CustomDataset`:
```python
import paddle
-from paddleaudio.features import LogMelSpectrogram
+from paddlespeech.audio.transform.spectrogram import LogMelSpectrogram
from custom_dataset import CustomDataset
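Once `CustomDataset` is defined with the new import path, wiring it into a data loader looks roughly like this. This is a minimal sketch: `CustomDataset`'s constructor arguments live in the elided part of the tutorial, so none are passed here, and `batch_size=4` is an arbitrary choice.

```python
# Hedged sketch: feed the tutorial's CustomDataset into a Paddle DataLoader.
# CustomDataset() is assumed to need no required arguments here.
import paddle
from custom_dataset import CustomDataset

train_ds = CustomDataset()
train_loader = paddle.io.DataLoader(
    train_ds, batch_size=4, shuffle=True, drop_last=False)

for waveforms, labels in train_loader:
    # Each batch yields audio feature arrays and their integer class labels.
    print(waveforms.shape, labels.shape)
    break
```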
diff --git a/docs/tutorial/cls/cls_tutorial.ipynb b/docs/tutorial/cls/cls_tutorial.ipynb
index 3cee64991..e37b086f7 100644
--- a/docs/tutorial/cls/cls_tutorial.ipynb
+++ b/docs/tutorial/cls/cls_tutorial.ipynb
@@ -52,8 +52,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# Environment setup: install paddlespeech and paddleaudio\n",
- "!pip install --upgrade pip && pip install paddlespeech paddleaudio -U"
+ "# Environment setup: install paddlespeech\n",
+ "!pip install --upgrade pip && pip install paddlespeech -U"
]
},
{
@@ -100,7 +100,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from paddleaudio import load\n",
+ "from paddlespeech.audio.backends import soundfile_load as load\n",
"data, sr = load(file='./dog.wav', mono=True, dtype='float32') # mono channel, float32 audio samples\n",
"print('wav shape: {}'.format(data.shape))\n",
"print('sample rate: {}'.format(sr))\n",
@@ -191,7 +191,7 @@
"
图片来源:https://ww2.mathworks.cn/help/audio/ref/mfcc.html\n",
"\n",
"
\n",
- "The following example uses `paddleaudio.features.LogMelSpectrogram` to demonstrate how to extract the LogFBank of the sample audio:"
+ "The following example uses `paddlespeech.audio.transform.spectrogram.LogMelSpectrogram` to demonstrate how to extract the LogFBank of the sample audio:"
]
},
{
@@ -200,7 +200,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from paddleaudio.features import LogMelSpectrogram\n",
+ "from paddlespeech.audio.transform.spectrogram import LogMelSpectrogram\n",
"\n",
"f_min=50.0\n",
"f_max=14000.0\n",
@@ -337,7 +337,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from paddleaudio.datasets import ESC50\n",
+ "from paddlespeech.audio.datasets import ESC50\n",
"\n",
"train_ds = ESC50(mode='train', sample_rate=sr)\n",
"dev_ds = ESC50(mode='dev', sample_rate=sr)"
@@ -348,7 +348,7 @@
"metadata": {},
"source": [
"### 3.1.2 特征提取\n",
- "通过下列代码,用 `paddleaudio.features.LogMelSpectrogram` 初始化一个音频特征提取器,在训练过程中实时提取音频的 LogFBank 特征,其中主要的参数如下: "
+ "通过下列代码,用 `paddlespeech.audio.transform.spectrogram.LogMelSpectrogram` 初始化一个音频特征提取器,在训练过程中实时提取音频的 LogFBank 特征,其中主要的参数如下: "
]
},
{
@@ -481,7 +481,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from paddleaudio.utils import logger\n",
+ "from paddlespeech.audio.utils import logger\n",
"\n",
"epochs = 20\n",
"steps_per_epoch = len(train_loader)\n",
diff --git a/examples/tess/cls0/local/train.py b/examples/tess/cls0/local/train.py
index f023a37b7..ad4926d76 100644
--- a/examples/tess/cls0/local/train.py
+++ b/examples/tess/cls0/local/train.py
@@ -16,9 +16,9 @@ import os
import paddle
import yaml
-from paddleaudio.utils import logger
-from paddleaudio.utils import Timer
+from paddlespeech.audio.utils import logger
+from paddlespeech.audio.utils.time import Timer
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
diff --git a/examples/voxceleb/sv0/local/data_prepare.py b/examples/voxceleb/sv0/local/data_prepare.py
index b4486b6f0..e5a5dff7b 100644
--- a/examples/voxceleb/sv0/local/data_prepare.py
+++ b/examples/voxceleb/sv0/local/data_prepare.py
@@ -14,9 +14,9 @@
import argparse
import paddle
-from paddleaudio.datasets.voxceleb import VoxCeleb
from yacs.config import CfgNode
+from paddlespeech.audio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything
diff --git a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
index 11908fe63..b65fa35b4 100644
--- a/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
+++ b/examples/voxceleb/sv0/local/make_rirs_noise_csv_dataset_from_json.py
@@ -21,9 +21,9 @@ import os
from typing import List
import tqdm
-from paddleaudio.backends import soundfile_load as load_audio
from yacs.config import CfgNode
+from paddlespeech.audio.backends import soundfile_load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks
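The alias keeps call sites unchanged across the package move. A hedged sketch of how `load_audio` is used downstream; `noise.wav` is a hypothetical path:

```python
# Hedged sketch: the aliased loader behaves exactly like soundfile_load.
from paddlespeech.audio.backends import soundfile_load as load_audio

waveform, sr = load_audio('noise.wav')  # hypothetical file path
print(waveform.shape, sr)               # numpy array, sample rate in Hz
```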
diff --git a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
index ebeb598a4..6ef2064a0 100644
--- a/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
+++ b/examples/voxceleb/sv0/local/make_vox_csv_dataset_from_json.py
@@ -22,9 +22,9 @@ import os
import random
import tqdm
-from paddleaudio.backends import soundfile_load as load_audio
from yacs.config import CfgNode
+from paddlespeech.audio.backends import soundfile_load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks
diff --git a/paddlespeech/audio/__init__.py b/paddlespeech/audio/__init__.py
index a7cf6caaf..f35bff869 100644
--- a/paddlespeech/audio/__init__.py
+++ b/paddlespeech/audio/__init__.py
@@ -11,6 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from . import backends
+from . import compliance
+from . import datasets
+from . import functional
from . import streamdata
from . import text
from . import transform
+from . import utils
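With these subpackage imports in place, everything hangs off a single `paddlespeech.audio` import. A quick smoke-test sketch:

```python
# Hedged sketch: verify the newly wired subpackages are importable.
import paddlespeech.audio as audio

print(audio.backends.load)      # soundfile-backed loader
print(audio.compliance.kaldi)   # Kaldi-compatible feature functions
print(audio.datasets)           # bundled datasets such as ESC50
```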
diff --git a/paddlespeech/audio/backends/__init__.py b/paddlespeech/audio/backends/__init__.py
new file mode 100644
index 000000000..7e4ee6506
--- /dev/null
+++ b/paddlespeech/audio/backends/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .soundfile_backend import depth_convert
+from .soundfile_backend import load
+from .soundfile_backend import normalize
+from .soundfile_backend import resample
+from .soundfile_backend import soundfile_load
+from .soundfile_backend import soundfile_save
+from .soundfile_backend import to_mono
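Note that this module now re-exports two loading interfaces side by side. A hedged sketch of the difference; `sample.wav` is a hypothetical path:

```python
# Hedged sketch: load() is torchaudio-style and returns a paddle.Tensor of
# shape [channel, time]; soundfile_load() is the legacy paddleaudio-style
# loader and returns a numpy array (optionally resampled and normalized).
from paddlespeech.audio.backends import load, soundfile_load

tensor_wav, sr1 = load('sample.wav')
array_wav, sr2 = soundfile_load('sample.wav', sr=16000)
print(tensor_wav.shape, sr1, array_wav.shape, sr2)
```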
diff --git a/paddlespeech/audio/backends/common.py b/paddlespeech/audio/backends/common.py
new file mode 100644
index 000000000..3065fe89f
--- /dev/null
+++ b/paddlespeech/audio/backends/common.py
@@ -0,0 +1,53 @@
+# Taken from https://github.com/pytorch/audio/blob/main/torchaudio/backend/common.py with modification.
+
+
+class AudioInfo:
+    """Return type of the ``info`` function.
+
+ This class is used by :ref:`"sox_io" backend` and
+ :ref:`"soundfile" backend with the new interface`.
+
+ :ivar int sample_rate: Sample rate
+ :ivar int num_frames: The number of frames
+ :ivar int num_channels: The number of channels
+ :ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
+ or when it cannot be accurately inferred.
+ :ivar str encoding: Audio encoding
+ The values encoding can take are one of the following:
+
+ * ``PCM_S``: Signed integer linear PCM
+ * ``PCM_U``: Unsigned integer linear PCM
+ * ``PCM_F``: Floating point linear PCM
+ * ``FLAC``: Flac, Free Lossless Audio Codec
+ * ``ULAW``: Mu-law
+ * ``ALAW``: A-law
+ * ``MP3`` : MP3, MPEG-1 Audio Layer III
+ * ``VORBIS``: OGG Vorbis
+    * ``AMR_WB``: Adaptive Multi-Rate Wideband
+    * ``AMR_NB``: Adaptive Multi-Rate Narrowband
+ * ``OPUS``: Opus
+ * ``HTK``: Single channel 16-bit PCM
+ * ``UNKNOWN`` : None of above
+ """
+
+ def __init__(
+ self,
+ sample_rate: int,
+ num_frames: int,
+ num_channels: int,
+ bits_per_sample: int,
+ encoding: str, ):
+ self.sample_rate = sample_rate
+ self.num_frames = num_frames
+ self.num_channels = num_channels
+ self.bits_per_sample = bits_per_sample
+ self.encoding = encoding
+
+ def __str__(self):
+ return (f"AudioMetaData("
+ f"sample_rate={self.sample_rate}, "
+ f"num_frames={self.num_frames}, "
+ f"num_channels={self.num_channels}, "
+ f"bits_per_sample={self.bits_per_sample}, "
+ f"encoding={self.encoding}"
+ f")")
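For illustration, `AudioInfo` can be constructed by hand; the values below are made up. Note that the `__str__` output says `AudioMetaData`, a name carried over from the torchaudio original.

```python
# Hedged sketch: AudioInfo is a plain metadata container.
info_obj = AudioInfo(
    sample_rate=16000,
    num_frames=32000,
    num_channels=1,
    bits_per_sample=16,
    encoding='PCM_S')
print(info_obj)
# AudioMetaData(sample_rate=16000, num_frames=32000, num_channels=1,
# bits_per_sample=16, encoding=PCM_S)
```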
diff --git a/paddlespeech/audio/backends/soundfile_backend.py b/paddlespeech/audio/backends/soundfile_backend.py
new file mode 100644
index 000000000..7611fd297
--- /dev/null
+++ b/paddlespeech/audio/backends/soundfile_backend.py
@@ -0,0 +1,677 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import warnings
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import paddle
+import resampy
+import soundfile
+from scipy.io import wavfile
+
+from ..utils import depth_convert
+from ..utils import ParameterError
+from .common import AudioInfo
+
+__all__ = [
+ 'resample',
+ 'to_mono',
+ 'normalize',
+ 'save',
+ 'soundfile_save',
+ 'load',
+ 'soundfile_load',
+ 'info',
+]
+NORMALIZE_TYPES = ['linear', 'gaussian']
+MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
+RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
+EPS = 1e-8
+
+
+def resample(y: np.ndarray,
+ src_sr: int,
+ target_sr: int,
+ mode: str='kaiser_fast') -> np.ndarray:
+ """Audio resampling.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ src_sr (int): Source sample rate.
+ target_sr (int): Target sample rate.
+ mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
+
+ Returns:
+ np.ndarray: `y` resampled to `target_sr`
+ """
+
+ if mode == 'kaiser_best':
+        warnings.warn(
+            f'Using resampy kaiser_best to resample {src_sr} => {target_sr}. This mode is pretty slow; '
+            'we recommend kaiser_fast for large-scale audio training.')
+
+    if not isinstance(y, np.ndarray):
+        raise ParameterError(
+            f'Only support numpy np.ndarray, but received y of type {type(y)}')
+
+ if mode not in RESAMPLE_MODES:
+ raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
+
+ return resampy.resample(y, src_sr, target_sr, filter=mode)
+
+
+def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
+    """Convert stereo audio to mono.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'.
+
+ Returns:
+ np.ndarray: `y` with mono channel.
+ """
+
+ if merge_type not in MERGE_TYPES:
+ raise ParameterError(
+ f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
+ )
+ if y.ndim > 2:
+ raise ParameterError(
+ f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
+ if y.ndim == 1: # nothing to merge
+ return y
+
+ if merge_type == 'ch0':
+ return y[0]
+ if merge_type == 'ch1':
+ return y[1]
+ if merge_type == 'random':
+ return y[np.random.randint(0, 2)]
+
+ # need to do averaging according to dtype
+
+ if y.dtype == 'float32':
+ y_out = (y[0] + y[1]) * 0.5
+ elif y.dtype == 'int16':
+ y_out = y.astype('int32')
+ y_out = (y_out[0] + y_out[1]) // 2
+ y_out = np.clip(y_out, np.iinfo(y.dtype).min,
+ np.iinfo(y.dtype).max).astype(y.dtype)
+
+ elif y.dtype == 'int8':
+ y_out = y.astype('int16')
+ y_out = (y_out[0] + y_out[1]) // 2
+ y_out = np.clip(y_out, np.iinfo(y.dtype).min,
+ np.iinfo(y.dtype).max).astype(y.dtype)
+ else:
+ raise ParameterError(f'Unsupported dtype: {y.dtype}')
+ return y_out
+
+
+def soundfile_load_(file: os.PathLike,
+ offset: Optional[float]=None,
+ dtype: str='int16',
+ duration: Optional[int]=None) -> Tuple[np.ndarray, int]:
+    """Load audio using the soundfile library. This function loads the audio file via libsndfile.
+
+ Args:
+ file (os.PathLike): File of waveform.
+ offset (Optional[float], optional): Offset to the start of waveform. Defaults to None.
+ dtype (str, optional): Data type of waveform. Defaults to 'int16'.
+ duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
+
+ Returns:
+ Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
+ """
+ with soundfile.SoundFile(file) as sf_desc:
+ sr_native = sf_desc.samplerate
+ if offset:
+ sf_desc.seek(int(offset * sr_native))
+ if duration is not None:
+ frame_duration = int(duration * sr_native)
+ else:
+ frame_duration = -1
+ y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T
+
+ return y, sf_desc.samplerate
+
+
+def normalize(y: np.ndarray, norm_type: str='linear',
+ mul_factor: float=1.0) -> np.ndarray:
+ """Normalize an input audio with additional multiplier.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ norm_type (str, optional): Type of normalization. Defaults to 'linear'.
+ mul_factor (float, optional): Scaling factor. Defaults to 1.0.
+
+ Returns:
+ np.ndarray: `y` after normalization.
+ """
+
+ if norm_type == 'linear':
+ amax = np.max(np.abs(y))
+ factor = 1.0 / (amax + EPS)
+ y = y * factor * mul_factor
+ elif norm_type == 'gaussian':
+ amean = np.mean(y)
+ astd = np.std(y)
+ astd = max(astd, EPS)
+ y = mul_factor * (y - amean) / astd
+ else:
+        raise NotImplementedError(f'norm_type should be in {NORMALIZE_TYPES}')
+
+ return y
+
+
+def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
+    """Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with an additional step to convert the input waveform to int16.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ sr (int): Sample rate.
+ file (os.PathLike): Path of audio file to save.
+ """
+ if not file.endswith('.wav'):
+ raise ParameterError(
+ f'only .wav file supported, but dst file name is: {file}')
+
+ if sr <= 0:
+ raise ParameterError(
+ f'Sample rate should be larger than 0, received sr = {sr}')
+
+ if y.dtype not in ['int16', 'int8']:
+ warnings.warn(
+ f'input data type is {y.dtype}, will convert data to int16 format before saving'
+ )
+ y_out = depth_convert(y, 'int16')
+ else:
+ y_out = y
+
+ wavfile.write(file, sr, y_out)
+
+
+def soundfile_load(
+ file: os.PathLike,
+ sr: Optional[int]=None,
+ mono: bool=True,
+ merge_type: str='average', # ch0,ch1,random,average
+ normal: bool=True,
+ norm_type: str='linear',
+ norm_mul_factor: float=1.0,
+ offset: float=0.0,
+ duration: Optional[int]=None,
+ dtype: str='float32',
+ resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]:
+    """Load audio file from disk. This function loads audio from disk using the audio backend.
+
+ Args:
+ file (os.PathLike): Path of audio file to load.
+ sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None.
+ mono (bool, optional): Return waveform with mono channel. Defaults to True.
+ merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'.
+ normal (bool, optional): Waveform normalization. Defaults to True.
+ norm_type (str, optional): Type of normalization. Defaults to 'linear'.
+ norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0.
+ offset (float, optional): Offset to the start of waveform. Defaults to 0.0.
+ duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
+ dtype (str, optional): Data type of waveform. Defaults to 'float32'.
+ resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
+
+ Returns:
+ Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
+ """
+
+ y, r = soundfile_load_(file, offset=offset, dtype=dtype, duration=duration)
+
+ if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
+ raise ParameterError(f'audio file {file} looks empty')
+
+ if mono:
+ y = to_mono(y, merge_type)
+
+ if sr is not None and sr != r:
+ y = resample(y, r, sr, mode=resample_mode)
+ r = sr
+
+ if normal:
+ y = normalize(y, norm_type, norm_mul_factor)
+ elif dtype in ['int8', 'int16']:
+ # still need to do normalization, before depth conversion
+ y = normalize(y, 'linear', 1.0)
+
+ y = depth_convert(y, dtype)
+ return y, r
+
+
+# The code below is taken from https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py, with some modifications.
+
+
+def _get_subtype_for_wav(dtype: paddle.dtype,
+ encoding: str,
+ bits_per_sample: int):
+ if not encoding:
+ if not bits_per_sample:
+ subtype = {
+ paddle.uint8: "PCM_U8",
+ paddle.int16: "PCM_16",
+ paddle.int32: "PCM_32",
+ paddle.float32: "FLOAT",
+ paddle.float64: "DOUBLE",
+ }.get(dtype)
+ if not subtype:
+ raise ValueError(f"Unsupported dtype for wav: {dtype}")
+ return subtype
+ if bits_per_sample == 8:
+ return "PCM_U8"
+ return f"PCM_{bits_per_sample}"
+ if encoding == "PCM_S":
+ if not bits_per_sample:
+ return "PCM_32"
+ if bits_per_sample == 8:
+ raise ValueError("wav does not support 8-bit signed PCM encoding.")
+ return f"PCM_{bits_per_sample}"
+ if encoding == "PCM_U":
+ if bits_per_sample in (None, 8):
+ return "PCM_U8"
+ raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
+ if encoding == "PCM_F":
+ if bits_per_sample in (None, 32):
+ return "FLOAT"
+ if bits_per_sample == 64:
+ return "DOUBLE"
+ raise ValueError("wav only supports 32/64-bit float PCM encoding.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "ULAW"
+ raise ValueError("wav only supports 8-bit mu-law encoding.")
+ if encoding == "ALAW":
+ if bits_per_sample in (None, 8):
+ return "ALAW"
+ raise ValueError("wav only supports 8-bit a-law encoding.")
+ raise ValueError(f"wav does not support {encoding}.")
+
+
+def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
+ if encoding in (None, "PCM_S"):
+ return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
+ if encoding in ("PCM_U", "PCM_F"):
+ raise ValueError(f"sph does not support {encoding} encoding.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "ULAW"
+ raise ValueError("sph only supports 8-bit for mu-law encoding.")
+ if encoding == "ALAW":
+ return "ALAW"
+ raise ValueError(f"sph does not support {encoding}.")
+
+
+def _get_subtype(dtype: paddle.dtype,
+ format: str,
+ encoding: str,
+ bits_per_sample: int):
+ if format == "wav":
+ return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
+ if format == "flac":
+ if encoding:
+ raise ValueError("flac does not support encoding.")
+ if not bits_per_sample:
+ return "PCM_16"
+ if bits_per_sample > 24:
+ raise ValueError("flac does not support bits_per_sample > 24.")
+ return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
+ if format in ("ogg", "vorbis"):
+ if encoding or bits_per_sample:
+ raise ValueError(
+ "ogg/vorbis does not support encoding/bits_per_sample.")
+ return "VORBIS"
+ if format == "sph":
+ return _get_subtype_for_sphere(encoding, bits_per_sample)
+ if format in ("nis", "nist"):
+ return "PCM_16"
+ raise ValueError(f"Unsupported format: {format}")
+
+
+def save(
+ filepath: str,
+ src: paddle.Tensor,
+ sample_rate: int,
+ channels_first: bool=True,
+ compression: Optional[float]=None,
+ format: Optional[str]=None,
+ encoding: Optional[str]=None,
+ bits_per_sample: Optional[int]=None, ):
+ """Save audio data to file.
+
+ Note:
+ The formats this function can handle depend on the soundfile installation.
+ This function is tested on the following formats;
+
+ * WAV
+
+ * 32-bit floating-point
+ * 32-bit signed integer
+ * 16-bit signed integer
+ * 8-bit unsigned integer
+
+ * FLAC
+ * OGG/VORBIS
+ * SPHERE
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+        ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend.
+
+ Args:
+ filepath (str or pathlib.Path): Path to audio file.
+        src (paddle.Tensor): Audio data to save. Must be a 2D tensor.
+        sample_rate (int): Sampling rate.
+ channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
+ otherwise `[time, channel]`.
+        compression (float or None, optional): Not used.
+ It is here only for interface compatibility reason with "sox_io" backend.
+ format (str or None, optional): Override the audio format.
+ When ``filepath`` argument is path-like object, audio format is
+ inferred from file extension. If the file extension is missing or
+ different, you can specify the correct format with this argument.
+
+ When ``filepath`` argument is file-like object,
+ this argument is required.
+
+ Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
+ ``"flac"`` and ``"sph"``.
+ encoding (str or None, optional): Changes the encoding for supported formats.
+ This argument is effective only for supported formats, such as
+            ``"wav"``, ``"flac"`` and ``"sph"``. Valid values are:
+
+ - ``"PCM_S"`` (signed integer Linear PCM)
+ - ``"PCM_U"`` (unsigned integer Linear PCM)
+ - ``"PCM_F"`` (floating point PCM)
+ - ``"ULAW"`` (mu-law)
+ - ``"ALAW"`` (a-law)
+
+ bits_per_sample (int or None, optional): Changes the bit depth for the
+ supported formats.
+ When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
+ you can change the bit depth.
+ Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
+
+ Supported formats/encodings/bit depth/compression are:
+
+ ``"wav"``
+ - 32-bit floating-point PCM
+ - 32-bit signed integer PCM
+ - 24-bit signed integer PCM
+ - 16-bit signed integer PCM
+ - 8-bit unsigned integer PCM
+ - 8-bit mu-law
+ - 8-bit a-law
+
+ Note:
+ Default encoding/bit depth is determined by the dtype of
+ the input Tensor.
+
+ ``"flac"``
+ - 8-bit
+ - 16-bit (default)
+ - 24-bit
+
+ ``"ogg"``, ``"vorbis"``
+ - Doesn't accept changing configuration.
+
+ ``"sph"``
+ - 8-bit signed integer PCM
+ - 16-bit signed integer PCM
+ - 24-bit signed integer PCM
+ - 32-bit signed integer PCM (default)
+ - 8-bit mu-law
+ - 8-bit a-law
+ - 16-bit a-law
+ - 24-bit a-law
+ - 32-bit a-law
+
+ """
+ if src.ndim != 2:
+ raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
+ if compression is not None:
+ warnings.warn(
+ '`save` function of "soundfile" backend does not support "compression" parameter. '
+ "The argument is silently ignored.")
+ if hasattr(filepath, "write"):
+ if format is None:
+ raise RuntimeError(
+ "`format` is required when saving to file object.")
+ ext = format.lower()
+ else:
+ ext = str(filepath).split(".")[-1].lower()
+
+ if bits_per_sample not in (None, 8, 16, 24, 32, 64):
+ raise ValueError("Invalid bits_per_sample.")
+ if bits_per_sample == 24:
+ warnings.warn(
+ "Saving audio with 24 bits per sample might warp samples near -1. "
+ "Using 16 bits per sample might be able to avoid this.")
+ subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
+
+    # sph is an extension used in TED-LIUM, but soundfile does not recognize it as NIST format,
+ # so we extend the extensions manually here
+ if ext in ["nis", "nist", "sph"] and format is None:
+ format = "NIST"
+
+ if channels_first:
+ src = src.t()
+
+ soundfile.write(
+ file=filepath,
+ data=src,
+ samplerate=sample_rate,
+ subtype=subtype,
+ format=format)
+
+
+_SUBTYPE2DTYPE = {
+ "PCM_S8": "int8",
+ "PCM_U8": "uint8",
+ "PCM_16": "int16",
+ "PCM_32": "int32",
+ "FLOAT": "float32",
+ "DOUBLE": "float64",
+}
+
+
+def load(
+ filepath: str,
+ frame_offset: int=0,
+ num_frames: int=-1,
+ normalize: bool=True,
+ channels_first: bool=True,
+ format: Optional[str]=None, ) -> Tuple[paddle.Tensor, int]:
+ """Load audio data from file.
+
+ Note:
+ The formats this function can handle depend on the soundfile installation.
+ This function is tested on the following formats;
+
+ * WAV
+
+ * 32-bit floating-point
+ * 32-bit signed integer
+ * 16-bit signed integer
+ * 8-bit unsigned integer
+
+ * FLAC
+ * OGG/VORBIS
+ * SPHERE
+
+ By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
+ ``float32`` dtype and the shape of `[channel, time]`.
+ The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
+
+ When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
+ signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
+ by providing ``normalize=False``, this function can return integer Tensor, where the samples
+ are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
+ for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
+
+ ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
+ ``flac`` and ``mp3``.
+ For these formats, this function always returns ``float32`` Tensor with values normalized to
+ ``[-1.0, 1.0]``.
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+        ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend.
+
+ Args:
+ filepath (path-like object or file-like object):
+ Source of audio data.
+ frame_offset (int, optional):
+ Number of frames to skip before start reading data.
+ num_frames (int, optional):
+ Maximum number of frames to read. ``-1`` reads all the remaining samples,
+ starting from ``frame_offset``.
+            This function may return fewer frames if there are not enough
+            frames in the given file.
+ normalize (bool, optional):
+            When ``True``, this function always returns ``float32``, and sample values are
+            normalized to ``[-1.0, 1.0]``.
+            If the input file is an integer WAV, giving ``False`` will change the resulting Tensor type to
+            an integer type.
+ This argument has no effect for formats other than integer WAV type.
+ channels_first (bool, optional):
+ When True, the returned Tensor has dimension `[channel, time]`.
+ Otherwise, the returned Tensor's dimension is `[time, channel]`.
+ format (str or None, optional):
+ Not used. PySoundFile does not accept format hint.
+
+ Returns:
+ (paddle.Tensor, int): Resulting Tensor and sample rate.
+ If the input file has integer wav format and normalization is off, then it has
+ integer type, else ``float32`` type. If ``channels_first=True``, it has
+ `[channel, time]` else `[time, channel]`.
+ """
+ with soundfile.SoundFile(filepath, "r") as file_:
+ if file_.format != "WAV" or normalize:
+ dtype = "float32"
+ elif file_.subtype not in _SUBTYPE2DTYPE:
+ raise ValueError(f"Unsupported subtype: {file_.subtype}")
+ else:
+ dtype = _SUBTYPE2DTYPE[file_.subtype]
+
+ frames = file_._prepare_read(frame_offset, None, num_frames)
+ waveform = file_.read(frames, dtype, always_2d=True)
+ sample_rate = file_.samplerate
+
+ waveform = paddle.to_tensor(waveform)
+ if channels_first:
+ waveform = paddle.transpose(waveform, perm=[1, 0])
+ return waveform, sample_rate
+
+
+# Mapping from soundfile subtype to number of bits per sample.
+# This is mostly heuristical and the value is set to 0 when it is irrelevant
+# (lossy formats) or when it can't be inferred.
+# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
+# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
+# the default seems to be 8 bits but it can be compressed further to 4 bits.
+# The dict is inspired from
+# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
+_SUBTYPE_TO_BITS_PER_SAMPLE = {
+ "PCM_S8": 8, # Signed 8 bit data
+ "PCM_16": 16, # Signed 16 bit data
+ "PCM_24": 24, # Signed 24 bit data
+ "PCM_32": 32, # Signed 32 bit data
+ "PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
+ "FLOAT": 32, # 32 bit float data
+ "DOUBLE": 64, # 64 bit float data
+ "ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+ "ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+ "IMA_ADPCM": 0, # IMA ADPCM.
+ "MS_ADPCM": 0, # Microsoft ADPCM.
+ "GSM610":
+ 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
+ "VOX_ADPCM": 0, # OKI / Dialogix ADPCM
+ "G721_32": 0, # 32kbs G721 ADPCM encoding.
+ "G723_24": 0, # 24kbs G723 ADPCM encoding.
+ "G723_40": 0, # 40kbs G723 ADPCM encoding.
+ "DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
+ "DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
+ "DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
+ "DWVW_N": 0, # N bit Delta Width Variable Word encoding.
+ "DPCM_8": 8, # 8 bit differential PCM (XI only)
+ "DPCM_16": 16, # 16 bit differential PCM (XI only)
+ "VORBIS": 0, # Xiph Vorbis encoding. (lossy)
+ "ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
+ "ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
+ "ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
+ "ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
+}
+
+
+def _get_bit_depth(subtype):
+ if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
+ warnings.warn(
+            f"The {subtype} subtype is unknown to PaddleSpeech. As a result, the bits_per_sample "
+ "attribute will be set to 0. If you are seeing this warning, please "
+ "report by opening an issue on github (after checking for existing/closed ones). "
+ "You may otherwise ignore this warning.")
+ return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
+
+
+_SUBTYPE_TO_ENCODING = {
+ "PCM_S8": "PCM_S",
+ "PCM_16": "PCM_S",
+ "PCM_24": "PCM_S",
+ "PCM_32": "PCM_S",
+ "PCM_U8": "PCM_U",
+ "FLOAT": "PCM_F",
+ "DOUBLE": "PCM_F",
+ "ULAW": "ULAW",
+ "ALAW": "ALAW",
+ "VORBIS": "VORBIS",
+}
+
+
+def _get_encoding(format: str, subtype: str):
+ if format == "FLAC":
+ return "FLAC"
+ return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
+
+
+def info(filepath: str, format: Optional[str]=None) -> AudioInfo:
+ """Get signal information of an audio file.
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+        ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend.
+
+ Args:
+ filepath (path-like object or file-like object):
+ Source of audio data.
+ format (str or None, optional):
+ Not used. PySoundFile does not accept format hint.
+
+ Returns:
+ AudioInfo: meta data of the given audio.
+
+ """
+ sinfo = soundfile.info(filepath)
+ return AudioInfo(
+ sinfo.samplerate,
+ sinfo.frames,
+ sinfo.channels,
+ bits_per_sample=_get_bit_depth(sinfo.subtype),
+ encoding=_get_encoding(sinfo.format, sinfo.subtype), )
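A hedged round-trip sketch of the three public entry points defined above; `demo.wav` is a hypothetical path and the random tensor stands in for real audio:

```python
# Hedged sketch: save a tensor, inspect its metadata, and load it back.
import paddle
from paddlespeech.audio.backends import soundfile_backend as backend

wav = paddle.randn([1, 16000], dtype='float32')  # [channel, time], 1 s @ 16 kHz
backend.save('demo.wav', wav, sample_rate=16000)

meta = backend.info('demo.wav')
print(meta)  # AudioMetaData(sample_rate=16000, num_frames=16000, ...)

loaded, sr = backend.load('demo.wav')
print(loaded.shape, sr)  # [1, 16000] 16000
```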
diff --git a/paddlespeech/audio/compliance/__init__.py b/paddlespeech/audio/compliance/__init__.py
new file mode 100644
index 000000000..c08f9ab11
--- /dev/null
+++ b/paddlespeech/audio/compliance/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from . import kaldi
+from . import librosa
diff --git a/paddlespeech/audio/compliance/kaldi.py b/paddlespeech/audio/compliance/kaldi.py
new file mode 100644
index 000000000..a94ec4053
--- /dev/null
+++ b/paddlespeech/audio/compliance/kaldi.py
@@ -0,0 +1,643 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from torchaudio(https://github.com/pytorch/audio)
+import math
+from typing import Tuple
+
+import paddle
+from paddle import Tensor
+
+from ..functional import create_dct
+from ..functional.window import get_window
+
+__all__ = [
+ 'spectrogram',
+ 'fbank',
+ 'mfcc',
+]
+
+# window types
+HANNING = 'hann'
+HAMMING = 'hamming'
+POVEY = 'povey'
+RECTANGULAR = 'rect'
+BLACKMAN = 'blackman'
+
+
+def _get_epsilon(dtype):
+ return paddle.to_tensor(1e-07, dtype=dtype)
+
+
+def _next_power_of_2(x: int) -> int:
+ return 1 if x == 0 else 2**(x - 1).bit_length()
+
+
+def _get_strided(waveform: Tensor,
+ window_size: int,
+ window_shift: int,
+ snip_edges: bool) -> Tensor:
+ assert waveform.dim() == 1
+ num_samples = waveform.shape[0]
+
+ if snip_edges:
+ if num_samples < window_size:
+ return paddle.empty((0, 0), dtype=waveform.dtype)
+ else:
+ m = 1 + (num_samples - window_size) // window_shift
+ else:
+ reversed_waveform = paddle.flip(waveform, [0])
+ m = (num_samples + (window_shift // 2)) // window_shift
+ pad = window_size // 2 - window_shift // 2
+ pad_right = reversed_waveform
+ if pad > 0:
+ pad_left = reversed_waveform[-pad:]
+ waveform = paddle.concat((pad_left, waveform, pad_right), axis=0)
+ else:
+ waveform = paddle.concat((waveform[-pad:], pad_right), axis=0)
+
+ return paddle.signal.frame(waveform, window_size, window_shift)[:, :m].T
+
+
+def _feature_window_function(
+ window_type: str,
+ window_size: int,
+ blackman_coeff: float,
+ dtype: int, ) -> Tensor:
+ if window_type == "hann":
+ return get_window('hann', window_size, fftbins=False, dtype=dtype)
+ elif window_type == "hamming":
+ return get_window('hamming', window_size, fftbins=False, dtype=dtype)
+ elif window_type == "povey":
+ return get_window(
+ 'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
+ elif window_type == "rect":
+ return paddle.ones([window_size], dtype=dtype)
+ elif window_type == "blackman":
+ a = 2 * math.pi / (window_size - 1)
+ window_function = paddle.arange(window_size, dtype=dtype)
+ return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
+ (0.5 - blackman_coeff) * paddle.cos(2 * a * window_function)
+ ).astype(dtype)
+ else:
+ raise Exception('Invalid window type ' + window_type)
+
+
+def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
+ energy_floor: float) -> Tensor:
+ log_energy = paddle.maximum(strided_input.pow(2).sum(1), epsilon).log()
+ if energy_floor == 0.0:
+ return log_energy
+ return paddle.maximum(
+ log_energy,
+ paddle.to_tensor(math.log(energy_floor), dtype=strided_input.dtype))
+
+
+def _get_waveform_and_window_properties(
+ waveform: Tensor,
+ channel: int,
+ sr: int,
+ frame_shift: float,
+ frame_length: float,
+ round_to_power_of_two: bool,
+ preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
+ channel = max(channel, 0)
+ assert channel < waveform.shape[0], (
+ 'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
+ waveform = waveform[channel, :] # size (n)
+ window_shift = int(
+ sr * frame_shift *
+ 0.001) # pass frame_shift and frame_length in milliseconds
+ window_size = int(sr * frame_length * 0.001)
+ padded_window_size = _next_power_of_2(
+ window_size) if round_to_power_of_two else window_size
+
+ assert 2 <= window_size <= len(waveform), (
+ 'choose a window size {} that is [2, {}]'.format(window_size,
+ len(waveform)))
+ assert 0 < window_shift, '`window_shift` must be greater than 0'
+ assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
+ ' use `round_to_power_of_two` or change `frame_length`'
+ assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
+ assert sr > 0, '`sr` must be greater than zero'
+ return waveform, window_shift, window_size, padded_window_size
+
+
+def _get_window(waveform: Tensor,
+ padded_window_size: int,
+ window_size: int,
+ window_shift: int,
+ window_type: str,
+ blackman_coeff: float,
+ snip_edges: bool,
+ raw_energy: bool,
+ energy_floor: float,
+ dither: float,
+ remove_dc_offset: bool,
+ preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
+ dtype = waveform.dtype
+ epsilon = _get_epsilon(dtype)
+
+ # (m, window_size)
+ strided_input = _get_strided(waveform, window_size, window_shift,
+ snip_edges)
+
+ if dither != 0.0:
+ x = paddle.maximum(epsilon,
+ paddle.rand(strided_input.shape, dtype=dtype))
+ rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
+ strided_input = strided_input + rand_gauss * dither
+
+ if remove_dc_offset:
+ row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) # (m, 1)
+ strided_input = strided_input - row_means
+
+ if raw_energy:
+ signal_log_energy = _get_log_energy(strided_input, epsilon,
+ energy_floor) # (m)
+
+ if preemphasis_coefficient != 0.0:
+ offset_strided_input = paddle.nn.functional.pad(
+ strided_input.unsqueeze(0), (1, 0),
+ data_format='NCL',
+ mode='replicate').squeeze(0) # (m, window_size + 1)
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
+ -1]
+
+ window_function = _feature_window_function(
+ window_type, window_size, blackman_coeff,
+ dtype).unsqueeze(0) # (1, window_size)
+ strided_input = strided_input * window_function # (m, window_size)
+
+ # (m, padded_window_size)
+ if padded_window_size != window_size:
+ padding_right = padded_window_size - window_size
+ strided_input = paddle.nn.functional.pad(
+ strided_input.unsqueeze(0), (0, padding_right),
+ data_format='NCL',
+ mode='constant',
+ value=0).squeeze(0)
+
+ if not raw_energy:
+ signal_log_energy = _get_log_energy(strided_input, epsilon,
+ energy_floor) # size (m)
+
+ return strided_input, signal_log_energy
+
+
+def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
+ if subtract_mean:
+ col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
+ tensor = tensor - col_means
+ return tensor
+
+
+def spectrogram(waveform: Tensor,
+ blackman_coeff: float=0.42,
+ channel: int=-1,
+ dither: float=0.0,
+ energy_floor: float=1.0,
+ frame_length: float=25.0,
+ frame_shift: float=10.0,
+ preemphasis_coefficient: float=0.97,
+ raw_energy: bool=True,
+ remove_dc_offset: bool=True,
+ round_to_power_of_two: bool=True,
+ sr: int=16000,
+ snip_edges: bool=True,
+ subtract_mean: bool=False,
+ window_type: str="povey") -> Tensor:
+ """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
+
+ Args:
+ waveform (Tensor): A waveform tensor with shape `(C, T)`.
+        blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
+        channel (int, optional): Select the channel of waveform. Defaults to -1.
+        dither (float, optional): Dithering constant. Defaults to 0.0.
+ energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
+ frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
+ frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
+ preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
+        raw_energy (bool, optional): Whether to compute energy before preemphasis and windowing. Defaults to True.
+        remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
+        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
+            to FFT. Defaults to True.
+        sr (int, optional): Sample rate of input waveform. Defaults to 16000.
+        snip_edges (bool, optional): Drop samples at the end of the waveform that can't fit a signal frame when it
+            is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
+ subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
+ window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
+
+ Returns:
+        Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)`, where m is the number of frames
+            and depends on frame_length and frame_shift.
+ """
+ dtype = waveform.dtype
+ epsilon = _get_epsilon(dtype)
+
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
+ waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
+ preemphasis_coefficient)
+
+ strided_input, signal_log_energy = _get_window(
+ waveform, padded_window_size, window_size, window_shift, window_type,
+ blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
+ remove_dc_offset, preemphasis_coefficient)
+
+ # (m, padded_window_size // 2 + 1, 2)
+ fft = paddle.fft.rfft(strided_input)
+
+ power_spectrum = paddle.maximum(
+ fft.abs().pow(2.), epsilon).log() # (m, padded_window_size // 2 + 1)
+ power_spectrum[:, 0] = signal_log_energy
+
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
+ return power_spectrum
+
+
+def _inverse_mel_scale_scalar(mel_freq: float) -> float:
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
+
+
+def _inverse_mel_scale(mel_freq: Tensor) -> Tensor:
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
+
+
+def _mel_scale_scalar(freq: float) -> float:
+ return 1127.0 * math.log(1.0 + freq / 700.0)
+
+
+def _mel_scale(freq: Tensor) -> Tensor:
+ return 1127.0 * (1.0 + freq / 700.0).log()
+
+
+def _vtln_warp_freq(vtln_low_cutoff: float,
+ vtln_high_cutoff: float,
+ low_freq: float,
+ high_freq: float,
+ vtln_warp_factor: float,
+ freq: Tensor) -> Tensor:
+ assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
+ assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
+ scale = 1.0 / vtln_warp_factor
+ Fl = scale * l
+ Fh = scale * h
+ assert l > low_freq and h < high_freq
+ scale_left = (Fl - low_freq) / (l - low_freq)
+ scale_right = (high_freq - Fh) / (high_freq - h)
+ res = paddle.empty_like(freq)
+
+ outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
+ | paddle.greater_than(freq, paddle.to_tensor(high_freq))
+ before_l = paddle.less_than(freq, paddle.to_tensor(l))
+ before_h = paddle.less_than(freq, paddle.to_tensor(h))
+ after_h = paddle.greater_equal(freq, paddle.to_tensor(h))
+
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
+ res[before_h] = scale * freq[before_h]
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
+
+ return res
+
+
+def _vtln_warp_mel_freq(vtln_low_cutoff: float,
+ vtln_high_cutoff: float,
+ low_freq,
+ high_freq: float,
+ vtln_warp_factor: float,
+ mel_freq: Tensor) -> Tensor:
+ return _mel_scale(
+ _vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
+ vtln_warp_factor, _inverse_mel_scale(mel_freq)))
+
+
+def _get_mel_banks(num_bins: int,
+ window_length_padded: int,
+ sample_freq: float,
+ low_freq: float,
+ high_freq: float,
+ vtln_low: float,
+ vtln_high: float,
+ vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
+ assert num_bins > 3, 'Must have at least 3 mel bins'
+ assert window_length_padded % 2 == 0
+ num_fft_bins = window_length_padded / 2
+ nyquist = 0.5 * sample_freq
+
+ if high_freq <= 0.0:
+ high_freq += nyquist
+
+ assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
+ ('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
+
+ fft_bin_width = sample_freq / window_length_padded
+ mel_low_freq = _mel_scale_scalar(low_freq)
+ mel_high_freq = _mel_scale_scalar(high_freq)
+
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
+
+ if vtln_high < 0.0:
+ vtln_high += nyquist
+
+ assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
+ (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
+ ('Bad values in options: vtln-low {} and vtln-high {}, versus '
+ 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
+
+ bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1)
+ # left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
+ # center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
+ # right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
+ left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
+ center_mel = left_mel + mel_freq_delta
+ right_mel = center_mel + mel_freq_delta
+
+ if vtln_warp_factor != 1.0:
+ left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
+ vtln_warp_factor, left_mel)
+ center_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
+ high_freq, vtln_warp_factor,
+ center_mel)
+ right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
+ high_freq, vtln_warp_factor, right_mel)
+
+ center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
+ # (1, num_fft_bins)
+ mel = _mel_scale(fft_bin_width * paddle.arange(
+ num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
+
+ # (num_bins, num_fft_bins)
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
+
+ if vtln_warp_factor == 1.0:
+ bins = paddle.maximum(
+ paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
+ else:
+ bins = paddle.zeros_like(up_slope)
+ up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
+ mel, center_mel)
+ down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
+ mel, right_mel)
+ bins[up_idx] = up_slope[up_idx]
+ bins[down_idx] = down_slope[down_idx]
+
+ return bins, center_freqs
+
+
+def fbank(waveform: Tensor,
+ blackman_coeff: float=0.42,
+ channel: int=-1,
+ dither: float=0.0,
+ energy_floor: float=1.0,
+ frame_length: float=25.0,
+ frame_shift: float=10.0,
+ high_freq: float=0.0,
+ htk_compat: bool=False,
+ low_freq: float=20.0,
+ n_mels: int=23,
+ preemphasis_coefficient: float=0.97,
+ raw_energy: bool=True,
+ remove_dc_offset: bool=True,
+ round_to_power_of_two: bool=True,
+ sr: int=16000,
+ snip_edges: bool=True,
+ subtract_mean: bool=False,
+ use_energy: bool=False,
+ use_log_fbank: bool=True,
+ use_power: bool=True,
+ vtln_high: float=-500.0,
+ vtln_low: float=100.0,
+ vtln_warp: float=1.0,
+ window_type: str="povey") -> Tensor:
+ """Compute and return filter banks from a waveform. The output is identical to Kaldi's.
+
+ Args:
+ waveform (Tensor): A waveform tensor with shape `(C, T)`. `C` is in the range [0,1].
+        blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
+        channel (int, optional): Select the channel of waveform. Defaults to -1.
+        dither (float, optional): Dithering constant. Defaults to 0.0.
+ energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
+ frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
+ frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
+ high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
+ htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
+ low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
+ n_mels (int, optional): Number of output mel bins. Defaults to 23.
+ preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
+        raw_energy (bool, optional): Whether to compute energy before preemphasis and windowing. Defaults to True.
+        remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
+        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
+            to FFT. Defaults to True.
+        sr (int, optional): Sample rate of input waveform. Defaults to 16000.
+        snip_edges (bool, optional): Drop samples at the end of the waveform that can't fit a signal frame when it
+            is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
+ subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
+        use_energy (bool, optional): Add a dimension with the energy of the spectrogram to the output. Defaults to False.
+ use_log_fbank (bool, optional): Return log fbank when it is set True. Defaults to True.
+ use_power (bool, optional): Whether to use power instead of magnitude. Defaults to True.
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
+ vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
+ window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
+
+ Returns:
+ Tensor: A filter banks tensor with shape `(m, n_mels)`.
+ """
+ dtype = waveform.dtype
+
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
+ waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
+ preemphasis_coefficient)
+
+ strided_input, signal_log_energy = _get_window(
+ waveform, padded_window_size, window_size, window_shift, window_type,
+ blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
+ remove_dc_offset, preemphasis_coefficient)
+
+ # (m, padded_window_size // 2 + 1)
+ spectrum = paddle.fft.rfft(strided_input).abs()
+ if use_power:
+ spectrum = spectrum.pow(2.)
+
+ # (n_mels, padded_window_size // 2)
+ mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
+ high_freq, vtln_low, vtln_high, vtln_warp)
+ # mel_energies = mel_energies.astype(dtype)
+ assert mel_energies.dtype == dtype
+
+ # (n_mels, padded_window_size // 2 + 1)
+ mel_energies = paddle.nn.functional.pad(
+ mel_energies.unsqueeze(0), (0, 1),
+ data_format='NCL',
+ mode='constant',
+ value=0).squeeze(0)
+
+ # (m, n_mels)
+ mel_energies = paddle.mm(spectrum, mel_energies.T)
+ if use_log_fbank:
+ mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
+
+ if use_energy:
+ signal_log_energy = signal_log_energy.unsqueeze(1)
+ if htk_compat:
+ mel_energies = paddle.concat(
+ (mel_energies, signal_log_energy), axis=1)
+ else:
+ mel_energies = paddle.concat(
+ (signal_log_energy, mel_energies), axis=1)
+
+ # (m, n_mels + 1)
+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
+ return mel_energies
+
+
+def _get_dct_matrix(n_mfcc: int, n_mels: int) -> Tensor:
+ dct_matrix = create_dct(n_mels, n_mels, 'ortho')
+ dct_matrix[:, 0] = math.sqrt(1 / float(n_mels))
+ dct_matrix = dct_matrix[:, :n_mfcc] # (n_mels, n_mfcc)
+ return dct_matrix
+
+
+def _get_lifter_coeffs(n_mfcc: int, cepstral_lifter: float) -> Tensor:
+ i = paddle.arange(n_mfcc)
+ return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
+ cepstral_lifter)
+
+
+def mfcc(waveform: Tensor,
+ blackman_coeff: float=0.42,
+ cepstral_lifter: float=22.0,
+ channel: int=-1,
+ dither: float=0.0,
+ energy_floor: float=1.0,
+ frame_length: float=25.0,
+ frame_shift: float=10.0,
+ high_freq: float=0.0,
+ htk_compat: bool=False,
+ low_freq: float=20.0,
+ n_mfcc: int=13,
+ n_mels: int=23,
+ preemphasis_coefficient: float=0.97,
+ raw_energy: bool=True,
+ remove_dc_offset: bool=True,
+ round_to_power_of_two: bool=True,
+ sr: int=16000,
+ snip_edges: bool=True,
+ subtract_mean: bool=False,
+ use_energy: bool=False,
+ vtln_high: float=-500.0,
+ vtln_low: float=100.0,
+ vtln_warp: float=1.0,
+ window_type: str="povey") -> Tensor:
+ """Compute and return mel frequency cepstral coefficients from a waveform. The output is
+ identical to Kaldi's.
+
+ Args:
+ waveform (Tensor): A waveform tensor with shape `(C, T)`.
+        blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
+        cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
+        channel (int, optional): Select the channel of waveform. Defaults to -1.
+        dither (float, optional): Dithering constant. Defaults to 0.0.
+ energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
+ frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
+ frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
+ high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
+ htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
+ low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
+ n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 13.
+ n_mels (int, optional): Number of output mel bins. Defaults to 23.
+ preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
+        raw_energy (bool, optional): Whether to compute energy before preemphasis and windowing. Defaults to True.
+        remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
+        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
+            to FFT. Defaults to True.
+        sr (int, optional): Sample rate of input waveform. Defaults to 16000.
+        snip_edges (bool, optional): Drop samples at the end of the waveform that can't fit a signal frame when it
+            is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
+ subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
+        use_energy (bool, optional): Add a dimension with the energy of the spectrogram to the output. Defaults to False.
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
+ vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
+ window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
+
+ Returns:
+ Tensor: A mel frequency cepstral coefficients tensor with shape `(m, n_mfcc)`.
+ """
+ assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
+ n_mfcc, n_mels)
+
+ dtype = waveform.dtype
+
+ # (m, n_mels + use_energy)
+ feature = fbank(
+ waveform=waveform,
+ blackman_coeff=blackman_coeff,
+ channel=channel,
+ dither=dither,
+ energy_floor=energy_floor,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ high_freq=high_freq,
+ htk_compat=htk_compat,
+ low_freq=low_freq,
+ n_mels=n_mels,
+ preemphasis_coefficient=preemphasis_coefficient,
+ raw_energy=raw_energy,
+ remove_dc_offset=remove_dc_offset,
+ round_to_power_of_two=round_to_power_of_two,
+ sr=sr,
+ snip_edges=snip_edges,
+ subtract_mean=False,
+ use_energy=use_energy,
+ use_log_fbank=True,
+ use_power=True,
+ vtln_high=vtln_high,
+ vtln_low=vtln_low,
+ vtln_warp=vtln_warp,
+ window_type=window_type)
+
+ if use_energy:
+ # (m)
+ signal_log_energy = feature[:, n_mels if htk_compat else 0]
+ mel_offset = int(not htk_compat)
+ feature = feature[:, mel_offset:(n_mels + mel_offset)]
+
+ # (n_mels, n_mfcc)
+ dct_matrix = _get_dct_matrix(n_mfcc, n_mels).astype(dtype=dtype)
+
+ # (m, n_mfcc)
+ feature = feature.matmul(dct_matrix)
+
+ if cepstral_lifter != 0.0:
+ # (1, n_mfcc)
+ lifter_coeffs = _get_lifter_coeffs(n_mfcc, cepstral_lifter).unsqueeze(0)
+ feature *= lifter_coeffs.astype(dtype=dtype)
+
+ if use_energy:
+ feature[:, 0] = signal_log_energy
+
+ if htk_compat:
+ energy = feature[:, 0].unsqueeze(1) # (m, 1)
+ feature = feature[:, 1:] # (m, n_mfcc - 1)
+ if not use_energy:
+ energy *= math.sqrt(2)
+
+ feature = paddle.concat((feature, energy), axis=1)
+
+ feature = _subtract_column_mean(feature, subtract_mean)
+ return feature
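+
+
+# A minimal usage sketch (illustrative only; the random waveform stands in
+# for a real mono 16 kHz recording):
+def _demo_kaldi_mfcc() -> Tensor:
+    waveform = paddle.randn([1, 16000])  # (C, T): one channel, 1 s at 16 kHz
+    feat = mfcc(waveform, sr=16000, n_mfcc=13, n_mels=23)
+    return feat  # (num_frames, n_mfcc), Kaldi-compatible MFCCs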
diff --git a/paddlespeech/audio/compliance/librosa.py b/paddlespeech/audio/compliance/librosa.py
new file mode 100644
index 000000000..c671d4fb8
--- /dev/null
+++ b/paddlespeech/audio/compliance/librosa.py
@@ -0,0 +1,788 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from librosa(https://github.com/librosa/librosa)
+import warnings
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import scipy
+from numpy.lib.stride_tricks import as_strided
+from scipy import signal
+
+from ..utils import depth_convert
+from ..utils import ParameterError
+
+__all__ = [
+ # dsp
+ 'stft',
+ 'mfcc',
+ 'hz_to_mel',
+ 'mel_to_hz',
+ 'mel_frequencies',
+ 'power_to_db',
+ 'compute_fbank_matrix',
+ 'melspectrogram',
+ 'spectrogram',
+ 'mu_encode',
+ 'mu_decode',
+ # augmentation
+ 'depth_augment',
+ 'spect_augment',
+ 'random_crop1d',
+ 'random_crop2d',
+ 'adaptive_spect_augment',
+]
+
+
+def _pad_center(data: np.ndarray, size: int, axis: int=-1,
+ **kwargs) -> np.ndarray:
+ """Pad an array to a target length along a target axis.
+
+ This differs from `np.pad` by centering the data prior to padding,
+ analogous to `str.center`
+ """
+
+ kwargs.setdefault("mode", "constant")
+ n = data.shape[axis]
+ lpad = int((size - n) // 2)
+ lengths = [(0, 0)] * data.ndim
+ lengths[axis] = (lpad, int(size - n - lpad))
+
+ if lpad < 0:
+ raise ParameterError(f"Target size ({size:d}) must be "
+ f"at least input size ({n:d})")
+
+ return np.pad(data, lengths, **kwargs)
+
+
+def _split_frames(x: np.ndarray,
+ frame_length: int,
+ hop_length: int,
+ axis: int=-1) -> np.ndarray:
+ """Slice a data array into (overlapping) frames.
+
+ This function is aligned with librosa.frame
+ """
+
+ if not isinstance(x, np.ndarray):
+ raise ParameterError(
+ f"Input must be of type numpy.ndarray, given type(x)={type(x)}")
+
+ if x.shape[axis] < frame_length:
+ raise ParameterError(f"Input is too short (n={x.shape[axis]:d})"
+ f" for frame_length={frame_length:d}")
+
+ if hop_length < 1:
+ raise ParameterError(f"Invalid hop_length: {hop_length:d}")
+
+ if axis == -1 and not x.flags["F_CONTIGUOUS"]:
+ warnings.warn(f"librosa.util.frame called with axis={axis} "
+ "on a non-contiguous input. This will result in a copy.")
+ x = np.asfortranarray(x)
+ elif axis == 0 and not x.flags["C_CONTIGUOUS"]:
+ warnings.warn(f"librosa.util.frame called with axis={axis} "
+ "on a non-contiguous input. This will result in a copy.")
+ x = np.ascontiguousarray(x)
+
+ n_frames = 1 + (x.shape[axis] - frame_length) // hop_length
+ strides = np.asarray(x.strides)
+
+ new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
+
+ if axis == -1:
+ shape = list(x.shape)[:-1] + [frame_length, n_frames]
+ strides = list(strides) + [hop_length * new_stride]
+
+ elif axis == 0:
+ shape = [n_frames, frame_length] + list(x.shape)[1:]
+ strides = [hop_length * new_stride] + list(strides)
+
+ else:
+ raise ParameterError(f"Frame axis={axis} must be either 0 or -1")
+
+ return as_strided(x, shape=shape, strides=strides)
+
+
+def _check_audio(y, mono=True) -> bool:
+ """Determine whether a variable contains valid audio data.
+
+ The audio y must be a np.ndarray, either 1-channel or 2-channel.
+ """
+ if not isinstance(y, np.ndarray):
+ raise ParameterError("Audio data must be of type numpy.ndarray")
+ if y.ndim > 2:
+ raise ParameterError(
+ f"Invalid shape for audio ndim={y.ndim:d}, shape={y.shape}")
+
+ if mono and y.ndim == 2:
+ raise ParameterError(
+ f"Invalid shape for mono audio ndim={y.ndim:d}, shape={y.shape}")
+
+ if (mono and len(y) == 0) or (not mono and y.size == 0):
+ raise ParameterError(f"Audio is empty ndim={y.ndim:d}, shape={y.shape}")
+
+ if not np.issubdtype(y.dtype, np.floating):
+ raise ParameterError("Audio data must be floating-point")
+
+ if not np.isfinite(y).all():
+ raise ParameterError("Audio buffer is not finite everywhere")
+
+ return True
+
+
+def hz_to_mel(frequencies: Union[float, List[float], np.ndarray],
+ htk: bool=False) -> np.ndarray:
+ """Convert Hz to Mels.
+
+ Args:
+ frequencies (Union[float, List[float], np.ndarray]): Frequencies in Hz.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+
+ Returns:
+ np.ndarray: Frequency in mels.
+ """
+ freq = np.asanyarray(frequencies)
+
+ if htk:
+ return 2595.0 * np.log10(1.0 + freq / 700.0)
+
+ # Fill in the linear part
+ f_min = 0.0
+ f_sp = 200.0 / 3
+
+ mels = (freq - f_min) / f_sp
+
+ # Fill in the log-scale part
+
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = np.log(6.4) / 27.0 # step size for log region
+
+ if freq.ndim:
+ # If we have array data, vectorize
+ log_t = freq >= min_log_hz
+ mels[log_t] = min_log_mel + \
+ np.log(freq[log_t] / min_log_hz) / logstep
+ elif freq >= min_log_hz:
+ # If we have scalar data, check directly
+ mels = min_log_mel + np.log(freq / min_log_hz) / logstep
+
+ return mels
+
+
+def mel_to_hz(mels: Union[float, List[float], np.ndarray],
+ htk: bool=False) -> np.ndarray:
+ """Convert mel bin numbers to frequencies.
+
+ Args:
+ mels (Union[float, List[float], np.ndarray]): Frequency in mels.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+
+ Returns:
+ np.ndarray: Frequencies in Hz.
+ """
+ mel_array = np.asanyarray(mels)
+
+ if htk:
+ return 700.0 * (10.0**(mel_array / 2595.0) - 1.0)
+
+ # Fill in the linear scale
+ f_min = 0.0
+ f_sp = 200.0 / 3
+ freqs = f_min + f_sp * mel_array
+
+ # And now the nonlinear scale
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = np.log(6.4) / 27.0 # step size for log region
+
+ if mel_array.ndim:
+ # If we have vector data, vectorize
+ log_t = mel_array >= min_log_mel
+ freqs[log_t] = min_log_hz * \
+ np.exp(logstep * (mel_array[log_t] - min_log_mel))
+ elif mel_array >= min_log_mel:
+ # If we have scalar data, check directly
+ freqs = min_log_hz * np.exp(logstep * (mel_array - min_log_mel))
+
+ return freqs
+
+
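+# A small round-trip sketch (illustrative; the pitch values are arbitrary):
+# frequencies above 1 kHz take the log branch, below it the linear branch.
+def _demo_mel_roundtrip() -> np.ndarray:
+    freqs = np.array([440.0, 880.0, 1760.0])
+    mels = hz_to_mel(freqs, htk=False)  # Slaney scale
+    return mel_to_hz(mels, htk=False)   # ~[440., 880., 1760.] up to rounding
+
+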
+def mel_frequencies(n_mels: int=128,
+ fmin: float=0.0,
+ fmax: float=11025.0,
+ htk: bool=False) -> np.ndarray:
+ """Compute mel frequencies.
+
+ Args:
+ n_mels (int, optional): Number of mel bins. Defaults to 128.
+ fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0.
+ fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+
+ Returns:
+ np.ndarray: Vector of n_mels frequencies in Hz with shape `(n_mels,)`.
+ """
+ # 'Center freqs' of mel bands - uniformly spaced between limits
+ min_mel = hz_to_mel(fmin, htk=htk)
+ max_mel = hz_to_mel(fmax, htk=htk)
+
+ mels = np.linspace(min_mel, max_mel, n_mels)
+
+ return mel_to_hz(mels, htk=htk)
+
+
+def fft_frequencies(sr: int, n_fft: int) -> np.ndarray:
+ """Compute fourier frequencies.
+
+ Args:
+ sr (int): Sample rate.
+ n_fft (int): FFT size.
+
+ Returns:
+ np.ndarray: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`.
+ """
+ return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)
+
+
+def compute_fbank_matrix(sr: int,
+ n_fft: int,
+ n_mels: int=128,
+ fmin: float=0.0,
+ fmax: Optional[float]=None,
+ htk: bool=False,
+ norm: str="slaney",
+ dtype: type=np.float32) -> np.ndarray:
+ """Compute fbank matrix.
+
+ Args:
+ sr (int): Sample rate.
+ n_fft (int): FFT size.
+ n_mels (int, optional): Number of mel bins. Defaults to 128.
+ fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0.
+ fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+ norm (str, optional): Type of normalization. Defaults to "slaney".
+ dtype (type, optional): Data type. Defaults to np.float32.
+
+
+ Returns:
+ np.ndarray: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`.
+ """
+ if norm != "slaney":
+ raise ParameterError('norm must be set to slaney')
+
+ if fmax is None:
+ fmax = float(sr) / 2
+
+ # Initialize the weights
+ n_mels = int(n_mels)
+ weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+
+ # Center freqs of each FFT bin
+ fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
+
+ # 'Center freqs' of mel bands - uniformly spaced between limits
+ mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk)
+
+ fdiff = np.diff(mel_f)
+ ramps = np.subtract.outer(mel_f, fftfreqs)
+
+ for i in range(n_mels):
+ # lower and upper slopes for all bins
+ lower = -ramps[i] / fdiff[i]
+ upper = ramps[i + 2] / fdiff[i + 1]
+
+ # .. then intersect them with each other and zero
+ weights[i] = np.maximum(0, np.minimum(lower, upper))
+
+ if norm == "slaney":
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
+ weights *= enorm[:, np.newaxis]
+
+ # Only check weights if f_mel[0] is positive
+ if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)):
+ # This means we have an empty channel somewhere
+ warnings.warn("Empty filters detected in mel frequency basis. "
+ "Some channels will produce empty responses. "
+ "Try increasing your sampling rate (and fmax) or "
+ "reducing n_mels.")
+
+ return weights
+
+
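+# A usage sketch (illustrative; the random spectrogram is a placeholder for a
+# real STFT power spectrogram):
+def _demo_apply_fbank() -> np.ndarray:
+    fb = compute_fbank_matrix(sr=16000, n_fft=512, n_mels=64)  # (64, 257)
+    spect = np.abs(np.random.randn(257, 100))**2  # (1 + n_fft//2, num_frames)
+    return np.matmul(fb, spect)  # (64, 100): mel-weighted power
+
+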
+def stft(x: np.ndarray,
+ n_fft: int=2048,
+ hop_length: Optional[int]=None,
+ win_length: Optional[int]=None,
+ window: str="hann",
+ center: bool=True,
+ dtype: type=np.complex64,
+ pad_mode: str="reflect") -> np.ndarray:
+ """Short-time Fourier transform (STFT).
+
+ Args:
+ x (np.ndarray): Input waveform in one dimension.
+ n_fft (int, optional): FFT size. Defaults to 2048.
+ hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None.
+ win_length (Optional[int], optional): The size of window. Defaults to None.
+ window (str, optional): A string of window specification. Defaults to "hann".
+ center (bool, optional): Whether to pad `x` so that the `t`-th frame is centered at `x[t * hop_length]`. Defaults to True.
+ dtype (type, optional): Data type of STFT results. Defaults to np.complex64.
+ pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
+
+ Returns:
+ np.ndarray: The complex STFT output with shape `(n_fft//2 + 1, num_frames)`.
+ """
+ _check_audio(x)
+
+ # By default, use the entire frame
+ if win_length is None:
+ win_length = n_fft
+
+ # Set the default hop, if it's not already specified
+ if hop_length is None:
+ hop_length = int(win_length // 4)
+
+ fft_window = signal.get_window(window, win_length, fftbins=True)
+
+ # Pad the window out to n_fft size
+ fft_window = _pad_center(fft_window, n_fft)
+
+ # Reshape so that the window can be broadcast
+ fft_window = fft_window.reshape((-1, 1))
+
+ # Pad the time series so that frames are centered
+ if center:
+ if n_fft > x.shape[-1]:
+ warnings.warn(
+ f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
+ )
+ x = np.pad(x, int(n_fft // 2), mode=pad_mode)
+
+ elif n_fft > x.shape[-1]:
+ raise ParameterError(
+ f"n_fft={n_fft} is too small for input signal of length={x.shape[-1]}"
+ )
+
+ # Window the time series.
+ x_frames = _split_frames(x, frame_length=n_fft, hop_length=hop_length)
+ # Pre-allocate the STFT matrix
+ stft_matrix = np.empty(
+ (int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F")
+ fft = np.fft # use numpy fft as default
+ # Constrain STFT block sizes to 256 KB
+ MAX_MEM_BLOCK = 2**8 * 2**10
+ # how many columns can we fit within MAX_MEM_BLOCK?
+ n_columns = MAX_MEM_BLOCK // (stft_matrix.shape[0] * stft_matrix.itemsize)
+ n_columns = max(n_columns, 1)
+
+ for bl_s in range(0, stft_matrix.shape[1], n_columns):
+ bl_t = min(bl_s + n_columns, stft_matrix.shape[1])
+ stft_matrix[:, bl_s:bl_t] = fft.rfft(
+ fft_window * x_frames[:, bl_s:bl_t], axis=0)
+
+ return stft_matrix
+
+
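+# A usage sketch (illustrative; the tone parameters are arbitrary):
+def _demo_stft() -> np.ndarray:
+    t = np.linspace(0, 1, 16000, endpoint=False, dtype='float32')
+    x = np.sin(2 * np.pi * 440.0 * t)  # 1 s of a 440 Hz tone
+    return stft(x, n_fft=512, hop_length=160)  # complex, shape (257, 101)
+
+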
+def power_to_db(spect: np.ndarray,
+ ref: float=1.0,
+ amin: float=1e-10,
+ top_db: Optional[float]=80.0) -> np.ndarray:
+ """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way.
+
+ Args:
+ spect (np.ndarray): STFT power spectrogram of an input waveform.
+ ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
+ amin (float, optional): Minimum threshold. Defaults to 1e-10.
+ top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to 80.0.
+
+ Returns:
+ np.ndarray: Power spectrogram in db scale.
+ """
+ spect = np.asarray(spect)
+
+ if amin <= 0:
+ raise ParameterError("amin must be strictly positive")
+
+ if np.issubdtype(spect.dtype, np.complexfloating):
+ warnings.warn(
+ "power_to_db was called on complex input so phase "
+ "information will be discarded. To suppress this warning, "
+ "call power_to_db(np.abs(D)**2) instead.")
+ magnitude = np.abs(spect)
+ else:
+ magnitude = spect
+
+ if callable(ref):
+ # User supplied a function to calculate reference power
+ ref_value = ref(magnitude)
+ else:
+ ref_value = np.abs(ref)
+
+ log_spec = 10.0 * np.log10(np.maximum(amin, magnitude))
+ log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
+
+ if top_db is not None:
+ if top_db < 0:
+ raise ParameterError("top_db must be non-negative")
+ log_spec = np.maximum(log_spec, log_spec.max() - top_db)
+
+ return log_spec
+
+
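+# A usage sketch building on `_demo_stft` above (illustrative; `top_db=60.0`
+# is an arbitrary choice):
+def _demo_power_to_db() -> np.ndarray:
+    power = np.abs(_demo_stft())**2  # drop phase before the dB conversion
+    return power_to_db(power, ref=1.0, amin=1e-10, top_db=60.0)
+
+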
+def mfcc(x: np.ndarray,
+ sr: int=16000,
+ spect: Optional[np.ndarray]=None,
+ n_mfcc: int=20,
+ dct_type: int=2,
+ norm: str="ortho",
+ lifter: int=0,
+ **kwargs) -> np.ndarray:
+ """Mel-frequency cepstral coefficients (MFCCs)
+
+ Args:
+ x (np.ndarray): Input waveform in one dimension.
+ sr (int, optional): Sample rate. Defaults to 16000.
+ spect (Optional[np.ndarray], optional): Input log-power Mel spectrogram. Defaults to None.
+ n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 20.
+ dct_type (int, optional): Discrete cosine transform (DCT) type. Defaults to 2.
+ norm (str, optional): Type of normalization. Defaults to "ortho".
+ lifter (int, optional): Cepstral filtering. Defaults to 0.
+
+ Returns:
+ np.ndarray: Mel frequency cepstral coefficients array with shape `(n_mfcc, num_frames)`.
+ """
+ if spect is None:
+ spect = melspectrogram(x, sr=sr, **kwargs)
+
+ M = scipy.fftpack.dct(spect, axis=0, type=dct_type, norm=norm)[:n_mfcc]
+
+ if lifter > 0:
+ factor = 1 + (lifter / 2) * np.sin(
+ np.pi * np.arange(1, 1 + n_mfcc, dtype=M.dtype) / lifter)
+ return M * factor[:, np.newaxis]
+ elif lifter == 0:
+ return M
+ else:
+ raise ParameterError(
+ f"MFCC lifter={lifter} must be a non-negative number")
+
+
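+# A usage sketch (illustrative; random noise stands in for real speech). With
+# `spect=None`, the dB-scaled mel spectrogram is computed internally via
+# `melspectrogram` below:
+def _demo_mfcc() -> np.ndarray:
+    x = np.random.randn(16000).astype('float32')
+    return mfcc(x, sr=16000, n_mfcc=20)  # (20, num_frames)
+
+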
+def melspectrogram(x: np.ndarray,
+ sr: int=16000,
+ window_size: int=512,
+ hop_length: int=320,
+ n_mels: int=64,
+ fmin: float=50.0,
+ fmax: Optional[float]=None,
+ window: str='hann',
+ center: bool=True,
+ pad_mode: str='reflect',
+ power: float=2.0,
+ to_db: bool=True,
+ ref: float=1.0,
+ amin: float=1e-10,
+ top_db: Optional[float]=None) -> np.ndarray:
+ """Compute mel-spectrogram.
+
+ Args:
+ x (np.ndarray): Input waveform in one dimension.
+ sr (int, optional): Sample rate. Defaults to 16000.
+ window_size (int, optional): Size of FFT and window length. Defaults to 512.
+ hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
+ n_mels (int, optional): Number of mel bins. Defaults to 64.
+ fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0.
+ fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
+ window (str, optional): A string of window specification. Defaults to "hann".
+ center (bool, optional): Whether to pad `x` so that the `t`-th frame is centered at `x[t * hop_length]`. Defaults to True.
+ pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
+ power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
+ to_db (bool, optional): Enable db scale. Defaults to True.
+ ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
+ amin (float, optional): Minimum threshold. Defaults to 1e-10.
+ top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None.
+
+ Returns:
+ np.ndarray: The mel-spectrogram in power scale or db scale with shape `(n_mels, num_frames)`.
+ """
+ _check_audio(x, mono=True)
+ if len(x) <= 0:
+ raise ParameterError('The input waveform is empty')
+
+ if fmax is None:
+ fmax = sr // 2
+ if fmin < 0 or fmin >= fmax:
+ raise ParameterError('fmin and fmax must satisfy 0 < fmin < fmax')
+
+ s = stft(
+ x,
+ n_fft=window_size,
+ hop_length=hop_length,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode)
+
+ spect = np.abs(s)**power
+ fb_matrix = compute_fbank_matrix(
+ sr=sr, n_fft=window_size, n_mels=n_mels, fmin=fmin, fmax=fmax)
+ mel_spect = np.matmul(fb_matrix, spect)
+ if to_db:
+ return power_to_db(mel_spect, ref=ref, amin=amin, top_db=top_db)
+ else:
+ return mel_spect
+
+
+def spectrogram(x: np.ndarray,
+ sr: int=16000,
+ window_size: int=512,
+ hop_length: int=320,
+ window: str='hann',
+ center: bool=True,
+ pad_mode: str='reflect',
+ power: float=2.0) -> np.ndarray:
+ """Compute spectrogram.
+
+ Args:
+ x (np.ndarray): Input waveform in one dimension.
+ sr (int, optional): Sample rate. Defaults to 16000.
+ window_size (int, optional): Size of FFT and window length. Defaults to 512.
+ hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
+ window (str, optional): A string of window specification. Defaults to "hann".
+ center (bool, optional): Whether to pad `x` so that the `t`-th frame is centered at `x[t * hop_length]`. Defaults to True.
+ pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
+ power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
+
+ Returns:
+ np.ndarray: The STFT spectrogram in power scale `(n_fft//2 + 1, num_frames)`.
+ """
+
+ s = stft(
+ x,
+ n_fft=window_size,
+ hop_length=hop_length,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode)
+
+ return np.abs(s)**power
+
+
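+# A usage sketch contrasting the two features (illustrative; random noise
+# stands in for real audio):
+def _demo_spectrograms() -> np.ndarray:
+    x = np.random.randn(16000).astype('float32')
+    lin = spectrogram(x, sr=16000, window_size=512, hop_length=320)  # (257, 51)
+    mel = melspectrogram(x, sr=16000, n_mels=64, to_db=True)  # (64, 51), in dB
+    return np.vstack([lin[:64], mel])  # a 64-bin slice stacked for comparison
+
+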
+def mu_encode(x: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
+ """Mu-law encoding. Encode waveform based on mu-law companding. When quantized is True, the result will be converted to integer in range `[0,mu-1]`. Otherwise, the resulting waveform is in range `[-1,1]`.
+
+ Args:
+ x (np.ndarray): The input waveform to encode.
+ mu (int, optional): The encoding parameter. Defaults to 255.
+ quantized (bool, optional): If `True`, quantize the encoded values into `1 + mu` distinct integer values. Defaults to True.
+
+ Returns:
+ np.ndarray: The mu-law encoded waveform.
+ """
+ y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
+ if quantized:
+ y = np.floor((y + 1) / 2 * mu + 0.5) # convert to [0 , mu-1]
+ return y
+
+
+def mu_decode(y: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
+ """Mu-law decoding. Compute the mu-law decoding given an input code. It assumes that the input `y` is in range `[0,mu-1]` when quantize is True and `[-1,1]` otherwise.
+
+ Args:
+ y (np.ndarray): The encoded waveform.
+ mu (int, optional): The encoding parameter. Defaults to 255.
+ quantized (bool, optional): If `True`, the input is assumed to be quantized to `1 + mu` distinct integer values. Defaults to True.
+
+ Returns:
+ np.ndarray: The mu-law decoded waveform.
+ """
+ if mu < 1:
+ raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...')
+
+ mu = mu - 1
+ if quantized: # undo the quantization
+ y = y * 2 / mu - 1
+ x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)
+ return x
+
+
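+# A round-trip sketch (illustrative; the ramp input is arbitrary and the
+# reconstruction is approximate by design of the quantization):
+def _demo_mu_law() -> np.ndarray:
+    x = np.linspace(-1.0, 1.0, 9, dtype='float32')
+    code = mu_encode(x, mu=255, quantized=True)  # integer codes
+    return mu_decode(code, mu=255, quantized=True)  # close to x, not exact
+
+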
+def _randint(high: int) -> int:
+ """Generate one random integer in range [0 high)
+
+ This is a helper function for random data augmentation
+ """
+ return int(np.random.randint(0, high=high))
+
+
+def depth_augment(y: np.ndarray,
+ choices: List=['int8', 'int16'],
+ probs: List[float]=[0.5, 0.5]) -> np.ndarray:
+ """ Audio depth augmentation. Do audio depth augmentation to simulate the distortion brought by quantization.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ choices (List, optional): A list of data types for depth conversion. Defaults to ['int8', 'int16'].
+ probs (List[float], optional): Probabilities of each depth conversion. Defaults to [0.5, 0.5].
+
+ Returns:
+ np.ndarray: The augmented waveform.
+ """
+ assert len(probs) == len(
+ choices
+ ), 'number of choices {} must be equal to size of probs {}'.format(
+ len(choices), len(probs))
+ depth = np.random.choice(choices, p=probs)
+ src_depth = y.dtype
+ y1 = depth_convert(y, depth)
+ y2 = depth_convert(y1, src_depth)
+
+ return y2
+
+
+def adaptive_spect_augment(spect: np.ndarray,
+ tempo_axis: int=0,
+ level: float=0.1) -> np.ndarray:
+ """Do adaptive spectrogram augmentation. The level of the augmentation is govern by the parameter level, ranging from 0 to 1, with 0 represents no augmentation.
+
+ Args:
+ spect (np.ndarray): Input spectrogram.
+ tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
+ level (float, optional): The level factor of masking. Defaults to 0.1.
+
+ Returns:
+ np.ndarray: The augmented spectrogram.
+ """
+ assert spect.ndim == 2, 'only supports 2d tensor or numpy array'
+ if tempo_axis == 0:
+ nt, nf = spect.shape
+ else:
+ nf, nt = spect.shape
+
+ time_mask_width = int(nt * level * 0.5)
+ freq_mask_width = int(nf * level * 0.5)
+
+ num_time_mask = int(10 * level)
+ num_freq_mask = int(10 * level)
+
+ if tempo_axis == 0:
+ for _ in range(num_time_mask):
+ start = _randint(nt - time_mask_width)
+ spect[start:start + time_mask_width, :] = 0
+ for _ in range(num_freq_mask):
+ start = _randint(nf - freq_mask_width)
+ spect[:, start:start + freq_mask_width] = 0
+ else:
+ for _ in range(num_time_mask):
+ start = _randint(nt - time_mask_width)
+ spect[:, start:start + time_mask_width] = 0
+ for _ in range(num_freq_mask):
+ start = _randint(nf - freq_mask_width)
+ spect[start:start + freq_mask_width, :] = 0
+
+ return spect
+
+
+def spect_augment(spect: np.ndarray,
+ tempo_axis: int=0,
+ max_time_mask: int=3,
+ max_freq_mask: int=3,
+ max_time_mask_width: int=30,
+ max_freq_mask_width: int=20) -> np.ndarray:
+ """Do spectrogram augmentation in both time and freq axis.
+
+ Args:
+ spect (np.ndarray): Input spectrogram.
+ tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
+ max_time_mask (int, optional): Maximum number of time masking. Defaults to 3.
+ max_freq_mask (int, optional): Maximum number of frequency masking. Defaults to 3.
+ max_time_mask_width (int, optional): Maximum width of time masking. Defaults to 30.
+ max_freq_mask_width (int, optional): Maximum width of frequency masking. Defaults to 20.
+
+ Returns:
+ np.ndarray: The augmented spectrogram.
+ """
+ assert spect.ndim == 2, 'only supports 2d tensor or numpy array'
+ if tempo_axis == 0:
+ nt, nf = spect.shape
+ else:
+ nf, nt = spect.shape
+
+ num_time_mask = _randint(max_time_mask)
+ num_freq_mask = _randint(max_freq_mask)
+
+ time_mask_width = _randint(max_time_mask_width)
+ freq_mask_width = _randint(max_freq_mask_width)
+
+ if tempo_axis == 0:
+ for _ in range(num_time_mask):
+ start = _randint(nt - time_mask_width)
+ spect[start:start + time_mask_width, :] = 0
+ for _ in range(num_freq_mask):
+ start = _randint(nf - freq_mask_width)
+ spect[:, start:start + freq_mask_width] = 0
+ else:
+ for _ in range(num_time_mask):
+ start = _randint(nt - time_mask_width)
+ spect[:, start:start + time_mask_width] = 0
+ for _ in range(num_freq_mask):
+ start = _randint(nf - freq_mask_width)
+ spect[start:start + freq_mask_width, :] = 0
+
+ return spect
+
+
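+# A usage sketch (illustrative; the spectrogram is random and masking happens
+# in place, so passing a copy keeps the original intact):
+def _demo_spect_augment() -> np.ndarray:
+    spect = np.abs(np.random.randn(100, 64)).astype('float32')  # (time, freq)
+    return spect_augment(spect.copy(), tempo_axis=0)
+
+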
+def random_crop1d(y: np.ndarray, crop_len: int) -> np.ndarray:
+ """ Random cropping on a input waveform.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D.
+ crop_len (int): Length of waveform to crop.
+
+ Returns:
+ np.ndarray: The cropped waveform.
+ """
+ if y.ndim != 1:
+ raise ParameterError('only accept 1d tensor or numpy array')
+ n = len(y)
+ idx = _randint(n - crop_len)
+ return y[idx:idx + crop_len]
+
+
+def random_crop2d(s: np.ndarray, crop_len: int,
+ tempo_axis: int=0) -> np.ndarray:
+ """ Random cropping on a spectrogram.
+
+ Args:
+ s (np.ndarray): Input spectrogram in 2D.
+ crop_len (int): Length of spectrogram to crop.
+ tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0.
+
+ Returns:
+ np.ndarray: The cropped spectrogram.
+ """
+ if tempo_axis >= s.ndim:
+ raise ParameterError('axis out of range')
+
+ n = s.shape[tempo_axis]
+ idx = _randint(high=n - crop_len)
+ sli = [slice(None) for i in range(s.ndim)]
+ sli[tempo_axis] = slice(idx, idx + crop_len)
+ out = s[tuple(sli)]
+ return out
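+
+
+# A usage sketch for both crop helpers (illustrative; shapes are arbitrary):
+def _demo_random_crops() -> np.ndarray:
+    y = np.random.randn(16000).astype('float32')
+    assert random_crop1d(y, crop_len=8000).shape == (8000, )
+    s = np.abs(np.random.randn(100, 64)).astype('float32')
+    return random_crop2d(s, crop_len=50, tempo_axis=0)  # (50, 64)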
diff --git a/paddlespeech/audio/datasets/__init__.py b/paddlespeech/audio/datasets/__init__.py
new file mode 100644
index 000000000..8068fa9d3
--- /dev/null
+++ b/paddlespeech/audio/datasets/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .esc50 import ESC50
+from .voxceleb import VoxCeleb
diff --git a/paddlespeech/audio/datasets/esc50.py b/paddlespeech/audio/datasets/esc50.py
new file mode 100644
index 000000000..684a8b8f5
--- /dev/null
+++ b/paddlespeech/audio/datasets/esc50.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import os
+from typing import List
+from typing import Tuple
+
+from ...utils.env import DATA_HOME
+from ..utils.download import download_and_decompress
+from .dataset import AudioClassificationDataset
+
+__all__ = ['ESC50']
+
+
+class ESC50(AudioClassificationDataset):
+ """
+ The ESC-50 dataset is a labeled collection of 2000 environmental audio recordings
+ suitable for benchmarking methods of environmental sound classification. The dataset
+ consists of 5-second-long recordings organized into 50 semantic classes (with
+ 40 examples per class).
+
+ Reference:
+ ESC: Dataset for Environmental Sound Classification
+ http://dx.doi.org/10.1145/2733373.2806390
+ """
+
+ archieves = [
+ {
+ 'url':
+ 'https://paddleaudio.bj.bcebos.com/datasets/ESC-50-master.zip',
+ 'md5': '7771e4b9d86d0945acce719c7a59305a',
+ },
+ ]
+ label_list = [
+ # Animals
+ 'Dog',
+ 'Rooster',
+ 'Pig',
+ 'Cow',
+ 'Frog',
+ 'Cat',
+ 'Hen',
+ 'Insects (flying)',
+ 'Sheep',
+ 'Crow',
+ # Natural soundscapes & water sounds
+ 'Rain',
+ 'Sea waves',
+ 'Crackling fire',
+ 'Crickets',
+ 'Chirping birds',
+ 'Water drops',
+ 'Wind',
+ 'Pouring water',
+ 'Toilet flush',
+ 'Thunderstorm',
+ # Human, non-speech sounds
+ 'Crying baby',
+ 'Sneezing',
+ 'Clapping',
+ 'Breathing',
+ 'Coughing',
+ 'Footsteps',
+ 'Laughing',
+ 'Brushing teeth',
+ 'Snoring',
+ 'Drinking, sipping',
+ # Interior/domestic sounds
+ 'Door knock',
+ 'Mouse click',
+ 'Keyboard typing',
+ 'Door, wood creaks',
+ 'Can opening',
+ 'Washing machine',
+ 'Vacuum cleaner',
+ 'Clock alarm',
+ 'Clock tick',
+ 'Glass breaking',
+ # Exterior/urban noises
+ 'Helicopter',
+ 'Chainsaw',
+ 'Siren',
+ 'Car horn',
+ 'Engine',
+ 'Train',
+ 'Church bells',
+ 'Airplane',
+ 'Fireworks',
+ 'Hand saw',
+ ]
+ meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
+ meta_info = collections.namedtuple(
+ 'META_INFO',
+ ('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'))
+ audio_path = os.path.join('ESC-50-master', 'audio')
+
+ def __init__(self,
+ mode: str='train',
+ split: int=1,
+ feat_type: str='raw',
+ **kwargs):
+ """
+ Args:
+ mode (:obj:`str`, `optional`, defaults to `train`):
+ It identifies the dataset mode (train or dev).
+ split (:obj:`int`, `optional`, defaults to 1):
+ It specifies the fold used as the dev dataset.
+ feat_type (:obj:`str`, `optional`, defaults to `raw`):
+ It identifies the feature type that the user wants to extract from an audio file.
+ """
+ files, labels = self._get_data(mode, split)
+ super(ESC50, self).__init__(
+ files=files, labels=labels, feat_type=feat_type, **kwargs)
+
+ def _get_meta_info(self) -> List[collections.namedtuple]:
+ ret = []
+ with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
+ for line in rf.readlines()[1:]:
+ ret.append(self.meta_info(*line.strip().split(',')))
+ return ret
+
+ def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
+ if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
+ not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
+ download_and_decompress(self.archieves, DATA_HOME)
+
+ meta_info = self._get_meta_info()
+
+ files = []
+ labels = []
+ for sample in meta_info:
+ filename, fold, target, _, _, _, _ = sample
+ if mode == 'train' and int(fold) != split:
+ files.append(os.path.join(DATA_HOME, self.audio_path, filename))
+ labels.append(int(target))
+
+ if mode != 'train' and int(fold) == split:
+ files.append(os.path.join(DATA_HOME, self.audio_path, filename))
+ labels.append(int(target))
+
+ return files, labels
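+
+
+# A usage sketch (illustrative): fold 1 serves as the dev split and the other
+# four folds as train. The first construction downloads ESC-50 into DATA_HOME.
+def _demo_esc50_splits():
+    train_ds = ESC50(mode='train', split=1, feat_type='raw')
+    dev_ds = ESC50(mode='dev', split=1, feat_type='raw')
+    return len(train_ds), len(dev_ds)  # 1600 and 400 clips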
diff --git a/paddlespeech/audio/datasets/voxceleb.py b/paddlespeech/audio/datasets/voxceleb.py
new file mode 100644
index 000000000..4daa6bf6f
--- /dev/null
+++ b/paddlespeech/audio/datasets/voxceleb.py
@@ -0,0 +1,356 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import csv
+import glob
+import os
+import random
+from multiprocessing import cpu_count
+from typing import List
+
+from paddle.io import Dataset
+from pathos.multiprocessing import Pool
+from tqdm import tqdm
+
+from ...utils.env import DATA_HOME
+from ..backends.soundfile_backend import soundfile_load as load_audio
+from ..utils.download import decompress
+from ..utils.download import download_and_decompress
+from .dataset import feat_funcs
+
+__all__ = ['VoxCeleb']
+
+
+class VoxCeleb(Dataset):
+ source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
+ archieves_audio_dev = [
+ {
+ 'url': source_url + 'vox1_dev_wav_partaa',
+ 'md5': 'e395d020928bc15670b570a21695ed96',
+ },
+ {
+ 'url': source_url + 'vox1_dev_wav_partab',
+ 'md5': 'bbfaaccefab65d82b21903e81a8a8020',
+ },
+ {
+ 'url': source_url + 'vox1_dev_wav_partac',
+ 'md5': '017d579a2a96a077f40042ec33e51512',
+ },
+ {
+ 'url': source_url + 'vox1_dev_wav_partad',
+ 'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
+ },
+ ]
+ archieves_audio_test = [
+ {
+ 'url': source_url + 'vox1_test_wav.zip',
+ 'md5': '185fdc63c3c739954633d50379a3d102',
+ },
+ ]
+ archieves_meta = [
+ {
+ 'url':
+ 'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
+ 'md5':
+ 'b73110731c9223c1461fe49cb48dddfc',
+ },
+ ]
+
+ num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+ sample_rate = 16000
+ meta_info = collections.namedtuple(
+ 'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
+ base_path = os.path.join(DATA_HOME, 'vox1')
+ wav_path = os.path.join(base_path, 'wav')
+ meta_path = os.path.join(base_path, 'meta')
+ veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
+ csv_path = os.path.join(base_path, 'csv')
+ subsets = ['train', 'dev', 'enroll', 'test']
+
+ def __init__(
+ self,
+ subset: str='train',
+ feat_type: str='raw',
+ random_chunk: bool=True,
+ chunk_duration: float=3.0, # seconds
+ split_ratio: float=0.9, # train split ratio
+ seed: int=0,
+ target_dir: str=None,
+ vox2_base_path=None,
+ **kwargs):
+ """VoxCeleb data prepare and get the specific dataset audio info
+
+ Args:
+ subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'.
+ feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'.
+ random_chunk (bool, optional): random select a duration from audio. Defaults to True.
+ chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0.
+ target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None.
+ vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None.
+ """
+ assert subset in self.subsets, \
+ 'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
+
+ self.subset = subset
+ self.spk_id2label = {}
+ self.feat_type = feat_type
+ self.feat_config = kwargs
+ self.random_chunk = random_chunk
+ self.chunk_duration = chunk_duration
+ self.split_ratio = split_ratio
+ self.target_dir = target_dir if target_dir else VoxCeleb.base_path
+ self.vox2_base_path = vox2_base_path
+
+ # If target_dir is set, relocate the vox data info from the base path to the target dir.
+ VoxCeleb.csv_path = os.path.join(
+ target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
+ VoxCeleb.meta_path = os.path.join(
+ target_dir, "voxceleb",
+ 'meta') if target_dir else VoxCeleb.meta_path
+ VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
+ 'veri_test2.txt')
+ # self._data = self._get_data()[:1000] # KP: Small dataset test.
+ self._data = self._get_data()
+ super(VoxCeleb, self).__init__()
+
+ # Set up a seed to reproduce training or predicting result.
+ # random.seed(seed)
+
+ def _get_data(self):
+ # Download audio files.
+ # Users need to decompress all of vox1/dev/wav and vox1/test/wav into the vox1/wav dir,
+ # so we check the status of the vox1/wav dir here.
+ print(f"wav base path: {self.wav_path}")
+ if not os.path.isdir(self.wav_path):
+ print("start to download the voxceleb1 dataset")
+ download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
+ self.archieves_audio_dev,
+ self.base_path,
+ decompress=False)
+ download_and_decompress( # download the vox1_test_wav.zip and unzip
+ self.archieves_audio_test,
+ self.base_path,
+ decompress=True)
+
+ # Download all parts and concatenate the files into one zip file.
+ dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
+ print(f'Concatenating all parts to: {dev_zipfile}')
+ os.system(
+ f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
+ )
+
+ # Extract all audio files of dev and test set.
+ decompress(dev_zipfile, self.base_path)
+
+ # Download meta files.
+ if not os.path.isdir(self.meta_path):
+ print("prepare the meta data")
+ download_and_decompress(
+ self.archieves_meta, self.meta_path, decompress=False)
+
+ # Data preparation.
+ if not os.path.isdir(self.csv_path):
+ os.makedirs(self.csv_path)
+ self.prepare_data()
+
+ data = []
+ print(
+ f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}"
+ )
+ with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
+ for line in rf.readlines()[1:]:
+ audio_id, duration, wav, start, stop, spk_id = line.strip(
+ ).split(',')
+ data.append(
+ self.meta_info(audio_id,
+ float(duration), wav,
+ int(start), int(stop), spk_id))
+
+ with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
+ for line in f.readlines():
+ spk_id, label = line.strip().split(' ')
+ self.spk_id2label[spk_id] = int(label)
+
+ return data
+
+ def _convert_to_record(self, idx: int):
+ sample = self._data[idx]
+
+ record = {}
+ # To show all fields in a namedtuple: `type(sample)._fields`
+ for field in type(sample)._fields:
+ record[field] = getattr(sample, field)
+
+ waveform, sr = load_audio(record['wav'])
+
+ # random select a chunk audio samples from the audio
+ if self.random_chunk:
+ num_wav_samples = waveform.shape[0]
+ num_chunk_samples = int(self.chunk_duration * sr)
+ start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
+ stop = start + num_chunk_samples
+ else:
+ start = record['start']
+ stop = record['stop']
+
+ waveform = waveform[start:stop]
+
+ assert self.feat_type in feat_funcs.keys(), \
+ f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
+ feat_func = feat_funcs[self.feat_type]
+ feat = feat_func(
+ waveform, sr=sr, **self.feat_config) if feat_func else waveform
+
+ record.update({'feat': feat})
+ if self.subset in ['train',
+ 'dev']: # Labels are available in train and dev.
+ record.update({'label': self.spk_id2label[record['spk_id']]})
+
+ return record
+
+ @staticmethod
+ def _get_chunks(seg_dur, audio_id, audio_duration):
+ num_chunks = int(audio_duration / seg_dur) # all in seconds
+
+ chunk_lst = [
+ audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
+ for i in range(num_chunks)
+ ]
+ return chunk_lst
+
+ def _get_audio_info(self, wav_file: str,
+ split_chunks: bool) -> List[List[str]]:
+ waveform, sr = load_audio(wav_file)
+ spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
+ audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
+ audio_duration = waveform.shape[0] / sr
+
+ ret = []
+ if split_chunks: # Split into pieces of self.chunk_duration seconds.
+ uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
+ audio_duration)
+
+ for chunk in uniq_chunks_list:
+ s, e = chunk.split("_")[-2:] # Timestamps of start and end
+ start_sample = int(float(s) * sr)
+ end_sample = int(float(e) * sr)
+ # id, duration, wav, start, stop, spk_id
+ ret.append([
+ chunk, audio_duration, wav_file, start_sample, end_sample,
+ spk_id
+ ])
+ else: # Keep whole audio.
+ ret.append([
+ audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
+ ])
+ return ret
+
+ def generate_csv(self,
+ wav_files: List[str],
+ output_file: str,
+ split_chunks: bool=True):
+ print(f'Generating csv: {output_file}')
+ header = ["id", "duration", "wav", "start", "stop", "spk_id"]
+ # Note: this may raise a C++ exception, but the program will execute fine,
+ # so we can ignore the exception.
+ with Pool(cpu_count()) as p:
+ infos = list(
+ tqdm(
+ p.imap(lambda x: self._get_audio_info(x, split_chunks),
+ wav_files),
+ total=len(wav_files)))
+
+ csv_lines = []
+ for info in infos:
+ csv_lines.extend(info)
+
+ with open(output_file, mode="w") as csv_f:
+ csv_writer = csv.writer(
+ csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
+ csv_writer.writerow(header)
+ for line in csv_lines:
+ csv_writer.writerow(line)
+
+ def prepare_data(self):
+ # Audio of speakers in veri_test_file should not be included in training set.
+ print("start to prepare the data csv file")
+ enroll_files = set()
+ test_files = set()
+ # get the enroll and test audio file path
+ with open(self.veri_test_file, 'r') as f:
+ for line in f.readlines():
+ _, enrol_file, test_file = line.strip().split(' ')
+ enroll_files.add(os.path.join(self.wav_path, enrol_file))
+ test_files.add(os.path.join(self.wav_path, test_file))
+ enroll_files = sorted(enroll_files)
+ test_files = sorted(test_files)
+
+ # get the enroll and test speakers
+ test_spks = set()
+ for file in (enroll_files + test_files):
+ spk = file.split('/wav/')[1].split('/')[0]
+ test_spks.add(spk)
+
+ # get all the train and dev audios file path
+ audio_files = []
+ speakers = set()
+ print("Getting file list...")
+ for path in [self.wav_path, self.vox2_base_path]:
+ # if vox2 directory is not set and vox2 is not a directory
+ # we will not process this directory
+ if not path or not os.path.exists(path):
+ print(f"{path} is an invalid path, please check again, "
+ "and we will ignore the vox2 base path")
+ continue
+ for file in glob.glob(
+ os.path.join(path, "**", "*.wav"), recursive=True):
+ spk = file.split('/wav/')[1].split('/')[0]
+ if spk in test_spks:
+ continue
+ speakers.add(spk)
+ audio_files.append(file)
+
+ print(
+ f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}"
+ )
+ # encode the train and dev speakers label to spk_id2label.txt
+ with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
+ for label, spk_id in enumerate(
+ sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
+ f.write(f'{spk_id} {label}\n')
+
+ audio_files = sorted(audio_files)
+ random.shuffle(audio_files)
+ split_idx = int(self.split_ratio * len(audio_files))
+ # split_ratio to train
+ train_files, dev_files = audio_files[:split_idx], audio_files[
+ split_idx:]
+
+ self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
+ self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
+
+ self.generate_csv(
+ enroll_files,
+ os.path.join(self.csv_path, 'enroll.csv'),
+ split_chunks=False)
+ self.generate_csv(
+ test_files,
+ os.path.join(self.csv_path, 'test.csv'),
+ split_chunks=False)
+
+ def __getitem__(self, idx):
+ return self._convert_to_record(idx)
+
+ def __len__(self):
+ return len(self._data)
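+
+
+# A usage sketch (illustrative; note that the first construction downloads and
+# unpacks the full VoxCeleb1 corpus, which is tens of gigabytes):
+def _demo_voxceleb():
+    train_ds = VoxCeleb(subset='train', feat_type='raw', chunk_duration=3.0)
+    record = train_ds[0]  # dict with 'feat', 'label' and the csv meta fields
+    return record['feat'].shape, train_ds.num_speakers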
diff --git a/paddlespeech/audio/functional/__init__.py b/paddlespeech/audio/functional/__init__.py
new file mode 100644
index 000000000..c85232df1
--- /dev/null
+++ b/paddlespeech/audio/functional/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .functional import compute_fbank_matrix
+from .functional import create_dct
+from .functional import fft_frequencies
+from .functional import hz_to_mel
+from .functional import mel_frequencies
+from .functional import mel_to_hz
+from .functional import power_to_db
diff --git a/paddlespeech/audio/functional/functional.py b/paddlespeech/audio/functional/functional.py
new file mode 100644
index 000000000..7c20f9013
--- /dev/null
+++ b/paddlespeech/audio/functional/functional.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from librosa(https://github.com/librosa/librosa)
+import math
+from typing import Optional
+from typing import Union
+
+import paddle
+from paddle import Tensor
+
+__all__ = [
+ 'hz_to_mel',
+ 'mel_to_hz',
+ 'mel_frequencies',
+ 'fft_frequencies',
+ 'compute_fbank_matrix',
+ 'power_to_db',
+ 'create_dct',
+]
+
+
+def hz_to_mel(freq: Union[Tensor, float],
+ htk: bool=False) -> Union[Tensor, float]:
+ """Convert Hz to Mels.
+
+ Args:
+ freq (Union[Tensor, float]): The input tensor with arbitrary shape.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+
+ Returns:
+ Union[Tensor, float]: Frequency in mels.
+ """
+
+ if htk:
+ if isinstance(freq, Tensor):
+ return 2595.0 * paddle.log10(1.0 + freq / 700.0)
+ else:
+ return 2595.0 * math.log10(1.0 + freq / 700.0)
+
+ # Fill in the linear part
+ f_min = 0.0
+ f_sp = 200.0 / 3
+
+ mels = (freq - f_min) / f_sp
+
+ # Fill in the log-scale part
+
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = math.log(6.4) / 27.0 # step size for log region
+
+ if isinstance(freq, Tensor):
+ target = min_log_mel + paddle.log(
+ freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
+ mask = (freq > min_log_hz).astype(freq.dtype)
+ mels = target * mask + mels * (
+ 1 - mask) # will replace by masked_fill OP in future
+ else:
+ if freq >= min_log_hz:
+ mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
+
+ return mels
+
+
+def mel_to_hz(mel: Union[float, Tensor],
+ htk: bool=False) -> Union[float, Tensor]:
+ """Convert mel bin numbers to frequencies.
+
+ Args:
+ mel (Union[float, Tensor]): The mel frequency represented as a tensor with arbitrary shape.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+
+ Returns:
+ Union[float, Tensor]: Frequencies in Hz.
+ """
+ if htk:
+ return 700.0 * (10.0**(mel / 2595.0) - 1.0)
+
+ f_min = 0.0
+ f_sp = 200.0 / 3
+ freqs = f_min + f_sp * mel
+ # And now the nonlinear scale
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = math.log(6.4) / 27.0 # step size for log region
+ if isinstance(mel, Tensor):
+ target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
+ mask = (mel > min_log_mel).astype(mel.dtype)
+ freqs = target * mask + freqs * (
+ 1 - mask) # will replace by masked_fill OP in future
+ else:
+ if mel >= min_log_mel:
+ freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
+
+ return freqs
+
+
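+# A tensor round-trip sketch (illustrative; the frequencies are arbitrary):
+def _demo_mel_roundtrip() -> Tensor:
+    freqs = paddle.to_tensor([440.0, 880.0, 1760.0])
+    return mel_to_hz(hz_to_mel(freqs))  # ~[440., 880., 1760.]
+
+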
+def mel_frequencies(n_mels: int=64,
+ f_min: float=0.0,
+ f_max: float=11025.0,
+ htk: bool=False,
+ dtype: str='float32') -> Tensor:
+ """Compute mel frequencies.
+
+ Args:
+ n_mels (int, optional): Number of mel bins. Defaults to 64.
+ f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0.
+ f_max (float, optional): Maximum frequency in Hz. Defaults to 11025.0.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+ dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'.
+
+ Returns:
+ Tensor: Tensor of n_mels frequencies in Hz with shape `(n_mels,)`.
+ """
+ # 'Center freqs' of mel bands - uniformly spaced between limits
+ min_mel = hz_to_mel(f_min, htk=htk)
+ max_mel = hz_to_mel(f_max, htk=htk)
+ mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
+ freqs = mel_to_hz(mels, htk=htk)
+ return freqs
+
+
+def fft_frequencies(sr: int, n_fft: int, dtype: str='float32') -> Tensor:
+ """Compute fourier frequencies.
+
+ Args:
+ sr (int): Sample rate.
+ n_fft (int): Number of fft bins.
+ dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'.
+
+ Returns:
+ Tensor: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`.
+ """
+ return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
+
+
+def compute_fbank_matrix(sr: int,
+ n_fft: int,
+ n_mels: int=64,
+ f_min: float=0.0,
+ f_max: Optional[float]=None,
+ htk: bool=False,
+ norm: Union[str, float]='slaney',
+ dtype: str='float32') -> Tensor:
+ """Compute fbank matrix.
+
+ Args:
+ sr (int): Sample rate.
+ n_fft (int): Number of fft bins.
+ n_mels (int, optional): Number of mel bins. Defaults to 64.
+ f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0.
+ f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
+ htk (bool, optional): Use htk scaling. Defaults to False.
+ norm (Union[str, float], optional): Type of normalization. Defaults to 'slaney'.
+ dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
+
+ Returns:
+ Tensor: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`.
+ """
+
+ if f_max is None:
+ f_max = float(sr) / 2
+
+ # Initialize the weights
+ weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+
+ # Center freqs of each FFT bin
+ fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
+
+ # 'Center freqs' of mel bands - uniformly spaced between limits
+ mel_f = mel_frequencies(
+ n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
+
+ fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
+ ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
+ #ramps = np.subtract.outer(mel_f, fftfreqs)
+
+ for i in range(n_mels):
+ # lower and upper slopes for all bins
+ lower = -ramps[i] / fdiff[i]
+ upper = ramps[i + 2] / fdiff[i + 1]
+
+ # .. then intersect them with each other and zero
+ weights[i] = paddle.maximum(
+ paddle.zeros_like(lower), paddle.minimum(lower, upper))
+
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ if norm == 'slaney':
+ enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
+ weights *= enorm.unsqueeze(1)
+ elif isinstance(norm, int) or isinstance(norm, float):
+ weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
+
+ return weights
+
+
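+# A usage sketch (illustrative; the random tensor stands in for an STFT power
+# spectrogram):
+def _demo_apply_fbank() -> Tensor:
+    fb = compute_fbank_matrix(sr=16000, n_fft=512, n_mels=64)  # (64, 257)
+    spect = paddle.rand([257, 100])  # (1 + n_fft//2, num_frames)
+    return paddle.matmul(fb, spect)  # (64, 100)
+
+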
+def power_to_db(spect: Tensor,
+ ref_value: float=1.0,
+ amin: float=1e-10,
+ top_db: Optional[float]=None) -> Tensor:
+ """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way.
+
+ Args:
+ spect (Tensor): STFT power spectrogram.
+ ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
+ amin (float, optional): Minimum threshold. Defaults to 1e-10.
+ top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None.
+
+ Returns:
+ Tensor: Power spectrogram in db scale.
+ """
+ if amin <= 0:
+ raise Exception("amin must be strictly positive")
+
+ if ref_value <= 0:
+ raise Exception("ref_value must be strictly positive")
+
+ ones = paddle.ones_like(spect)
+ log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, spect))
+ log_spec -= 10.0 * math.log10(max(ref_value, amin))
+
+ if top_db is not None:
+ if top_db < 0:
+ raise Exception("top_db must be non-negative")
+ log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
+
+ return log_spec
+
+
+def create_dct(n_mfcc: int,
+ n_mels: int,
+ norm: Optional[str]='ortho',
+ dtype: str='float32') -> Tensor:
+ """Create a discrete cosine transform(DCT) matrix.
+
+ Args:
+ n_mfcc (int): Number of mel frequency cepstral coefficients.
+ n_mels (int): Number of mel filterbanks.
+ norm (Optional[str], optional): Normalization type. Defaults to 'ortho'.
+ dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
+
+ Returns:
+ Tensor: The DCT matrix with shape `(n_mels, n_mfcc)`.
+ """
+ n = paddle.arange(n_mels, dtype=dtype)
+ k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1)
+ dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
+ k) # size (n_mfcc, n_mels)
+ if norm is None:
+ dct *= 2.0
+ else:
+ assert norm == "ortho"
+ dct[0] *= 1.0 / math.sqrt(2.0)
+ dct *= math.sqrt(2.0 / float(n_mels))
+ return dct.T
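+
+
+# A usage sketch (illustrative): the typical final MFCC step, projecting
+# log-mel features onto the DCT basis.
+def _demo_create_dct() -> Tensor:
+    dct = create_dct(n_mfcc=20, n_mels=64)  # (64, 20)
+    log_mel = (paddle.rand([10, 64]) + 1.0).log()  # (num_frames, n_mels), toy values
+    return paddle.matmul(log_mel, dct)  # (10, 20)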
diff --git a/paddlespeech/audio/functional/window.py b/paddlespeech/audio/functional/window.py
new file mode 100644
index 000000000..c518dbab3
--- /dev/null
+++ b/paddlespeech/audio/functional/window.py
@@ -0,0 +1,373 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import paddle
+from paddle import Tensor
+
+
+class WindowFunctionRegister(object):
+ def __init__(self):
+ self._functions_dict = dict()
+
+ def register(self):
+ def add_subfunction(func):
+ name = func.__name__
+ self._functions_dict[name] = func
+ return func
+
+ return add_subfunction
+
+ def get(self, name):
+ return self._functions_dict[name]
+
+
+window_function_register = WindowFunctionRegister()
+
+
+@window_function_register.register()
+def _cat(x: List[Tensor], data_type: str) -> Tensor:
+ tensors = [paddle.to_tensor(t, data_type) for t in x]
+ return paddle.concat(tensors)
+
+
+@window_function_register.register()
+def _acosh(x: Union[Tensor, float]) -> Tensor:
+ if isinstance(x, float):
+ return math.log(x + math.sqrt(x**2 - 1))
+ return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))
+
+
+@window_function_register.register()
+def _extend(M: int, sym: bool) -> Tuple[int, bool]:
+ """Extend window by 1 sample if needed for DFT-even symmetry."""
+ if not sym:
+ return M + 1, True
+ else:
+ return M, False
+
+
+@window_function_register.register()
+def _len_guards(M: int) -> bool:
+ """Handle small or incorrect window lengths."""
+ if int(M) != M or M < 0:
+ raise ValueError('Window length M must be a non-negative integer')
+
+ return M <= 1
+
+
+@window_function_register.register()
+def _truncate(w: Tensor, needed: bool) -> Tensor:
+ """Truncate window by 1 sample if needed for DFT-even symmetry."""
+ if needed:
+ return w[:-1]
+ else:
+ return w
+
+
+@window_function_register.register()
+def _general_gaussian(M: int, p, sig, sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute a window with a generalized Gaussian shape.
+ This function is consistent with scipy.signal.windows.general_gaussian().
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+
+ n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
+ w = paddle.exp(-0.5 * paddle.abs(n / sig)**(2 * p))
+
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _general_cosine(M: int, a: List[float], sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute a generic weighted sum of cosine terms window.
+ This function is consistent with scipy.signal.windows.general_cosine().
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+ fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype)
+ w = paddle.zeros((M, ), dtype=dtype)
+ for k in range(len(a)):
+ w += a[k] * paddle.cos(k * fac)
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _general_hamming(M: int, alpha: float, sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute a generalized Hamming window.
+ This function is consistent with scipy.signal.windows.general_hamming()
+ """
+ return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype)
+
+
+@window_function_register.register()
+def _taylor(M: int,
+ nbar=4,
+ sll=30,
+ norm=True,
+ sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute a Taylor window.
+ The Taylor window taper function approximates the Dolph-Chebyshev window's
+ constant sidelobe level for a parameterized number of near-in sidelobes.
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+ # Original text uses a negative sidelobe level parameter and then negates
+ # it in the calculation of B. To keep consistent with other methods we
+ # assume the sidelobe level parameter to be positive.
+ B = 10**(sll / 20)
+ A = _acosh(B) / math.pi
+ s2 = nbar**2 / (A**2 + (nbar - 0.5)**2)
+ ma = paddle.arange(1, nbar, dtype=dtype)
+
+ Fm = paddle.empty((nbar - 1, ), dtype=dtype)
+ signs = paddle.empty_like(ma)
+ signs[::2] = 1
+ signs[1::2] = -1
+ m2 = ma * ma
+ for mi in range(len(ma)):
+ numer = signs[mi] * paddle.prod(
+ 1 - m2[mi] / s2 / (A**2 + (ma - 0.5)**2))
+ if mi == 0:
+ denom = 2 * paddle.prod(1 - m2[mi] / m2[mi + 1:])
+ elif mi == len(ma) - 1:
+ denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi])
+ else:
+ denom = (2 * paddle.prod(1 - m2[mi] / m2[:mi]) *
+ paddle.prod(1 - m2[mi] / m2[mi + 1:]))
+
+ Fm[mi] = numer / denom
+
+ def W(n):
+ return 1 + 2 * paddle.matmul(
+ Fm.unsqueeze(0),
+ paddle.cos(2 * math.pi * ma.unsqueeze(1) *
+ (n - M / 2.0 + 0.5) / M), )
+
+ w = W(paddle.arange(0, M, dtype=dtype))
+
+ # normalize (Note that this is not described in the original text [1])
+ if norm:
+ scale = 1.0 / W((M - 1) / 2)
+ w *= scale
+ w = w.squeeze()
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a Hamming window.
+ The Hamming window is a taper formed by using a raised cosine with
+ non-zero endpoints, optimized to minimize the nearest side lobe.
+ """
+ return _general_hamming(M, 0.54, sym, dtype=dtype)
+
+
+@window_function_register.register()
+def _hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a Hann window.
+ The Hann window is a taper formed by using a raised cosine or sine-squared
+ with ends that touch zero.
+ """
+ return _general_hamming(M, 0.5, sym, dtype=dtype)
+
+
+@window_function_register.register()
+def _tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a Tukey window.
+ The Tukey window is also known as a tapered cosine window.
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+
+ if alpha <= 0:
+ return paddle.ones((M, ), dtype=dtype)
+ elif alpha >= 1.0:
+ return _hann(M, sym=sym, dtype=dtype)
+
+ M, needs_trunc = _extend(M, sym)
+
+ n = paddle.arange(0, M, dtype=dtype)
+ width = int(alpha * (M - 1) / 2.0)
+ n1 = n[0:width + 1]
+ n2 = n[width + 1:M - width - 1]
+ n3 = n[M - width - 1:]
+
+ w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
+ w2 = paddle.ones(n2.shape, dtype=dtype)
+ w3 = 0.5 * (1 + paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha /
+ (M - 1))))
+ w = paddle.concat([w1, w2, w3])
+
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _gaussian(M: int, std: float, sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute a Gaussian window.
+ The Gaussian window has a Gaussian shape defined by the standard deviation (std).
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+
+ n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
+ sig2 = 2 * std * std
+ w = paddle.exp(-(n**2) / sig2)
+
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _exponential(M: int,
+ center=None,
+ tau=1.0,
+ sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute an exponential (or Poisson) window."""
+ if sym and center is not None:
+ raise ValueError("If sym==True, center must be None.")
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+
+ if center is None:
+ center = (M - 1) / 2
+
+ n = paddle.arange(0, M, dtype=dtype)
+ w = paddle.exp(-paddle.abs(n - center) / tau)
+
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a triangular window."""
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+
+ n = paddle.arange(1, (M + 1) // 2 + 1, dtype=dtype)
+ if M % 2 == 0:
+ w = (2 * n - 1.0) / M
+ w = paddle.concat([w, w[::-1]])
+ else:
+ w = 2 * n / (M + 1.0)
+ w = paddle.concat([w, w[-2::-1]])
+
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a Bohman window.
+ The Bohman window is the autocorrelation of a cosine window.
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+
+ fac = paddle.abs(paddle.linspace(-1, 1, M, dtype=dtype)[1:-1])
+ w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin(
+ math.pi * fac)
+ w = _cat([0, w, 0], dtype)
+
+ return _truncate(w, needs_trunc)
+
+
+@window_function_register.register()
+def _blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a Blackman window.
+ The Blackman window is a taper formed by using the first three terms of
+ a summation of cosines. It was designed to have close to the minimal
+ leakage possible. It is close to optimal, only slightly worse than a
+ Kaiser window.
+ """
+ return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)
+
+
+@window_function_register.register()
+def _cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
+ """Compute a window with a simple cosine shape."""
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+ w = paddle.sin(math.pi / M * (paddle.arange(0, M, dtype=dtype) + 0.5))
+
+ return _truncate(w, needs_trunc)
+
+
+def get_window(
+ window: Union[str, Tuple[str, float]],
+ win_length: int,
+ fftbins: bool=True,
+ dtype: str='float64', ) -> Tensor:
+ """Return a window of a given length and type.
+
+ Args:
+ window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
+ win_length (int): Number of samples.
+ fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
+ dtype (str, optional): The data type of the return window. Defaults to 'float64'.
+
+ Returns:
+ Tensor: The window represented as a tensor.
+
+ Examples:
+ .. code-block:: python
+
+ import paddle
+
+ n_fft = 512
+ cosine_window = paddle.audio.functional.get_window('cosine', n_fft)
+
+ std = 7
+ gaussian_window = paddle.audio.functional.get_window(('gaussian', std), n_fft)
+ """
+ sym = not fftbins
+
+ args = ()
+ if isinstance(window, tuple):
+ winstr = window[0]
+ if len(window) > 1:
+ args = window[1:]
+ elif isinstance(window, str):
+ if window in ['gaussian', 'exponential']:
+ raise ValueError("The '" + window + "' window needs one or "
+ "more parameters -- pass a tuple.")
+ else:
+ winstr = window
+ else:
+ raise ValueError("%s as window type is not supported." %
+ str(type(window)))
+
+ try:
+ winfunc = window_function_register.get('_' + winstr)
+ except KeyError as e:
+ raise ValueError("Unknown window type.") from e
+
+ params = (win_length, ) + args
+ kwargs = {'sym': sym}
+ return winfunc(*params, dtype=dtype, **kwargs)
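
As a quick sanity check on the vendored window functions above, the sketch below compares `get_window` against `scipy.signal.get_window`, which these implementations are documented to follow. It assumes SciPy is installed and imports `get_window` directly from the new module path added in this patch.

```python
import numpy as np
from scipy import signal

from paddlespeech.audio.functional.window import get_window

# Periodic (fftbins=True) Hann window of length 512, as used for STFT framing.
w_paddle = get_window('hann', 512, fftbins=True, dtype='float64')
w_scipy = signal.get_window('hann', 512, fftbins=True)
print(np.allclose(w_paddle.numpy(), w_scipy, atol=1e-10))  # expect True

# Parameterized windows are passed as a (name, param) tuple.
w_gauss = get_window(('gaussian', 7), 512)
print(w_gauss.shape)  # [512]
```
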
diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py
index 2e82226df..664509842 100644
--- a/paddlespeech/audio/streamdata/autodecode.py
+++ b/paddlespeech/audio/streamdata/autodecode.py
@@ -304,13 +304,11 @@ def paddle_audio(key, data):
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
return None
- import paddleaudio
-
+ from paddlespeech.audio.backends import soundfile_load
+
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
- return paddleaudio.backends.soundfile_load(fname)
+ return soundfile_load(fname)
################################################################
diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py
index 110b4a304..9a00c2dc6 100644
--- a/paddlespeech/audio/streamdata/filters.py
+++ b/paddlespeech/audio/streamdata/filters.py
@@ -22,8 +22,6 @@ from fnmatch import fnmatch
from functools import reduce
import paddle
-from paddleaudio import backends
-from paddleaudio.compliance import kaldi
from . import autodecode
from . import utils
@@ -33,6 +31,8 @@ from ..transform.spec_augment import time_mask
from ..transform.spec_augment import time_warp
from ..utils.tensor_utils import pad_sequence
from .utils import PipelineStage
+from paddlespeech.audio import backends
+from paddlespeech.audio.compliance import kaldi
class FilterFunction(object):
diff --git a/paddlespeech/audio/streamdata/soundfile.py b/paddlespeech/audio/streamdata/soundfile.py
new file mode 100644
index 000000000..7611fd297
--- /dev/null
+++ b/paddlespeech/audio/streamdata/soundfile.py
@@ -0,0 +1,677 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import warnings
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import paddle
+import resampy
+import soundfile
+from scipy.io import wavfile
+
+from ..utils import depth_convert
+from ..utils import ParameterError
+from .common import AudioInfo
+
+__all__ = [
+ 'resample',
+ 'to_mono',
+ 'normalize',
+ 'save',
+ 'soundfile_save',
+ 'load',
+ 'soundfile_load',
+ 'info',
+]
+NORMALIZE_TYPES = ['linear', 'gaussian']
+MERGE_TYPES = ['ch0', 'ch1', 'random', 'average']
+RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast']
+EPS = 1e-8
+
+
+def resample(y: np.ndarray,
+ src_sr: int,
+ target_sr: int,
+ mode: str='kaiser_fast') -> np.ndarray:
+ """Audio resampling.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ src_sr (int): Source sample rate.
+ target_sr (int): Target sample rate.
+ mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
+
+ Returns:
+ np.ndarray: `y` resampled to `target_sr`
+ """
+
+ if mode == 'kaiser_best':
+ warnings.warn(
+ f'Using resampy kaiser_best mode to resample {src_sr}=>{target_sr}. '
+ 'This mode is pretty slow; we recommend kaiser_fast for large-scale audio training.')
+
+ if not isinstance(y, np.ndarray):
+ raise ParameterError(
+ f'Only support numpy np.ndarray, but received y of type {type(y)}')
+
+ if mode not in RESAMPLE_MODES:
+ raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
+
+ return resampy.resample(y, src_sr, target_sr, filter=mode)
+
+
+def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
+ """Convert sterior audio to mono.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'.
+
+ Returns:
+ np.ndarray: `y` with mono channel.
+ """
+
+ if merge_type not in MERGE_TYPES:
+ raise ParameterError(
+ f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}'
+ )
+ if y.ndim > 2:
+ raise ParameterError(
+ f'Unsupported audio array, y.ndim > 2, the shape is {y.shape}')
+ if y.ndim == 1: # nothing to merge
+ return y
+
+ if merge_type == 'ch0':
+ return y[0]
+ if merge_type == 'ch1':
+ return y[1]
+ if merge_type == 'random':
+ return y[np.random.randint(0, 2)]
+
+ # need to do averaging according to dtype
+
+ if y.dtype == 'float32':
+ y_out = (y[0] + y[1]) * 0.5
+ elif y.dtype == 'int16':
+ y_out = y.astype('int32')
+ y_out = (y_out[0] + y_out[1]) // 2
+ y_out = np.clip(y_out, np.iinfo(y.dtype).min,
+ np.iinfo(y.dtype).max).astype(y.dtype)
+
+ elif y.dtype == 'int8':
+ y_out = y.astype('int16')
+ y_out = (y_out[0] + y_out[1]) // 2
+ y_out = np.clip(y_out, np.iinfo(y.dtype).min,
+ np.iinfo(y.dtype).max).astype(y.dtype)
+ else:
+ raise ParameterError(f'Unsupported dtype: {y.dtype}')
+ return y_out
+
+
+def soundfile_load_(file: os.PathLike,
+ offset: Optional[float]=None,
+ dtype: str='int16',
+ duration: Optional[int]=None) -> Tuple[np.ndarray, int]:
+ """Load audio using soundfile library. This function load audio file using libsndfile.
+
+ Args:
+ file (os.PathLike): File of waveform.
+ offset (Optional[float], optional): Offset to the start of waveform. Defaults to None.
+ dtype (str, optional): Data type of waveform. Defaults to 'int16'.
+ duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
+
+ Returns:
+ Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
+ """
+ with soundfile.SoundFile(file) as sf_desc:
+ sr_native = sf_desc.samplerate
+ if offset:
+ sf_desc.seek(int(offset * sr_native))
+ if duration is not None:
+ frame_duration = int(duration * sr_native)
+ else:
+ frame_duration = -1
+ y = sf_desc.read(frames=frame_duration, dtype=dtype, always_2d=False).T
+
+ return y, sf_desc.samplerate
+
+
+def normalize(y: np.ndarray, norm_type: str='linear',
+ mul_factor: float=1.0) -> np.ndarray:
+ """Normalize an input audio with additional multiplier.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ norm_type (str, optional): Type of normalization. Defaults to 'linear'.
+ mul_factor (float, optional): Scaling factor. Defaults to 1.0.
+
+ Returns:
+ np.ndarray: `y` after normalization.
+ """
+
+ if norm_type == 'linear':
+ amax = np.max(np.abs(y))
+ factor = 1.0 / (amax + EPS)
+ y = y * factor * mul_factor
+ elif norm_type == 'gaussian':
+ amean = np.mean(y)
+ astd = np.std(y)
+ astd = max(astd, EPS)
+ y = mul_factor * (y - amean) / astd
+ else:
+ raise NotImplementedError(f'norm_type should be in {NORMALIZE_TYPES}')
+
+ return y
+
+
+def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None:
+ """Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16.
+
+ Args:
+ y (np.ndarray): Input waveform array in 1D or 2D.
+ sr (int): Sample rate.
+ file (os.PathLike): Path of audio file to save.
+ """
+ if not file.endswith('.wav'):
+ raise ParameterError(
+ f'only .wav file supported, but dst file name is: {file}')
+
+ if sr <= 0:
+ raise ParameterError(
+ f'Sample rate should be larger than 0, received sr = {sr}')
+
+ if y.dtype not in ['int16', 'int8']:
+ warnings.warn(
+ f'input data type is {y.dtype}, will convert data to int16 format before saving'
+ )
+ y_out = depth_convert(y, 'int16')
+ else:
+ y_out = y
+
+ wavfile.write(file, sr, y_out)
+
+
+def soundfile_load(
+ file: os.PathLike,
+ sr: Optional[int]=None,
+ mono: bool=True,
+ merge_type: str='average', # ch0,ch1,random,average
+ normal: bool=True,
+ norm_type: str='linear',
+ norm_mul_factor: float=1.0,
+ offset: float=0.0,
+ duration: Optional[int]=None,
+ dtype: str='float32',
+ resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]:
+ """Load audio file from disk. This function loads audio from disk using using audio backend.
+
+ Args:
+ file (os.PathLike): Path of audio file to load.
+ sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None.
+ mono (bool, optional): Return waveform with mono channel. Defaults to True.
+ merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'.
+ normal (bool, optional): Waveform normalization. Defaults to True.
+ norm_type (str, optional): Type of normalization. Defaults to 'linear'.
+ norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0.
+ offset (float, optional): Offset to the start of waveform. Defaults to 0.0.
+ duration (Optional[int], optional): Duration of waveform to read. Defaults to None.
+ dtype (str, optional): Data type of waveform. Defaults to 'float32'.
+ resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
+
+ Returns:
+ Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate.
+ """
+
+ y, r = soundfile_load_(file, offset=offset, dtype=dtype, duration=duration)
+
+ if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)):
+ raise ParameterError(f'audio file {file} looks empty')
+
+ if mono:
+ y = to_mono(y, merge_type)
+
+ if sr is not None and sr != r:
+ y = resample(y, r, sr, mode=resample_mode)
+ r = sr
+
+ if normal:
+ y = normalize(y, norm_type, norm_mul_factor)
+ elif dtype in ['int8', 'int16']:
+ # still need to do normalization, before depth conversion
+ y = normalize(y, 'linear', 1.0)
+
+ y = depth_convert(y, dtype)
+ return y, r
+
+
+# The code below is taken from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/soundfile_backend.py, with some modifications.
+
+
+def _get_subtype_for_wav(dtype: paddle.dtype,
+ encoding: str,
+ bits_per_sample: int):
+ if not encoding:
+ if not bits_per_sample:
+ subtype = {
+ paddle.uint8: "PCM_U8",
+ paddle.int16: "PCM_16",
+ paddle.int32: "PCM_32",
+ paddle.float32: "FLOAT",
+ paddle.float64: "DOUBLE",
+ }.get(dtype)
+ if not subtype:
+ raise ValueError(f"Unsupported dtype for wav: {dtype}")
+ return subtype
+ if bits_per_sample == 8:
+ return "PCM_U8"
+ return f"PCM_{bits_per_sample}"
+ if encoding == "PCM_S":
+ if not bits_per_sample:
+ return "PCM_32"
+ if bits_per_sample == 8:
+ raise ValueError("wav does not support 8-bit signed PCM encoding.")
+ return f"PCM_{bits_per_sample}"
+ if encoding == "PCM_U":
+ if bits_per_sample in (None, 8):
+ return "PCM_U8"
+ raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
+ if encoding == "PCM_F":
+ if bits_per_sample in (None, 32):
+ return "FLOAT"
+ if bits_per_sample == 64:
+ return "DOUBLE"
+ raise ValueError("wav only supports 32/64-bit float PCM encoding.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "ULAW"
+ raise ValueError("wav only supports 8-bit mu-law encoding.")
+ if encoding == "ALAW":
+ if bits_per_sample in (None, 8):
+ return "ALAW"
+ raise ValueError("wav only supports 8-bit a-law encoding.")
+ raise ValueError(f"wav does not support {encoding}.")
+
+
+def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
+ if encoding in (None, "PCM_S"):
+ return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
+ if encoding in ("PCM_U", "PCM_F"):
+ raise ValueError(f"sph does not support {encoding} encoding.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "ULAW"
+ raise ValueError("sph only supports 8-bit for mu-law encoding.")
+ if encoding == "ALAW":
+ return "ALAW"
+ raise ValueError(f"sph does not support {encoding}.")
+
+
+def _get_subtype(dtype: paddle.dtype,
+ format: str,
+ encoding: str,
+ bits_per_sample: int):
+ if format == "wav":
+ return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
+ if format == "flac":
+ if encoding:
+ raise ValueError("flac does not support encoding.")
+ if not bits_per_sample:
+ return "PCM_16"
+ if bits_per_sample > 24:
+ raise ValueError("flac does not support bits_per_sample > 24.")
+ return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
+ if format in ("ogg", "vorbis"):
+ if encoding or bits_per_sample:
+ raise ValueError(
+ "ogg/vorbis does not support encoding/bits_per_sample.")
+ return "VORBIS"
+ if format == "sph":
+ return _get_subtype_for_sphere(encoding, bits_per_sample)
+ if format in ("nis", "nist"):
+ return "PCM_16"
+ raise ValueError(f"Unsupported format: {format}")
+
+
+def save(
+ filepath: str,
+ src: paddle.Tensor,
+ sample_rate: int,
+ channels_first: bool=True,
+ compression: Optional[float]=None,
+ format: Optional[str]=None,
+ encoding: Optional[str]=None,
+ bits_per_sample: Optional[int]=None, ):
+ """Save audio data to file.
+
+ Note:
+ The formats this function can handle depend on the soundfile installation.
+ This function is tested on the following formats;
+
+ * WAV
+
+ * 32-bit floating-point
+ * 32-bit signed integer
+ * 16-bit signed integer
+ * 8-bit unsigned integer
+
+ * FLAC
+ * OGG/VORBIS
+ * SPHERE
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+ ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend.
+
+ Args:
+ filepath (str or pathlib.Path): Path to audio file.
+ src (paddle.Tensor): Audio data to save. must be 2D tensor.
+ sample_rate (int): sampling rate
+ channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
+ otherwise `[time, channel]`.
+ compression (float or None, optional): Not used.
+ It is here only for interface compatibility reason with "sox_io" backend.
+ format (str or None, optional): Override the audio format.
+ When ``filepath`` argument is path-like object, audio format is
+ inferred from file extension. If the file extension is missing or
+ different, you can specify the correct format with this argument.
+
+ When ``filepath`` argument is file-like object,
+ this argument is required.
+
+ Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
+ ``"flac"`` and ``"sph"``.
+ encoding (str or None, optional): Changes the encoding for supported formats.
+ This argument is effective only for supported formats, such as
+ ``"wav"``, ``""flac"`` and ``"sph"``. Valid values are:
+
+ - ``"PCM_S"`` (signed integer Linear PCM)
+ - ``"PCM_U"`` (unsigned integer Linear PCM)
+ - ``"PCM_F"`` (floating point PCM)
+ - ``"ULAW"`` (mu-law)
+ - ``"ALAW"`` (a-law)
+
+ bits_per_sample (int or None, optional): Changes the bit depth for the
+ supported formats.
+ When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
+ you can change the bit depth.
+ Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
+
+ Supported formats/encodings/bit depth/compression are:
+
+ ``"wav"``
+ - 32-bit floating-point PCM
+ - 32-bit signed integer PCM
+ - 24-bit signed integer PCM
+ - 16-bit signed integer PCM
+ - 8-bit unsigned integer PCM
+ - 8-bit mu-law
+ - 8-bit a-law
+
+ Note:
+ Default encoding/bit depth is determined by the dtype of
+ the input Tensor.
+
+ ``"flac"``
+ - 8-bit
+ - 16-bit (default)
+ - 24-bit
+
+ ``"ogg"``, ``"vorbis"``
+ - Doesn't accept changing configuration.
+
+ ``"sph"``
+ - 8-bit signed integer PCM
+ - 16-bit signed integer PCM
+ - 24-bit signed integer PCM
+ - 32-bit signed integer PCM (default)
+ - 8-bit mu-law
+ - 8-bit a-law
+ - 16-bit a-law
+ - 24-bit a-law
+ - 32-bit a-law
+
+ """
+ if src.ndim != 2:
+ raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
+ if compression is not None:
+ warnings.warn(
+ '`save` function of "soundfile" backend does not support "compression" parameter. '
+ "The argument is silently ignored.")
+ if hasattr(filepath, "write"):
+ if format is None:
+ raise RuntimeError(
+ "`format` is required when saving to file object.")
+ ext = format.lower()
+ else:
+ ext = str(filepath).split(".")[-1].lower()
+
+ if bits_per_sample not in (None, 8, 16, 24, 32, 64):
+ raise ValueError("Invalid bits_per_sample.")
+ if bits_per_sample == 24:
+ warnings.warn(
+ "Saving audio with 24 bits per sample might warp samples near -1. "
+ "Using 16 bits per sample might be able to avoid this.")
+ subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
+
+ # sph is an extension used in TED-LIUM, but soundfile does not recognize it as NIST format,
+ # so we extend the extensions manually here
+ if ext in ["nis", "nist", "sph"] and format is None:
+ format = "NIST"
+
+ if channels_first:
+ src = src.t()
+
+ soundfile.write(
+ file=filepath,
+ data=src,
+ samplerate=sample_rate,
+ subtype=subtype,
+ format=format)
+
+
+_SUBTYPE2DTYPE = {
+ "PCM_S8": "int8",
+ "PCM_U8": "uint8",
+ "PCM_16": "int16",
+ "PCM_32": "int32",
+ "FLOAT": "float32",
+ "DOUBLE": "float64",
+}
+
+
+def load(
+ filepath: str,
+ frame_offset: int=0,
+ num_frames: int=-1,
+ normalize: bool=True,
+ channels_first: bool=True,
+ format: Optional[str]=None, ) -> Tuple[paddle.Tensor, int]:
+ """Load audio data from file.
+
+ Note:
+ The formats this function can handle depend on the soundfile installation.
+ This function is tested on the following formats;
+
+ * WAV
+
+ * 32-bit floating-point
+ * 32-bit signed integer
+ * 16-bit signed integer
+ * 8-bit unsigned integer
+
+ * FLAC
+ * OGG/VORBIS
+ * SPHERE
+
+ By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
+ ``float32`` dtype and the shape of `[channel, time]`.
+ The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
+
+ When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
+ signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
+ by providing ``normalize=False``, this function can return integer Tensor, where the samples
+ are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
+ for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
+
+ ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
+ ``flac`` and ``mp3``.
+ For these formats, this function always returns ``float32`` Tensor with values normalized to
+ ``[-1.0, 1.0]``.
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+ ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend.
+
+ Args:
+ filepath (path-like object or file-like object):
+ Source of audio data.
+ frame_offset (int, optional):
+ Number of frames to skip before start reading data.
+ num_frames (int, optional):
+ Maximum number of frames to read. ``-1`` reads all the remaining samples,
+ starting from ``frame_offset``.
+ This function may return fewer frames than requested if the file does not
+ contain enough frames after ``frame_offset``.
+ normalize (bool, optional):
+ When ``True``, this function always returns ``float32``, and sample values are
+ normalized to ``[-1.0, 1.0]``.
+ If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
+ integer type.
+ This argument has no effect for formats other than integer WAV type.
+ channels_first (bool, optional):
+ When True, the returned Tensor has dimension `[channel, time]`.
+ Otherwise, the returned Tensor's dimension is `[time, channel]`.
+ format (str or None, optional):
+ Not used. PySoundFile does not accept format hint.
+
+ Returns:
+ (paddle.Tensor, int): Resulting Tensor and sample rate.
+ If the input file has integer wav format and normalization is off, then it has
+ integer type, else ``float32`` type. If ``channels_first=True``, it has
+ `[channel, time]` else `[time, channel]`.
+ """
+ with soundfile.SoundFile(filepath, "r") as file_:
+ if file_.format != "WAV" or normalize:
+ dtype = "float32"
+ elif file_.subtype not in _SUBTYPE2DTYPE:
+ raise ValueError(f"Unsupported subtype: {file_.subtype}")
+ else:
+ dtype = _SUBTYPE2DTYPE[file_.subtype]
+
+ frames = file_._prepare_read(frame_offset, None, num_frames)
+ waveform = file_.read(frames, dtype, always_2d=True)
+ sample_rate = file_.samplerate
+
+ waveform = paddle.to_tensor(waveform)
+ if channels_first:
+ waveform = paddle.transpose(waveform, perm=[1, 0])
+ return waveform, sample_rate
+
+
+# Mapping from soundfile subtype to number of bits per sample.
+# This is mostly heuristical and the value is set to 0 when it is irrelevant
+# (lossy formats) or when it can't be inferred.
+# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
+# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
+# the default seems to be 8 bits but it can be compressed further to 4 bits.
+# The dict is inspired from
+# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
+_SUBTYPE_TO_BITS_PER_SAMPLE = {
+ "PCM_S8": 8, # Signed 8 bit data
+ "PCM_16": 16, # Signed 16 bit data
+ "PCM_24": 24, # Signed 24 bit data
+ "PCM_32": 32, # Signed 32 bit data
+ "PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
+ "FLOAT": 32, # 32 bit float data
+ "DOUBLE": 64, # 64 bit float data
+ "ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+ "ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+ "IMA_ADPCM": 0, # IMA ADPCM.
+ "MS_ADPCM": 0, # Microsoft ADPCM.
+ "GSM610":
+ 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
+ "VOX_ADPCM": 0, # OKI / Dialogix ADPCM
+ "G721_32": 0, # 32kbs G721 ADPCM encoding.
+ "G723_24": 0, # 24kbs G723 ADPCM encoding.
+ "G723_40": 0, # 40kbs G723 ADPCM encoding.
+ "DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
+ "DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
+ "DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
+ "DWVW_N": 0, # N bit Delta Width Variable Word encoding.
+ "DPCM_8": 8, # 8 bit differential PCM (XI only)
+ "DPCM_16": 16, # 16 bit differential PCM (XI only)
+ "VORBIS": 0, # Xiph Vorbis encoding. (lossy)
+ "ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
+ "ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
+ "ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
+ "ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
+}
+
+
+def _get_bit_depth(subtype):
+ if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
+ warnings.warn(
+ f"The {subtype} subtype is unknown to PaddleAudio. As a result, the bits_per_sample "
+ "attribute will be set to 0. If you are seeing this warning, please "
+ "report by opening an issue on github (after checking for existing/closed ones). "
+ "You may otherwise ignore this warning.")
+ return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
+
+
+_SUBTYPE_TO_ENCODING = {
+ "PCM_S8": "PCM_S",
+ "PCM_16": "PCM_S",
+ "PCM_24": "PCM_S",
+ "PCM_32": "PCM_S",
+ "PCM_U8": "PCM_U",
+ "FLOAT": "PCM_F",
+ "DOUBLE": "PCM_F",
+ "ULAW": "ULAW",
+ "ALAW": "ALAW",
+ "VORBIS": "VORBIS",
+}
+
+
+def _get_encoding(format: str, subtype: str):
+ if format == "FLAC":
+ return "FLAC"
+ return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
+
+
+def info(filepath: str, format: Optional[str]=None) -> AudioInfo:
+ """Get signal information of an audio file.
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+ ``pathlib.Path`` object as well. This is for consistency with the ``"sox_io"`` backend.
+
+ Args:
+ filepath (path-like object or file-like object):
+ Source of audio data.
+ format (str or None, optional):
+ Not used. PySoundFile does not accept format hint.
+
+ Returns:
+ AudioInfo: meta data of the given audio.
+
+ """
+ sinfo = soundfile.info(filepath)
+ return AudioInfo(
+ sinfo.samplerate,
+ sinfo.frames,
+ sinfo.channels,
+ bits_per_sample=_get_bit_depth(sinfo.subtype),
+ encoding=_get_encoding(sinfo.format, sinfo.subtype), )
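
A small round-trip sketch of the backend added above, under the assumption that the module is importable at the new path `paddlespeech.audio.streamdata.soundfile` and that `resampy` is installed. The tone, file path, and parameters are made up for illustration.

```python
import numpy as np

from paddlespeech.audio.streamdata.soundfile import (info, soundfile_load,
                                                     soundfile_save)

# Synthesize one second of a 440 Hz tone as int16 and write it as a wav file.
sr = 16000
t = np.arange(sr) / sr
y = (np.sin(2 * np.pi * 440 * t) * 32767).astype('int16')
soundfile_save(y, sr, '/tmp/tone.wav')

print(info('/tmp/tone.wav'))  # sample rate, frames, channels, bit depth, encoding

# Reload as normalized float32, resampled to 8 kHz with the kaiser_fast filter.
wav, r = soundfile_load('/tmp/tone.wav', sr=8000, dtype='float32')
print(wav.dtype, r, wav.shape)
```
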
diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py
index 3adf4892a..8429e6f77 100644
--- a/paddlespeech/audio/streamdata/tariterators.py
+++ b/paddlespeech/audio/streamdata/tariterators.py
@@ -20,9 +20,9 @@ trace = False
meta_prefix = "__"
meta_suffix = "__"
-import paddleaudio
import paddle
import numpy as np
+from paddlespeech.audio.backends import soundfile_load
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
@@ -111,7 +111,7 @@ def tar_file_iterator(fileobj,
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'wav':
- waveform, sample_rate = paddleaudio.backends.soundfile_load(
+ waveform, sample_rate = soundfile_load(
stream.extractfile(tarinfo), normal=False)
result = dict(
fname=prefix, wav=waveform, sample_rate=sample_rate)
@@ -163,7 +163,7 @@ def tar_file_and_group_iterator(fileobj,
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
- waveform, sample_rate = paddleaudio.backends.soundfile_load(
+ waveform, sample_rate = soundfile_load(
file_obj, normal=False)
waveform = paddle.to_tensor(
np.expand_dims(np.array(waveform), 0),
diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py
index f2dab3169..a4da86ec7 100644
--- a/paddlespeech/audio/transform/spectrogram.py
+++ b/paddlespeech/audio/transform/spectrogram.py
@@ -15,9 +15,10 @@
import librosa
import numpy as np
import paddle
-from paddleaudio.compliance import kaldi
from python_speech_features import logfbank
+from paddlespeech.audio.compliance import kaldi
+
def stft(x,
n_fft,
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
index 5e2168e3d..fa49f7bdb 100644
--- a/paddlespeech/cli/cls/infer.py
+++ b/paddlespeech/cli/cls/infer.py
@@ -22,11 +22,11 @@ import numpy as np
import paddle
import yaml
from paddle.audio.features import LogMelSpectrogram
-from paddleaudio.backends import soundfile_load as load
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
+from paddlespeech.audio.backends import soundfile_load as load
__all__ = ['CLSExecutor']
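
Since this hunk rewires the CLS executor's audio loader, a minimal smoke test may be useful; this sketch assumes the usual paddlespeech CLI calling convention (a callable executor taking `audio_file` and `topk`), and the wav path is made up.

```python
from paddlespeech.cli.cls.infer import CLSExecutor

cls_executor = CLSExecutor()
# Downloads the default PANNs model on first use; the wav path is illustrative.
result = cls_executor(audio_file='./cat.wav', topk=3)
print(result)
```
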
diff --git a/paddlespeech/cli/kws/infer.py b/paddlespeech/cli/kws/infer.py
index ce2f3f461..6dee4cc84 100644
--- a/paddlespeech/cli/kws/infer.py
+++ b/paddlespeech/cli/kws/infer.py
@@ -20,12 +20,12 @@ from typing import Union
import paddle
import yaml
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.compliance.kaldi import fbank as kaldi_fbank
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.compliance.kaldi import fbank as kaldi_fbank
__all__ = ['KWSExecutor']
diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py
index 57a781656..c4ae11c75 100644
--- a/paddlespeech/cli/vector/infer.py
+++ b/paddlespeech/cli/vector/infer.py
@@ -22,13 +22,13 @@ from typing import Union
import paddle
import soundfile
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
diff --git a/paddlespeech/cls/exps/panns/deploy/predict.py b/paddlespeech/cls/exps/panns/deploy/predict.py
index a6b735335..3085a8482 100644
--- a/paddlespeech/cls/exps/panns/deploy/predict.py
+++ b/paddlespeech/cls/exps/panns/deploy/predict.py
@@ -19,10 +19,10 @@ import paddle
from paddle import inference
from paddle.audio.datasets import ESC50
from paddle.audio.features import LogMelSpectrogram
-from paddleaudio.backends import soundfile_load as load_audio
from scipy.special import softmax
import paddlespeech.utils
+from paddlespeech.audio.backends import soundfile_load as load_audio
# yapf: disable
parser = argparse.ArgumentParser()
diff --git a/paddlespeech/cls/exps/panns/export_model.py b/paddlespeech/cls/exps/panns/export_model.py
index e860b54aa..5163dbacf 100644
--- a/paddlespeech/cls/exps/panns/export_model.py
+++ b/paddlespeech/cls/exps/panns/export_model.py
@@ -15,8 +15,8 @@ import argparse
import os
import paddle
-from paddleaudio.datasets import ESC50
+from paddlespeech.audio.datasets import ESC50
from paddlespeech.cls.models import cnn14
from paddlespeech.cls.models import SoundClassifier
diff --git a/paddlespeech/cls/exps/panns/predict.py b/paddlespeech/cls/exps/panns/predict.py
index 4681e4dc9..6b0eb9f68 100644
--- a/paddlespeech/cls/exps/panns/predict.py
+++ b/paddlespeech/cls/exps/panns/predict.py
@@ -18,12 +18,11 @@ import paddle
import paddle.nn.functional as F
import yaml
from paddle.audio.features import LogMelSpectrogram
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.utils import logger
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.utils import logger
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
-#from paddleaudio.features import LogMelSpectrogram
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
diff --git a/paddlespeech/cls/exps/panns/train.py b/paddlespeech/cls/exps/panns/train.py
index b768919be..5e5e0809d 100644
--- a/paddlespeech/cls/exps/panns/train.py
+++ b/paddlespeech/cls/exps/panns/train.py
@@ -17,9 +17,9 @@ import os
import paddle
import yaml
from paddle.audio.features import LogMelSpectrogram
-from paddleaudio.utils import logger
-from paddleaudio.utils import Timer
+from paddlespeech.audio.utils import logger
+from paddlespeech.audio.utils import Timer
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
diff --git a/paddlespeech/cls/models/panns/panns.py b/paddlespeech/cls/models/panns/panns.py
index 6f9af9b52..37deae80c 100644
--- a/paddlespeech/cls/models/panns/panns.py
+++ b/paddlespeech/cls/models/panns/panns.py
@@ -15,8 +15,8 @@ import os
import paddle.nn as nn
import paddle.nn.functional as F
-from paddleaudio.utils.download import load_state_dict_from_url
+from paddlespeech.audio.utils.download import load_state_dict_from_url
from paddlespeech.utils.env import MODEL_HOME
__all__ = ['CNN14', 'CNN10', 'CNN6', 'cnn14', 'cnn10', 'cnn6']
diff --git a/paddlespeech/kws/exps/mdtc/train.py b/paddlespeech/kws/exps/mdtc/train.py
index bb727d36a..d5bb5e020 100644
--- a/paddlespeech/kws/exps/mdtc/train.py
+++ b/paddlespeech/kws/exps/mdtc/train.py
@@ -14,10 +14,10 @@
import os
import paddle
-from paddleaudio.utils import logger
-from paddleaudio.utils import Timer
from yacs.config import CfgNode
+from paddlespeech.audio.utils import logger
+from paddlespeech.audio.utils import Timer
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel
diff --git a/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py b/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py
index 22329d5e0..ac5720fd5 100644
--- a/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py
+++ b/paddlespeech/s2t/frontend/featurizer/audio_featurizer.py
@@ -14,10 +14,11 @@
"""Contains the audio featurizer class."""
import numpy as np
import paddle
-import paddleaudio.compliance.kaldi as kaldi
from python_speech_features import delta
from python_speech_features import mfcc
+import paddlespeech.audio.compliance.kaldi as kaldi
+
class AudioFeaturizer():
"""Audio featurizer, for extracting features from audio contents of
diff --git a/paddlespeech/s2t/modules/fbank.py b/paddlespeech/s2t/modules/fbank.py
index 30671c274..8d76a4727 100644
--- a/paddlespeech/s2t/modules/fbank.py
+++ b/paddlespeech/s2t/modules/fbank.py
@@ -1,7 +1,7 @@
import paddle
from paddle import nn
-from paddleaudio.compliance import kaldi
+from paddlespeech.audio.compliance import kaldi
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
diff --git a/paddlespeech/server/engine/vector/python/vector_engine.py b/paddlespeech/server/engine/vector/python/vector_engine.py
index 7d86f3df7..f02a942fb 100644
--- a/paddlespeech/server/engine/vector/python/vector_engine.py
+++ b/paddlespeech/server/engine/vector/python/vector_engine.py
@@ -16,9 +16,9 @@ from collections import OrderedDict
import numpy as np
import paddle
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.cli.log import logger
from paddlespeech.cli.vector.infer import VectorExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
diff --git a/paddlespeech/server/util.py b/paddlespeech/server/util.py
index 6aa6fd589..47871922b 100644
--- a/paddlespeech/server/util.py
+++ b/paddlespeech/server/util.py
@@ -24,13 +24,13 @@ from typing import Any
from typing import Dict
import paddle
-import paddleaudio
import requests
import yaml
from paddle.framework import load
from .entry import client_commands
from .entry import server_commands
+from paddlespeech.audio.backends import soundfile_load
from paddlespeech.cli import download
try:
from .. import __version__
@@ -289,7 +289,7 @@ def _note_one_stat(cls_name, params={}):
if 'audio_file' in params:
try:
- _, sr = paddleaudio.backends.soundfile_load(params['audio_file'])
+ _, sr = soundfile_load(params['audio_file'])
except Exception:
sr = -1
diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py
index 5901c805a..b29d0863e 100644
--- a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py
+++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/layers.py
@@ -13,9 +13,9 @@
# limitations under the License.
import paddle
import paddle.nn.functional as F
-import paddleaudio.functional as audio_F
from paddle import nn
+from paddlespeech.audio.functional import create_dct
from paddlespeech.utils.initialize import _calculate_gain
from paddlespeech.utils.initialize import xavier_uniform_
@@ -243,7 +243,7 @@ class MFCC(nn.Layer):
self.n_mfcc = n_mfcc
self.n_mels = n_mels
self.norm = 'ortho'
- dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
+ dct_mat = create_dct(self.n_mfcc, self.n_mels, self.norm)
self.register_buffer('dct_mat', dct_mat)
def forward(self, mel_specgram: paddle.Tensor):
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
index 821b1deed..a2a19cb66 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/extract_emb.py
@@ -16,10 +16,10 @@ import os
import time
import paddle
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/test.py b/paddlespeech/vector/exps/ecapa_tdnn/test.py
index f15dbf9b7..167b82422 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/test.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/test.py
@@ -18,7 +18,7 @@ import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
-from paddleaudio.metric import compute_eer
+from sklearn.metrics import roc_curve
from tqdm import tqdm
from yacs.config import CfgNode
@@ -129,6 +129,23 @@ def compute_verification_scores(id2embedding, train_cohort, config):
return scores, labels
+def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
+ """Compute the equal error rate (EER) and the corresponding score threshold.
+
+ Args:
+ labels (np.ndarray): trial labels, shape: [N], one-dimensional, where N is the number of trials
+ scores (np.ndarray): trial scores, shape: [N], one-dimensional, where N is the number of trials
+
+ Returns:
+ List[float]: the EER and its corresponding threshold
+ """
+ fpr, tpr, thresholds = roc_curve(y_true=labels, y_score=scores)
+ fnr = 1 - tpr
+ idx = np.nanargmin(np.absolute(fnr - fpr))
+ eer_threshold = thresholds[idx]
+ eer = fpr[idx]
+ return eer, eer_threshold
+
+
def main(args, config):
"""The main process for test the speaker verification model
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
index 2dc7a7164..3966a900d 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -20,9 +20,9 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
-from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
+from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
diff --git a/paddlespeech/vector/io/dataset.py b/paddlespeech/vector/io/dataset.py
index dff8ad9fd..ae5c83637 100644
--- a/paddlespeech/vector/io/dataset.py
+++ b/paddlespeech/vector/io/dataset.py
@@ -15,9 +15,9 @@ from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
diff --git a/paddlespeech/vector/io/dataset_from_json.py b/paddlespeech/vector/io/dataset_from_json.py
index 852f39a94..1d1a4ad9c 100644
--- a/paddlespeech/vector/io/dataset_from_json.py
+++ b/paddlespeech/vector/io/dataset_from_json.py
@@ -16,9 +16,10 @@ from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
-from paddleaudio.backends import soundfile_load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
-from paddleaudio.compliance.librosa import mfcc
+
+from paddlespeech.audio.backends import soundfile_load as load_audio
+from paddlespeech.audio.compliance.librosa import melspectrogram
+from paddlespeech.audio.compliance.librosa import mfcc
@dataclass
diff --git a/setup.py b/setup.py
index 8c2a4c1b7..49d9188a6 100644
--- a/setup.py
+++ b/setup.py
@@ -99,7 +99,6 @@ base = [
determine_opencc_version(), # opencc or opencc==1.1.6
"opencc-python-reimplemented",
"pandas",
- "paddleaudio>=1.1.0",
"paddlenlp>=2.4.8",
"paddleslim>=2.3.4",
"ppdiffusers>=0.9.0",
@@ -122,6 +121,7 @@ base = [
"webrtcvad",
"yacs>=0.1.8",
"zhon",
+ "sklearn",
]
server = ["pattern_singleton", "websockets"]