import itertools
from unittest import skipIf

from paddleaudio._internal.module_utils import is_module_available
from parameterized import parameterized


def name_func(func, _, params):
    """Build a test-case name from the test function and its parameters.

    Intended as the ``name_func`` hook of ``parameterized.expand``: the
    positional arguments of ``params`` are stringified and joined with
    underscores, then appended to the function name.
    """
    return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'


def dtype2subtype(dtype):
    """Map a dtype name (e.g. ``"int16"``) to a soundfile subtype string.

    Raises:
        KeyError: if ``dtype`` is not one of the mapped dtype names.
    """
    return {
        "float64": "DOUBLE",
        "float32": "FLOAT",
        "int32": "PCM_32",
        "int16": "PCM_16",
        "uint8": "PCM_U8",
        "int8": "PCM_S8",
    }[dtype]


def skipIfFormatNotSupported(fmt):
    """Return a ``unittest.skipIf`` decorator keyed on soundfile support.

    The decorated test is skipped when ``fmt`` is not among the entries
    returned by ``soundfile.available_formats()``, or unconditionally when
    the ``soundfile`` module is not installed.
    """
    if is_module_available("soundfile"):
        import soundfile

        # NOTE(review): available_formats() presumably returns a dict keyed by
        # upper-case format names (e.g. "WAV"); confirm callers pass that form.
        fmts = soundfile.available_formats()
        return skipIf(fmt not in fmts,
                      f'"{fmt}" is not supported by soundfile')
    return skipIf(True, '"soundfile" not available.')


def parameterize(*params):
    """Expand a test over the Cartesian product of the given parameter lists."""
    return parameterized.expand(
        list(itertools.product(*params)), name_func=name_func)


def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Return the soundfile WAV subtype for an (encoding, bits) request.

    Args:
        dtype: dtype name; consulted only when both ``encoding`` and
            ``bits_per_sample`` are ``None``.
        encoding: encoding name such as ``"PCM_S"`` or ``"ULAW"``, or ``None``.
        bits_per_sample: bit depth, or ``None`` for the encoding's default.

    Returns:
        The soundfile subtype string, e.g. ``"PCM_16"``.

    Raises:
        ValueError: if WAV does not support the (encoding, bits) pair.
        KeyError: if the dtype-derived default is requested for an
            unmapped dtype.
    """
    if encoding is None and bits_per_sample is None:
        # Resolve the dtype-dependent default lazily: building it eagerly would
        # make an unmapped dtype break lookups that never use it.
        return dtype2subtype(dtype)
    subtype = {
        (None, 8): "PCM_U8",
        ("PCM_U", None): "PCM_U8",
        ("PCM_U", 8): "PCM_U8",
        ("PCM_S", None): "PCM_32",
        ("PCM_S", 16): "PCM_16",
        ("PCM_S", 32): "PCM_32",
        ("PCM_F", None): "FLOAT",
        ("PCM_F", 32): "FLOAT",
        ("PCM_F", 64): "DOUBLE",
        ("ULAW", None): "ULAW",
        ("ULAW", 8): "ULAW",
        ("ALAW", None): "ALAW",
        ("ALAW", 8): "ALAW",
    }.get((encoding, bits_per_sample))
    if subtype is None:
        raise ValueError(
            f"wav does not support ({encoding}, {bits_per_sample}).")
    return subtype


def get_encoding(ext, dtype):
    """Return the expected encoding name for a file extension and dtype.

    Compressed formats (mp3, flac, vorbis) report their upper-cased
    extension; anything else is derived from ``dtype``.

    Raises:
        KeyError: if ``ext`` is not compressed and ``dtype`` is unmapped.
    """
    exts = {
        "mp3",
        "flac",
        "vorbis",
    }
    if ext in exts:
        return ext.upper()
    encodings = {
        "float64": "PCM_F",
        "float32": "PCM_F",
        "int32": "PCM_S",
        "int16": "PCM_S",
        "uint8": "PCM_U",
    }
    return encodings[dtype]


def get_bit_depth(dtype):
    """Return the bit depth of a dtype name (e.g. ``"int16"`` -> 16).

    Raises:
        KeyError: if ``dtype`` is not one of the mapped dtype names.
    """
    bit_depths = {
        "float64": 64,
        "float32": 32,
        "int32": 32,
        "int16": 16,
        "uint8": 8,
    }
    return bit_depths[dtype]


def get_bits_per_sample(ext, dtype):
    """Return the expected bits-per-sample for a file extension and dtype.

    flac reports 24 and lossy formats (mp3, vorbis) report 0 regardless of
    ``dtype``; other extensions use the dtype's bit depth.

    Raises:
        KeyError: if ``ext`` has no fixed value and ``dtype`` is unmapped.
    """
    bits_per_samples = {
        "flac": 24,
        "mp3": 0,
        "vorbis": 0,
    }
    # Check the extension first: passing get_bit_depth(dtype) as a dict.get
    # default would evaluate it eagerly and raise for dtype-independent exts.
    if ext in bits_per_samples:
        return bits_per_samples[ext]
    return get_bit_depth(dtype)