|
|
@ -10,25 +10,26 @@ import paddle.nn as nn
|
|
|
|
from visualdl import LogWriter
|
|
|
|
from visualdl import LogWriter
|
|
|
|
|
|
|
|
|
|
|
|
import paddlespeech
|
|
|
|
import paddlespeech
|
|
|
|
import paddlespeech.t2s.modules.losses as losses
|
|
|
|
import paddlespeech.t2s.modules.losses as _losses
|
|
|
|
from paddlespeech.audiotools import ml
|
|
|
|
|
|
|
|
from paddlespeech.audiotools.core import AudioSignal
|
|
|
|
from paddlespeech.audiotools.core import AudioSignal
|
|
|
|
from paddlespeech.audiotools.core import util
|
|
|
|
from paddlespeech.audiotools.core import util
|
|
|
|
from paddlespeech.audiotools.data import transforms
|
|
|
|
from paddlespeech.audiotools.data import transforms
|
|
|
|
from paddlespeech.audiotools.data.datasets import AudioDataset
|
|
|
|
from paddlespeech.audiotools.data.datasets import AudioDataset
|
|
|
|
from paddlespeech.audiotools.data.datasets import AudioLoader
|
|
|
|
from paddlespeech.audiotools.data.datasets import AudioLoader
|
|
|
|
from paddlespeech.audiotools.data.datasets import ConcatDataset
|
|
|
|
from paddlespeech.audiotools.data.datasets import ConcatDataset
|
|
|
|
|
|
|
|
from paddlespeech.audiotools.ml import Accelerator
|
|
|
|
from paddlespeech.audiotools.ml.decorators import timer
|
|
|
|
from paddlespeech.audiotools.ml.decorators import timer
|
|
|
|
from paddlespeech.audiotools.ml.decorators import Tracker
|
|
|
|
from paddlespeech.audiotools.ml.decorators import Tracker
|
|
|
|
from paddlespeech.audiotools.ml.decorators import when
|
|
|
|
from paddlespeech.audiotools.ml.decorators import when
|
|
|
|
from paddlespeech.codec.models.dac_.model import DAC
|
|
|
|
from paddlespeech.codec.models.dac_.model import DAC
|
|
|
|
from paddlespeech.codec.models.dac_.model import Discriminator
|
|
|
|
from paddlespeech.codec.models.dac_.model import Discriminator
|
|
|
|
|
|
|
|
from paddlespeech.t2s.training.seeding import seed_everything
|
|
|
|
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
|
|
|
|
|
|
|
|
# Optimizers
|
|
|
|
# Optimizers
|
|
|
|
AdamW = argbind.bind(paddle.optimizer.AdamW, "generator", "discriminator")
|
|
|
|
AdamW = argbind.bind(paddle.optimizer.AdamW, "generator", "discriminator")
|
|
|
|
Accelerator = argbind.bind(ml.Accelerator, without_prefix=True)
|
|
|
|
# Accelerator = argbind.bind(ml.Accelerator, without_prefix=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@argbind.bind("generator", "discriminator")
|
|
|
|
@argbind.bind("generator", "discriminator")
|
|
|
@ -46,17 +47,14 @@ Discriminator = argbind.bind(Discriminator)
|
|
|
|
AudioDataset = argbind.bind(AudioDataset, "train", "val")
|
|
|
|
AudioDataset = argbind.bind(AudioDataset, "train", "val")
|
|
|
|
AudioLoader = argbind.bind(AudioLoader, "train", "val")
|
|
|
|
AudioLoader = argbind.bind(AudioLoader, "train", "val")
|
|
|
|
|
|
|
|
|
|
|
|
# Transforms
|
|
|
|
|
|
|
|
filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
|
|
|
|
|
|
|
"BaseTransform",
|
|
|
|
|
|
|
|
"Compose",
|
|
|
|
|
|
|
|
"Choose", ]
|
|
|
|
|
|
|
|
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Loss
|
|
|
|
# Loss
|
|
|
|
filter_fn = lambda fn: hasattr(fn, "forward") and "Loss" in fn.__name__
|
|
|
|
# filter_fn = lambda fn: hasattr(fn, "forward") and "Loss" in fn.__name__
|
|
|
|
losses = argbind.bind_module(
|
|
|
|
def filter_fn(fn):
|
|
|
|
paddlespeech.t2s.modules.losses, filter_fn=filter_fn)
|
|
|
|
return hasattr(fn, "forward") and "Loss" in fn.__name__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
losses = argbind.bind_module(_losses, filter_fn=filter_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_infinite_loader(dataloader):
|
|
|
|
def get_infinite_loader(dataloader):
|
|
|
@ -65,13 +63,33 @@ def get_infinite_loader(dataloader):
|
|
|
|
yield batch
|
|
|
|
yield batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Transforms
|
|
|
|
|
|
|
|
# filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
|
|
|
|
|
|
|
# "BaseTransform",
|
|
|
|
|
|
|
|
# "Compose",
|
|
|
|
|
|
|
|
# "Choose", ]
|
|
|
|
|
|
|
|
def filter_fn(fn):
|
|
|
|
|
|
|
|
return hasattr(fn, "transform") and fn.__qualname__ not in [
|
|
|
|
|
|
|
|
"BaseTransform",
|
|
|
|
|
|
|
|
"Compose",
|
|
|
|
|
|
|
|
"Choose",
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# to_tfm = lambda l: [getattr(tfm, x)() for x in l]
|
|
|
|
|
|
|
|
def to_tfm(l):
|
|
|
|
|
|
|
|
return [getattr(tfm, x)() for x in l]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@argbind.bind("train", "val")
|
|
|
|
@argbind.bind("train", "val")
|
|
|
|
def build_transform(
|
|
|
|
def build_transform(
|
|
|
|
augment_prob: float=1.0,
|
|
|
|
augment_prob: float=1.0,
|
|
|
|
preprocess: list=["Identity"],
|
|
|
|
preprocess: list=["Identity"],
|
|
|
|
augment: list=["Identity"],
|
|
|
|
augment: list=["Identity"],
|
|
|
|
postprocess: list=["Identity"], ):
|
|
|
|
postprocess: list=["Identity"], ):
|
|
|
|
to_tfm = lambda l: [getattr(tfm, x)() for x in l]
|
|
|
|
|
|
|
|
preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess")
|
|
|
|
preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess")
|
|
|
|
augment = transforms.Compose(
|
|
|
|
augment = transforms.Compose(
|
|
|
|
*to_tfm(augment), name="augment", prob=augment_prob)
|
|
|
|
*to_tfm(augment), name="augment", prob=augment_prob)
|
|
|
@ -121,10 +139,10 @@ class State:
|
|
|
|
tracker: Tracker
|
|
|
|
tracker: Tracker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@argbind.bind(without_prefix=True)
|
|
|
|
# @argbind.bind(without_prefix=True)
|
|
|
|
def load(
|
|
|
|
def load(
|
|
|
|
args,
|
|
|
|
args,
|
|
|
|
accel: ml.Accelerator,
|
|
|
|
accel: Accelerator,
|
|
|
|
tracker: Tracker,
|
|
|
|
tracker: Tracker,
|
|
|
|
save_path: str,
|
|
|
|
save_path: str,
|
|
|
|
resume: bool=False,
|
|
|
|
resume: bool=False,
|
|
|
@ -282,7 +300,7 @@ def checkpoint(state, save_iters, save_path):
|
|
|
|
tags = ["latest"]
|
|
|
|
tags = ["latest"]
|
|
|
|
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
|
|
|
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
|
|
|
if state.tracker.is_best("val", "mel/loss"):
|
|
|
|
if state.tracker.is_best("val", "mel/loss"):
|
|
|
|
state.tracker.print(f"Best generator so far")
|
|
|
|
state.tracker.print("Best generator so far")
|
|
|
|
tags.append("best")
|
|
|
|
tags.append("best")
|
|
|
|
if state.tracker.step in save_iters:
|
|
|
|
if state.tracker.step in save_iters:
|
|
|
|
tags.append(f"{state.tracker.step // 1000}k")
|
|
|
|
tags.append(f"{state.tracker.step // 1000}k")
|
|
|
@ -339,11 +357,11 @@ def validate(state, val_dataloader, accel):
|
|
|
|
return output
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@argbind.bind(without_prefix=True)
|
|
|
|
# @argbind.bind(without_prefix=True)
|
|
|
|
def train(
|
|
|
|
def train(
|
|
|
|
args,
|
|
|
|
args,
|
|
|
|
accel: ml.Accelerator,
|
|
|
|
accel: Accelerator,
|
|
|
|
seed: int=0,
|
|
|
|
seed: int=2025,
|
|
|
|
save_path: str="ckpt",
|
|
|
|
save_path: str="ckpt",
|
|
|
|
num_iters: int=250000,
|
|
|
|
num_iters: int=250000,
|
|
|
|
save_iters: list=[10000, 50000, 100000, 200000],
|
|
|
|
save_iters: list=[10000, 50000, 100000, 200000],
|
|
|
@ -360,7 +378,7 @@ def train(
|
|
|
|
"vq/commitment_loss": 0.25,
|
|
|
|
"vq/commitment_loss": 0.25,
|
|
|
|
"vq/codebook_loss": 1.0,
|
|
|
|
"vq/codebook_loss": 1.0,
|
|
|
|
}, ):
|
|
|
|
}, ):
|
|
|
|
util.seed(seed)
|
|
|
|
seed_everything(seed)
|
|
|
|
Path(save_path).mkdir(exist_ok=True, parents=True)
|
|
|
|
Path(save_path).mkdir(exist_ok=True, parents=True)
|
|
|
|
writer = LogWriter(
|
|
|
|
writer = LogWriter(
|
|
|
|
log_dir=f"{save_path}/logs") if accel.local_rank == 0 else None
|
|
|
|
log_dir=f"{save_path}/logs") if accel.local_rank == 0 else None
|
|
|
@ -401,7 +419,8 @@ def train(
|
|
|
|
train_dataloader, start=tracker.step):
|
|
|
|
train_dataloader, start=tracker.step):
|
|
|
|
train_loop(state, batch, accel, lambdas)
|
|
|
|
train_loop(state, batch, accel, lambdas)
|
|
|
|
|
|
|
|
|
|
|
|
last_iter = tracker.step == num_iters - 1 if num_iters is not None else False
|
|
|
|
last_iter = (tracker.step == num_iters - 1
|
|
|
|
|
|
|
|
if num_iters is not None else False)
|
|
|
|
if tracker.step % sample_freq == 0 or last_iter:
|
|
|
|
if tracker.step % sample_freq == 0 or last_iter:
|
|
|
|
save_samples(state, val_idx, writer)
|
|
|
|
save_samples(state, val_idx, writer)
|
|
|
|
|
|
|
|
|
|
|
@ -416,10 +435,11 @@ def train(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
args = argbind.parse_args()
|
|
|
|
# args = argbind.parse_args()
|
|
|
|
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
|
|
|
# args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
|
|
|
with argbind.scope(args):
|
|
|
|
# with argbind.scope(args):
|
|
|
|
with Accelerator() as accel:
|
|
|
|
with Accelerator() as accel:
|
|
|
|
if accel.local_rank != 0:
|
|
|
|
if accel.local_rank != 0:
|
|
|
|
sys.tracebacklimit = 0
|
|
|
|
sys.tracebacklimit = 0
|
|
|
|
train(args, accel)
|
|
|
|
# train(args, accel)
|
|
|
|
|
|
|
|
train(None, accel)
|
|
|
|