parent
59d82c0c65
commit
d264118416
@ -1,34 +1,177 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from parameterized import parameterized
|
||||
from paddlespeech.audio.backends import sox_io_backend
|
||||
|
||||
class TestInfo(unittest.TestCase):
|
||||
|
||||
def test_wav(self, dtype, sample_rate, num_channels, sample_size):
|
||||
"""check wav file correctly """
|
||||
path = 'testdata/test.wav'
|
||||
info = sox_io_backend.get_info_file(path)
|
||||
assert info.sample_rate == sample_rate
|
||||
assert info.num_frames == sample_size # duration*sample_rate
|
||||
assert info.num_channels == num_channels
|
||||
assert info.bits_per_sample == get_bit_depth(dtype)
|
||||
assert info.encoding == get_encoding('wav', dtype)
|
||||
|
||||
from tests.unit.common_utils import (
|
||||
get_wav_data,
|
||||
load_wav,
|
||||
save_wav,
|
||||
nested_params,
|
||||
TempDirMixin,
|
||||
sox_utils
|
||||
)
|
||||
|
||||
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/save_test.py
|
||||
|
||||
def _get_sox_encoding(encoding):
|
||||
encodings = {
|
||||
"PCM_F": "floating-point",
|
||||
"PCM_S": "signed-integer",
|
||||
"PCM_U": "unsigned-integer",
|
||||
"ULAW": "u-law",
|
||||
"ALAW": "a-law",
|
||||
}
|
||||
return encodings.get(encoding)
|
||||
|
||||
class TestSaveBase(TempDirMixin):
|
||||
def assert_save_consistency(
|
||||
self,
|
||||
format: str,
|
||||
*,
|
||||
compression: float = None,
|
||||
encoding: str = None,
|
||||
bits_per_sample: int = None,
|
||||
sample_rate: float = 8000,
|
||||
num_channels: int = 2,
|
||||
num_frames: float = 3 * 8000,
|
||||
src_dtype: str = "int32",
|
||||
test_mode: str = "path",
|
||||
):
|
||||
"""`save` function produces file that is comparable with `sox` command
|
||||
|
||||
To compare that the file produced by `save` function agains the file produced by
|
||||
the equivalent `sox` command, we need to load both files.
|
||||
But there are many formats that cannot be opened with common Python modules (like
|
||||
SciPy).
|
||||
So we use `sox` command to prepare the original data and convert the saved files
|
||||
into a format that SciPy can read (PCM wav).
|
||||
The following diagram illustrates this process. The difference is 2.1. and 3.1.
|
||||
|
||||
This assumes that
|
||||
- loading data with SciPy preserves the data well.
|
||||
- converting the resulting files into WAV format with `sox` preserve the data well.
|
||||
|
||||
x
|
||||
| 1. Generate source wav file with SciPy
|
||||
|
|
||||
v
|
||||
-------------- wav ----------------
|
||||
| |
|
||||
| 2.1. load with scipy | 3.1. Convert to the target
|
||||
| then save it into the target | format depth with sox
|
||||
| format with torchaudio |
|
||||
v v
|
||||
target format target format
|
||||
| |
|
||||
| 2.2. Convert to wav with sox | 3.2. Convert to wav with sox
|
||||
| |
|
||||
v v
|
||||
wav wav
|
||||
| |
|
||||
| 2.3. load with scipy | 3.3. load with scipy
|
||||
| |
|
||||
v v
|
||||
tensor -------> compare <--------- tensor
|
||||
|
||||
"""
|
||||
cmp_encoding = "floating-point"
|
||||
cmp_bit_depth = 32
|
||||
|
||||
src_path = self.get_temp_path("1.source.wav")
|
||||
tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
|
||||
tst_path = self.get_temp_path("2.2.result.wav")
|
||||
sox_path = self.get_temp_path(f"3.1.sox.{format}")
|
||||
ref_path = self.get_temp_path("3.2.ref.wav")
|
||||
|
||||
# 1. Generate original wav
|
||||
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
|
||||
save_wav(src_path, data, sample_rate)
|
||||
|
||||
# 2.1. Convert the original wav to target format with torchaudio
|
||||
data = load_wav(src_path, normalize=False)[0]
|
||||
if test_mode == "path":
|
||||
sox_io_backend.save(
|
||||
tgt_path, data, sample_rate, compression=compression, encoding=encoding, bits_per_sample=bits_per_sample
|
||||
)
|
||||
elif test_mode == "fileobj":
|
||||
with open(tgt_path, "bw") as file_:
|
||||
sox_io_backend.save(
|
||||
file_,
|
||||
data,
|
||||
sample_rate,
|
||||
format=format,
|
||||
compression=compression,
|
||||
encoding=encoding,
|
||||
bits_per_sample=bits_per_sample,
|
||||
)
|
||||
elif test_mode == "bytesio":
|
||||
file_ = io.BytesIO()
|
||||
sox_io_backend.save(
|
||||
file_,
|
||||
data,
|
||||
sample_rate,
|
||||
format=format,
|
||||
compression=compression,
|
||||
encoding=encoding,
|
||||
bits_per_sample=bits_per_sample,
|
||||
)
|
||||
file_.seek(0)
|
||||
with open(tgt_path, "bw") as f:
|
||||
f.write(file_.read())
|
||||
else:
|
||||
raise ValueError(f"Unexpected test mode: {test_mode}")
|
||||
# 2.2. Convert the target format to wav with sox
|
||||
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
|
||||
# 2.3. Load with SciPy
|
||||
found = load_wav(tst_path, normalize=False)[0]
|
||||
|
||||
# 3.1. Convert the original wav to target format with sox
|
||||
sox_encoding = _get_sox_encoding(encoding)
|
||||
sox_utils.convert_audio_file(
|
||||
src_path, sox_path, compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample
|
||||
)
|
||||
# 3.2. Convert the target format to wav with sox
|
||||
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
|
||||
# 3.3. Load with SciPy
|
||||
expected = load_wav(ref_path, normalize=False)[0]
|
||||
|
||||
np.testing.assert_array_almost_equal(found, expected)
|
||||
|
||||
class TestSave(TestSaveBase, unittest.TestCase):
|
||||
@nested_params(
|
||||
["path",],
|
||||
[
|
||||
("PCM_U", 8),
|
||||
("PCM_S", 16),
|
||||
("PCM_S", 32),
|
||||
("PCM_F", 32),
|
||||
("PCM_F", 64),
|
||||
("ULAW", 8),
|
||||
("ALAW", 8),
|
||||
],
|
||||
)
|
||||
def test_save_wav(self, test_mode, enc_params):
|
||||
encoding, bits_per_sample = enc_params
|
||||
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
|
||||
|
||||
@nested_params(
|
||||
["path", ],
|
||||
[
|
||||
("float32",),
|
||||
("int32",),
|
||||
("int16",),
|
||||
("uint8",),
|
||||
],
|
||||
)
|
||||
def test_save_wav_dtype(self, test_mode, params):
|
||||
(dtype,) = params
|
||||
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1,8 +1,14 @@
|
||||
from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav
|
||||
from .parameterized_utils import load_params, nested_params
|
||||
from .case_utils import (
|
||||
TempDirMixin
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_wav_data",
|
||||
"load_wav",
|
||||
"save_wav",
|
||||
"normalize_wav"
|
||||
"normalize_wav",
|
||||
"load_params",
|
||||
"nested_params",
|
||||
]
|
||||
|
@ -0,0 +1,56 @@
|
||||
import functools
|
||||
import os.path
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
from paddlespeech.audio._internal.module_utils import (
|
||||
is_kaldi_available,
|
||||
is_module_available,
|
||||
is_sox_available,
|
||||
)
|
||||
|
||||
class TempDirMixin:
|
||||
"""Mixin to provide easy access to temp dir"""
|
||||
|
||||
temp_dir_ = None
|
||||
|
||||
@classmethod
|
||||
def get_base_temp_dir(cls):
|
||||
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
|
||||
# this is handy for debugging.
|
||||
key = "TORCHAUDIO_TEST_TEMP_DIR"
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
if cls.temp_dir_ is None:
|
||||
cls.temp_dir_ = tempfile.TemporaryDirectory()
|
||||
return cls.temp_dir_.name
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if cls.temp_dir_ is not None:
|
||||
try:
|
||||
cls.temp_dir_.cleanup()
|
||||
cls.temp_dir_ = None
|
||||
except PermissionError:
|
||||
# On Windows there is a know issue with `shutil.rmtree`,
|
||||
# which fails intermittenly.
|
||||
#
|
||||
# https://github.com/python/cpython/issues/74168
|
||||
#
|
||||
# We observed this on CircleCI, where Windows job raises
|
||||
# PermissionError.
|
||||
#
|
||||
# Following the above thread, we ignore it.
|
||||
pass
|
||||
super().tearDownClass()
|
||||
|
||||
def get_temp_path(self, *paths):
|
||||
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
|
||||
path = os.path.join(temp_dir, *paths)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
return path
|
@ -0,0 +1,50 @@
|
||||
import json
|
||||
from itertools import product
|
||||
|
||||
from parameterized import param, parameterized
|
||||
|
||||
def get_asset_path(*paths):
|
||||
"""Return full path of a test asset"""
|
||||
return os.path.join(_TEST_DIR_PATH, "assets", *paths)
|
||||
|
||||
def load_params(*paths):
|
||||
with open(get_asset_path(*paths), "r") as file:
|
||||
return [param(json.loads(line)) for line in file]
|
||||
|
||||
def _name_func(func, _, params):
|
||||
strs = []
|
||||
for arg in params.args:
|
||||
if isinstance(arg, tuple):
|
||||
strs.append("_".join(str(a) for a in arg))
|
||||
else:
|
||||
strs.append(str(arg))
|
||||
# sanitize the test name
|
||||
name = "_".join(strs)
|
||||
return parameterized.to_safe_name(f"{func.__name__}_{name}")
|
||||
|
||||
|
||||
def nested_params(*params_set, name_func=_name_func):
|
||||
"""Generate the cartesian product of the given list of parameters.
|
||||
|
||||
Args:
|
||||
params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
|
||||
all the parameters have to be specified with the class, only using kwargs.
|
||||
"""
|
||||
flatten = [p for params in params_set for p in params]
|
||||
|
||||
# Parameters to be nested are given as list of plain objects
|
||||
if all(not isinstance(p, param) for p in flatten):
|
||||
args = list(product(*params_set))
|
||||
return parameterized.expand(args, name_func=_name_func)
|
||||
|
||||
# Parameters to be nested are given as list of `parameterized.param`
|
||||
if not all(isinstance(p, param) for p in flatten):
|
||||
raise TypeError("When using ``parameterized.param``, " "all the parameters have to be of the ``param`` type.")
|
||||
if any(p.args for p in flatten):
|
||||
raise ValueError(
|
||||
"When using ``parameterized.param``, " "all the parameters have to be provided as keyword argument."
|
||||
)
|
||||
args = [param()]
|
||||
for params in params_set:
|
||||
args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
|
||||
return parameterized.expand(args)
|
@ -0,0 +1,116 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
|
||||
def get_encoding(dtype):
|
||||
encodings = {
|
||||
"float32": "floating-point",
|
||||
"int32": "signed-integer",
|
||||
"int16": "signed-integer",
|
||||
"uint8": "unsigned-integer",
|
||||
}
|
||||
return encodings[dtype]
|
||||
|
||||
|
||||
def get_bit_depth(dtype):
|
||||
bit_depths = {
|
||||
"float32": 32,
|
||||
"int32": 32,
|
||||
"int16": 16,
|
||||
"uint8": 8,
|
||||
}
|
||||
return bit_depths[dtype]
|
||||
|
||||
|
||||
def gen_audio_file(
|
||||
path,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
*,
|
||||
encoding=None,
|
||||
bit_depth=None,
|
||||
compression=None,
|
||||
attenuation=None,
|
||||
duration=1,
|
||||
comment_file=None,
|
||||
):
|
||||
"""Generate synthetic audio file with `sox` command."""
|
||||
if path.endswith(".wav"):
|
||||
warnings.warn("Use get_wav_data and save_wav to generate wav file for accurate result.")
|
||||
command = [
|
||||
"sox",
|
||||
"-V3", # verbose
|
||||
"--no-dither", # disable automatic dithering
|
||||
"-R",
|
||||
# -R is supposed to be repeatable, though the implementation looks suspicious
|
||||
# and not setting the seed to a fixed value.
|
||||
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
|
||||
# search "sox_globals.repeatable"
|
||||
]
|
||||
if bit_depth is not None:
|
||||
command += ["--bits", str(bit_depth)]
|
||||
command += [
|
||||
"--rate",
|
||||
str(sample_rate),
|
||||
"--null", # no input
|
||||
"--channels",
|
||||
str(num_channels),
|
||||
]
|
||||
if compression is not None:
|
||||
command += ["--compression", str(compression)]
|
||||
if bit_depth is not None:
|
||||
command += ["--bits", str(bit_depth)]
|
||||
if encoding is not None:
|
||||
command += ["--encoding", str(encoding)]
|
||||
if comment_file is not None:
|
||||
command += ["--comment-file", str(comment_file)]
|
||||
command += [
|
||||
str(path),
|
||||
"synth",
|
||||
str(duration), # synthesizes for the given duration [sec]
|
||||
"sawtooth",
|
||||
"1",
|
||||
# saw tooth covers the both ends of value range, which is a good property for test.
|
||||
# similar to linspace(-1., 1.)
|
||||
# this introduces bigger boundary effect than sine when converted to mp3
|
||||
]
|
||||
if attenuation is not None:
|
||||
command += ["vol", f"-{attenuation}dB"]
|
||||
print(" ".join(command), file=sys.stderr)
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
|
||||
def convert_audio_file(src_path, dst_path, *, encoding=None, bit_depth=None, compression=None):
|
||||
"""Convert audio file with `sox` command."""
|
||||
command = ["sox", "-V3", "--no-dither", "-R", str(src_path)]
|
||||
if encoding is not None:
|
||||
command += ["--encoding", str(encoding)]
|
||||
if bit_depth is not None:
|
||||
command += ["--bits", str(bit_depth)]
|
||||
if compression is not None:
|
||||
command += ["--compression", str(compression)]
|
||||
command += [dst_path]
|
||||
print(" ".join(command), file=sys.stderr)
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
|
||||
def _flattern(effects):
|
||||
if not effects:
|
||||
return effects
|
||||
if isinstance(effects[0], str):
|
||||
return effects
|
||||
return [item for sublist in effects for item in sublist]
|
||||
|
||||
|
||||
def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None):
|
||||
"""Run sox effects"""
|
||||
effect = _flattern(effect)
|
||||
command = ["sox", "-V", "--no-dither", input_file]
|
||||
if output_bitdepth:
|
||||
command += ["--bits", str(output_bitdepth)]
|
||||
command += [output_file] + effect
|
||||
if output_sample_rate:
|
||||
command += ["rate", str(output_sample_rate)]
|
||||
print(" ".join(command))
|
||||
subprocess.run(command, check=True)
|
Loading…
Reference in new issue