@@ -23,8 +23,10 @@ import paddle
 import yaml
 from paddle import DataParallel
 from paddle import distributed as dist
+from paddle import nn
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
+from paddle.optimizer import Adam
 from yacs.config import CfgNode
 
 from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
@@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
 from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
 from paddlespeech.t2s.training.extensions.snapshot import Snapshot
 from paddlespeech.t2s.training.extensions.visualizer import VisualDL
-from paddlespeech.t2s.training.optimizer import build_optimizers
 from paddlespeech.t2s.training.seeding import seed_everything
 from paddlespeech.t2s.training.trainer import Trainer
 
@@ -118,12 +119,27 @@ def train_sp(args, config):
 
     odim = config.n_mels
     model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
+
+    # model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams"
+    # state_dict = paddle.load(model_path)
+    # new_state_dict = {}
+    # for key, value in state_dict.items():
+    #     new_key = "model." + key
+    #     new_state_dict[new_key] = value
+    # model.set_state_dict(new_state_dict)
 
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")
 
-    optimizer = build_optimizers(model, **config["optimizer"])
+    scheduler = paddle.optimizer.lr.NoamDecay(
+        d_model=config["scheduler_params"]["d_model"],
+        warmup_steps=config["scheduler_params"]["warmup_steps"])
+    grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"])
+    optimizer = Adam(
+        learning_rate=scheduler,
+        grad_clip=grad_clip,
+        parameters=model.parameters())
     print("optimizer done!")
 
     output_dir = Path(args.output_dir)
@@ -136,6 +152,7 @@ def train_sp(args, config):
     updater = ErnieSATUpdater(
         model=model,
         optimizer=optimizer,
+        scheduler=scheduler,
         dataloader=train_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
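
The core change in the third hunk replaces the generic `build_optimizers` helper with an explicitly wired Noam learning-rate schedule plus global-norm gradient clipping. Below is a minimal runnable sketch of that wiring; the numeric values and the `nn.Linear` placeholder stand in for the real `config["scheduler_params"]`, `config["grad_clip"]`, and ErnieSAT model, and are not part of the diff:

```python
import paddle
from paddle import nn
from paddle.optimizer import Adam

# Hypothetical stand-ins for config["scheduler_params"] and config["grad_clip"].
d_model = 384
warmup_steps = 4000
grad_clip_norm = 1.0

model = nn.Linear(8, 8)  # placeholder module in place of ErnieSAT

# Noam schedule: the lr rises linearly for warmup_steps, then decays ~ step^-0.5.
scheduler = paddle.optimizer.lr.NoamDecay(
    d_model=d_model, warmup_steps=warmup_steps)
grad_clip = nn.ClipGradByGlobalNorm(grad_clip_norm)
optimizer = Adam(
    learning_rate=scheduler,
    grad_clip=grad_clip,
    parameters=model.parameters())

# Each step: optimizer.step() applies gradients (clipped to the global-norm
# budget), then scheduler.step() advances the Noam schedule.
for step in range(3):
    loss = model(paddle.randn([2, 8])).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    scheduler.step()
    print(step, scheduler.get_lr())
```

Because the schedule is now owned by the training script rather than hidden inside `build_optimizers`, the `scheduler` object can be passed to `ErnieSATUpdater` (fourth hunk) and stepped alongside the optimizer.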
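The commented-out block after model construction warm-starts from a pretrained checkpoint whose parameter names apparently lack the `model.` prefix used by the training-time graph. A sketch of the same remapping as a reusable helper, assuming that interpretation; the helper name and example path are illustrative, not part of the PR:

```python
import paddle

def load_with_prefix(model: paddle.nn.Layer, ckpt_path: str,
                     prefix: str = "model.") -> None:
    """Load a .pdparams checkpoint, renaming every key with `prefix`
    so it matches the wrapped model's parameter names."""
    state_dict = paddle.load(ckpt_path)
    new_state_dict = {prefix + key: value for key, value in state_dict.items()}
    model.set_state_dict(new_state_dict)

# Usage (hypothetical path):
# load_with_prefix(model, "pretrained_model/paddle_checkpoint_en/model.pdparams")
```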