From d250ab0f95647e2129254e2a66e7ec2d7a5ed6b9 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Fri, 22 Nov 2024 08:28:12 +0000 Subject: [PATCH] add util && add quality --- audio/audiotools/basemodel.py | 51 ++++--- audio/audiotools/data/__init__.py | 3 + audio/audiotools/data/preprocess.py | 83 ++++++++++++ audio/audiotools/decorators.py | 55 ++++---- audio/audiotools/quality.py | 69 ++++++++++ audio/audiotools/util.py | 203 ++-------------------------- 6 files changed, 218 insertions(+), 246 deletions(-) create mode 100644 audio/audiotools/data/__init__.py create mode 100644 audio/audiotools/data/preprocess.py create mode 100644 audio/audiotools/quality.py diff --git a/audio/audiotools/basemodel.py b/audio/audiotools/basemodel.py index 2b9f916f9..633b50430 100644 --- a/audio/audiotools/basemodel.py +++ b/audio/audiotools/basemodel.py @@ -49,14 +49,13 @@ class BaseModel(nn.Layer): INTERN = [] def save( - self, - path: str, - metadata: dict = None, - package: bool = False, - intern: list = [], - extern: list = [], - mock: list = [], - ): + self, + path: str, + metadata: dict=None, + package: bool=False, + intern: list=[], + extern: list=[], + mock: list=[], ): """Saves the model, either as a package, or just as weights, alongside some specified metadata. @@ -123,13 +122,12 @@ class BaseModel(nn.Layer): @classmethod def load( - cls, - location: str, - *args, - package_name: str = None, - strict: bool = False, - **kwargs, - ): + cls, + location: str, + *args, + package_name: str=None, + strict: bool=False, + **kwargs, ): """Load model from a path. Tries first to load as a package, and if that fails, tries to load as weights. The arguments to the class are specified inside the model weights file. @@ -178,11 +176,10 @@ class BaseModel(nn.Layer): raise NotImplementedError("Currently Paddle does not support packaging") def save_to_folder( - self, - folder: typing.Union[str, Path], - extra_data: dict = None, - package: bool = False, - ): + self, + folder: typing.Union[str, Path], + extra_data: dict=None, + package: bool=False, ): """Dumps a model into a folder, as both a package and as weights, as well as anything specified in ``extra_data``. ``extra_data`` is a dictionary of other @@ -229,12 +226,11 @@ class BaseModel(nn.Layer): @classmethod def load_from_folder( - cls, - folder: typing.Union[str, Path], - package: bool = False, - strict: bool = False, - **kwargs, - ): + cls, + folder: typing.Union[str, Path], + package: bool=False, + strict: bool=False, + **kwargs, ): """Loads the model from a folder generated by :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. Like that function, this one looks for a subfolder that has @@ -265,8 +261,7 @@ class BaseModel(nn.Layer): extra_data = {} excluded = ["package.pth", "weights.pth"] files = [ - x - for x in folder.glob("*") + x for x in folder.glob("*") if x.is_file() and x.name not in excluded ] for f in files: diff --git a/audio/audiotools/data/__init__.py b/audio/audiotools/data/__init__.py new file mode 100644 index 000000000..5265384bd --- /dev/null +++ b/audio/audiotools/data/__init__.py @@ -0,0 +1,3 @@ +# from . import datasets +from . import preprocess +# from . import transforms diff --git a/audio/audiotools/data/preprocess.py b/audio/audiotools/data/preprocess.py new file mode 100644 index 000000000..47de4352b --- /dev/null +++ b/audio/audiotools/data/preprocess.py @@ -0,0 +1,83 @@ +import csv +import os +from pathlib import Path + +from audio_signal import AudioSignal +from tqdm import tqdm +# from ..core import AudioSignal + + +def create_csv(audio_files: list, + output_csv: Path, + loudness: bool=False, + data_path: str=None): + """Converts a folder of audio files to a CSV file. If ``loudness = True``, + the output of this function will create a CSV file that looks something + like: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + .. note:: + The paths above are written relative to the ``data_path`` argument + which defaults to the environment variable ``PATH_TO_DATA`` if + it isn't passed to this function, and defaults to the empty string + if that environment variable is not set. + + You can produce a CSV file from a directory of audio files via: + + >>> import audiotools + >>> directory = ... + >>> audio_files = audiotools.util.find_audio(directory) + >>> output_path = "train.csv" + >>> audiotools.data.preprocess.create_csv( + >>> audio_files, output_csv, loudness=True + >>> ) + + Note that you can create empty rows in the CSV file by passing an empty + string or None in the ``audio_files`` list. This is useful if you want to + sync multiple CSV files in a multitrack setting. The loudness of these + empty rows will be set to -inf. + + Parameters + ---------- + audio_files : list + List of audio files. + output_csv : Path + Output CSV, with each row containing the relative path of every file + to ``data_path``, if specified (defaults to None). + loudness : bool + Compute loudness of entire file and store alongside path. + """ + + info = [] + pbar = tqdm(audio_files) + for af in pbar: + af = Path(af) + pbar.set_description(f"Processing {af.name}") + _info = {} + if af.name == "": + _info["path"] = "" + if loudness: + _info["loudness"] = -float("inf") + else: + _info["path"] = af.relative_to( + data_path) if data_path is not None else af + if loudness: + _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() + + info.append(_info) + + with open(output_csv, "w") as f: + writer = csv.DictWriter(f, fieldnames=list(info[0].keys())) + writer.writeheader() + + for item in info: + writer.writerow(item) diff --git a/audio/audiotools/decorators.py b/audio/audiotools/decorators.py index 27982758a..7086313d4 100644 --- a/audio/audiotools/decorators.py +++ b/audio/audiotools/decorators.py @@ -88,7 +88,7 @@ def when(condition): return decorator -def timer(prefix: str = "time"): +def timer(prefix: str="time"): """✅Adds execution time to the output dictionary of the decorated function. The function decorated by this must output a dictionary. The key added will follow the form "[prefix]/[name_of_function]" @@ -161,13 +161,12 @@ class Tracker: """ def __init__( - self, - writer: LogWriter = None, - log_file: str = None, - rank: int = 0, - console_width: int = 100, - step: int = 0, - ): + self, + writer: LogWriter=None, + log_file: str=None, + rank: int=0, + console_width: int=100, + step: int=0, ): """ Initializes the Tracker object. @@ -199,14 +198,12 @@ class Tracker: BarColumn(), TimeElapsedColumn(), "/", - TimeRemainingColumn(), - ) + TimeRemainingColumn(), ) self.consoles = [Console(width=console_width)] self.live = Live(console=self.consoles[0], refresh_per_second=10) if log_file is not None: self.consoles.append( - Console(width=console_width, file=open(log_file, "a")) - ) + Console(width=console_width, file=open(log_file, "a"))) def print(self, msg): """ @@ -259,10 +256,7 @@ class Tracker: group, padding=(0, 5), title="[b]Progress", - border_style="blue", - ), - ) - ) + border_style="blue", ), )) def done(self, label: str, title: str): """ @@ -286,13 +280,12 @@ class Tracker: self.print(group) def track( - self, - label: str, - length: int, - completed: int = 0, - op: dist.ReduceOp = dist.ReduceOp.AVG, - ddp_active: bool = "LOCAL_RANK" in os.environ, - ): + self, + label: str, + length: int, + completed: int=0, + op: dist.ReduceOp=dist.ReduceOp.AVG, + ddp_active: bool="LOCAL_RANK" in os.environ, ): """ A decorator for tracking the progress and metrics of a function. @@ -310,10 +303,13 @@ class Tracker: Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ. """ self.tasks[label] = { - "pbar": self.pbar.add_task( - f"[white]Iteration ({label})", total=length, completed=completed - ), - "table": Table(), + "pbar": + self.pbar.add_task( + f"[white]Iteration ({label})", + total=length, + completed=completed), + "table": + Table(), } self.metrics[label] = { "value": defaultdict(), @@ -356,7 +352,7 @@ class Tracker: return decorator - def log(self, label: str, value_type: str = "value", history: bool = True): + def log(self, label: str, value_type: str="value", history: bool=True): """ A decorator for logging the metrics of a function. @@ -385,8 +381,7 @@ class Tracker: v = v() if isinstance(v, Mean) else v if self.writer is not None: self.writer.add_scalar( - tag=f"{k}/{label}", value=v, step=self.step - ) + tag=f"{k}/{label}", value=v, step=self.step) if label in self.history: self.history[label][k].append(v) diff --git a/audio/audiotools/quality.py b/audio/audiotools/quality.py new file mode 100644 index 000000000..eec3014bc --- /dev/null +++ b/audio/audiotools/quality.py @@ -0,0 +1,69 @@ +import os + +import numpy as np +import paddle +from audio_signal import AudioSignal + + +def visqol( + estimates: AudioSignal, + references: AudioSignal, + mode: str="audio", ): # pragma: no cover + """ViSQOL score. + + Parameters + ---------- + estimates : AudioSignal + Degraded AudioSignal + references : AudioSignal + Reference AudioSignal + mode : str, optional + 'audio' or 'speech', by default 'audio' + + Returns + ------- + Tensor[float] + ViSQOL score (MOS-LQO) + """ + try: + from pyvisqol import visqol_lib_py + from pyvisqol.pb2 import visqol_config_pb2 + from pyvisqol.pb2 import similarity_result_pb2 + except ImportError: + from visqol import visqol_lib_py + from visqol.pb2 import visqol_config_pb2 + from visqol.pb2 import similarity_result_pb2 + + config = visqol_config_pb2.VisqolConfig() + if mode == "audio": + target_sr = 48000 + config.options.use_speech_scoring = False + svr_model_path = "libsvm_nu_svr_model.txt" + elif mode == "speech": + target_sr = 16000 + config.options.use_speech_scoring = True + svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" + else: + raise ValueError(f"Unrecognized mode: {mode}") + config.audio.sample_rate = target_sr + config.options.svr_model_path = os.path.join( + os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path) + + api = visqol_lib_py.VisqolApi() + api.Create(config) + + estimates = estimates.clone().to_mono().resample(target_sr) + references = references.clone().to_mono().resample(target_sr) + + visqols = [] + for i in range(estimates.batch_size): + _visqol = api.Measure( + references.audio_data[i, 0].detach().cpu().numpy().astype(float), + estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), ) + visqols.append(_visqol.moslqo) + return paddle.to_tensor(np.array(visqols)) + + +if __name__ == "__main__": + signal = AudioSignal(paddle.randn([44100]), 44100) + print(visqol(signal, signal)) diff --git a/audio/audiotools/util.py b/audio/audiotools/util.py index 9f984aec0..8f8fa1da1 100644 --- a/audio/audiotools/util.py +++ b/audio/audiotools/util.py @@ -12,12 +12,16 @@ from typing import Dict from typing import List from typing import Optional +import librosa import numpy as np import paddle import soundfile +from audio_signal import AudioSignal from flatten_dict import flatten from flatten_dict import unflatten +from ..data.preprocess import create_csv + @dataclass class Info: @@ -89,35 +93,6 @@ def _get_value(other): return other -def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int): - """Closest frequency bin given a frequency, number - of bins, and a sampling rate. - - Parameters - ---------- - hz : paddle.Tensor - Tensor of frequencies in Hz. - n_fft : int - Number of FFT bins. - sample_rate : int - Sample rate of audio. - - Returns - ------- - paddle.Tensor - Closest bins to the data. - """ - shape = hz.shape - hz = hz.flatten() - freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2) - hz[hz > sample_rate / 2] = sample_rate / 2 - - closest = (hz[None, :] - freqs[:, None]).abs() - closest_bins = closest.min(dim=0).indices - - return closest_bins.reshape(*shape) - - def random_state(seed: typing.Union[int, np.random.RandomState]): """✅ Turn seed into a np.random.RandomState instance. @@ -151,37 +126,25 @@ def random_state(seed: typing.Union[int, np.random.RandomState]): " instance" % seed) -def seed(random_seed, set_cudnn=False): - """ +def seed(random_seed): + """✅ Seeds all random states with the same random seed for reproducibility. Seeds ``numpy``, ``random`` and ``paddle`` random generators. - For full reproducibility, two further options must be set - according to the paddle documentation: - https://pypaddle.org/docs/stable/notes/randomness.html - To do this, ``set_cudnn`` must be True. It defaults to - False, since setting it to True results in a performance - hit. Args: random_seed (int): integer corresponding to random seed to use. - set_cudnn (bool): Whether or not to set cudnn into determinstic - mode and off of benchmark mode. Defaults to False. """ - paddle.manual_seed(random_seed) + paddle.seed(random_seed) np.random.seed(random_seed) random.seed(random_seed) - if set_cudnn: - paddle.backends.cudnn.deterministic = True - paddle.backends.cudnn.benchmark = False - @contextmanager def _close_temp_files(tmpfiles: list): - """Utility function for creating a context and closing all temporary files + """✅Utility function for creating a context and closing all temporary files once the context is exited. For correct functionality, all temporary file handles created inside the context must be appended to the ```tmpfiles``` list. @@ -214,7 +177,7 @@ AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS): - """Finds all audio files in a directory recursively. + """✅Finds all audio files in a directory recursively. Returns a list. Parameters @@ -247,7 +210,7 @@ def read_sources( remove_empty: bool=True, relative_path: str="", ext: List[str]=AUDIO_EXTENSIONS, ): - """Reads audio sources that can either be folders + """✅Reads audio sources that can either be folders full of audio files, or CSV files that contain paths to audio files. CSV files that adhere to the expected format can be generated by @@ -292,7 +255,7 @@ def read_sources( def choose_from_list_of_lists(state: np.random.RandomState, list_of_lists: list, p: float=None): - """Choose a single item from a list of lists. + """✅Choose a single item from a list of lists. Parameters ---------- @@ -335,7 +298,7 @@ def chdir(newdir: typing.Union[Path, str]): def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], device: str="cpu"): - """Moves items in a batch (typically generated by a DataLoader as a list + """✅Moves items in a batch (typically generated by a DataLoader as a list or a dict) to the specified device. This works even if dictionaries are nested. @@ -352,6 +315,7 @@ def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], typing.Union[dict, list, paddle.Tensor] Batch with all values moved to the specified device. """ + device = device.replace("cuda", "gpu") if isinstance(batch, dict): batch = flatten(batch) for key, val in batch.items(): @@ -372,7 +336,7 @@ def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None): - """Samples from a distribution defined by a tuple. The first + """✅Samples from a distribution defined by a tuple. The first item in the tuple is the distribution type, and the rest of the items are arguments to that distribution. The distribution function is gotten from the ``np.random.RandomState`` object. @@ -414,64 +378,6 @@ def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None): return dist_fn(*dist_tuple[1:]) -def collate(list_of_dicts: list, n_splits: int=None): - """Collates a list of dictionaries (e.g. as returned by a - dataloader) into a dictionary with batched values. This routine - uses the default paddle collate function for everything - except AudioSignal objects, which are handled by the - :py:func:`audiotools.core.audio_signal.AudioSignal.batch` - function. - - This function takes n_splits to enable splitting a batch - into multiple sub-batches for the purposes of gradient accumulation, - etc. - - Parameters - ---------- - list_of_dicts : list - List of dictionaries to be collated. - n_splits : int - Number of splits to make when creating the batches (split into - sub-batches). Useful for things like gradient accumulation. - - Returns - ------- - dict - Dictionary containing batched data. - """ - - from . import AudioSignal - - batches = [] - list_len = len(list_of_dicts) - - return_list = False if n_splits is None else True - n_splits = 1 if n_splits is None else n_splits - n_items = int(math.ceil(list_len / n_splits)) - - for i in range(0, list_len, n_items): - # Flatten the dictionaries to avoid recursion. - list_of_dicts_ = [flatten(d) for d in list_of_dicts[i:i + n_items]] - dict_of_lists = { - k: [dic[k] for dic in list_of_dicts_] - for k in list_of_dicts_[0] - } - - batch = {} - for k, v in dict_of_lists.items(): - if isinstance(v, list): - if all(isinstance(s, AudioSignal) for s in v): - batch[k] = AudioSignal.batch(v, pad_signals=True) - else: - # Borrow the default collate fn from paddle. - batch[k] = paddle.utils.data._utils.collate.default_collate( - v) - batches.append(unflatten(batch)) - - batches = batches[0] if not return_list else batches - return batches - - BASE_SIZE = 864 DEFAULT_FIG_SIZE = (9, 3) @@ -483,7 +389,7 @@ def format_figure( format_axes: bool=True, format: bool=True, font_color: str="white", ): - """Prettifies the spectrogram and waveform plots. A title + """✅Prettifies the spectrogram and waveform plots. A title can be inset into the top right corner, and the axes can be inset into the figure, allowing the data to take up the entire image. Used in @@ -578,82 +484,3 @@ def format_figure( va="top", color="white", ) t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) - - -def generate_chord_dataset( - max_voices: int=8, - sample_rate: int=44100, - num_items: int=5, - duration: float=1.0, - min_note: str="C2", - max_note: str="C6", - output_dir: Path="chords", ): - """ - Generates a toy multitrack dataset of chords, synthesized from sine waves. - - - Parameters - ---------- - max_voices : int, optional - Maximum number of voices in a chord, by default 8 - sample_rate : int, optional - Sample rate of audio, by default 44100 - num_items : int, optional - Number of items to generate, by default 5 - duration : float, optional - Duration of each item, by default 1.0 - min_note : str, optional - Minimum note in the dataset, by default "C2" - max_note : str, optional - Maximum note in the dataset, by default "C6" - output_dir : Path, optional - Directory to save the dataset, by default "chords" - - """ - import librosa - from . import AudioSignal - from ..data.preprocess import create_csv - - min_midi = librosa.note_to_midi(min_note) - max_midi = librosa.note_to_midi(max_note) - - tracks = [] - for idx in range(num_items): - track = {} - # figure out how many voices to put in this track - num_voices = random.randint(1, max_voices) - for voice_idx in range(num_voices): - # choose some random params - midinote = random.randint(min_midi, max_midi) - dur = random.uniform(0.85 * duration, duration) - - sig = AudioSignal.wave( - frequency=librosa.midi_to_hz(midinote), - duration=dur, - sample_rate=sample_rate, - shape="sine", ) - track[f"voice_{voice_idx}"] = sig - tracks.append(track) - - # save the tracks to disk - output_dir = Path(output_dir) - output_dir.mkdir(exist_ok=True) - for idx, track in enumerate(tracks): - track_dir = output_dir / f"track_{idx}" - track_dir.mkdir(exist_ok=True) - for voice_name, sig in track.items(): - sig.write(track_dir / f"{voice_name}.wav") - - all_voices = list(set([k for track in tracks for k in track.keys()])) - voice_lists = {voice: [] for voice in all_voices} - for track in tracks: - for voice_name in all_voices: - if voice_name in track: - voice_lists[voice_name].append(track[voice_name].path_to_file) - else: - voice_lists[voice_name].append("") - - for voice_name, paths in voice_lists.items(): - create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) - - return output_dir