Merge b248cf9b69
into 8247eba840
commit
61512af956
@ -0,0 +1,445 @@
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import argbind
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from visualdl import LogWriter
|
||||
|
||||
import paddlespeech
|
||||
import paddlespeech.t2s.modules.losses as _losses
|
||||
from paddlespeech.audiotools.core import AudioSignal
|
||||
from paddlespeech.audiotools.core import util
|
||||
from paddlespeech.audiotools.data import transforms
|
||||
from paddlespeech.audiotools.data.datasets import AudioDataset
|
||||
from paddlespeech.audiotools.data.datasets import AudioLoader
|
||||
from paddlespeech.audiotools.data.datasets import ConcatDataset
|
||||
from paddlespeech.audiotools.ml import Accelerator
|
||||
from paddlespeech.audiotools.ml.decorators import timer
|
||||
from paddlespeech.audiotools.ml.decorators import Tracker
|
||||
from paddlespeech.audiotools.ml.decorators import when
|
||||
from paddlespeech.codec.models.dac_.model import DAC
|
||||
from paddlespeech.codec.models.dac_.model import Discriminator
|
||||
from paddlespeech.t2s.training.seeding import seed_everything
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
# Optimizers
|
||||
AdamW = argbind.bind(paddle.optimizer.AdamW, "generator", "discriminator")
|
||||
# Accelerator = argbind.bind(ml.Accelerator, without_prefix=True)
|
||||
|
||||
|
||||
@argbind.bind("generator", "discriminator")
|
||||
def ExponentialLR(optimizer, gamma: float=1.0):
|
||||
return paddle.optimizer.lr.ExponentialDecay(optimizer, gamma)
|
||||
|
||||
|
||||
# Models
|
||||
# DAC = argbind.bind(dac.model.DAC)
|
||||
# Discriminator = argbind.bind(dac.model.Discriminator)
|
||||
DAC = argbind.bind(DAC)
|
||||
Discriminator = argbind.bind(Discriminator)
|
||||
|
||||
# Data
|
||||
AudioDataset = argbind.bind(AudioDataset, "train", "val")
|
||||
AudioLoader = argbind.bind(AudioLoader, "train", "val")
|
||||
|
||||
|
||||
# Loss
|
||||
# filter_fn = lambda fn: hasattr(fn, "forward") and "Loss" in fn.__name__
|
||||
def filter_fn(fn):
|
||||
return hasattr(fn, "forward") and "Loss" in fn.__name__
|
||||
|
||||
|
||||
losses = argbind.bind_module(_losses, filter_fn=filter_fn)
|
||||
|
||||
|
||||
def get_infinite_loader(dataloader):
|
||||
while True:
|
||||
for batch in dataloader:
|
||||
yield batch
|
||||
|
||||
|
||||
# Transforms
|
||||
# filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
||||
# "BaseTransform",
|
||||
# "Compose",
|
||||
# "Choose", ]
|
||||
def filter_fn(fn):
|
||||
return hasattr(fn, "transform") and fn.__qualname__ not in [
|
||||
"BaseTransform",
|
||||
"Compose",
|
||||
"Choose",
|
||||
]
|
||||
|
||||
|
||||
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
||||
|
||||
|
||||
# to_tfm = lambda l: [getattr(tfm, x)() for x in l]
|
||||
def to_tfm(l):
|
||||
return [getattr(tfm, x)() for x in l]
|
||||
|
||||
|
||||
@argbind.bind("train", "val")
|
||||
def build_transform(
|
||||
augment_prob: float=1.0,
|
||||
preprocess: list=["Identity"],
|
||||
augment: list=["Identity"],
|
||||
postprocess: list=["Identity"], ):
|
||||
preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess")
|
||||
augment = transforms.Compose(
|
||||
*to_tfm(augment), name="augment", prob=augment_prob)
|
||||
postprocess = transforms.Compose(*to_tfm(postprocess), name="postprocess")
|
||||
transform = transforms.Compose(preprocess, augment, postprocess)
|
||||
return transform
|
||||
|
||||
|
||||
@argbind.bind("train", "val", "test")
|
||||
def build_dataset(
|
||||
sample_rate: int,
|
||||
folders: dict=None, ):
|
||||
# Give one loader per key/value of dictionary, where
|
||||
# value is a list of folders. Create a dataset for each one.
|
||||
# Concatenate the datasets with ConcatDataset, which
|
||||
# cycles through them.
|
||||
datasets = []
|
||||
for _, v in folders.items():
|
||||
loader = AudioLoader(sources=v)
|
||||
transform = build_transform()
|
||||
dataset = AudioDataset(loader, sample_rate, transform=transform)
|
||||
datasets.append(dataset)
|
||||
|
||||
dataset = ConcatDataset(datasets)
|
||||
dataset.transform = transform
|
||||
return dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
generator: DAC
|
||||
optimizer_g: AdamW
|
||||
scheduler_g: ExponentialLR
|
||||
|
||||
discriminator: Discriminator
|
||||
optimizer_d: AdamW
|
||||
scheduler_d: ExponentialLR
|
||||
|
||||
stft_loss: losses.MultiScaleSTFTLoss
|
||||
mel_loss: losses.MelSpectrogramLoss
|
||||
gan_loss: losses.GANLoss
|
||||
waveform_loss: nn.L1Loss
|
||||
|
||||
train_data: AudioDataset
|
||||
val_data: AudioDataset
|
||||
|
||||
tracker: Tracker
|
||||
|
||||
|
||||
# @argbind.bind(without_prefix=True)
|
||||
def load(
|
||||
args,
|
||||
accel: Accelerator,
|
||||
tracker: Tracker,
|
||||
save_path: str,
|
||||
resume: bool=False,
|
||||
tag: str="latest",
|
||||
load_weights: bool=False, ):
|
||||
generator, g_extra = None, {}
|
||||
discriminator, d_extra = None, {}
|
||||
|
||||
if resume:
|
||||
kwargs = {
|
||||
"folder": f"{save_path}/{tag}",
|
||||
"map_location": "cpu",
|
||||
"package": not load_weights,
|
||||
}
|
||||
tracker.print(
|
||||
f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}")
|
||||
if (Path(kwargs["folder"]) / "dac").exists():
|
||||
generator, g_extra = DAC.load_from_folder(**kwargs)
|
||||
if (Path(kwargs["folder"]) / "discriminator").exists():
|
||||
discriminator, d_extra = Discriminator.load_from_folder(**kwargs)
|
||||
|
||||
generator = DAC() if generator is None else generator
|
||||
discriminator = Discriminator() if discriminator is None else discriminator
|
||||
|
||||
tracker.print(generator)
|
||||
tracker.print(discriminator)
|
||||
|
||||
generator = accel.prepare_model(generator)
|
||||
discriminator = accel.prepare_model(discriminator)
|
||||
|
||||
with argbind.scope(args, "generator"):
|
||||
optimizer_g = AdamW(generator.parameters(), use_zero=accel.use_ddp)
|
||||
scheduler_g = ExponentialLR(optimizer_g)
|
||||
with argbind.scope(args, "discriminator"):
|
||||
optimizer_d = AdamW(discriminator.parameters(), use_zero=accel.use_ddp)
|
||||
scheduler_d = ExponentialLR(optimizer_d)
|
||||
|
||||
if "optimizer.pth" in g_extra:
|
||||
optimizer_g.load_state_dict(g_extra["optimizer.pth"])
|
||||
if "scheduler.pth" in g_extra:
|
||||
scheduler_g.load_state_dict(g_extra["scheduler.pth"])
|
||||
if "tracker.pth" in g_extra:
|
||||
tracker.load_state_dict(g_extra["tracker.pth"])
|
||||
|
||||
if "optimizer.pth" in d_extra:
|
||||
optimizer_d.load_state_dict(d_extra["optimizer.pth"])
|
||||
if "scheduler.pth" in d_extra:
|
||||
scheduler_d.load_state_dict(d_extra["scheduler.pth"])
|
||||
|
||||
sample_rate = accel.unwrap(generator).sample_rate
|
||||
with argbind.scope(args, "train"):
|
||||
train_data = build_dataset(sample_rate)
|
||||
with argbind.scope(args, "val"):
|
||||
val_data = build_dataset(sample_rate)
|
||||
|
||||
waveform_loss = nn.L1Loss()
|
||||
stft_loss = losses.MultiScaleSTFTLoss()
|
||||
mel_loss = losses.MelSpectrogramLoss()
|
||||
gan_loss = losses.GANLoss(discriminator)
|
||||
|
||||
return State(
|
||||
generator=generator,
|
||||
optimizer_g=optimizer_g,
|
||||
scheduler_g=scheduler_g,
|
||||
discriminator=discriminator,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_d=scheduler_d,
|
||||
waveform_loss=waveform_loss,
|
||||
stft_loss=stft_loss,
|
||||
mel_loss=mel_loss,
|
||||
gan_loss=gan_loss,
|
||||
tracker=tracker,
|
||||
train_data=train_data,
|
||||
val_data=val_data, )
|
||||
|
||||
|
||||
@timer()
|
||||
@paddle.no_grad()
|
||||
def val_loop(batch, state, accel):
|
||||
state.generator.eval()
|
||||
batch = util.prepare_batch(batch, accel.device)
|
||||
signal = state.val_data.transform(batch["signal"].clone(),
|
||||
**batch["transform_args"])
|
||||
|
||||
out = state.generator(signal.audio_data, signal.sample_rate)
|
||||
recons = AudioSignal(out["audio"], signal.sample_rate)
|
||||
|
||||
return {
|
||||
"loss": state.mel_loss(recons, signal),
|
||||
"mel/loss": state.mel_loss(recons, signal),
|
||||
"stft/loss": state.stft_loss(recons, signal),
|
||||
"waveform/loss": state.waveform_loss(recons, signal),
|
||||
}
|
||||
|
||||
|
||||
@timer()
|
||||
def train_loop(state, batch, accel, lambdas):
|
||||
state.generator.train()
|
||||
state.discriminator.train()
|
||||
output = {}
|
||||
|
||||
batch = util.prepare_batch(batch, accel.device)
|
||||
with paddle.no_grad():
|
||||
signal = state.train_data.transform(batch["signal"].clone(),
|
||||
**batch["transform_args"])
|
||||
|
||||
with accel.autocast():
|
||||
out = state.generator(signal.audio_data, signal.sample_rate)
|
||||
recons = AudioSignal(out["audio"], signal.sample_rate)
|
||||
commitment_loss = out["vq/commitment_loss"]
|
||||
codebook_loss = out["vq/codebook_loss"]
|
||||
|
||||
with accel.autocast():
|
||||
output["adv/disc_loss"] = state.gan_loss.discriminator_loss(recons,
|
||||
signal)
|
||||
|
||||
state.optimizer_d.zero_grad()
|
||||
accel.backward(output["adv/disc_loss"])
|
||||
accel.scaler.unscale_(state.optimizer_d)
|
||||
output["other/grad_norm_d"] = paddle.nn.utils.clip_grad_norm_(
|
||||
state.discriminator.parameters(), 10.0)
|
||||
accel.step(state.optimizer_d)
|
||||
state.scheduler_d.step()
|
||||
|
||||
with accel.autocast():
|
||||
output["stft/loss"] = state.stft_loss(recons, signal)
|
||||
output["mel/loss"] = state.mel_loss(recons, signal)
|
||||
output["waveform/loss"] = state.waveform_loss(recons, signal)
|
||||
(output["adv/gen_loss"],
|
||||
output["adv/feat_loss"], ) = state.gan_loss.generator_loss(recons,
|
||||
signal)
|
||||
output["vq/commitment_loss"] = commitment_loss
|
||||
output["vq/codebook_loss"] = codebook_loss
|
||||
output["loss"] = sum(
|
||||
[v * output[k] for k, v in lambdas.items() if k in output])
|
||||
|
||||
state.optimizer_g.zero_grad()
|
||||
accel.backward(output["loss"])
|
||||
accel.scaler.unscale_(state.optimizer_g)
|
||||
output["other/grad_norm"] = paddle.nn.utils.clip_grad_norm_(
|
||||
state.generator.parameters(), 1e3)
|
||||
accel.step(state.optimizer_g)
|
||||
state.scheduler_g.step()
|
||||
accel.update()
|
||||
|
||||
output["other/learning_rate"] = state.optimizer_g.param_groups[0]["lr"]
|
||||
output["other/batch_size"] = signal.batch_size * accel.world_size
|
||||
|
||||
return {k: v for k, v in sorted(output.items())}
|
||||
|
||||
|
||||
def checkpoint(state, save_iters, save_path):
|
||||
metadata = {"logs": state.tracker.history}
|
||||
|
||||
tags = ["latest"]
|
||||
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
||||
if state.tracker.is_best("val", "mel/loss"):
|
||||
state.tracker.print("Best generator so far")
|
||||
tags.append("best")
|
||||
if state.tracker.step in save_iters:
|
||||
tags.append(f"{state.tracker.step // 1000}k")
|
||||
|
||||
for tag in tags:
|
||||
generator_extra = {
|
||||
"optimizer.pth": state.optimizer_g.state_dict(),
|
||||
"scheduler.pth": state.scheduler_g.state_dict(),
|
||||
"tracker.pth": state.tracker.state_dict(),
|
||||
"metadata.pth": metadata,
|
||||
}
|
||||
accel.unwrap(state.generator).metadata = metadata
|
||||
accel.unwrap(state.generator).save_to_folder(f"{save_path}/{tag}",
|
||||
generator_extra)
|
||||
discriminator_extra = {
|
||||
"optimizer.pth": state.optimizer_d.state_dict(),
|
||||
"scheduler.pth": state.scheduler_d.state_dict(),
|
||||
}
|
||||
accel.unwrap(state.discriminator).save_to_folder(f"{save_path}/{tag}",
|
||||
discriminator_extra)
|
||||
|
||||
|
||||
@paddle.no_grad()
|
||||
def save_samples(state, val_idx, writer):
|
||||
state.tracker.print("Saving audio samples to TensorBoard")
|
||||
state.generator.eval()
|
||||
|
||||
samples = [state.val_data[idx] for idx in val_idx]
|
||||
batch = state.val_data.collate(samples)
|
||||
batch = util.prepare_batch(batch, accel.device)
|
||||
signal = state.train_data.transform(batch["signal"].clone(),
|
||||
**batch["transform_args"])
|
||||
|
||||
out = state.generator(signal.audio_data, signal.sample_rate)
|
||||
recons = AudioSignal(out["audio"], signal.sample_rate)
|
||||
|
||||
audio_dict = {"recons": recons}
|
||||
if state.tracker.step == 0:
|
||||
audio_dict["signal"] = signal
|
||||
|
||||
for k, v in audio_dict.items():
|
||||
for nb in range(v.batch_size):
|
||||
v[nb].cpu().write_audio_to_tb(f"{k}/sample_{nb}.wav", writer,
|
||||
state.tracker.step)
|
||||
|
||||
|
||||
def validate(state, val_dataloader, accel):
|
||||
for batch in val_dataloader:
|
||||
output = val_loop(batch, state, accel)
|
||||
# Consolidate state dicts if using ZeroRedundancyOptimizer
|
||||
if hasattr(state.optimizer_g, "consolidate_state_dict"):
|
||||
state.optimizer_g.consolidate_state_dict()
|
||||
state.optimizer_d.consolidate_state_dict()
|
||||
return output
|
||||
|
||||
|
||||
# @argbind.bind(without_prefix=True)
|
||||
def train(
|
||||
args,
|
||||
accel: Accelerator,
|
||||
seed: int=2025,
|
||||
save_path: str="ckpt",
|
||||
num_iters: int=250000,
|
||||
save_iters: list=[10000, 50000, 100000, 200000],
|
||||
sample_freq: int=10000,
|
||||
valid_freq: int=1000,
|
||||
batch_size: int=12,
|
||||
val_batch_size: int=10,
|
||||
num_workers: int=8,
|
||||
val_idx: list=[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
lambdas: dict={
|
||||
"mel/loss": 100.0,
|
||||
"adv/feat_loss": 2.0,
|
||||
"adv/gen_loss": 1.0,
|
||||
"vq/commitment_loss": 0.25,
|
||||
"vq/codebook_loss": 1.0,
|
||||
}, ):
|
||||
seed_everything(seed)
|
||||
Path(save_path).mkdir(exist_ok=True, parents=True)
|
||||
writer = LogWriter(
|
||||
log_dir=f"{save_path}/logs") if accel.local_rank == 0 else None
|
||||
tracker = Tracker(
|
||||
writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank)
|
||||
|
||||
state = load(args, accel, tracker, save_path)
|
||||
train_dataloader = accel.prepare_dataloader(
|
||||
state.train_data,
|
||||
start_idx=state.tracker.step * batch_size,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
collate_fn=state.train_data.collate, )
|
||||
train_dataloader = get_infinite_loader(train_dataloader)
|
||||
val_dataloader = accel.prepare_dataloader(
|
||||
state.val_data,
|
||||
start_idx=0,
|
||||
num_workers=num_workers,
|
||||
batch_size=val_batch_size,
|
||||
collate_fn=state.val_data.collate,
|
||||
persistent_workers=True if num_workers > 0 else False, )
|
||||
|
||||
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
||||
# and only run when specific conditions are met.
|
||||
global train_loop, val_loop, validate, save_samples, checkpoint
|
||||
train_loop = tracker.log(
|
||||
"train", "value", history=False)(tracker.track(
|
||||
"train", num_iters, completed=state.tracker.step)(train_loop))
|
||||
val_loop = tracker.track("val", len(val_dataloader))(val_loop)
|
||||
validate = tracker.log("val", "mean")(validate)
|
||||
|
||||
# These functions run only on the 0-rank process
|
||||
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
||||
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
||||
|
||||
with tracker.live:
|
||||
for tracker.step, batch in enumerate(
|
||||
train_dataloader, start=tracker.step):
|
||||
train_loop(state, batch, accel, lambdas)
|
||||
|
||||
last_iter = (tracker.step == num_iters - 1
|
||||
if num_iters is not None else False)
|
||||
if tracker.step % sample_freq == 0 or last_iter:
|
||||
save_samples(state, val_idx, writer)
|
||||
|
||||
if tracker.step % valid_freq == 0 or last_iter:
|
||||
validate(state, val_dataloader, accel)
|
||||
checkpoint(state, save_iters, save_path)
|
||||
# Reset validation progress bar, print summary since last validation.
|
||||
tracker.done("val", f"Iteration {tracker.step}")
|
||||
|
||||
if last_iter:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# args = argbind.parse_args()
|
||||
# args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
||||
# with argbind.scope(args):
|
||||
with Accelerator() as accel:
|
||||
if accel.local_rank != 0:
|
||||
sys.tracebacklimit = 0
|
||||
# train(args, accel)
|
||||
train(None, accel)
|
@ -0,0 +1,4 @@
|
||||
from .base import CodecMixin
|
||||
from .base import DACFile
|
||||
from .dac import DAC
|
||||
from .discriminator import Discriminator
|
@ -0,0 +1,298 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import tqdm
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.audiotools import AudioSignal
|
||||
|
||||
SUPPORTED_VERSIONS = ["1.0.0"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DACFile:
|
||||
codes: paddle.Tensor
|
||||
|
||||
# Metadata
|
||||
chunk_length: int
|
||||
original_length: int
|
||||
input_db: paddle.Tensor
|
||||
channels: int
|
||||
sample_rate: int
|
||||
padding: bool
|
||||
dac_version: str
|
||||
|
||||
def save(self, path):
|
||||
artifacts = {
|
||||
"codes": self.codes.numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"input_db": self.input_db.numpy().astype(np.float32),
|
||||
"original_length": self.original_length,
|
||||
"sample_rate": self.sample_rate,
|
||||
"chunk_length": self.chunk_length,
|
||||
"channels": self.channels,
|
||||
"padding": self.padding,
|
||||
"dac_version": SUPPORTED_VERSIONS[-1],
|
||||
},
|
||||
}
|
||||
path = Path(path).with_suffix(".dac")
|
||||
with open(path, "wb") as f:
|
||||
np.save(f, artifacts)
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
artifacts = np.load(path, allow_pickle=True)[()]
|
||||
codes = paddle.to_tensor(artifacts["codes"].astype(int))
|
||||
if artifacts["metadata"].get("dac_version",
|
||||
None) not in SUPPORTED_VERSIONS:
|
||||
raise RuntimeError(
|
||||
f"Given file {path} can't be loaded with this version of descript-audio-codec."
|
||||
)
|
||||
return cls(codes=codes, **artifacts["metadata"])
|
||||
|
||||
|
||||
class CodecMixin:
|
||||
@property
|
||||
def padding(self):
|
||||
if not hasattr(self, "_padding"):
|
||||
self._padding = True
|
||||
return self._padding
|
||||
|
||||
@padding.setter
|
||||
def padding(self, value):
|
||||
assert isinstance(value, bool)
|
||||
|
||||
layers = [
|
||||
l for l in self.sublayers()
|
||||
if isinstance(l, (nn.Conv1D, nn.Conv1DTranspose))
|
||||
]
|
||||
|
||||
for layer in layers:
|
||||
if value:
|
||||
if hasattr(layer, "original_padding"):
|
||||
layer._padding = layer.original_padding
|
||||
else:
|
||||
if isinstance(layer._padding, int):
|
||||
# TODO: drryanhuang, fix this condition
|
||||
layer._padding = [layer._padding]
|
||||
layer.original_padding = layer._padding
|
||||
layer._padding = tuple(0 for _ in range(len(layer._padding)))
|
||||
|
||||
self._padding = value
|
||||
|
||||
def get_delay(self):
|
||||
# Any number works here, delay is invariant to input length
|
||||
l_out = self.get_output_length(0)
|
||||
L = l_out
|
||||
|
||||
layers = []
|
||||
for layer in self.sublayers():
|
||||
if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)):
|
||||
layers.append(layer)
|
||||
|
||||
for layer in reversed(layers):
|
||||
d = layer._dilation[0]
|
||||
k = layer._kernel_size[0]
|
||||
s = layer._stride[0]
|
||||
|
||||
if isinstance(layer, nn.Conv1DTranspose):
|
||||
L = ((L - d * (k - 1) - 1) / s) + 1
|
||||
elif isinstance(layer, nn.Conv1D):
|
||||
L = (L - 1) * s + d * (k - 1) + 1
|
||||
|
||||
L = math.ceil(L)
|
||||
|
||||
l_in = L
|
||||
|
||||
return (l_in - l_out) // 2
|
||||
|
||||
def get_output_length(self, input_length):
|
||||
L = input_length
|
||||
# Calculate output length
|
||||
for layer in self.sublayers():
|
||||
if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)):
|
||||
d = layer._dilation[0]
|
||||
k = layer._kernel_size[0]
|
||||
s = layer._stride[0]
|
||||
|
||||
if isinstance(layer, nn.Conv1D):
|
||||
L = ((L - d * (k - 1) - 1) / s) + 1
|
||||
elif isinstance(layer, nn.Conv1DTranspose):
|
||||
L = (L - 1) * s + d * (k - 1) + 1
|
||||
|
||||
L = math.floor(L)
|
||||
return L
|
||||
|
||||
@paddle.no_grad()
|
||||
def compress(
|
||||
self,
|
||||
audio_path_or_signal: Union[str, Path, AudioSignal],
|
||||
win_duration: float=1.0,
|
||||
verbose: bool=False,
|
||||
normalize_db: float=-16,
|
||||
n_quantizers: int=None, ) -> DACFile:
|
||||
"""Processes an audio signal from a file or AudioSignal object into
|
||||
discrete codes. This function processes the signal in short windows,
|
||||
using constant GPU memory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_path_or_signal : Union[str, Path, AudioSignal]
|
||||
audio signal to reconstruct
|
||||
win_duration : float, optional
|
||||
window duration in seconds, by default 5.0
|
||||
verbose : bool, optional
|
||||
by default False
|
||||
normalize_db : float, optional
|
||||
normalize db, by default -16
|
||||
|
||||
Returns
|
||||
-------
|
||||
DACFile
|
||||
Object containing compressed codes and metadata
|
||||
required for decompression
|
||||
"""
|
||||
audio_signal = audio_path_or_signal
|
||||
if isinstance(audio_signal, (str, Path)):
|
||||
audio_signal = AudioSignal.load_from_file_with_ffmpeg(
|
||||
str(audio_signal))
|
||||
|
||||
self.eval()
|
||||
original_padding = self.padding
|
||||
original_device = audio_signal.device
|
||||
|
||||
audio_signal = audio_signal.clone()
|
||||
original_sr = audio_signal.sample_rate
|
||||
|
||||
resample_fn = audio_signal.resample
|
||||
loudness_fn = audio_signal.loudness
|
||||
|
||||
# If audio is > 10 minutes long, use the ffmpeg versions
|
||||
if audio_signal.signal_duration >= 10 * 60 * 60:
|
||||
resample_fn = audio_signal.ffmpeg_resample
|
||||
loudness_fn = audio_signal.ffmpeg_loudness
|
||||
|
||||
original_length = audio_signal.signal_length
|
||||
resample_fn(self.sample_rate)
|
||||
input_db = loudness_fn()
|
||||
|
||||
if normalize_db is not None:
|
||||
audio_signal.normalize(normalize_db)
|
||||
audio_signal.ensure_max_of_audio()
|
||||
|
||||
nb, nac, nt = audio_signal.audio_data.shape
|
||||
audio_signal.audio_data = audio_signal.audio_data.reshape(
|
||||
[nb * nac, 1, nt])
|
||||
win_duration = (audio_signal.signal_duration
|
||||
if win_duration is None else win_duration)
|
||||
|
||||
if audio_signal.signal_duration <= win_duration:
|
||||
# Unchunked compression (used if signal length < win duration)
|
||||
self.padding = True
|
||||
n_samples = nt
|
||||
hop = nt
|
||||
else:
|
||||
# Chunked inference
|
||||
self.padding = False
|
||||
# Zero-pad signal on either side by the delay
|
||||
audio_signal.zero_pad(self.delay, self.delay)
|
||||
n_samples = int(win_duration * self.sample_rate)
|
||||
# Round n_samples to nearest hop length multiple
|
||||
n_samples = int(
|
||||
math.ceil(n_samples / self.hop_length) * self.hop_length)
|
||||
hop = self.get_output_length(n_samples)
|
||||
|
||||
codes = []
|
||||
range_fn = range if not verbose else tqdm.trange
|
||||
|
||||
for i in range_fn(0, nt, hop):
|
||||
x = audio_signal[..., i:i + n_samples]
|
||||
x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
|
||||
|
||||
audio_data = x.audio_data
|
||||
audio_data = self.preprocess(audio_data, self.sample_rate)
|
||||
_, c, _, _, _ = self.encode(audio_data, n_quantizers)
|
||||
codes.append(c)
|
||||
chunk_length = c.shape[-1]
|
||||
|
||||
codes = paddle.concat(codes, axis=-1)
|
||||
|
||||
dac_file = DACFile(
|
||||
codes=codes,
|
||||
chunk_length=chunk_length,
|
||||
original_length=original_length,
|
||||
input_db=input_db,
|
||||
channels=nac,
|
||||
sample_rate=original_sr,
|
||||
padding=self.padding,
|
||||
dac_version=SUPPORTED_VERSIONS[-1], )
|
||||
|
||||
if n_quantizers is not None:
|
||||
codes = codes[:, :n_quantizers, :]
|
||||
|
||||
self.padding = original_padding
|
||||
return dac_file
|
||||
|
||||
@paddle.no_grad()
|
||||
def decompress(
|
||||
self,
|
||||
obj: Union[str, Path, DACFile],
|
||||
verbose: bool=False, ) -> AudioSignal:
|
||||
"""Reconstruct audio from a given .dac file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : Union[str, Path, DACFile]
|
||||
.dac file location or corresponding DACFile object.
|
||||
verbose : bool, optional
|
||||
Prints progress if True, by default False
|
||||
|
||||
Returns
|
||||
-------
|
||||
AudioSignal
|
||||
Object with the reconstructed audio
|
||||
"""
|
||||
self.eval()
|
||||
if isinstance(obj, (str, Path)):
|
||||
obj = DACFile.load(obj)
|
||||
|
||||
original_padding = self.padding
|
||||
self.padding = obj.padding
|
||||
|
||||
range_fn = range if not verbose else tqdm.trange
|
||||
codes = obj.codes
|
||||
# original_device = codes.device
|
||||
chunk_length = obj.chunk_length
|
||||
recons = []
|
||||
|
||||
for i in range_fn(0, codes.shape[-1], chunk_length):
|
||||
c = codes[..., i:i + chunk_length]
|
||||
z = self.quantizer.from_codes(c)[0]
|
||||
r = self.decode(z)
|
||||
recons.append(r)
|
||||
|
||||
recons = paddle.concat(recons, axis=-1)
|
||||
recons = AudioSignal(recons, self.sample_rate)
|
||||
|
||||
resample_fn = recons.resample
|
||||
loudness_fn = recons.loudness
|
||||
|
||||
# If audio is > 10 minutes long, use the ffmpeg versions
|
||||
if recons.signal_duration >= 10 * 60 * 60:
|
||||
resample_fn = recons.ffmpeg_resample
|
||||
loudness_fn = recons.ffmpeg_loudness
|
||||
|
||||
recons.normalize(obj.input_db)
|
||||
resample_fn(obj.sample_rate)
|
||||
recons = recons[..., :obj.original_length]
|
||||
loudness_fn()
|
||||
recons.audio_data = recons.audio_data.reshape(
|
||||
[-1, obj.channels, obj.original_length])
|
||||
|
||||
self.padding = original_padding
|
||||
return recons
|
@ -0,0 +1,354 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.audiotools import AudioSignal
|
||||
from paddlespeech.codec.models.dac_.model.base import CodecMixin
|
||||
from paddlespeech.codec.models.dac_.nn.layers import Snake1d
|
||||
from paddlespeech.codec.models.dac_.nn.layers import WNConv1d
|
||||
from paddlespeech.codec.models.dac_.nn.layers import WNConvTranspose1d
|
||||
from paddlespeech.codec.models.dac_.nn.quantize import ResidualVectorQuantize
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1D):
|
||||
nn.initializer.TruncatedNormal(std=0.02)(m.weight)
|
||||
nn.initializer.Constant(0)(m.bias)
|
||||
|
||||
|
||||
class ResidualUnit(nn.Layer):
|
||||
def __init__(self, dim: int=16, dilation: int=1):
|
||||
super(ResidualUnit, self).__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
||||
Snake1d(dim),
|
||||
WNConv1d(dim, dim, kernel_size=1), )
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class EncoderBlock(nn.Layer):
|
||||
def __init__(self, dim: int=16, stride: int=1):
|
||||
super(EncoderBlock, self).__init__()
|
||||
self.block = nn.Sequential(
|
||||
ResidualUnit(dim // 2, dilation=1),
|
||||
ResidualUnit(dim // 2, dilation=3),
|
||||
ResidualUnit(dim // 2, dilation=9),
|
||||
Snake1d(dim // 2),
|
||||
WNConv1d(
|
||||
dim // 2,
|
||||
dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=int(np.ceil(stride / 2).astype(int)), ), )
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class Encoder(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int=64,
|
||||
strides: list=[2, 4, 8, 8],
|
||||
d_latent: int=64, ):
|
||||
super(Encoder, self).__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
self.block += [EncoderBlock(d_model, stride=stride)]
|
||||
|
||||
# Create last convolution
|
||||
self.block += [
|
||||
Snake1d(d_model),
|
||||
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
|
||||
]
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class DecoderBlock(nn.Layer):
|
||||
def __init__(self, input_dim: int=16, output_dim: int=8, stride: int=1):
|
||||
super(DecoderBlock, self).__init__()
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(input_dim),
|
||||
WNConvTranspose1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=int(np.ceil(stride / 2).astype(int)), ),
|
||||
ResidualUnit(output_dim, dilation=1),
|
||||
ResidualUnit(output_dim, dilation=3),
|
||||
ResidualUnit(output_dim, dilation=9), )
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class Decoder(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
d_out: int=1, ):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
# Add first conv layer
|
||||
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2**(i + 1)
|
||||
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class DAC(nn.Layer, CodecMixin):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int=64,
|
||||
encoder_rates: list=[2, 4, 8, 8],
|
||||
latent_dim: int=None,
|
||||
decoder_dim: int=1536,
|
||||
decoder_rates: list=[8, 8, 4, 2],
|
||||
n_codebooks: int=9,
|
||||
codebook_size: int=1024,
|
||||
codebook_dim: int=8,
|
||||
quantizer_dropout: bool=False,
|
||||
sample_rate: int=44100, ):
|
||||
super(DAC, self).__init__()
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2**len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
|
||||
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.quantizer = ResidualVectorQuantize(
|
||||
input_dim=latent_dim,
|
||||
n_codebooks=n_codebooks,
|
||||
codebook_size=codebook_size,
|
||||
codebook_dim=codebook_dim,
|
||||
quantizer_dropout=quantizer_dropout, )
|
||||
|
||||
self.decoder = Decoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates, )
|
||||
self.sample_rate = sample_rate
|
||||
self.apply(init_weights)
|
||||
|
||||
self.delay = self.get_delay()
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = np.ceil(length / self.hop_length) * self.hop_length - length
|
||||
audio_data = F.pad(
|
||||
audio_data, [0, int(right_pad)],
|
||||
mode='constant',
|
||||
value=0,
|
||||
data_format="NCL")
|
||||
|
||||
return audio_data
|
||||
|
||||
def encode(
|
||||
self,
|
||||
audio_data: paddle.Tensor,
|
||||
n_quantizers: int=None, ):
|
||||
"""Encode given audio data and return quantized latent codes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_data : Tensor[B x 1 x T]
|
||||
Audio data to encode
|
||||
n_quantizers : int, optional
|
||||
Number of quantizers to use, by default None
|
||||
If None, all quantizers are used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"z" : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"codes" : Tensor[B x N x T]
|
||||
Codebook indices for each codebook
|
||||
(quantized discrete representation of input)
|
||||
"latents" : Tensor[B x N*D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"vq/commitment_loss" : Tensor[1]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook
|
||||
entries
|
||||
"vq/codebook_loss" : Tensor[1]
|
||||
Codebook loss to update the codebook
|
||||
"length" : int
|
||||
Number of samples in input audio
|
||||
"""
|
||||
z = self.encoder(audio_data)
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
|
||||
z, n_quantizers)
|
||||
return z, codes, latents, commitment_loss, codebook_loss
|
||||
|
||||
def decode(self, z: paddle.Tensor):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
audio_data: paddle.Tensor,
|
||||
sample_rate: int=None,
|
||||
n_quantizers: int=None, ):
|
||||
"""Model forward pass
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_data : Tensor[B x 1 x T]
|
||||
Audio data to encode
|
||||
sample_rate : int, optional
|
||||
Sample rate of audio data in Hz, by default None
|
||||
If None, defaults to `self.sample_rate`
|
||||
n_quantizers : int, optional
|
||||
Number of quantizers to use, by default None.
|
||||
If None, all quantizers are used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"z" : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
"codes" : Tensor[B x N x T]
|
||||
Codebook indices for each codebook
|
||||
(quantized discrete representation of input)
|
||||
"latents" : Tensor[B x N*D x T]
|
||||
Projected latents (continuous representation of input before quantization)
|
||||
"vq/commitment_loss" : Tensor[1]
|
||||
Commitment loss to train encoder to predict vectors closer to codebook
|
||||
entries
|
||||
"vq/codebook_loss" : Tensor[1]
|
||||
Codebook loss to update the codebook
|
||||
"length" : int
|
||||
Number of samples in input audio
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
length = audio_data.shape[-1]
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.encode(
|
||||
audio_data, n_quantizers)
|
||||
|
||||
x = self.decode(z)
|
||||
return {
|
||||
"audio": x[..., :length],
|
||||
"z": z,
|
||||
"codes": codes,
|
||||
"latents": latents,
|
||||
"vq/commitment_loss": commitment_loss,
|
||||
"vq/codebook_loss": codebook_loss,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
model = DAC()
|
||||
|
||||
def fn(o, p):
|
||||
return o + f" {p / 1e6:<.3f}M params."
|
||||
|
||||
for n, m in model.named_sublayers():
|
||||
# print(type(m))
|
||||
o = m.extra_repr()
|
||||
p = sum([np.prod(p.shape) for p in m.parameters()])
|
||||
setattr(m, "extra_repr", partial(fn, o=o, p=p))
|
||||
print(model)
|
||||
print("Total # of params: ",
|
||||
sum([np.prod(p.shape) for p in model.parameters()]))
|
||||
|
||||
length = 88200 * 2
|
||||
x = paddle.randn([1, 1, length])
|
||||
x.stop_gradient = False
|
||||
|
||||
# Make a forward pass
|
||||
out = model(x)["audio"]
|
||||
print("Input shape:", x.shape)
|
||||
print("Output shape:", out.shape)
|
||||
|
||||
# Create gradient variable
|
||||
grad = paddle.zeros_like(out)
|
||||
grad[:, :, grad.shape[-1] // 2] = 1
|
||||
|
||||
# Make a backward pass
|
||||
out.backward(grad)
|
||||
|
||||
# Check non-zero values
|
||||
gradmap = x.grad.squeeze(0)
|
||||
gradmap = (gradmap != 0).sum(0) # sum across features
|
||||
rf = (gradmap != 0).sum()
|
||||
|
||||
print(f"Receptive field: {rf.item()}"
|
||||
) # TODO: drryanhuang, fix this question, why is RF zero?
|
||||
|
||||
x = AudioSignal(paddle.randn([1, 1, 44100 * 60]), 44100)
|
||||
model.decompress(model.compress(x, verbose=True), verbose=True)
|
@ -0,0 +1,235 @@
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddlespeech.audiotools import AudioSignal
|
||||
from paddlespeech.audiotools import ml
|
||||
from paddlespeech.audiotools import STFTParams
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
act = kwargs.pop("act", True)
|
||||
model = nn.Conv1D(*args, **kwargs)
|
||||
conv = nn.utils.weight_norm(model)
|
||||
if not act:
|
||||
return conv
|
||||
return nn.Sequential(conv, nn.LeakyReLU(0.1))
|
||||
|
||||
|
||||
def WNConv2d(*args, **kwargs):
|
||||
act = kwargs.pop("act", True)
|
||||
model = nn.Conv2D(*args, **kwargs)
|
||||
conv = nn.utils.weight_norm(model)
|
||||
if not act:
|
||||
return conv
|
||||
return nn.Sequential(conv, nn.LeakyReLU(0.1))
|
||||
|
||||
|
||||
class MPD(nn.Layer):
|
||||
def __init__(self, period):
|
||||
super(MPD, self).__init__()
|
||||
self.period = period
|
||||
self.convs = nn.LayerList([
|
||||
WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
|
||||
WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
|
||||
WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
|
||||
WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
|
||||
WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
|
||||
])
|
||||
self.conv_post = WNConv2d(
|
||||
1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False)
|
||||
|
||||
def pad_to_period(self, x):
|
||||
t = x.shape[-1]
|
||||
x = F.pad(
|
||||
x, (0, self.period - t % self.period),
|
||||
mode="reflect",
|
||||
data_format="NCL")
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
x = self.pad_to_period(x)
|
||||
# x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
|
||||
b, c, lp = x.shape
|
||||
l = lp // self.period
|
||||
p = self.period
|
||||
x = x.reshape([b, c, l, p])
|
||||
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
|
||||
return fmap
|
||||
|
||||
|
||||
class MSD(nn.Layer):
|
||||
def __init__(self, rate: int=1, sample_rate: int=44100):
|
||||
super(MSD, self).__init__()
|
||||
self.convs = nn.LayerList([
|
||||
WNConv1d(1, 16, 15, 1, padding=7),
|
||||
WNConv1d(16, 64, 41, 4, groups=4, padding=20),
|
||||
WNConv1d(64, 256, 41, 4, groups=16, padding=20),
|
||||
WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
|
||||
WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
|
||||
WNConv1d(1024, 1024, 5, 1, padding=2),
|
||||
])
|
||||
self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
|
||||
self.sample_rate = sample_rate
|
||||
self.rate = rate
|
||||
|
||||
def forward(self, x):
|
||||
x = AudioSignal(x, self.sample_rate)
|
||||
x.resample(self.sample_rate // self.rate)
|
||||
x = x.audio_data
|
||||
|
||||
fmap = []
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
|
||||
return fmap
|
||||
|
||||
|
||||
BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
|
||||
|
||||
|
||||
class MRD(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
window_length: int,
|
||||
hop_factor: float=0.25,
|
||||
sample_rate: int=44100,
|
||||
bands: list=BANDS, ):
|
||||
"""Complex multi-band spectrogram discriminator.
|
||||
Parameters
|
||||
----------
|
||||
window_length : int
|
||||
Window length of STFT.
|
||||
hop_factor : float, optional
|
||||
Hop factor of the STFT, defaults to ``0.25 * window_length``.
|
||||
sample_rate : int, optional
|
||||
Sampling rate of audio in Hz, by default 44100
|
||||
bands : list, optional
|
||||
Bands to run discriminator over.
|
||||
"""
|
||||
super(MRD, self).__init__()
|
||||
|
||||
self.window_length = window_length
|
||||
self.hop_factor = hop_factor
|
||||
self.sample_rate = sample_rate
|
||||
self.stft_params = STFTParams(
|
||||
window_length=window_length,
|
||||
hop_length=int(window_length * hop_factor),
|
||||
match_stride=True, )
|
||||
|
||||
n_fft = window_length // 2 + 1
|
||||
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||
self.bands = bands
|
||||
|
||||
ch = 32
|
||||
|
||||
def convs():
|
||||
return nn.LayerList([
|
||||
WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
|
||||
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
|
||||
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
|
||||
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
|
||||
WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
|
||||
])
|
||||
|
||||
self.band_convs = nn.LayerList(
|
||||
[convs() for _ in range(len(self.bands))])
|
||||
self.conv_post = WNConv2d(
|
||||
ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
|
||||
|
||||
def spectrogram(self, x):
|
||||
x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
|
||||
x = paddle.as_real(x.stft())
|
||||
# x = rearrange(x, "b 1 f t c -> (b 1) c t f")
|
||||
x = x.transpose([0, 1, 4, 3, 2]).flatten(stop_axis=1)
|
||||
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0]:b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x):
|
||||
x_bands = self.spectrogram(x)
|
||||
fmap = []
|
||||
|
||||
x = []
|
||||
for band, stack in zip(x_bands, self.band_convs):
|
||||
for layer in stack:
|
||||
band = layer(band)
|
||||
fmap.append(band)
|
||||
x.append(band)
|
||||
|
||||
x = paddle.concat(x, axis=-1)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
|
||||
return fmap
|
||||
|
||||
|
||||
class Discriminator(ml.BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
rates: list=[],
|
||||
periods: list=[2, 3, 5, 7, 11],
|
||||
fft_sizes: list=[2048, 1024, 512],
|
||||
sample_rate: int=44100,
|
||||
bands: list=BANDS, ):
|
||||
"""Discriminator that combines multiple discriminators.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rates : list, optional
|
||||
sampling rates (in Hz) to run MSD at, by default []
|
||||
If empty, MSD is not used.
|
||||
periods : list, optional
|
||||
periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
|
||||
fft_sizes : list, optional
|
||||
Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
|
||||
sample_rate : int, optional
|
||||
Sampling rate of audio in Hz, by default 44100
|
||||
bands : list, optional
|
||||
Bands to run MRD at, by default `BANDS`
|
||||
"""
|
||||
super(Discriminator, self).__init__()
|
||||
discs = []
|
||||
discs += [MPD(p) for p in periods]
|
||||
discs += [MSD(r, sample_rate=sample_rate) for r in rates]
|
||||
discs += [
|
||||
MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes
|
||||
]
|
||||
self.discriminators = nn.LayerList(discs)
|
||||
|
||||
def preprocess(self, y):
|
||||
# Remove DC offset
|
||||
y = y - y.mean(axis=-1, keepdim=True)
|
||||
# Peak normalize the volume of input audio
|
||||
y = 0.8 * y / (paddle.abs(y).max(axis=-1, keepdim=True)[0] + 1e-9)
|
||||
return y
|
||||
|
||||
def forward(self, x):
|
||||
x = self.preprocess(x)
|
||||
fmaps = [d(x) for d in self.discriminators]
|
||||
return fmaps
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
disc = Discriminator()
|
||||
x = paddle.zeros([1, 1, 44100])
|
||||
results = disc(x)
|
||||
for i, result in enumerate(results):
|
||||
print(f"disc{i}")
|
||||
for i, r in enumerate(result):
|
||||
print(r.shape, r.mean(), r.min(), r.max())
|
||||
print()
|
@ -0,0 +1,3 @@
|
||||
from . import layers
|
||||
from . import quantize
|
||||
# from . import loss
|
@ -0,0 +1,31 @@
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle.nn.utils import weight_norm
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1D(*args, **kwargs), name='weight', dim=1)
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(
|
||||
nn.Conv1DTranspose(*args, **kwargs), name='weight', dim=1)
|
||||
|
||||
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape([shape[0], shape[1], -1])
|
||||
x = x + (alpha + 1e-9).reciprocal() * paddle.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Layer):
|
||||
def __init__(self, channels):
|
||||
super(Snake1d, self).__init__()
|
||||
self.alpha = self.create_parameter(
|
||||
shape=[1, channels, 1],
|
||||
default_initializer=nn.initializer.Constant(1.0))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
@ -0,0 +1,161 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.utils import weight_norm
|
||||
|
||||
from paddlespeech.codec.models.dac_.nn.layers import WNConv1d
|
||||
|
||||
|
||||
class VectorQuantize(nn.Layer):
|
||||
def __init__(self, input_dim, codebook_size, codebook_dim):
|
||||
super(VectorQuantize, self).__init__()
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
|
||||
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
||||
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
||||
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
||||
|
||||
def forward(self, z):
|
||||
z_e = self.in_proj(z) # z_e : (B x D x T)
|
||||
z_q, indices = self.decode_latents(z_e)
|
||||
|
||||
commitment_loss = F.mse_loss(
|
||||
z_e, z_q.detach(), reduction='none').mean(axis=[1, 2])
|
||||
codebook_loss = F.mse_loss(
|
||||
z_q, z_e.detach(), reduction='none').mean(axis=[1, 2])
|
||||
|
||||
z_q = z_e + (z_q - z_e).detach(
|
||||
) # noop in forward pass, straight-through gradient estimator in backward pass
|
||||
z_q = self.out_proj(z_q)
|
||||
|
||||
return z_q, commitment_loss, codebook_loss, indices, z_e
|
||||
|
||||
def embed_code(self, embed_id):
|
||||
return F.embedding(embed_id, self.codebook.weight)
|
||||
|
||||
def decode_code(self, embed_id):
|
||||
return self.embed_code(embed_id).transpose([0, 2, 1])
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# encodings = Rearrange('b d t -> (b t) d')(latents)
|
||||
encodings = latents.transpose([0, 2, 1]).flatten(stop_axis=1)
|
||||
codebook = self.codebook.weight # codebook: (N x D)
|
||||
|
||||
# L2 normalize encodings and codebook (ViT-VQGAN)
|
||||
encodings = F.normalize(encodings, axis=1)
|
||||
codebook = F.normalize(codebook, axis=1)
|
||||
|
||||
# Compute euclidean distance with codebook
|
||||
dist = (encodings.pow(2).sum(axis=1, keepdim=True) - 2 * encodings
|
||||
@ codebook.T + codebook.pow(2).sum(axis=1, keepdim=True).T)
|
||||
# indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
||||
indices = (-dist).argmax(axis=1).reshape([latents.shape[0], -1])
|
||||
z_q = self.decode_code(indices)
|
||||
return z_q, indices
|
||||
|
||||
|
||||
class ResidualVectorQuantize(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int=512,
|
||||
n_codebooks: int=9,
|
||||
codebook_size: int=1024,
|
||||
codebook_dim: Union[int, list]=8,
|
||||
quantizer_dropout: float=0.0, ):
|
||||
super().__init__()
|
||||
if isinstance(codebook_dim, int):
|
||||
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
||||
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_dim = codebook_dim
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.quantizers = nn.LayerList([
|
||||
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
||||
for i in range(n_codebooks)
|
||||
])
|
||||
self.quantizer_dropout = quantizer_dropout
|
||||
|
||||
def forward(self, z, n_quantizers: int=None):
|
||||
z_q = 0
|
||||
residual = z
|
||||
commitment_loss = 0
|
||||
codebook_loss = 0
|
||||
|
||||
codebook_indices = []
|
||||
latents = []
|
||||
|
||||
if n_quantizers is None:
|
||||
n_quantizers = self.n_codebooks
|
||||
if self.training:
|
||||
n_quantizers = paddle.ones((z.shape[0], )) * self.n_codebooks + 1
|
||||
dropout = paddle.randint(1, self.n_codebooks + 1, (z.shape[0], ))
|
||||
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
||||
if dropout[:n_dropout].size:
|
||||
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
||||
|
||||
for i, quantizer in enumerate(self.quantizers):
|
||||
if not self.training and i >= n_quantizers:
|
||||
break
|
||||
|
||||
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
||||
residual)
|
||||
|
||||
mask = (paddle.full(
|
||||
(z.shape[0], ), fill_value=i) < n_quantizers).astype("float32")
|
||||
z_q = z_q + z_q_i * mask[:, None, None]
|
||||
residual = residual - z_q_i
|
||||
|
||||
commitment_loss += (commitment_loss_i * mask).mean()
|
||||
codebook_loss += (codebook_loss_i * mask).mean()
|
||||
|
||||
codebook_indices.append(indices_i)
|
||||
latents.append(z_e_i)
|
||||
|
||||
codes = paddle.stack(codebook_indices, axis=1)
|
||||
latents = paddle.concat(latents, axis=1)
|
||||
|
||||
return z_q, codes, latents, commitment_loss, codebook_loss
|
||||
|
||||
def from_codes(self, codes: paddle.Tensor):
|
||||
z_q = 0
|
||||
z_p = []
|
||||
n_codebooks = codes.shape[1]
|
||||
for i in range(n_codebooks):
|
||||
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
||||
z_p.append(z_p_i)
|
||||
|
||||
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
||||
z_q = z_q + z_q_i
|
||||
return z_q, paddle.concat(z_p, axis=1), codes
|
||||
|
||||
def from_latents(self, latents: paddle.Tensor):
|
||||
z_q = 0
|
||||
z_p = []
|
||||
codes = []
|
||||
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
||||
|
||||
n_codebooks = np.where(dims <= latents.shape[1])[0].max(
|
||||
axis=0, keepdims=True)[0]
|
||||
for i in range(n_codebooks):
|
||||
j, k = dims[i], dims[i + 1]
|
||||
z_p_i, codes_i = self.quantizers[i].decode_latents(
|
||||
latents[:, j:k, :])
|
||||
z_p.append(z_p_i)
|
||||
codes.append(codes_i)
|
||||
|
||||
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
||||
z_q = z_q + z_q_i
|
||||
|
||||
return z_q, paddle.concat(z_p, axis=1), paddle.stack(codes, axis=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
rvq = ResidualVectorQuantize(quantizer_dropout=True)
|
||||
x = paddle.randn((16, 512, 80))
|
||||
y = rvq(x)
|
||||
print(y[2].shape)
|
@ -0,0 +1,90 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import argbind
|
||||
import numpy as np
|
||||
import paddle
|
||||
from tqdm import tqdm
|
||||
|
||||
from dac import DACFile
|
||||
from dac.utils import load_model
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
|
||||
@argbind.bind(group="decode", positional=True, without_prefix=True)
|
||||
@paddle.no_grad()
|
||||
def decode(
|
||||
input: str,
|
||||
output: str="",
|
||||
weights_path: str="",
|
||||
model_tag: str="latest",
|
||||
model_bitrate: str="8kbps",
|
||||
device: str="cuda",
|
||||
model_type: str="44khz",
|
||||
verbose: bool=False, ):
|
||||
"""Decode audio from codes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input : str
|
||||
Path to input directory or file
|
||||
output : str, optional
|
||||
Path to output directory, by default "".
|
||||
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
|
||||
weights_path : str, optional
|
||||
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
|
||||
model_tag and model_type.
|
||||
model_tag : str, optional
|
||||
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
|
||||
model_bitrate: str
|
||||
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
||||
device : str, optional
|
||||
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
|
||||
model_type : str, optional
|
||||
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
|
||||
"""
|
||||
generator = load_model(
|
||||
model_type=model_type,
|
||||
model_bitrate=model_bitrate,
|
||||
tag=model_tag,
|
||||
load_path=weights_path, )
|
||||
generator.eval()
|
||||
|
||||
# Find all .dac files in input directory
|
||||
_input = Path(input)
|
||||
input_files = list(_input.glob("**/*.dac"))
|
||||
|
||||
# If input is a .dac file, add it to the list
|
||||
if _input.suffix == ".dac":
|
||||
input_files.append(_input)
|
||||
|
||||
# Create output directory
|
||||
output = Path(output)
|
||||
output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in tqdm(range(len(input_files)), desc="Decoding files"):
|
||||
# Load file
|
||||
artifact = DACFile.load(input_files[i])
|
||||
|
||||
# Reconstruct audio from codes
|
||||
recons = generator.decompress(artifact, verbose=verbose)
|
||||
|
||||
# Compute output path
|
||||
relative_path = input_files[i].relative_to(input)
|
||||
output_dir = output / relative_path.parent
|
||||
if not relative_path.name:
|
||||
output_dir = output
|
||||
relative_path = input_files[i]
|
||||
output_name = relative_path.with_suffix(".wav").name
|
||||
output_path = output_dir / output_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write to file
|
||||
recons.write(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = argbind.parse_args()
|
||||
with argbind.scope(args):
|
||||
decode()
|
@ -0,0 +1,89 @@
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import argbind
|
||||
import paddle
|
||||
from tqdm import tqdm
|
||||
|
||||
from dac.utils import load_model
|
||||
from paddlespeech.audiotools import AudioSignal
|
||||
from paddlespeech.audiotools.core import util
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
|
||||
@argbind.bind(group="encode", positional=True, without_prefix=True)
|
||||
@paddle.no_grad()
|
||||
def encode(
|
||||
input: str,
|
||||
output: str="",
|
||||
weights_path: str="",
|
||||
model_tag: str="latest",
|
||||
model_bitrate: str="8kbps",
|
||||
n_quantizers: int=None,
|
||||
model_type: str="44khz",
|
||||
win_duration: float=5.0,
|
||||
verbose: bool=False, ):
|
||||
"""Encode audio files in input path to .dac format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input : str
|
||||
Path to input audio file or directory
|
||||
output : str, optional
|
||||
Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
|
||||
weights_path : str, optional
|
||||
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
|
||||
model_tag and model_type.
|
||||
model_tag : str, optional
|
||||
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
|
||||
model_bitrate: str
|
||||
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
||||
n_quantizers : int, optional
|
||||
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
|
||||
device : str, optional
|
||||
Device to use, by default "cuda"
|
||||
model_type : str, optional
|
||||
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
|
||||
"""
|
||||
generator = load_model(
|
||||
model_type=model_type,
|
||||
model_bitrate=model_bitrate,
|
||||
tag=model_tag,
|
||||
load_path=weights_path, )
|
||||
generator.eval()
|
||||
kwargs = {"n_quantizers": n_quantizers}
|
||||
|
||||
# Find all audio files in input path
|
||||
input = Path(input)
|
||||
audio_files = util.find_audio(input)
|
||||
|
||||
output = Path(output)
|
||||
output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in tqdm(range(len(audio_files)), desc="Encoding files"):
|
||||
# Load file
|
||||
signal = AudioSignal(audio_files[i])
|
||||
|
||||
# Encode audio to .dac format
|
||||
artifact = generator.compress(
|
||||
signal, win_duration, verbose=verbose, **kwargs)
|
||||
|
||||
# Compute output path
|
||||
relative_path = audio_files[i].relative_to(input)
|
||||
output_dir = output / relative_path.parent
|
||||
if not relative_path.name:
|
||||
output_dir = output
|
||||
relative_path = audio_files[i]
|
||||
output_name = relative_path.with_suffix(".dac").name
|
||||
output_path = output_dir / output_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
artifact.save(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = argbind.parse_args()
|
||||
with argbind.scope(args):
|
||||
encode()
|
Loading…
Reference in new issue