You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
2.2 KiB
90 lines
2.2 KiB
2 years ago
|
import itertools
|
||
|
from unittest import skipIf
|
||
|
|
||
|
from paddleaudio._internal.module_utils import is_module_available
|
||
|
from parameterized import parameterized
|
||
|
|
||
|
|
||
|
def name_func(func, _, params):
    """Build a test-case name from a function and its positional arguments.

    Args:
        func: Object whose ``__name__`` becomes the name prefix.
        _: Ignored (kept for the ``name_func`` callback signature).
        params: Object exposing an ``args`` iterable; each item is rendered
            with ``str`` and joined with underscores.

    Returns:
        ``"<func name>_<arg1>_<arg2>..."``.
    """
    arg_labels = [str(value) for value in params.args]
    suffix = "_".join(arg_labels)
    return f"{func.__name__}_{suffix}"
|
||
|
|
||
|
|
||
|
def dtype2subtype(dtype):
    """Translate a dtype name into its subtype name.

    Args:
        dtype: One of ``"float64"``, ``"float32"``, ``"int32"``, ``"int16"``,
            ``"uint8"`` or ``"int8"``.

    Returns:
        The subtype string associated with ``dtype``.

    Raises:
        KeyError: If ``dtype`` is not one of the names listed above.
    """
    subtype_by_dtype = dict(
        float64="DOUBLE",
        float32="FLOAT",
        int32="PCM_32",
        int16="PCM_16",
        uint8="PCM_U8",
        int8="PCM_S8",
    )
    return subtype_by_dtype[dtype]
|
||
|
|
||
|
|
||
|
def skipIfFormatNotSupported(fmt):
    """Return a ``unittest.skipIf`` decorator keyed on soundfile support for ``fmt``.

    Args:
        fmt: Format name to look up in ``soundfile.available_formats()``.

    Returns:
        A skip decorator: unconditional when the ``soundfile`` module is not
        available, otherwise active only when ``fmt`` is not among the
        formats soundfile reports.
    """
    # Without soundfile nothing can be checked, so always skip.
    if not is_module_available("soundfile"):
        return skipIf(True, '"soundfile" not available.')

    # Imported lazily so this module can be loaded without soundfile.
    import soundfile

    supported = soundfile.available_formats()
    unsupported = fmt not in supported
    return skipIf(unsupported, f'"{fmt}" is not supported by soundfile')
|
||
|
|
||
|
|
||
|
def parameterize(*params):
    """Expand a test over the Cartesian product of the given iterables.

    Args:
        *params: Iterables of candidate values, one per test argument.

    Returns:
        The result of ``parameterized.expand`` applied to every combination,
        with ``name_func`` used to name the generated cases.
    """
    combinations = list(itertools.product(*params))
    return parameterized.expand(combinations, name_func=name_func)
|
||
|
|
||
|
|
||
|
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Return the wav subtype name for an ``(encoding, bits_per_sample)`` pair.

    Args:
        dtype: dtype name; only consulted (via ``dtype2subtype``) when both
            ``encoding`` and ``bits_per_sample`` are ``None``.
        encoding: Encoding name (``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"``,
            ``"ULAW"``, ``"ALAW"``) or ``None``.
        bits_per_sample: Bit depth or ``None``.

    Returns:
        The subtype string for the requested combination.

    Raises:
        ValueError: If the ``(encoding, bits_per_sample)`` pair is not in the
            table of supported combinations.
        KeyError: Propagated from ``dtype2subtype`` when the dtype fallback is
            used with a dtype it does not know.
    """
    # Resolve the dtype fallback lazily: evaluating it inside the table would
    # raise KeyError for an unknown dtype even when an explicit, valid
    # (encoding, bits_per_sample) pair makes the dtype irrelevant.
    if encoding is None and bits_per_sample is None:
        return dtype2subtype(dtype)

    subtype = {
        (None, 8): "PCM_U8",
        ("PCM_U", None): "PCM_U8",
        ("PCM_U", 8): "PCM_U8",
        ("PCM_S", None): "PCM_32",
        ("PCM_S", 16): "PCM_16",
        ("PCM_S", 32): "PCM_32",
        ("PCM_F", None): "FLOAT",
        ("PCM_F", 32): "FLOAT",
        ("PCM_F", 64): "DOUBLE",
        ("ULAW", None): "ULAW",
        ("ULAW", 8): "ULAW",
        ("ALAW", None): "ALAW",
        ("ALAW", 8): "ALAW",
    }.get((encoding, bits_per_sample))
    if subtype is not None:
        return subtype
    raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
|
||
|
|
||
|
def get_encoding(ext, dtype):
    """Return the encoding name for a file extension / dtype combination.

    Args:
        ext: File extension. For ``"mp3"``, ``"flac"`` and ``"vorbis"`` the
            encoding is the upper-cased extension and ``dtype`` is ignored.
        dtype: dtype name used for every other extension; one of
            ``"float32"``, ``"int32"``, ``"int16"`` or ``"uint8"``.

    Returns:
        The encoding string.

    Raises:
        KeyError: If ``ext`` is not one of the three special extensions and
            ``dtype`` is not in the dtype table.
    """
    if ext in {"mp3", "flac", "vorbis"}:
        return ext.upper()

    encoding_by_dtype = dict(
        float32="PCM_F",
        int32="PCM_S",
        int16="PCM_S",
        uint8="PCM_U",
    )
    return encoding_by_dtype[dtype]
|
||
|
|
||
|
|
||
|
def get_bit_depth(dtype):
    """Return the number of bits per value for a dtype name.

    Args:
        dtype: One of ``"float32"``, ``"int32"``, ``"int16"`` or ``"uint8"``.

    Returns:
        The bit depth as an int.

    Raises:
        KeyError: If ``dtype`` is not one of the names listed above.
    """
    width_by_dtype = dict(float32=32, int32=32, int16=16, uint8=8)
    return width_by_dtype[dtype]
|
||
|
|
||
|
def get_bits_per_sample(ext, dtype):
    """Return the bits-per-sample value for a file extension / dtype pair.

    Args:
        ext: File extension. ``"flac"`` maps to 24, ``"mp3"`` and
            ``"vorbis"`` map to 0, regardless of ``dtype``.
        dtype: dtype name; only consulted (via ``get_bit_depth``) when
            ``ext`` is not one of the extensions above.

    Returns:
        The bits-per-sample value as an int.

    Raises:
        KeyError: Propagated from ``get_bit_depth`` when ``ext`` has no fixed
            value and ``dtype`` is unknown to it.
    """
    bits_per_samples = {
        "flac": 24,
        "mp3": 0,
        "vorbis": 0,
    }
    # Look up the dtype only when needed: passing get_bit_depth(dtype) as the
    # dict.get default would evaluate it eagerly and raise KeyError for dtypes
    # outside its table even when the extension alone fixes the answer.
    if ext in bits_per_samples:
        return bits_per_samples[ext]
    return get_bit_depth(dtype)
|