pull/3973/head
drryanhuang 7 months ago
parent 490dca9e80
commit 595a88a1e3

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import util from . import util
from ._julius import fft_conv1d # from ._julius import fft_conv1d
from ._julius import FFTConv1D # from ._julius import FFTConv1D
from ._julius import highpass_filter from ._julius import highpass_filter
from ._julius import highpass_filters from ._julius import highpass_filters
from ._julius import lowpass_filter from ._julius import lowpass_filter

@ -20,8 +20,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddlespeech.t2s.modules import fft_conv1d
from paddlespeech.t2s.modules import FFTConv1D
from paddlespeech.utils import satisfy_paddle_version from paddlespeech.utils import satisfy_paddle_version
__all__ = [ __all__ = [
@ -312,6 +311,7 @@ class LowPassFilters(nn.Layer):
mode="replicate", mode="replicate",
data_format="NCL") data_format="NCL")
if self.fft: if self.fft:
from paddlespeech.t2s.modules.fftconv1d import fft_conv1d
out = fft_conv1d(_input, self.filters, stride=self.stride) out = fft_conv1d(_input, self.filters, stride=self.stride)
else: else:
out = F.conv1d(_input, self.filters, stride=self.stride) out = F.conv1d(_input, self.filters, stride=self.stride)

@ -13,12 +13,9 @@ import typing
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Iterable
from typing import List from typing import List
from typing import NamedTuple
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Type from typing import Type
@ -34,7 +31,6 @@ from flatten_dict import unflatten
from .audio_signal import AudioSignal from .audio_signal import AudioSignal
from paddlespeech.utils import satisfy_paddle_version from paddlespeech.utils import satisfy_paddle_version
from paddlespeech.vector.training.seeding import seed_everything
__all__ = [ __all__ = [
"exp_compat", "exp_compat",

@ -42,6 +42,23 @@ class ResumableSequentialSampler(SequenceSampler):
self.start_idx = 0 # set the index back to 0 so for the next epoch self.start_idx = 0 # set the index back to 0 so for the next epoch
class DummyScaler:
def __init__(self):
pass
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
class Accelerator: class Accelerator:
"""This class is used to prepare models and dataloaders for """This class is used to prepare models and dataloaders for
usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
@ -78,22 +95,6 @@ class Accelerator:
self.local_rank = 0 if trainer_id is None else int(trainer_id) self.local_rank = 0 if trainer_id is None else int(trainer_id)
self.amp = amp self.amp = amp
class DummyScaler:
def __init__(self):
pass
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = paddle.amp.GradScaler() if self.amp else DummyScaler() self.scaler = paddle.amp.GradScaler() if self.amp else DummyScaler()
def __enter__(self): def __enter__(self):

@ -1,3 +1,4 @@
# TODO(DrRyanHuang): rm this file
# MIT License, Copyright (c) 2023-Present, Descript. # MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# #

@ -10,25 +10,26 @@ import paddle.nn as nn
from visualdl import LogWriter from visualdl import LogWriter
import paddlespeech import paddlespeech
import paddlespeech.t2s.modules.losses as losses import paddlespeech.t2s.modules.losses as _losses
from paddlespeech.audiotools import ml
from paddlespeech.audiotools.core import AudioSignal from paddlespeech.audiotools.core import AudioSignal
from paddlespeech.audiotools.core import util from paddlespeech.audiotools.core import util
from paddlespeech.audiotools.data import transforms from paddlespeech.audiotools.data import transforms
from paddlespeech.audiotools.data.datasets import AudioDataset from paddlespeech.audiotools.data.datasets import AudioDataset
from paddlespeech.audiotools.data.datasets import AudioLoader from paddlespeech.audiotools.data.datasets import AudioLoader
from paddlespeech.audiotools.data.datasets import ConcatDataset from paddlespeech.audiotools.data.datasets import ConcatDataset
from paddlespeech.audiotools.ml import Accelerator
from paddlespeech.audiotools.ml.decorators import timer from paddlespeech.audiotools.ml.decorators import timer
from paddlespeech.audiotools.ml.decorators import Tracker from paddlespeech.audiotools.ml.decorators import Tracker
from paddlespeech.audiotools.ml.decorators import when from paddlespeech.audiotools.ml.decorators import when
from paddlespeech.codec.models.dac_.model import DAC from paddlespeech.codec.models.dac_.model import DAC
from paddlespeech.codec.models.dac_.model import Discriminator from paddlespeech.codec.models.dac_.model import Discriminator
from paddlespeech.t2s.training.seeding import seed_everything
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
# Optimizers # Optimizers
AdamW = argbind.bind(paddle.optimizer.AdamW, "generator", "discriminator") AdamW = argbind.bind(paddle.optimizer.AdamW, "generator", "discriminator")
Accelerator = argbind.bind(ml.Accelerator, without_prefix=True) # Accelerator = argbind.bind(ml.Accelerator, without_prefix=True)
@argbind.bind("generator", "discriminator") @argbind.bind("generator", "discriminator")
@ -46,17 +47,14 @@ Discriminator = argbind.bind(Discriminator)
AudioDataset = argbind.bind(AudioDataset, "train", "val") AudioDataset = argbind.bind(AudioDataset, "train", "val")
AudioLoader = argbind.bind(AudioLoader, "train", "val") AudioLoader = argbind.bind(AudioLoader, "train", "val")
# Transforms
filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
"BaseTransform",
"Compose",
"Choose", ]
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
# Loss # Loss
filter_fn = lambda fn: hasattr(fn, "forward") and "Loss" in fn.__name__ # filter_fn = lambda fn: hasattr(fn, "forward") and "Loss" in fn.__name__
losses = argbind.bind_module( def filter_fn(fn):
paddlespeech.t2s.modules.losses, filter_fn=filter_fn) return hasattr(fn, "forward") and "Loss" in fn.__name__
losses = argbind.bind_module(_losses, filter_fn=filter_fn)
def get_infinite_loader(dataloader): def get_infinite_loader(dataloader):
@ -65,13 +63,33 @@ def get_infinite_loader(dataloader):
yield batch yield batch
# Transforms
# filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
# "BaseTransform",
# "Compose",
# "Choose", ]
def filter_fn(fn):
return hasattr(fn, "transform") and fn.__qualname__ not in [
"BaseTransform",
"Compose",
"Choose",
]
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
# to_tfm = lambda l: [getattr(tfm, x)() for x in l]
def to_tfm(l):
return [getattr(tfm, x)() for x in l]
@argbind.bind("train", "val") @argbind.bind("train", "val")
def build_transform( def build_transform(
augment_prob: float=1.0, augment_prob: float=1.0,
preprocess: list=["Identity"], preprocess: list=["Identity"],
augment: list=["Identity"], augment: list=["Identity"],
postprocess: list=["Identity"], ): postprocess: list=["Identity"], ):
to_tfm = lambda l: [getattr(tfm, x)() for x in l]
preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess") preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess")
augment = transforms.Compose( augment = transforms.Compose(
*to_tfm(augment), name="augment", prob=augment_prob) *to_tfm(augment), name="augment", prob=augment_prob)
@ -121,10 +139,10 @@ class State:
tracker: Tracker tracker: Tracker
@argbind.bind(without_prefix=True) # @argbind.bind(without_prefix=True)
def load( def load(
args, args,
accel: ml.Accelerator, accel: Accelerator,
tracker: Tracker, tracker: Tracker,
save_path: str, save_path: str,
resume: bool=False, resume: bool=False,
@ -282,7 +300,7 @@ def checkpoint(state, save_iters, save_path):
tags = ["latest"] tags = ["latest"]
state.tracker.print(f"Saving to {str(Path('.').absolute())}") state.tracker.print(f"Saving to {str(Path('.').absolute())}")
if state.tracker.is_best("val", "mel/loss"): if state.tracker.is_best("val", "mel/loss"):
state.tracker.print(f"Best generator so far") state.tracker.print("Best generator so far")
tags.append("best") tags.append("best")
if state.tracker.step in save_iters: if state.tracker.step in save_iters:
tags.append(f"{state.tracker.step // 1000}k") tags.append(f"{state.tracker.step // 1000}k")
@ -339,11 +357,11 @@ def validate(state, val_dataloader, accel):
return output return output
@argbind.bind(without_prefix=True) # @argbind.bind(without_prefix=True)
def train( def train(
args, args,
accel: ml.Accelerator, accel: Accelerator,
seed: int=0, seed: int=2025,
save_path: str="ckpt", save_path: str="ckpt",
num_iters: int=250000, num_iters: int=250000,
save_iters: list=[10000, 50000, 100000, 200000], save_iters: list=[10000, 50000, 100000, 200000],
@ -360,7 +378,7 @@ def train(
"vq/commitment_loss": 0.25, "vq/commitment_loss": 0.25,
"vq/codebook_loss": 1.0, "vq/codebook_loss": 1.0,
}, ): }, ):
util.seed(seed) seed_everything(seed)
Path(save_path).mkdir(exist_ok=True, parents=True) Path(save_path).mkdir(exist_ok=True, parents=True)
writer = LogWriter( writer = LogWriter(
log_dir=f"{save_path}/logs") if accel.local_rank == 0 else None log_dir=f"{save_path}/logs") if accel.local_rank == 0 else None
@ -401,7 +419,8 @@ def train(
train_dataloader, start=tracker.step): train_dataloader, start=tracker.step):
train_loop(state, batch, accel, lambdas) train_loop(state, batch, accel, lambdas)
last_iter = tracker.step == num_iters - 1 if num_iters is not None else False last_iter = (tracker.step == num_iters - 1
if num_iters is not None else False)
if tracker.step % sample_freq == 0 or last_iter: if tracker.step % sample_freq == 0 or last_iter:
save_samples(state, val_idx, writer) save_samples(state, val_idx, writer)
@ -416,10 +435,11 @@ def train(
if __name__ == "__main__": if __name__ == "__main__":
args = argbind.parse_args() # args = argbind.parse_args()
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0 # args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
with argbind.scope(args): # with argbind.scope(args):
with Accelerator() as accel: with Accelerator() as accel:
if accel.local_rank != 0: if accel.local_rank != 0:
sys.tracebacklimit = 0 sys.tracebacklimit = 0
train(args, accel) # train(args, accel)
train(None, accel)

@ -1,7 +1,4 @@
import math
from functools import partial from functools import partial
from typing import List
from typing import Union
import numpy as np import numpy as np
import paddle import paddle
@ -9,7 +6,6 @@ import paddle.nn.functional as F
from paddle import nn from paddle import nn
from paddlespeech.audiotools import AudioSignal from paddlespeech.audiotools import AudioSignal
from paddlespeech.audiotools.ml import BaseModel
from paddlespeech.codec.models.dac_.model.base import CodecMixin from paddlespeech.codec.models.dac_.model.base import CodecMixin
from paddlespeech.codec.models.dac_.nn.layers import Snake1d from paddlespeech.codec.models.dac_.nn.layers import Snake1d
from paddlespeech.codec.models.dac_.nn.layers import WNConv1d from paddlespeech.codec.models.dac_.nn.layers import WNConv1d
@ -195,7 +191,7 @@ class DAC(nn.Layer, CodecMixin):
length = audio_data.shape[-1] length = audio_data.shape[-1]
right_pad = np.ceil(length / self.hop_length) * self.hop_length - length right_pad = np.ceil(length / self.hop_length) * self.hop_length - length
audio_data = F.pad( audio_data = F.pad(
audio_data, [0, right_pad], audio_data, [0, int(right_pad)],
mode='constant', mode='constant',
value=0, value=0,
data_format="NCL") data_format="NCL")

@ -4,10 +4,9 @@ import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn.utils import weight_norm
from .layers import WNConv1d from paddlespeech.codec.models.dac_.nn.layers import WNConv1d
# from dac.nn.layers import WNConv1d
class VectorQuantize(nn.Layer): class VectorQuantize(nn.Layer):
@ -106,7 +105,8 @@ class ResidualVectorQuantize(nn.Layer):
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual) residual)
mask = paddle.full((z.shape[0], ), fill_value=i) < n_quantizers mask = (paddle.full(
(z.shape[0], ), fill_value=i) < n_quantizers).astype("float32")
z_q = z_q + z_q_i * mask[:, None, None] z_q = z_q + z_q_i * mask[:, None, None]
residual = residual - z_q_i residual = residual - z_q_i

@ -3,17 +3,17 @@ from pathlib import Path
import argbind import argbind
import numpy as np import numpy as np
import torch import paddle
from tqdm import tqdm
from dac import DACFile from dac import DACFile
from dac.utils import load_model from dac.utils import load_model
from tqdm import tqdm
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
@argbind.bind(group="decode", positional=True, without_prefix=True) @argbind.bind(group="decode", positional=True, without_prefix=True)
@torch.inference_mode() @paddle.no_grad()
@torch.no_grad()
def decode( def decode(
input: str, input: str,
output: str="", output: str="",
@ -49,7 +49,6 @@ def decode(
model_bitrate=model_bitrate, model_bitrate=model_bitrate,
tag=model_tag, tag=model_tag,
load_path=weights_path, ) load_path=weights_path, )
generator.to(device)
generator.eval() generator.eval()
# Find all .dac files in input directory # Find all .dac files in input directory
@ -64,7 +63,7 @@ def decode(
output = Path(output) output = Path(output)
output.mkdir(parents=True, exist_ok=True) output.mkdir(parents=True, exist_ok=True)
for i in tqdm(range(len(input_files)), desc=f"Decoding files"): for i in tqdm(range(len(input_files)), desc="Decoding files"):
# Load file # Load file
artifact = DACFile.load(input_files[i]) artifact = DACFile.load(input_files[i])

@ -3,18 +3,18 @@ import warnings
from pathlib import Path from pathlib import Path
import argbind import argbind
import torch import paddle
from audiotools import AudioSignal
from audiotools.core import util
from dac.utils import load_model
from tqdm import tqdm from tqdm import tqdm
from dac.utils import load_model
from paddlespeech.audiotools import AudioSignal
from paddlespeech.audiotools.core import util
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
@argbind.bind(group="encode", positional=True, without_prefix=True) @argbind.bind(group="encode", positional=True, without_prefix=True)
@torch.inference_mode() @paddle.no_grad()
@torch.no_grad()
def encode( def encode(
input: str, input: str,
output: str="", output: str="",
@ -22,7 +22,6 @@ def encode(
model_tag: str="latest", model_tag: str="latest",
model_bitrate: str="8kbps", model_bitrate: str="8kbps",
n_quantizers: int=None, n_quantizers: int=None,
device: str="cuda",
model_type: str="44khz", model_type: str="44khz",
win_duration: float=5.0, win_duration: float=5.0,
verbose: bool=False, ): verbose: bool=False, ):
@ -53,7 +52,6 @@ def encode(
model_bitrate=model_bitrate, model_bitrate=model_bitrate,
tag=model_tag, tag=model_tag,
load_path=weights_path, ) load_path=weights_path, )
generator.to(device)
generator.eval() generator.eval()
kwargs = {"n_quantizers": n_quantizers} kwargs = {"n_quantizers": n_quantizers}

Loading…
Cancel
Save