add test_load.py

pull/2195/head
YangZhou 3 years ago
parent 63b4494700
commit 59d82c0c65

@ -103,7 +103,6 @@ def _load_lib(lib: str) -> bool:
If a dependency is missing, then users have to install it.
"""
path = _get_lib_path(lib)
warnings.warn("lib path is :" + str(path))
if not path.exists():
warnings.warn("lib path is not exists:" + str(path))
return False

@ -8,7 +8,7 @@ from paddle import Tensor
from .common import AudioMetaData
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.aduio import _paddleaudio as paddleaudio
from paddlespeech.audio import _paddleaudio as paddleaudio
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py

@ -178,36 +178,68 @@ py::array convert_to_tensor(
const py::dtype dtype,
const bool normalize,
const bool channels_first) {
// todo refector later(SGoat)
py::array t;
uint64_t dummy = 0;
SOX_SAMPLE_LOCALS;
int32_t num_rows = num_samples / num_channels;
if (normalize || dtype.char_() == 'f') {
t = py::array(dtype, {num_samples / num_channels, num_channels});
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (float*)t.mutable_data(0, 0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy);
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(float*)t2.mutable_data(row_idx, col_idx) = *(float*)t.data(col_idx, row_idx);
}
return t2;
}
} else if (dtype.char_() == 'i') {
//t = torch::from_blob(
// buffer, {num_samples / num_channels, num_channels}, torch::kInt32)
// .clone();
t = py::array(dtype, {num_samples / num_channels, num_channels});
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (int*)t.mutable_data(0, 0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = buffer[i];
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(int*)t2.mutable_data(row_idx, col_idx) = *(int*)t.data(col_idx, row_idx);
}
return t2;
}
} else if (dtype.char_() == 'h') { // int16
t = py::array(dtype, {num_samples / num_channels, num_channels});
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (int16_t*)t.mutable_data(0, 0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy);
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(int16_t*)t2.mutable_data(row_idx, col_idx) = *(int16_t*)t.data(col_idx, row_idx);
}
return t2;
}
} else if (dtype.char_() == 'b') {
//t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8);
t = py::array(dtype, {num_rows, num_channels});
auto ptr = (uint8_t*)t.mutable_data(0,0);
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
}
if (channels_first) {
py::array t2 = py::array(dtype, {num_channels, num_rows});
for (int32_t row_idx = 0; row_idx < num_channels; ++row_idx) {
for (int32_t col_idx = 0; col_idx < num_rows; ++col_idx)
*(uint8_t*)t2.mutable_data(row_idx, col_idx) = *(uint8_t*)t.data(col_idx, row_idx);
}
return t2;
}
} else {
throw std::runtime_error("Unsupported dtype.");
}

@ -0,0 +1,32 @@
def get_encoding(ext, dtype):
exts = {
"mp3",
"flac",
"vorbis",
}
encodings = {
"float32": "PCM_F",
"int32": "PCM_S",
"int16": "PCM_S",
"uint8": "PCM_U",
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bit_depth(dtype):
bit_depths = {
"float32": 32,
"int32": 32,
"int16": 16,
"uint8": 8,
}
return bit_depths[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
"flac": 24,
"mp3": 0,
"vorbis": 0,
}
return bits_per_samples.get(ext, get_bit_depth(dtype))

@ -0,0 +1,34 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddlespeech.audio.backends import sox_io_backend
class TestInfo(unittest.TestCase):
def test_wav(self, dtype, sample_rate, num_channels, sample_size):
"""check wav file correctly """
path = 'testdata/test.wav'
info = sox_io_backend.get_info_file(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_size # duration*sample_rate
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,47 @@
import unittest
import itertools
from parameterized import parameterized
import numpy as np
from paddlespeech.audio._internal import module_utils as _mod_utils
from paddlespeech.audio.backends import sox_io_backend
from tests.unit.common_utils import (
get_wav_data,
load_wav,
save_wav,
)
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/load_test.py
class TestLoad(unittest.TestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Wav data loaded with sox_io backend should match those with scipy
"""
path = 'testdata/reference.wav'
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
data, sr = sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
np.testing.assert_array_almost_equal(data, expected, decimal=4)
@parameterized.expand(
list(
itertools.product(
["float64", "float32", "int32",],
[8000, 16000],
[1, 2],
[False, True],
)
),
)
def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,34 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddlespeech.audio.backends import sox_io_backend
class TestInfo(unittest.TestCase):
def test_wav(self, dtype, sample_rate, num_channels, sample_size):
"""check wav file correctly """
path = 'testdata/test.wav'
info = sox_io_backend.get_info_file(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_size # duration*sample_rate
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,8 @@
from .wav_utils import get_wav_data, load_wav, save_wav, normalize_wav
__all__ = [
"get_wav_data",
"load_wav",
"save_wav",
"normalize_wav"
]

@ -0,0 +1,92 @@
from typing import Optional
import scipy.io.wavfile
import paddle
def normalize_wav(tensor: paddle.Tensor) -> paddle.Tensor:
if tensor.dtype == paddle.float32:
pass
elif tensor.dtype == paddle.int32:
tensor = paddle.cast(tensor, paddle.float32)
tensor[tensor > 0] /= 2147483647.0
tensor[tensor < 0] /= 2147483648.0
elif tensor.dtype == paddle.int16:
tensor = paddle.cast(tensor, paddle.float32)
tensor[tensor > 0] /= 32767.0
tensor[tensor < 0] /= 32768.0
elif tensor.dtype == paddle.uint8:
tensor = paddle.cast(tensor, paddle.float32) - 128
tensor[tensor > 0] /= 127.0
tensor[tensor < 0] /= 128.0
return tensor
def get_wav_data(
dtype: str,
num_channels: int,
*,
num_frames: Optional[int] = None,
normalize: bool = True,
channels_first: bool = True,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483648, 2147483647] for int32
[-32768, 32767] for int16
[0, 255] for uint8
num_frames allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_ = getattr(paddle, dtype)
if num_frames is None:
if dtype == "uint8":
num_frames = 256
else:
num_frames = 1 << 16
# paddle linspace not support uint8, int8, int16
#if dtype == "uint8":
# base = paddle.linspace(0, 255, num_frames, dtype=dtype_)
#elif dtype == "int8":
# base = paddle.linspace(-128, 127, num_frames, dtype=dtype_)
if dtype == "float32":
base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
elif dtype == "float64":
base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_)
elif dtype == "int32":
base = paddle.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
#elif dtype == "int16":
# base = paddle.linspace(-32768, 32767, num_frames, dtype=dtype_)
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
data = base.tile([num_channels, 1])
if not channels_first:
data = data.transpose([1, 0])
if normalize:
data = normalize_wav(data)
return data
def load_wav(path: str, normalize=True, channels_first=True) -> paddle.Tensor:
"""Load wav file without paddleaudio"""
sample_rate, data = scipy.io.wavfile.read(path)
data = paddle.to_tensor(data.copy())
if data.ndim == 1:
data = data.unsqueeze(1)
if normalize:
data = normalize_wav(data)
if channels_first:
data = data.transpose([1, 0])
return data, sample_rate
def save_wav(path, data, sample_rate, channels_first=True):
"""Save wav file without paddleaudio"""
if channels_first:
data = data.transpose([1, 0])
scipy.io.wavfile.write(path, sample_rate, data.numpy())
Loading…
Cancel
Save