diff --git a/examples/other/tts_finetune/tts3/README.md b/examples/other/tts_finetune/tts3/README.md
index 1ad30328..192ee7ff 100644
--- a/examples/other/tts_finetune/tts3/README.md
+++ b/examples/other/tts_finetune/tts3/README.md
@@ -75,6 +75,15 @@ When "Prepare" done. The structure of the current directory is listed below.
 ```
+### Set finetune.yaml
+`finetune.yaml` contains the configuration for fine-tuning. You can tune these options to get better results.
+Arguments:
+  - `batch_size`: fine-tuning batch size. Default: -1, which means 64, the same as the pretrained model.
+  - `learning_rate`: learning rate. Default: 0.0001.
+  - `num_snapshots`: number of saved checkpoints. Default: -1, which means 5, the same as the pretrained model.
+  - `frozen_layers`: layers to freeze; must be a list. If you don't want to freeze any layer, set it to [].
+
+
 ## Get Started
 Run the command below to
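The `-1` defaults above follow a simple rule: a positive value overrides the pretrained configuration, and `-1` keeps it. A minimal sketch of that override logic (illustrative only; the helper name is hypothetical, and the actual code lives in the `finetune.py` diff below):

```python
# Illustrative sketch of the "-1 keeps the pretrained value" rule described above.
# `config` is the pretrained model config, `finetune_config` is the parsed finetune.yaml.
def apply_finetune_overrides(config, finetune_config):
    # Only override when the user supplied a positive value.
    if finetune_config.batch_size > 0:
        config.batch_size = finetune_config.batch_size
    if finetune_config.learning_rate > 0:
        config.optimizer.learning_rate = finetune_config.learning_rate
    if finetune_config.num_snapshots > 0:
        config.num_snapshots = finetune_config.num_snapshots
    return config
```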
diff --git a/examples/other/tts_finetune/tts3/finetune.py b/examples/other/tts_finetune/tts3/finetune.py
index 0f060b44..207e2dbc 100644
--- a/examples/other/tts_finetune/tts3/finetune.py
+++ b/examples/other/tts_finetune/tts3/finetune.py
@@ -14,6 +14,7 @@ import argparse
 import os
 from pathlib import Path
+from typing import List
 from typing import Union
 
 import yaml
 
@@ -21,10 +22,10 @@ from local.check_oov import get_check_result
 from local.extract import extract_feature
 from local.label_process import get_single_label
 from local.prepare_env import generate_finetune_env
+from local.train import train_sp
 from paddle import distributed as dist
 from yacs.config import CfgNode
-from paddlespeech.t2s.exps.fastspeech2.train import train_sp
 
 from utils.gen_duration_from_textgrid import gen_duration_from_textgrid
 
 DICT_EN = 'tools/aligner/cmudict-0.7b'
@@ -38,15 +39,24 @@ os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
 
 
 class TrainArgs():
-    def __init__(self, ngpu, config_file, dump_dir: Path, output_dir: Path):
+    def __init__(self,
+                 ngpu,
+                 config_file,
+                 dump_dir: Path,
+                 output_dir: Path,
+                 frozen_layers: List[str]):
+        # config: fastspeech2 config file.
         self.config = str(config_file)
         self.train_metadata = str(dump_dir / "train/norm/metadata.jsonl")
         self.dev_metadata = str(dump_dir / "dev/norm/metadata.jsonl")
+        # model output dir.
         self.output_dir = str(output_dir)
         self.ngpu = ngpu
         self.phones_dict = str(dump_dir / "phone_id_map.txt")
         self.speaker_dict = str(dump_dir / "speaker_id_map.txt")
         self.voice_cloning = False
+        # frozen layers
+        self.frozen_layers = frozen_layers
 
 
 def get_mfa_result(
@@ -122,12 +132,12 @@ if __name__ == '__main__':
         "--ngpu", type=int, default=2, help="if ngpu=0, use cpu.")
     parser.add_argument(
         "--epoch", type=int, default=100, help="finetune epoch")
 
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=-1,
-        help="batch size, default -1 means same as pretrained model")
+    parser.add_argument(
+        "--finetune_config",
+        type=str,
+        default="./finetune.yaml",
+        help="Path to finetune config file")
 
     args = parser.parse_args()
 
@@ -147,8 +156,14 @@ if __name__ == '__main__':
     with open(config_file) as f:
         config = CfgNode(yaml.safe_load(f))
     config.max_epoch = config.max_epoch + args.epoch
-    if args.batch_size > 0:
-        config.batch_size = args.batch_size
+
+    with open(args.finetune_config) as f2:
+        finetune_config = CfgNode(yaml.safe_load(f2))
+    config.batch_size = finetune_config.batch_size if finetune_config.batch_size > 0 else config.batch_size
+    config.optimizer.learning_rate = finetune_config.learning_rate if finetune_config.learning_rate > 0 else config.optimizer.learning_rate
+    config.num_snapshots = finetune_config.num_snapshots if finetune_config.num_snapshots > 0 else config.num_snapshots
+    frozen_layers = finetune_config.frozen_layers
+    assert type(frozen_layers) == list, "frozen_layers should be a list."
 
     if args.lang == 'en':
         lexicon_file = DICT_EN
@@ -158,6 +173,13 @@ if __name__ == '__main__':
         mfa_phone_file = MFA_PHONE_ZH
     else:
         print('please input right lang!!')
+
+    print(f"finetune max_epoch: {config.max_epoch}")
+    print(f"finetune batch_size: {config.batch_size}")
+    print(f"finetune learning_rate: {config.optimizer.learning_rate}")
+    print(f"finetune num_snapshots: {config.num_snapshots}")
+    print(f"finetune frozen_layers: {frozen_layers}")
+
     am_phone_file = pretrained_model_dir / "phone_id_map.txt"
     label_file = input_dir / "labels.txt"
 
@@ -181,7 +203,8 @@ if __name__ == '__main__':
     generate_finetune_env(output_dir, pretrained_model_dir)
 
     # create a new args for training
-    train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir)
+    train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir,
+                           frozen_layers)
 
     # finetune models
     # dispatch
diff --git a/examples/other/tts_finetune/tts3/finetune.yaml b/examples/other/tts_finetune/tts3/finetune.yaml
new file mode 100644
index 00000000..374a69f3
--- /dev/null
+++ b/examples/other/tts_finetune/tts3/finetune.yaml
@@ -0,0 +1,12 @@
+###########################################################
+#                     PARAS SETTING                       #
+###########################################################
+# Set to -1 to indicate that the parameter is the same as the pretrained model configuration
+
+batch_size: -1
+learning_rate: 0.0001    # learning rate
+num_snapshots: -1
+
+# frozen_layers should be a list
+# if you don't want to freeze any layer, set frozen_layers to []
+frozen_layers: ["encoder", "duration_predictor"]
diff --git a/examples/other/tts_finetune/tts3/local/extract.py b/examples/other/tts_finetune/tts3/local/extract.py
index edd92420..630b58ce 100644
--- a/examples/other/tts_finetune/tts3/local/extract.py
+++ b/examples/other/tts_finetune/tts3/local/extract.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import math
 import os
 from operator import itemgetter
 from pathlib import Path
@@ -211,9 +210,9 @@ def extract_feature(duration_file: str,
     mel_extractor, pitch_extractor, energy_extractor = get_extractor(config)
 
     wav_files = sorted(list((input_dir).rglob("*.wav")))
-    # split data into 3 sections, train: 80%, dev: 10%, test: 10%
-    num_train = math.ceil(len(wav_files) * 0.8)
-    num_dev = math.ceil(len(wav_files) * 0.1)
+    # split data into 3 sections, train: len(wav_files) - 2, dev: 1, test: 1
+    num_train = len(wav_files) - 2
+    num_dev = 1
     print(num_train, num_dev)
 
     train_wav_files = wav_files[:num_train]
diff --git a/examples/other/tts_finetune/tts3/local/train.py b/examples/other/tts_finetune/tts3/local/train.py
new file mode 100644
index 00000000..d065ae59
--- /dev/null
+++ b/examples/other/tts_finetune/tts3/local/train.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import shutil
+from pathlib import Path
+from typing import List
+
+import jsonlines
+import numpy as np
+import paddle
+from paddle import DataParallel
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+
+from paddlespeech.t2s.datasets.am_batch_fn import fastspeech2_multi_spk_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import fastspeech2_single_spk_batch_fn
+from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
+from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Evaluator
+from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Updater
+from paddlespeech.t2s.training.extensions.snapshot import Snapshot
+from paddlespeech.t2s.training.extensions.visualizer import VisualDL
+from paddlespeech.t2s.training.optimizer import build_optimizers
+from paddlespeech.t2s.training.seeding import seed_everything
+from paddlespeech.t2s.training.trainer import Trainer
+
+
+def freeze_layer(model, layers: List[str]):
+    """freeze layers
+
+    Args:
+        layers (List[str]): frozen layers
+    """
+    for layer in layers:
+ layer + ".parameters()"): + param.trainable = False + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + world_size = paddle.distributed.get_world_size() + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + fields = [ + "text", "text_lengths", "speech", "speech_lengths", "durations", + "pitch", "energy" + ] + converters = {"speech": np.load, "pitch": np.load, "energy": np.load} + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker fastspeech2!") + collate_fn = fastspeech2_multi_spk_batch_fn + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + fields += ["spk_id"] + elif args.voice_cloning: + print("Training voice cloning!") + collate_fn = fastspeech2_multi_spk_batch_fn + fields += ["spk_emb"] + converters["spk_emb"] = np.load + else: + print("single speaker fastspeech2!") + collate_fn = fastspeech2_single_spk_batch_fn + print("spk_num:", spk_num) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + + print("samplers done!") + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + shuffle=False, + drop_last=False, + batch_size=config.batch_size, + collate_fn=collate_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_mels + model = FastSpeech2( + idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"]) + + # freeze layer + if args.frozen_layers != []: + freeze_layer(model, args.frozen_layers) + + if world_size > 1: + model = DataParallel(model) + print("model done!") + + optimizer = build_optimizers(model, **config["optimizer"]) + print("optimizer done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = FastSpeech2Updater( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + output_dir=output_dir, + **config["updater"]) + + trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) + + evaluator = FastSpeech2Evaluator( + model, dev_dataloader, output_dir=output_dir, **config["updater"]) + + if dist.get_rank() == 0: + 
diff --git a/examples/other/tts_finetune/tts3/run.sh b/examples/other/tts_finetune/tts3/run.sh
index 9bb7ec6f..9c877e64 100755
--- a/examples/other/tts_finetune/tts3/run.sh
+++ b/examples/other/tts_finetune/tts3/run.sh
@@ -10,11 +10,12 @@ mfa_dir=./mfa_result
 dump_dir=./dump
 output_dir=./exp/default
 lang=zh
-ngpu=2
+ngpu=1
+finetune_config=./finetune.yaml
 
-ckpt=snapshot_iter_96600
+ckpt=snapshot_iter_96699
 
-gpus=0,1
+gpus=1
 CUDA_VISIBLE_DEVICES=${gpus}
 stage=0
 stop_stage=100
@@ -35,7 +36,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --output_dir=${output_dir} \
         --lang=${lang} \
         --ngpu=${ngpu} \
-        --epoch=100
+        --epoch=100 \
+        --finetune_config=${finetune_config}
 
 fi
 
@@ -54,7 +56,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         --voc_stat=pretrained_models/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
         --lang=zh \
         --text=${BIN_DIR}/../sentences.txt \
-        --output_dir=./test_e2e \
+        --output_dir=./test_e2e/ \
        --phones_dict=${dump_dir}/phone_id_map.txt \
        --speaker_dict=${dump_dir}/speaker_id_map.txt \
        --spk_id=0
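With the default `frozen_layers: ["encoder", "duration_predictor"]`, only the remaining sub-layers of FastSpeech2 are updated during fine-tuning. A quick sanity check (a sketch with a hypothetical helper, assuming the `model` object built in `train_sp` after `freeze_layer` has run) can confirm that the freeze took effect:

```python
def report_trainable(model):
    """Print how many parameters are still trainable after freezing (illustrative)."""
    total = sum(1 for _ in model.parameters())
    trainable = sum(1 for p in model.parameters() if p.trainable)
    print(f"trainable parameters: {trainable}/{total}")
```

Calling it right after `freeze_layer(model, args.frozen_layers)` should show a noticeably smaller trainable count than an unfrozen run; if the two numbers are equal, the layer names in `finetune.yaml` did not match any sub-layer of the model.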