modify info test

pull/2195/head
YangZhou 3 years ago
parent 7abfad804e
commit a55592f7c2

@ -88,7 +88,7 @@ def save(filepath: str,
)
@_mod_utils.requires_sox()
def info(filepath: str, format: Optional[str] = "") -> None:
def info(filepath: str, format: Optional[str] = None,) -> AudioMetaData:
if hasattr(filepath, "read"):
sinfo = paddleaudio.get_info_fileobj(filepath, format)
if sinfo is not None:

@ -13,13 +13,14 @@ using namespace paddleaudio::sox_utils;
namespace paddleaudio {
namespace sox_io {
auto get_info_file(const std::string &path, const std::string &format)
auto get_info_file(const std::string &path,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
SoxFormat sf(
sox_open_read(path.data(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.empty() ? nullptr : format.data()));
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf, path);
@ -61,7 +62,8 @@ std::vector<std::vector<std::string>> get_effects(
return effects;
}
auto get_info_fileobj(py::object fileobj, const std::string &format)
auto get_info_fileobj(py::object fileobj,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
const auto capacity = [&]() {
const auto bufsiz = get_buffer_size();
@ -80,7 +82,7 @@ auto get_info_fileobj(py::object fileobj, const std::string &format)
buf_size,
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.empty() ? nullptr : format.data()));
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_memfile(sf);

@ -10,10 +10,12 @@ namespace py = pybind11;
namespace paddleaudio {
namespace sox_io {
auto get_info_file(const std::string &path, const std::string &format)
auto get_info_file(const std::string &path,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
auto get_info_fileobj(py::object fileobj, const std::string &format)
auto get_info_fileobj(py::object fileobj,
const tl::optional<std::string> &format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
tl::optional<std::tuple<py::array, int64_t>> load_audio_fileobj(

@ -1,34 +1,290 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import itertools
import tarfile
from contextlib import contextmanager
import numpy as np
import paddle
import os
import io
from parameterized import parameterized
from paddlespeech.audio.backends import sox_io_backend
class TestInfo(unittest.TestCase):
from tests.unit.common_utils import (
get_wav_data,
load_wav,
save_wav,
TempDirMixin,
sox_utils,
data_utils
)
def test_wav(self, dtype, sample_rate, num_channels, sample_size):
"""check wav file correctly """
path = 'testdata/test.wav'
info = sox_io_backend.get_info_file(path)
from common import get_encoding, get_bits_per_sample
#code is from:https://github.com/pytorch/audio/blob/main/torchaudio/test/torchaudio_unittest/backend/sox_io/info_test.py
class TestInfo(TempDirMixin, unittest.TestCase):
@parameterized.expand(
list(
itertools.product(
["float32", "int32",],
[8000, 16000],
[1, 2],
)
),
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(
list(
itertools.product(
["float32", "int32"],
[8000, 16000],
[4, 8, 16, 32],
)
),
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
def test_ulaw(self):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="u-law", duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ULAW"
def test_alaw(self):
"""`sox_io_backend.info` can check alaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="a-law", duration=duration
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_size # duration*sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
assert info.bits_per_sample == 8
assert info.encoding == "ALAW"
#class TestInfoOpus(unittest.TestCase):
#@parameterized.expand(
#list(
#itertools.product(
#["96k"],
#[1, 2],
#[0, 5, 10],
#)
#),
#)
#def test_opus(self, bitrate, num_channels, compression_level):
#"""`sox_io_backend.info` can check opus file correcty"""
#path = data_utils.get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
#info = sox_io_backend.info(path)
#assert info.sample_rate == 48000
#assert info.num_frames == 32768
#assert info.num_channels == num_channels
#assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
#assert info.encoding == "OPUS"
class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f"test.{ext}")
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file(
path,
sample_rate,
num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration,
comment_file=comment_file,
)
return path
def _gen_comment_file(self, comments):
comment_path = self.get_temp_path("comment.txt")
with open(comment_path, "w") as file_:
file_.writelines(comments)
return comment_path
class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj
def read(self, n):
return self.fileobj.read(n)
class TestFileObject(FileObjTestBase, unittest.TestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as fileobj:
return sox_io_backend.info(fileobj, format_)
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ["mp3"] else None
with open(path, "rb") as file_:
fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_)
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path("archive.tar.gz")
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ["mp3"] else None
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)
@contextmanager
def _set_buffer_size(self, buffer_size):
try:
original_buffer_size = get_buffer_size()
set_buffer_size(buffer_size)
yield
finally:
set_buffer_size(original_buffer_size)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
]
)
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
]
)
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
]
)
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 8000
num_frames = 4
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save