diff --git a/paddlespeech/audiotools/core/util.py b/paddlespeech/audiotools/core/util.py index 676d57704..a5891a470 100644 --- a/paddlespeech/audiotools/core/util.py +++ b/paddlespeech/audiotools/core/util.py @@ -13,12 +13,9 @@ import typing from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any from typing import Callable from typing import Dict -from typing import Iterable from typing import List -from typing import NamedTuple from typing import Optional from typing import Tuple from typing import Type @@ -33,7 +30,6 @@ from flatten_dict import flatten from flatten_dict import unflatten from paddlespeech.utils import satisfy_paddle_version -from paddlespeech.vector.training.seeding import seed_everything __all__ = [ "exp_compat", diff --git a/paddlespeech/audiotools/ml/accelerator.py b/paddlespeech/audiotools/ml/accelerator.py index 74cb3331b..add632ebb 100644 --- a/paddlespeech/audiotools/ml/accelerator.py +++ b/paddlespeech/audiotools/ml/accelerator.py @@ -42,6 +42,23 @@ class ResumableSequentialSampler(SequenceSampler): self.start_idx = 0 # set the index back to 0 so for the next epoch +class DummyScaler: + def __init__(self): + pass + + def step(self, optimizer): + optimizer.step() + + def scale(self, loss): + return loss + + def unscale_(self, optimizer): + return optimizer + + def update(self): + pass + + class Accelerator: """This class is used to prepare models and dataloaders for usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to @@ -78,22 +95,6 @@ class Accelerator: self.local_rank = 0 if trainer_id is None else int(trainer_id) self.amp = amp - class DummyScaler: - def __init__(self): - pass - - def step(self, optimizer): - optimizer.step() - - def scale(self, loss): - return loss - - def unscale_(self, optimizer): - return optimizer - - def update(self): - pass - self.scaler = paddle.amp.GradScaler() if self.amp else DummyScaler() def __enter__(self): diff --git a/paddlespeech/audiotools/ml/basemodel.py b/paddlespeech/audiotools/ml/basemodel.py index 2d5683266..cd4ca5213 100644 --- a/paddlespeech/audiotools/ml/basemodel.py +++ b/paddlespeech/audiotools/ml/basemodel.py @@ -1,3 +1,4 @@ +# TODO(DrRyanHuang): rm this file # MIT License, Copyright (c) 2023-Present, Descript. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # diff --git a/paddlespeech/codec/exps/dac/train.py b/paddlespeech/codec/exps/dac/train.py new file mode 100644 index 000000000..b6a61e281 --- /dev/null +++ b/paddlespeech/codec/exps/dac/train.py @@ -0,0 +1,445 @@ +import os +import sys +import warnings +from dataclasses import dataclass +from pathlib import Path + +import argbind +import paddle +import paddle.nn as nn +from visualdl import LogWriter + +import paddlespeech +import paddlespeech.t2s.modules.losses as _losses +from paddlespeech.audiotools.core import AudioSignal +from paddlespeech.audiotools.core import util +from paddlespeech.audiotools.data import transforms +from paddlespeech.audiotools.data.datasets import AudioDataset +from paddlespeech.audiotools.data.datasets import AudioLoader +from paddlespeech.audiotools.data.datasets import ConcatDataset +from paddlespeech.audiotools.ml import Accelerator +from paddlespeech.audiotools.ml.decorators import timer +from paddlespeech.audiotools.ml.decorators import Tracker +from paddlespeech.audiotools.ml.decorators import when +from paddlespeech.codec.models.dac_.model import DAC +from paddlespeech.codec.models.dac_.model import Discriminator +from paddlespeech.t2s.training.seeding import seed_everything + +warnings.filterwarnings("ignore", category=UserWarning) + +# Optimizers +AdamW = argbind.bind(paddle.optimizer.AdamW, "generator", "discriminator") +# Accelerator = argbind.bind(ml.Accelerator, without_prefix=True) + + +@argbind.bind("generator", "discriminator") +def ExponentialLR(optimizer, gamma: float=1.0): + return paddle.optimizer.lr.ExponentialDecay(optimizer, gamma) + + +# Models +# DAC = argbind.bind(dac.model.DAC) +# Discriminator = argbind.bind(dac.model.Discriminator) +DAC = argbind.bind(DAC) +Discriminator = argbind.bind(Discriminator) + +# Data +AudioDataset = argbind.bind(AudioDataset, "train", "val") +AudioLoader = argbind.bind(AudioLoader, "train", "val") + + +# Loss +# filter_fn = lambda fn: hasattr(fn, "forward") and "Loss" in fn.__name__ +def filter_fn(fn): + return hasattr(fn, "forward") and "Loss" in fn.__name__ + + +losses = argbind.bind_module(_losses, filter_fn=filter_fn) + + +def get_infinite_loader(dataloader): + while True: + for batch in dataloader: + yield batch + + +# Transforms +# filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [ +# "BaseTransform", +# "Compose", +# "Choose", ] +def filter_fn(fn): + return hasattr(fn, "transform") and fn.__qualname__ not in [ + "BaseTransform", + "Compose", + "Choose", + ] + + +tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn) + + +# to_tfm = lambda l: [getattr(tfm, x)() for x in l] +def to_tfm(l): + return [getattr(tfm, x)() for x in l] + + +@argbind.bind("train", "val") +def build_transform( + augment_prob: float=1.0, + preprocess: list=["Identity"], + augment: list=["Identity"], + postprocess: list=["Identity"], ): + preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess") + augment = transforms.Compose( + *to_tfm(augment), name="augment", prob=augment_prob) + postprocess = transforms.Compose(*to_tfm(postprocess), name="postprocess") + transform = transforms.Compose(preprocess, augment, postprocess) + return transform + + +@argbind.bind("train", "val", "test") +def build_dataset( + sample_rate: int, + folders: dict=None, ): + # Give one loader per key/value of dictionary, where + # value is a list of folders. Create a dataset for each one. + # Concatenate the datasets with ConcatDataset, which + # cycles through them. + datasets = [] + for _, v in folders.items(): + loader = AudioLoader(sources=v) + transform = build_transform() + dataset = AudioDataset(loader, sample_rate, transform=transform) + datasets.append(dataset) + + dataset = ConcatDataset(datasets) + dataset.transform = transform + return dataset + + +@dataclass +class State: + generator: DAC + optimizer_g: AdamW + scheduler_g: ExponentialLR + + discriminator: Discriminator + optimizer_d: AdamW + scheduler_d: ExponentialLR + + stft_loss: losses.MultiScaleSTFTLoss + mel_loss: losses.MelSpectrogramLoss + gan_loss: losses.GANLoss + waveform_loss: nn.L1Loss + + train_data: AudioDataset + val_data: AudioDataset + + tracker: Tracker + + +# @argbind.bind(without_prefix=True) +def load( + args, + accel: Accelerator, + tracker: Tracker, + save_path: str, + resume: bool=False, + tag: str="latest", + load_weights: bool=False, ): + generator, g_extra = None, {} + discriminator, d_extra = None, {} + + if resume: + kwargs = { + "folder": f"{save_path}/{tag}", + "map_location": "cpu", + "package": not load_weights, + } + tracker.print( + f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}") + if (Path(kwargs["folder"]) / "dac").exists(): + generator, g_extra = DAC.load_from_folder(**kwargs) + if (Path(kwargs["folder"]) / "discriminator").exists(): + discriminator, d_extra = Discriminator.load_from_folder(**kwargs) + + generator = DAC() if generator is None else generator + discriminator = Discriminator() if discriminator is None else discriminator + + tracker.print(generator) + tracker.print(discriminator) + + generator = accel.prepare_model(generator) + discriminator = accel.prepare_model(discriminator) + + with argbind.scope(args, "generator"): + optimizer_g = AdamW(generator.parameters(), use_zero=accel.use_ddp) + scheduler_g = ExponentialLR(optimizer_g) + with argbind.scope(args, "discriminator"): + optimizer_d = AdamW(discriminator.parameters(), use_zero=accel.use_ddp) + scheduler_d = ExponentialLR(optimizer_d) + + if "optimizer.pth" in g_extra: + optimizer_g.load_state_dict(g_extra["optimizer.pth"]) + if "scheduler.pth" in g_extra: + scheduler_g.load_state_dict(g_extra["scheduler.pth"]) + if "tracker.pth" in g_extra: + tracker.load_state_dict(g_extra["tracker.pth"]) + + if "optimizer.pth" in d_extra: + optimizer_d.load_state_dict(d_extra["optimizer.pth"]) + if "scheduler.pth" in d_extra: + scheduler_d.load_state_dict(d_extra["scheduler.pth"]) + + sample_rate = accel.unwrap(generator).sample_rate + with argbind.scope(args, "train"): + train_data = build_dataset(sample_rate) + with argbind.scope(args, "val"): + val_data = build_dataset(sample_rate) + + waveform_loss = nn.L1Loss() + stft_loss = losses.MultiScaleSTFTLoss() + mel_loss = losses.MelSpectrogramLoss() + gan_loss = losses.GANLoss(discriminator) + + return State( + generator=generator, + optimizer_g=optimizer_g, + scheduler_g=scheduler_g, + discriminator=discriminator, + optimizer_d=optimizer_d, + scheduler_d=scheduler_d, + waveform_loss=waveform_loss, + stft_loss=stft_loss, + mel_loss=mel_loss, + gan_loss=gan_loss, + tracker=tracker, + train_data=train_data, + val_data=val_data, ) + + +@timer() +@paddle.no_grad() +def val_loop(batch, state, accel): + state.generator.eval() + batch = util.prepare_batch(batch, accel.device) + signal = state.val_data.transform(batch["signal"].clone(), + **batch["transform_args"]) + + out = state.generator(signal.audio_data, signal.sample_rate) + recons = AudioSignal(out["audio"], signal.sample_rate) + + return { + "loss": state.mel_loss(recons, signal), + "mel/loss": state.mel_loss(recons, signal), + "stft/loss": state.stft_loss(recons, signal), + "waveform/loss": state.waveform_loss(recons, signal), + } + + +@timer() +def train_loop(state, batch, accel, lambdas): + state.generator.train() + state.discriminator.train() + output = {} + + batch = util.prepare_batch(batch, accel.device) + with paddle.no_grad(): + signal = state.train_data.transform(batch["signal"].clone(), + **batch["transform_args"]) + + with accel.autocast(): + out = state.generator(signal.audio_data, signal.sample_rate) + recons = AudioSignal(out["audio"], signal.sample_rate) + commitment_loss = out["vq/commitment_loss"] + codebook_loss = out["vq/codebook_loss"] + + with accel.autocast(): + output["adv/disc_loss"] = state.gan_loss.discriminator_loss(recons, + signal) + + state.optimizer_d.zero_grad() + accel.backward(output["adv/disc_loss"]) + accel.scaler.unscale_(state.optimizer_d) + output["other/grad_norm_d"] = paddle.nn.utils.clip_grad_norm_( + state.discriminator.parameters(), 10.0) + accel.step(state.optimizer_d) + state.scheduler_d.step() + + with accel.autocast(): + output["stft/loss"] = state.stft_loss(recons, signal) + output["mel/loss"] = state.mel_loss(recons, signal) + output["waveform/loss"] = state.waveform_loss(recons, signal) + (output["adv/gen_loss"], + output["adv/feat_loss"], ) = state.gan_loss.generator_loss(recons, + signal) + output["vq/commitment_loss"] = commitment_loss + output["vq/codebook_loss"] = codebook_loss + output["loss"] = sum( + [v * output[k] for k, v in lambdas.items() if k in output]) + + state.optimizer_g.zero_grad() + accel.backward(output["loss"]) + accel.scaler.unscale_(state.optimizer_g) + output["other/grad_norm"] = paddle.nn.utils.clip_grad_norm_( + state.generator.parameters(), 1e3) + accel.step(state.optimizer_g) + state.scheduler_g.step() + accel.update() + + output["other/learning_rate"] = state.optimizer_g.param_groups[0]["lr"] + output["other/batch_size"] = signal.batch_size * accel.world_size + + return {k: v for k, v in sorted(output.items())} + + +def checkpoint(state, save_iters, save_path): + metadata = {"logs": state.tracker.history} + + tags = ["latest"] + state.tracker.print(f"Saving to {str(Path('.').absolute())}") + if state.tracker.is_best("val", "mel/loss"): + state.tracker.print("Best generator so far") + tags.append("best") + if state.tracker.step in save_iters: + tags.append(f"{state.tracker.step // 1000}k") + + for tag in tags: + generator_extra = { + "optimizer.pth": state.optimizer_g.state_dict(), + "scheduler.pth": state.scheduler_g.state_dict(), + "tracker.pth": state.tracker.state_dict(), + "metadata.pth": metadata, + } + accel.unwrap(state.generator).metadata = metadata + accel.unwrap(state.generator).save_to_folder(f"{save_path}/{tag}", + generator_extra) + discriminator_extra = { + "optimizer.pth": state.optimizer_d.state_dict(), + "scheduler.pth": state.scheduler_d.state_dict(), + } + accel.unwrap(state.discriminator).save_to_folder(f"{save_path}/{tag}", + discriminator_extra) + + +@paddle.no_grad() +def save_samples(state, val_idx, writer): + state.tracker.print("Saving audio samples to TensorBoard") + state.generator.eval() + + samples = [state.val_data[idx] for idx in val_idx] + batch = state.val_data.collate(samples) + batch = util.prepare_batch(batch, accel.device) + signal = state.train_data.transform(batch["signal"].clone(), + **batch["transform_args"]) + + out = state.generator(signal.audio_data, signal.sample_rate) + recons = AudioSignal(out["audio"], signal.sample_rate) + + audio_dict = {"recons": recons} + if state.tracker.step == 0: + audio_dict["signal"] = signal + + for k, v in audio_dict.items(): + for nb in range(v.batch_size): + v[nb].cpu().write_audio_to_tb(f"{k}/sample_{nb}.wav", writer, + state.tracker.step) + + +def validate(state, val_dataloader, accel): + for batch in val_dataloader: + output = val_loop(batch, state, accel) + # Consolidate state dicts if using ZeroRedundancyOptimizer + if hasattr(state.optimizer_g, "consolidate_state_dict"): + state.optimizer_g.consolidate_state_dict() + state.optimizer_d.consolidate_state_dict() + return output + + +# @argbind.bind(without_prefix=True) +def train( + args, + accel: Accelerator, + seed: int=2025, + save_path: str="ckpt", + num_iters: int=250000, + save_iters: list=[10000, 50000, 100000, 200000], + sample_freq: int=10000, + valid_freq: int=1000, + batch_size: int=12, + val_batch_size: int=10, + num_workers: int=8, + val_idx: list=[0, 1, 2, 3, 4, 5, 6, 7], + lambdas: dict={ + "mel/loss": 100.0, + "adv/feat_loss": 2.0, + "adv/gen_loss": 1.0, + "vq/commitment_loss": 0.25, + "vq/codebook_loss": 1.0, + }, ): + seed_everything(seed) + Path(save_path).mkdir(exist_ok=True, parents=True) + writer = LogWriter( + log_dir=f"{save_path}/logs") if accel.local_rank == 0 else None + tracker = Tracker( + writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank) + + state = load(args, accel, tracker, save_path) + train_dataloader = accel.prepare_dataloader( + state.train_data, + start_idx=state.tracker.step * batch_size, + num_workers=num_workers, + batch_size=batch_size, + collate_fn=state.train_data.collate, ) + train_dataloader = get_infinite_loader(train_dataloader) + val_dataloader = accel.prepare_dataloader( + state.val_data, + start_idx=0, + num_workers=num_workers, + batch_size=val_batch_size, + collate_fn=state.val_data.collate, + persistent_workers=True if num_workers > 0 else False, ) + + # Wrap the functions so that they neatly track in TensorBoard + progress bars + # and only run when specific conditions are met. + global train_loop, val_loop, validate, save_samples, checkpoint + train_loop = tracker.log( + "train", "value", history=False)(tracker.track( + "train", num_iters, completed=state.tracker.step)(train_loop)) + val_loop = tracker.track("val", len(val_dataloader))(val_loop) + validate = tracker.log("val", "mean")(validate) + + # These functions run only on the 0-rank process + save_samples = when(lambda: accel.local_rank == 0)(save_samples) + checkpoint = when(lambda: accel.local_rank == 0)(checkpoint) + + with tracker.live: + for tracker.step, batch in enumerate( + train_dataloader, start=tracker.step): + train_loop(state, batch, accel, lambdas) + + last_iter = (tracker.step == num_iters - 1 + if num_iters is not None else False) + if tracker.step % sample_freq == 0 or last_iter: + save_samples(state, val_idx, writer) + + if tracker.step % valid_freq == 0 or last_iter: + validate(state, val_dataloader, accel) + checkpoint(state, save_iters, save_path) + # Reset validation progress bar, print summary since last validation. + tracker.done("val", f"Iteration {tracker.step}") + + if last_iter: + break + + +if __name__ == "__main__": + # args = argbind.parse_args() + # args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0 + # with argbind.scope(args): + with Accelerator() as accel: + if accel.local_rank != 0: + sys.tracebacklimit = 0 + # train(args, accel) + train(None, accel) diff --git a/paddlespeech/codec/models/dac_/model/__init__.py b/paddlespeech/codec/models/dac_/model/__init__.py new file mode 100644 index 000000000..02a75b7ad --- /dev/null +++ b/paddlespeech/codec/models/dac_/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/paddlespeech/codec/models/dac_/model/base.py b/paddlespeech/codec/models/dac_/model/base.py new file mode 100644 index 000000000..87617f6af --- /dev/null +++ b/paddlespeech/codec/models/dac_/model/base.py @@ -0,0 +1,298 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import paddle +import tqdm +from paddle import nn + +from paddlespeech.audiotools import AudioSignal + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: paddle.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: paddle.Tensor + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = paddle.to_tensor(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", + None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." + ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.sublayers() + if isinstance(l, (nn.Conv1D, nn.Conv1DTranspose)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer._padding = layer.original_padding + else: + if isinstance(layer._padding, int): + # TODO: drryanhuang, fix this condition + layer._padding = [layer._padding] + layer.original_padding = layer._padding + layer._padding = tuple(0 for _ in range(len(layer._padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.sublayers(): + if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)): + layers.append(layer) + + for layer in reversed(layers): + d = layer._dilation[0] + k = layer._kernel_size[0] + s = layer._stride[0] + + if isinstance(layer, nn.Conv1DTranspose): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1D): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.sublayers(): + if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)): + d = layer._dilation[0] + k = layer._kernel_size[0] + s = layer._stride[0] + + if isinstance(layer, nn.Conv1D): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1DTranspose): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @paddle.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float=1.0, + verbose: bool=False, + normalize_db: float=-16, + n_quantizers: int=None, ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. + + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg( + str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape( + [nb * nac, 1, nt]) + win_duration = (audio_signal.signal_duration + if win_duration is None else win_duration) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int( + math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i:i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c) + chunk_length = c.shape[-1] + + codes = paddle.concat(codes, axis=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @paddle.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool=False, ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. + verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + # original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i:i + chunk_length] + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r) + + recons = paddle.concat(recons, axis=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., :obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + [-1, obj.channels, obj.original_length]) + + self.padding = original_padding + return recons diff --git a/paddlespeech/codec/models/dac_/model/dac.py b/paddlespeech/codec/models/dac_/model/dac.py new file mode 100644 index 000000000..3d182d152 --- /dev/null +++ b/paddlespeech/codec/models/dac_/model/dac.py @@ -0,0 +1,354 @@ +from functools import partial + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.audiotools import AudioSignal +from paddlespeech.codec.models.dac_.model.base import CodecMixin +from paddlespeech.codec.models.dac_.nn.layers import Snake1d +from paddlespeech.codec.models.dac_.nn.layers import WNConv1d +from paddlespeech.codec.models.dac_.nn.layers import WNConvTranspose1d +from paddlespeech.codec.models.dac_.nn.quantize import ResidualVectorQuantize + + +def init_weights(m): + if isinstance(m, nn.Conv1D): + nn.initializer.TruncatedNormal(std=0.02)(m.weight) + nn.initializer.Constant(0)(m.bias) + + +class ResidualUnit(nn.Layer): + def __init__(self, dim: int=16, dilation: int=1): + super(ResidualUnit, self).__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Layer): + def __init__(self, dim: int=16, stride: int=1): + super(EncoderBlock, self).__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=int(np.ceil(stride / 2).astype(int)), ), ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Layer): + def __init__( + self, + d_model: int=64, + strides: list=[2, 4, 8, 8], + d_latent: int=64, ): + super(Encoder, self).__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Layer): + def __init__(self, input_dim: int=16, output_dim: int=8, stride: int=1): + super(DecoderBlock, self).__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=int(np.ceil(stride / 2).astype(int)), ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Layer): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int=1, ): + super(Decoder, self).__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2**(i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(nn.Layer, CodecMixin): + def __init__( + self, + encoder_dim: int=64, + encoder_rates: list=[2, 4, 8, 8], + latent_dim: int=None, + decoder_dim: int=1536, + decoder_rates: list=[8, 8, 4, 2], + n_codebooks: int=9, + codebook_size: int=1024, + codebook_dim: int=8, + quantizer_dropout: bool=False, + sample_rate: int=44100, ): + super(DAC, self).__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2**len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = np.ceil(length / self.hop_length) * self.hop_length - length + audio_data = F.pad( + audio_data, [0, int(right_pad)], + mode='constant', + value=0, + data_format="NCL") + + return audio_data + + def encode( + self, + audio_data: paddle.Tensor, + n_quantizers: int=None, ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: paddle.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: paddle.Tensor, + sample_rate: int=None, + n_quantizers: int=None, ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + + model = DAC() + + def fn(o, p): + return o + f" {p / 1e6:<.3f}M params." + + for n, m in model.named_sublayers(): + # print(type(m)) + o = m.extra_repr() + p = sum([np.prod(p.shape) for p in m.parameters()]) + setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", + sum([np.prod(p.shape) for p in model.parameters()])) + + length = 88200 * 2 + x = paddle.randn([1, 1, length]) + x.stop_gradient = False + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = paddle.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}" + ) # TODO: drryanhuang, fix this question, why is RF zero? + + x = AudioSignal(paddle.randn([1, 1, 44100 * 60]), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/paddlespeech/codec/models/dac_/model/discriminator.py b/paddlespeech/codec/models/dac_/model/discriminator.py new file mode 100644 index 000000000..9e4668e0e --- /dev/null +++ b/paddlespeech/codec/models/dac_/model/discriminator.py @@ -0,0 +1,235 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddlespeech.audiotools import AudioSignal +from paddlespeech.audiotools import ml +from paddlespeech.audiotools import STFTParams + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + model = nn.Conv1D(*args, **kwargs) + conv = nn.utils.weight_norm(model) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + model = nn.Conv2D(*args, **kwargs) + conv = nn.utils.weight_norm(model) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Layer): + def __init__(self, period): + super(MPD, self).__init__() + self.period = period + self.convs = nn.LayerList([ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ]) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad( + x, (0, self.period - t % self.period), + mode="reflect", + data_format="NCL") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + # x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + b, c, lp = x.shape + l = lp // self.period + p = self.period + x = x.reshape([b, c, l, p]) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Layer): + def __init__(self, rate: int=1, sample_rate: int=44100): + super(MSD, self).__init__() + self.convs = nn.LayerList([ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ]) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Layer): + def __init__( + self, + window_length: int, + hop_factor: float=0.25, + sample_rate: int=44100, + bands: list=BANDS, ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super(MRD, self).__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + + def convs(): + return nn.LayerList([ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ]) + + self.band_convs = nn.LayerList( + [convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d( + ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = paddle.as_real(x.stft()) + # x = rearrange(x, "b 1 f t c -> (b 1) c t f") + x = x.transpose([0, 1, 4, 3, 2]).flatten(stop_axis=1) + + # Split into bands + x_bands = [x[..., b[0]:b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = paddle.concat(x, axis=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(ml.BaseModel): + def __init__( + self, + rates: list=[], + periods: list=[2, 3, 5, 7, 11], + fft_sizes: list=[2048, 1024, 512], + sample_rate: int=44100, + bands: list=BANDS, ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super(Discriminator, self).__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [ + MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes + ] + self.discriminators = nn.LayerList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(axis=-1, keepdim=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (paddle.abs(y).max(axis=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = paddle.zeros([1, 1, 44100]) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/paddlespeech/codec/models/dac_/nn/__init__.py b/paddlespeech/codec/models/dac_/nn/__init__.py new file mode 100644 index 000000000..a04e28ca1 --- /dev/null +++ b/paddlespeech/codec/models/dac_/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import quantize +# from . import loss diff --git a/paddlespeech/codec/models/dac_/nn/layers.py b/paddlespeech/codec/models/dac_/nn/layers.py new file mode 100644 index 000000000..3e7d9c5ca --- /dev/null +++ b/paddlespeech/codec/models/dac_/nn/layers.py @@ -0,0 +1,31 @@ +import paddle +import paddle.nn as nn +from paddle.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1D(*args, **kwargs), name='weight', dim=1) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm( + nn.Conv1DTranspose(*args, **kwargs), name='weight', dim=1) + + +def snake(x, alpha): + shape = x.shape + x = x.reshape([shape[0], shape[1], -1]) + x = x + (alpha + 1e-9).reciprocal() * paddle.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Layer): + def __init__(self, channels): + super(Snake1d, self).__init__() + self.alpha = self.create_parameter( + shape=[1, channels, 1], + default_initializer=nn.initializer.Constant(1.0)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/paddlespeech/codec/models/dac_/nn/quantize.py b/paddlespeech/codec/models/dac_/nn/quantize.py new file mode 100644 index 000000000..27acb38cf --- /dev/null +++ b/paddlespeech/codec/models/dac_/nn/quantize.py @@ -0,0 +1,161 @@ +from typing import Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.utils import weight_norm + +from paddlespeech.codec.models.dac_.nn.layers import WNConv1d + + +class VectorQuantize(nn.Layer): + def __init__(self, input_dim, codebook_size, codebook_dim): + super(VectorQuantize, self).__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss( + z_e, z_q.detach(), reduction='none').mean(axis=[1, 2]) + codebook_loss = F.mse_loss( + z_q, z_e.detach(), reduction='none').mean(axis=[1, 2]) + + z_q = z_e + (z_q - z_e).detach( + ) # noop in forward pass, straight-through gradient estimator in backward pass + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose([0, 2, 1]) + + def decode_latents(self, latents): + # encodings = Rearrange('b d t -> (b t) d')(latents) + encodings = latents.transpose([0, 2, 1]).flatten(stop_axis=1) + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings, axis=1) + codebook = F.normalize(codebook, axis=1) + + # Compute euclidean distance with codebook + dist = (encodings.pow(2).sum(axis=1, keepdim=True) - 2 * encodings + @ codebook.T + codebook.pow(2).sum(axis=1, keepdim=True).T) + # indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + indices = (-dist).argmax(axis=1).reshape([latents.shape[0], -1]) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Layer): + def __init__( + self, + input_dim: int=512, + n_codebooks: int=9, + codebook_size: int=1024, + codebook_dim: Union[int, list]=8, + quantizer_dropout: float=0.0, ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.LayerList([ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ]) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int=None): + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = paddle.ones((z.shape[0], )) * self.n_codebooks + 1 + dropout = paddle.randint(1, self.n_codebooks + 1, (z.shape[0], )) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + if dropout[:n_dropout].size: + n_quantizers[:n_dropout] = dropout[:n_dropout] + + for i, quantizer in enumerate(self.quantizers): + if not self.training and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual) + + mask = (paddle.full( + (z.shape[0], ), fill_value=i) < n_quantizers).astype("float32") + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = paddle.stack(codebook_indices, axis=1) + latents = paddle.concat(latents, axis=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: paddle.Tensor): + z_q = 0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, paddle.concat(z_p, axis=1), codes + + def from_latents(self, latents: paddle.Tensor): + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max( + axis=0, keepdims=True)[0] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents( + latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, paddle.concat(z_p, axis=1), paddle.stack(codes, axis=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = paddle.randn((16, 512, 80)) + y = rvq(x) + print(y[2].shape) diff --git a/paddlespeech/codec/models/dac_/utils/decode.py b/paddlespeech/codec/models/dac_/utils/decode.py new file mode 100644 index 000000000..df8ca499d --- /dev/null +++ b/paddlespeech/codec/models/dac_/utils/decode.py @@ -0,0 +1,90 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import paddle +from tqdm import tqdm + +from dac import DACFile +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@paddle.no_grad() +def decode( + input: str, + output: str="", + weights_path: str="", + model_tag: str="latest", + model_bitrate: str="8kbps", + device: str="cuda", + model_type: str="44khz", + verbose: bool=False, ): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, ) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc="Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/paddlespeech/codec/models/dac_/utils/encode.py b/paddlespeech/codec/models/dac_/utils/encode.py new file mode 100644 index 000000000..5b70f0010 --- /dev/null +++ b/paddlespeech/codec/models/dac_/utils/encode.py @@ -0,0 +1,89 @@ +import math +import warnings +from pathlib import Path + +import argbind +import paddle +from tqdm import tqdm + +from dac.utils import load_model +from paddlespeech.audiotools import AudioSignal +from paddlespeech.audiotools.core import util + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@paddle.no_grad() +def encode( + input: str, + output: str="", + weights_path: str="", + model_tag: str="latest", + model_bitrate: str="8kbps", + n_quantizers: int=None, + model_type: str="44khz", + win_duration: float=5.0, + verbose: bool=False, ): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, ) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress( + signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode()