diff --git a/examples/csmsc/voc6/conf/default.yaml b/examples/csmsc/voc6/conf/default.yaml
new file mode 100644
index 000000000..2c838fb99
--- /dev/null
+++ b/examples/csmsc/voc6/conf/default.yaml
@@ -0,0 +1,68 @@
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fs: 24000          # Sampling rate.
+n_fft: 2048        # FFT size (samples).
+n_shift: 300       # Hop size (samples). 12.5ms
+win_length: 1200   # Window length (samples). 50ms
+                   # If set to null, it will be the same as n_fft.
+window: "hann"     # Window function.
+n_mels: 80         # Number of mel basis.
+fmin: 80           # Minimum frequency in mel basis calculation. (Hz)
+fmax: 7600         # Maximum frequency in mel basis calculation. (Hz)
+mu_law: True       # Recommended to suppress noise if using raw bits
+peak_norm: True
+
+
+###########################################################
+#                       MODEL SETTING                     #
+###########################################################
+model:
+    rnn_dims: 512                  # Hidden dims of RNN Layers.
+    fc_dims: 512
+    bits: 9                        # Bit depth of signal
+    aux_context_window: 2
+    aux_channels: 80               # Number of channels for auxiliary feature conv.
+                                   # Must be the same as n_mels.
+    upsample_scales: [4, 5, 3, 5]  # Upsampling scales. Product of these must be the same as the hop size (n_shift), same as pwgan here
+    compute_dims: 128
+    res_out_dims: 128
+    res_blocks: 10
+    mode: RAW                      # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics)
+inference:
+    gen_batched: True              # whether to generate samples in batch mode
+    target: 12000                  # target number of samples to be generated in each batch entry
+    overlap: 600                   # number of samples for crossfading between batches
+
+
+###########################################################
+#                   DATA LOADER SETTING                   #
+###########################################################
+batch_size: 64              # Batch size.
+batch_max_steps: 4500       # Length of each audio in batch. Make sure it is divisible by n_shift (hop size).
+num_workers: 2 # Number of workers in DataLoader. +valid_size: 50 + +########################################################### +# OPTIMIZER SETTING # +########################################################### +grad_clip: 4.0 +learning_rate: 1.0e-4 + + +########################################################### +# INTERVAL SETTING # +########################################################### + +train_max_steps: 400000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. +gen_eval_samples_interval_steps: 5000 # the iteration interval of generating valid samples +generate_num: 5 # number of samples to generate at each checkpoint + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/csmsc/voc6/local/preprocess.sh b/examples/csmsc/voc6/local/preprocess.sh new file mode 100755 index 000000000..064aea557 --- /dev/null +++ b/examples/csmsc/voc6/local/preprocess.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/preprocess.py \ + --input=~/datasets/BZNSYP/ \ + --output=dump \ + --dataset=csmsc \ + --config=${config_path} \ + --num-cpu=20 +fi diff --git a/examples/csmsc/voc6/local/synthesize.sh b/examples/csmsc/voc6/local/synthesize.sh new file mode 100755 index 000000000..876c8444e --- /dev/null +++ b/examples/csmsc/voc6/local/synthesize.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +test_input=$4 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + 
+    --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+    --input=${test_input} \
+    --output-dir=${train_output_path}/test
diff --git a/examples/csmsc/voc6/local/train.sh b/examples/csmsc/voc6/local/train.sh
new file mode 100755
index 000000000..900450cdd
--- /dev/null
+++ b/examples/csmsc/voc6/local/train.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+python ${BIN_DIR}/train.py \
+    --config=${config_path} \
+    --data=dump/ \
+    --output-dir=${train_output_path} \
+    --ngpu=1
diff --git a/examples/csmsc/voc6/path.sh b/examples/csmsc/voc6/path.sh
new file mode 100755
index 000000000..b0c98584d
--- /dev/null
+++ b/examples/csmsc/voc6/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=wavernn
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
diff --git a/examples/csmsc/voc6/run.sh b/examples/csmsc/voc6/run.sh
new file mode 100755
index 000000000..bd32e3d2e
--- /dev/null
+++ b/examples/csmsc/voc6/run.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+test_input=dump/mel_test
+ckpt_name=snapshot_iter_100000.pdz
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train the model
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # copy some test mels from dump
+    mkdir -p ${test_input}
+    cp -r dump/mel/00995*.npy ${test_input}
+    # synthesize
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ${test_input}|| exit -1
+fi
diff --git a/paddlespeech/t2s/datasets/__init__.py b/paddlespeech/t2s/datasets/__init__.py
index fc64a82f2..acaf808a4 100644
--- a/paddlespeech/t2s/datasets/__init__.py
+++ b/paddlespeech/t2s/datasets/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .common import *
+from .csmsc import *
 from .ljspeech import *
diff --git a/paddlespeech/t2s/datasets/csmsc.py b/paddlespeech/t2s/datasets/csmsc.py
new file mode 100644
index 000000000..9928a73af
--- /dev/null
+++ b/paddlespeech/t2s/datasets/csmsc.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+
+from paddle.io import Dataset
+
+__all__ = ["CSMSCMetaData"]
+
+
+class CSMSCMetaData(Dataset):
+    def __init__(self, root):
+        """Index the CSMSC (baker) corpus under `root`.
+        :param root: the path of baker dataset
+        """
+        self.root = os.path.abspath(root)
+        records = []
+        # entries are line pairs: "<wav name> <text>", then its pinyin line
+        self.meta_info = ["file_path", "text", "pinyin"]
+
+        metadata_path = os.path.join(self.root,
+                                     "ProsodyLabeling/000001-010000.txt")
+        wav_dirs = os.path.join(self.root, "Wave")
+        with open(metadata_path, 'r', encoding='utf-8') as f:
+            while True:
+                line1 = f.readline().strip()
+                if not line1:
+                    break
+                line2 = f.readline().strip()
+                strs = line1.split()
+                wav_fname = strs[0] + '.wav'
+                wav_filepath = os.path.join(wav_dirs, wav_fname)
+                text = strs[1].strip()
+                records.append([wav_filepath, text, line2])
+
+        self.records = records
+
+    def __getitem__(self, i):
+        return self.records[i]
+
+    def __len__(self):
+        return len(self.records)
+
+    def get_meta_info(self):
+        return self.meta_info
diff --git a/paddlespeech/t2s/datasets/vocoder_batch_fn.py b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
index 2e4f740fb..496bf902a 100644
--- a/paddlespeech/t2s/datasets/vocoder_batch_fn.py
+++ b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
@@ -11,8 +11,146 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
+from pathlib import Path
+
 import numpy as np
 import paddle
+from paddle.io import Dataset
+
+
+def label_2_float(x, bits):
+    """Map labels in [0, 2**bits - 1] linearly onto [-1, 1]."""
+    return 2 * x / (2**bits - 1.) - 1.
+
+
+def float_2_label(x, bits):
+    """Map values in [-1, 1] onto [0, 2**bits - 1] without rounding."""
+    assert abs(x).max() <= 1.0
+    x = (x + 1.) * (2**bits - 1) / 2
+    return x.clip(0, 2**bits - 1)
+
+
+def encode_mu_law(x, mu):
+    """Mu-law compand `x` and quantize it to `mu` levels (0 .. mu - 1)."""
+    mu = mu - 1
+    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
+    return np.floor((fx + 1) / 2 * mu + 0.5)
+
+
+def decode_mu_law(y, mu, from_labels=True):
+    """Expand mu-law values back to a signal (inverse of `encode_mu_law`)."""
+    # TODO: get rid of log2 - makes no sense
+    if from_labels:
+        y = label_2_float(y, math.log2(mu))
+    mu = mu - 1
+    x = paddle.sign(y) / mu * ((1 + mu)**paddle.abs(y) - 1)
+    return x
+
+
+class WaveRNNDataset(Dataset):
+    """Pairs of `mel/` and `wav/` .npy files listed in `root/metadata.csv`."""
+
+    def __init__(self, root):
+        self.root = Path(root).expanduser()
+
+        records = []
+
+        with open(self.root / "metadata.csv", 'r') as rf:
+
+            for line in rf:
+                name = line.split("\t")[0]
+                mel_path = str(self.root / "mel" / (str(name) + ".npy"))
+                wav_path = str(self.root / "wav" / (str(name) + ".npy"))
+                records.append((mel_path, wav_path))
+
+        self.records = records
+
+    def __getitem__(self, i):
+        mel_name, wav_name = self.records[i]
+        mel = np.load(mel_name)
+        wav = np.load(wav_name)
+        return mel, wav
+
+    def __len__(self):
+        return len(self.records)
+
+
+class WaveRNNClip(object):
+    """Collate function cutting aligned random mel / waveform windows."""
+
+    def __init__(self,
+                 mode: str='RAW',
+                 batch_max_steps: int=4500,
+                 hop_size: int=300,
+                 aux_context_window: int=2,
+                 bits: int=9):
+        self.mode = mode
+        self.mel_win = batch_max_steps // hop_size + 2 * aux_context_window
+        self.batch_max_steps = batch_max_steps
+        self.hop_size = hop_size
+        self.aux_context_window = aux_context_window
+        if self.mode == 'MOL':
+            self.bits = 16
+        else:
+            self.bits = bits
+
+    def __call__(self, batch):
+        """Cut one random window per (mel, quant) pair and batch them.
+
+        Returns (x, y, mels): x is the waveform window scaled to [-1, 1],
+        y is the same window shifted one sample later (int64 labels in RAW
+        mode, scaled to [-1, 1] in MOL mode), mels are the matching frames.
+        """
+        # batch: [mel, quant]
+        # the mel window is 2 * aux_context_window frames wider than the
+        # waveform window so that the resnet can 'see' wider than input length
+        # max_offsets = n_frames - 2 - (mel_win + 2 * aux_context_window)
+        max_offsets = [
+            x[0].shape[-1] - 2 - (self.mel_win + 2 * self.aux_context_window)
+            for x in batch
+        ]
+        if any(offset < 0 for offset in max_offsets):
+            raise ValueError(
+                "mel is too short to cut a window of batch_max_steps samples")
+        # the slice point of mel selecting randomly; randint needs high > low,
+        # so a clip that fits exactly (offset == 0) always starts at frame 0
+        mel_offsets = [
+            np.random.randint(0, max(offset, 1)) for offset in max_offsets
+        ]
+        # the slice point of wav, which is aux_context_window frames behind
+        sig_offsets = [(offset + self.aux_context_window) * self.hop_size
+                       for offset in mel_offsets]
+        # mels.shape[-1] = batch_max_steps // hop_size + 2 * aux_context_window
+        mels = [
+            x[0][:, mel_offsets[i]:mel_offsets[i] + self.mel_win]
+            for i, x in enumerate(batch)
+        ]
+        # labels.shape[-1] = batch_max_steps + 1
+        labels = [
+            x[1][sig_offsets[i]:sig_offsets[i] + self.batch_max_steps + 1]
+            for i, x in enumerate(batch)
+        ]
+
+        mels = np.stack(mels).astype(np.float32)
+        labels = np.stack(labels).astype(np.int64)
+
+        mels = paddle.to_tensor(mels)
+        labels = paddle.to_tensor(labels, dtype='int64')
+
+        # x is input, y is label
+        x = labels[:, :self.batch_max_steps]
+        y = labels[:, 1:]
+        # value range of quant as written by preprocessing:
+        #   mode = RAW: [0, 2**bits - 1] (mu-law companded when mu_law is set)
+        #   mode = MOL: [0, 2**16 - 1]
+        # x is always rescaled to [-1, 1]
+        x = label_2_float(paddle.cast(x, dtype='float32'), self.bits)
+        # y is rescaled to [-1, 1] only in MOL mode
+        if self.mode == 'MOL':
+            y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)
+
+        return x, y, mels
 
 
 class Clip(object):
diff --git a/paddlespeech/t2s/exps/wavernn/__init__.py b/paddlespeech/t2s/exps/wavernn/__init__.py
new file mode 100644
index 000000000..abf198b97
--- /dev/null
+++ b/paddlespeech/t2s/exps/wavernn/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/t2s/exps/wavernn/preprocess.py b/paddlespeech/t2s/exps/wavernn/preprocess.py
new file mode 100644
index 000000000..a26c6702a
--- /dev/null
+++ b/paddlespeech/t2s/exps/wavernn/preprocess.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from multiprocessing import cpu_count
+from multiprocessing import Pool
+from pathlib import Path
+
+import librosa
+import numpy as np
+import pandas as pd
+import tqdm
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.data.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets import CSMSCMetaData
+from paddlespeech.t2s.datasets import LJSpeechMetaData
+from paddlespeech.t2s.datasets.vocoder_batch_fn import encode_mu_law
+from paddlespeech.t2s.datasets.vocoder_batch_fn import float_2_label
+
+
+class Transform(object):
+    """Turn one metadata record into saved mel and quantized-wav .npy files."""
+
+    def __init__(self, output_dir: Path, config):
+        self.fs = config.fs
+        self.peak_norm = config.peak_norm
+        self.bits = config.model.bits
+        self.mode = config.model.mode
+        self.mu_law = config.mu_law
+
+        # reject a bad mode before any directory is created
+        if self.mode != 'RAW' and self.mode != 'MOL':
+            raise RuntimeError('Unknown mode value - ', self.mode)
+
+        self.wav_dir = output_dir / "wav"
+        self.mel_dir = output_dir / "mel"
+        self.wav_dir.mkdir(exist_ok=True)
+        self.mel_dir.mkdir(exist_ok=True)
+
+        self.mel_extractor = LogMelFBank(
+            sr=config.fs,
+            n_fft=config.n_fft,
+            hop_length=config.n_shift,
+            win_length=config.win_length,
+            window=config.window,
+            n_mels=config.n_mels,
+            fmin=config.fmin,
+            fmax=config.fmax)
+
+    def __call__(self, example):
+        """Process one record; return (base name, mel frames, wav samples)."""
+        wav_path, _, _ = example
+
+        base_name = os.path.splitext(os.path.basename(wav_path))[0]
+        wav, _ = librosa.load(wav_path, sr=self.fs)
+        peak = np.abs(wav).max()
+        # an all-zero clip has peak 0: skip the division instead of making NaNs
+        if peak > 0 and (self.peak_norm or peak > 1.0):
+            wav /= peak
+
+        mel = self.mel_extractor.get_log_mel_fbank(wav).T
+        if self.mode == 'RAW':
+            if self.mu_law:
+                quant = encode_mu_law(wav, mu=2**self.bits)
+            else:
+                quant = float_2_label(wav, bits=self.bits)
+        elif self.mode == 'MOL':
+            quant = float_2_label(wav, bits=16)
+
+        mel = mel.astype(np.float32)
+        audio = quant.astype(np.int64)
+
+        np.save(str(self.wav_dir / base_name), audio)
+        np.save(str(self.mel_dir / base_name), mel)
+
+        return base_name, mel.shape[-1], audio.shape[-1]
+
+
+def create_dataset(config,
+                   input_dir,
+                   output_dir,
+                   nprocs: int=1,
+                   dataset_type: str="ljspeech"):
+    """Preprocess a whole corpus into `output_dir` and write metadata.csv."""
+    input_dir = Path(input_dir).expanduser()
+    # LJSpeechMetaData.records: [filename, normalized text, speaker name(ljspeech)]
+    # CSMSCMetaData.records: [filename, normalized text, pinyin]
+    if dataset_type == 'ljspeech':
+        dataset = LJSpeechMetaData(input_dir)
+    else:
+        dataset = CSMSCMetaData(input_dir)
+    output_dir = Path(output_dir).expanduser()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    transform = Transform(output_dir, config)
+
+    file_names = []
+
+    # the with-block shuts the worker pool down once every result is consumed
+    with Pool(processes=nprocs) as pool:
+        for info in tqdm.tqdm(
+                pool.imap(transform, dataset), total=len(dataset)):
+            base_name, mel_len, audio_len = info
+            file_names.append((base_name, mel_len, audio_len))
+
+    meta_data = pd.DataFrame.from_records(file_names)
+    meta_data.to_csv(
+        str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
+    print("saved meta data in to {}".format(
+        os.path.join(output_dir, "metadata.csv")))
+
+    print("Done!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="create dataset")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config.")
+
+    parser.add_argument(
+        "--input", type=str, help="path of the ljspeech dataset")
+    parser.add_argument(
+        "--output", type=str, help="path to save output dataset")
+    parser.add_argument(
+        "--num-cpu",
+        type=int,
+        # cpu_count() // 2 is 0 on a single-core host, and Pool rejects 0
+        default=max(cpu_count() // 2, 1),
+        help="number of process.")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="ljspeech",
+        help="The dataset to preprocess, ljspeech or csmsc")
+
+    args = parser.parse_args()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    if args.dataset != "ljspeech" and args.dataset != "csmsc":
+        raise RuntimeError('Unknown dataset - ', args.dataset)
+
+    create_dataset(
+        config,
+        input_dir=args.input,
+        output_dir=args.output,
+        nprocs=args.num_cpu,
+        dataset_type=args.dataset)
diff --git a/paddlespeech/t2s/exps/wavernn/synthesize.py b/paddlespeech/t2s/exps/wavernn/synthesize.py
new file mode 100644
index 000000000..e08c52b60
--- /dev/null
+++ b/paddlespeech/t2s/exps/wavernn/synthesize.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from paddle import distributed as dist
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.models.wavernn import WaveRNN
+
+
+def main():
+    """Vocode every mel .npy in --input with a WaveRNN checkpoint."""
+    parser = argparse.ArgumentParser(description="Synthesize with WaveRNN.")
+    parser.add_argument("--config", type=str, help="WaveRNN config file.")
+    parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+    parser.add_argument(
+        "--input",
+        type=str,
+        help="path of directory containing mel spectrogram (in .npy format)")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        parser.error("ngpu should >= 0 !")
+
+    model = WaveRNN(
+        hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
+    state_dict = paddle.load(args.checkpoint)
+    model.set_state_dict(state_dict["main_params"])
+
+    model.eval()
+
+    mel_dir = Path(args.input).expanduser()
+    output_dir = Path(args.output_dir).expanduser()
+    output_dir.mkdir(parents=True, exist_ok=True)
+    for file_path in sorted(mel_dir.iterdir()):
+        mel = np.load(str(file_path))
+        mel = paddle.to_tensor(mel)
+        mel = mel.transpose([1, 0])
+        # input shape is (T', C_aux)
+        audio = model.generate(
+            c=mel,
+            batched=config.inference.gen_batched,
+            target=config.inference.target,
+            overlap=config.inference.overlap,
+            mu_law=config.mu_law,
+            gen_display=True)
+        audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
+        sf.write(audio_path, audio.numpy(), samplerate=config.fs)
+        print("[synthesize] {} -> {}".format(file_path, audio_path))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/exps/wavernn/train.py b/paddlespeech/t2s/exps/wavernn/train.py
new file mode 100644
index 000000000..d7bfc49bf
--- /dev/null
+++ b/paddlespeech/t2s/exps/wavernn/train.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+import shutil
+from pathlib import Path
+
+import paddle
+import yaml
+from paddle import DataParallel
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.optimizer import Adam
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.data import dataset
+from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNClip
+from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNDataset
+from paddlespeech.t2s.models.wavernn import WaveRNN
+from paddlespeech.t2s.models.wavernn import WaveRNNEvaluator
+from paddlespeech.t2s.models.wavernn import WaveRNNUpdater
+from paddlespeech.t2s.modules.losses import discretized_mix_logistic_loss
+from paddlespeech.t2s.training.extensions.snapshot import Snapshot
+from paddlespeech.t2s.training.extensions.visualizer import VisualDL
+from paddlespeech.t2s.training.seeding import seed_everything
+from paddlespeech.t2s.training.trainer import Trainer
+
+
+def train_sp(args, config):
+    """Train WaveRNN in this process; `main` spawns one call per GPU."""
+    # decide the device type and whether to run in parallel
+    world_size = paddle.distributed.get_world_size()
+    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
+        paddle.set_device("cpu")
+    else:
+        paddle.set_device("gpu")
+        if world_size > 1:
+            paddle.distributed.init_parallel_env()
+
+    # set the random seed, it is a must for multiprocess training
+    seed_everything(config.seed)
+
+    print(
+        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+    )
+
+    wavernn_dataset = WaveRNNDataset(args.data)
+
+    train_dataset, dev_dataset = dataset.split(
+        wavernn_dataset, len(wavernn_dataset) - config.valid_size)
+
+    batch_fn = WaveRNNClip(
+        mode=config.model.mode,
+        aux_context_window=config.model.aux_context_window,
+        hop_size=config.n_shift,
+        batch_max_steps=config.batch_max_steps,
+        bits=config.model.bits)
+
+    # collate function and dataloader
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        drop_last=True)
+    dev_sampler = DistributedBatchSampler(
+        dev_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        drop_last=False)
+    print("samplers done!")
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        collate_fn=batch_fn,
+        num_workers=config.num_workers)
+
+    dev_dataloader = DataLoader(
+        dev_dataset,
+        collate_fn=batch_fn,
+        batch_sampler=dev_sampler,
+        num_workers=config.num_workers)
+    valid_generate_loader = DataLoader(dev_dataset, batch_size=1)
+    print("dataloaders done!")
+
+    model = WaveRNN(
+        hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
+    if world_size > 1:
+        model = DataParallel(model)
+    print("model done!")
+
+    if config.model.mode == 'RAW':
+        criterion = paddle.nn.CrossEntropyLoss(axis=1)
+    elif config.model.mode == 'MOL':
+        criterion = discretized_mix_logistic_loss
+    else:
+        # fail fast: an unknown mode would otherwise train with no criterion
+        raise RuntimeError('Unknown model mode value - ', config.model.mode)
+    print("criterions done!")
+    clip = paddle.nn.ClipGradByGlobalNorm(config.grad_clip)
+    optimizer = Adam(
+        parameters=model.parameters(),
+        learning_rate=config.learning_rate,
+        grad_clip=clip)
+
+    print("optimizer done!")
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    if dist.get_rank() == 0:
+        config_name = os.path.basename(args.config)
+        # copy conf to output_dir
+        shutil.copyfile(args.config, output_dir / config_name)
+
+    updater = WaveRNNUpdater(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        dataloader=train_dataloader,
+        output_dir=output_dir,
+        mode=config.model.mode)
+
+    evaluator = WaveRNNEvaluator(
+        model=model,
+        dataloader=dev_dataloader,
+        criterion=criterion,
+        output_dir=output_dir,
+        valid_generate_loader=valid_generate_loader,
+        config=config)
+
+    trainer = Trainer(
+        updater,
+        stop_trigger=(config.train_max_steps, "iteration"),
+        out=output_dir)
+
+    if dist.get_rank() == 0:
+        trainer.extend(
+            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
+        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots),
+            trigger=(config.save_interval_steps, 'iteration'))
+
+    print("Trainer Done!")
+    trainer.run()
+
+
+def main():
+    """Parse the CLI arguments and config, then dispatch to `train_sp`."""
+
+    parser = argparse.ArgumentParser(description="Train a WaveRNN model.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config.")
+    parser.add_argument("--data", type=str, help="input")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+
+    args = parser.parse_args()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.ngpu > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/models/__init__.py b/paddlespeech/t2s/models/__init__.py
index 65227374e..97f8695c2 100644
--- a/paddlespeech/t2s/models/__init__.py
+++ b/paddlespeech/t2s/models/__init__.py
@@ -20,3 +20,4 @@ from .speedyspeech import *
 from .tacotron2 import *
 from .transformer_tts import *
 from .waveflow import *
+from .wavernn import *
diff --git a/paddlespeech/t2s/models/wavernn/__init__.py b/paddlespeech/t2s/models/wavernn/__init__.py
new file mode 100644
index 000000000..80ffd0688
--- /dev/null
+++ b/paddlespeech/t2s/models/wavernn/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .wavernn import * +from .wavernn_updater import * diff --git a/paddlespeech/t2s/models/wavernn/wavernn.py b/paddlespeech/t2s/models/wavernn/wavernn.py new file mode 100644 index 000000000..5d1cbd39d --- /dev/null +++ b/paddlespeech/t2s/models/wavernn/wavernn.py @@ -0,0 +1,592 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +import time +from typing import List + +import numpy as np +import paddle +from paddle import nn +from paddle.nn import functional as F + +from paddlespeech.t2s.datasets.vocoder_batch_fn import decode_mu_law +from paddlespeech.t2s.modules.losses import sample_from_discretized_mix_logistic +from paddlespeech.t2s.modules.nets_utils import initialize +from paddlespeech.t2s.modules.upsample import Stretch2D + + +class ResBlock(nn.Layer): + def __init__(self, dims): + super(ResBlock, self).__init__() + self.conv1 = nn.Conv1D(dims, dims, kernel_size=1, bias_attr=False) + self.conv2 = nn.Conv1D(dims, dims, kernel_size=1, bias_attr=False) + self.batch_norm1 = nn.BatchNorm1D(dims) + self.batch_norm2 = nn.BatchNorm1D(dims) + + def forward(self, x): + ''' + conv -> bn -> relu -> conv -> bn + residual connection + ''' + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Layer): + def __init__(self, + res_blocks: int=10, + compute_dims: int=128, + res_out_dims: int=128, + aux_channels: int=80, + aux_context_window: int=0): + super().__init__() + k_size = aux_context_window * 2 + 1 + # pay attention here, the dim reduces aux_context_window * 2 + self.conv_in = nn.Conv1D( + aux_channels, compute_dims, kernel_size=k_size, bias_attr=False) + self.batch_norm = nn.BatchNorm1D(compute_dims) + self.layers = nn.LayerList() + for _ in range(res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1D(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + ''' + Parameters + ---------- + x : Tensor + Input tensor (B, in_dims, T). + Returns + ---------- + Tensor + Output tensor (B, res_out_dims, T). 
+ ''' + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class UpsampleNetwork(nn.Layer): + def __init__(self, + aux_channels: int=80, + upsample_scales: List[int]=[4, 5, 3, 5], + compute_dims: int=128, + res_blocks: int=10, + res_out_dims: int=128, + aux_context_window: int=2): + super().__init__() + # total_scale is the total Up sampling multiple + total_scale = np.prod(upsample_scales) + # TODO pad*total_scale is numpy.int64 + self.indent = int(aux_context_window * total_scale) + self.resnet = MelResNet( + res_blocks=res_blocks, + aux_channels=aux_channels, + compute_dims=compute_dims, + res_out_dims=res_out_dims, + aux_context_window=aux_context_window) + self.resnet_stretch = Stretch2D(total_scale, 1) + self.up_layers = nn.LayerList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2D(scale, 1) + + conv = nn.Conv2D( + 1, 1, kernel_size=k_size, padding=padding, bias_attr=False) + weight_ = paddle.full_like(conv.weight, 1. / k_size[1]) + conv.weight.set_value(weight_) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + ''' + Parameters + ---------- + c : Tensor + Input tensor (B, C_aux, T). + Returns + ---------- + Tensor + Output tensor (B, (T - 2 * pad) * prob(upsample_scales), C_aux). + Tensor + Output tensor (B, (T - 2 * pad) * prob(upsample_scales), res_out_dims). 
def __init__(
        self,
        rnn_dims: int=512,
        fc_dims: int=512,
        bits: int=9,
        aux_context_window: int=2,
        upsample_scales: List[int]=[4, 5, 3, 5],
        aux_channels: int=80,
        compute_dims: int=128,
        res_out_dims: int=128,
        res_blocks: int=10,
        hop_length: int=300,
        sample_rate: int=24000,
        mode='RAW',
        init_type: str="xavier_uniform", ):
    '''
    Parameters
    ----------
    rnn_dims : int, optional
        Hidden dims of RNN Layers.
    fc_dims : int, optional
        Dims of FC Layers.
    bits : int, optional
        bit depth of signal.
    aux_context_window : int, optional
        The context window size of the first convolution applied to the
        auxiliary input, by default 2
    upsample_scales : List[int], optional
        Upsample scales of the upsample network.
    aux_channels : int, optional
        Auxiliary channel of the residual blocks.
    compute_dims : int, optional
        Dims of Conv1D in MelResNet.
    res_out_dims : int, optional
        Dims of output in MelResNet.
    res_blocks : int, optional
        Number of residual blocks.
    hop_length : int, optional
        Hop size in samples; stored and used by `generate` to size the
        output waveform.
    sample_rate : int, optional
        Sampling rate; stored on the instance.
    mode : str, optional
        Output mode of the WaveRNN vocoder. `MOL` for Mixture of Logistic Distribution,
        and `RAW` for quantized bits as the model's output.
    init_type : str
        How to initialize parameters.

    Raises
    ----------
    RuntimeError
        If `mode` is neither `RAW` nor `MOL`.
    '''
    super().__init__()
    self.mode = mode
    self.aux_context_window = aux_context_window
    if self.mode == 'RAW':
        # one class per quantization level
        self.n_classes = 2**bits
    elif self.mode == 'MOL':
        self.n_classes = 30
    else:
        # The exception used to be constructed but never raised, so an
        # unknown mode only failed later with a missing `n_classes`.
        raise RuntimeError('Unknown model mode value - ', self.mode)

    # List of rnns to call 'flatten_parameters()' on
    self._to_flatten = []

    self.rnn_dims = rnn_dims
    # the upsampled auxiliary features are consumed in four equal slices
    self.aux_dims = res_out_dims // 4
    self.hop_length = hop_length
    self.sample_rate = sample_rate

    # initialize parameters
    # NOTE(review): presumably installs a global initializer that applies to
    # the layers created below (it is reset at the end) -- confirm.
    initialize(self, init_type)

    self.upsample = UpsampleNetwork(
        aux_channels=aux_channels,
        upsample_scales=upsample_scales,
        compute_dims=compute_dims,
        res_blocks=res_blocks,
        res_out_dims=res_out_dims,
        aux_context_window=aux_context_window)
    # input: previous sample (1) + mel frame + first auxiliary slice
    self.I = nn.Linear(aux_channels + self.aux_dims + 1, rnn_dims)

    self.rnn1 = nn.GRU(rnn_dims, rnn_dims)
    self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims)
    self._to_flatten += [self.rnn1, self.rnn2]

    self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
    self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
    self.fc3 = nn.Linear(fc_dims, self.n_classes)

    # Avoid fragmentation of RNN parameters and associated warning
    self._flatten_parameters()

    nn.initializer.set_global_initializer(None)
@paddle.no_grad()
def generate(self,
             c,
             batched: bool=True,
             target: int=12000,
             overlap: int=600,
             mu_law: bool=True,
             gen_display: bool=False):
    """
    Autoregressively synthesize a waveform from a mel spectrogram.

    Parameters
    ----------
    c : Tensor
        input mels, (T', C_aux)
    batched : bool
        generate in batch or not
    target : int
        target number of samples to be generated in each batch entry
    overlap : int
        number of samples for crossfading between batches
    mu_law : bool
        use mu law or not (only honoured in `RAW` mode)
    gen_display : bool
        print a progress bar every 1000 generated steps
    Returns
    ----------
    Tensor
        wav sequence, (wave_len, 1) with wave_len = (T' - 1) * hop_length.
    """

    # Remember the caller's mode so it can be restored afterwards; the
    # method used to switch to training mode unconditionally, which flipped
    # an eval-mode model (e.g. during validation) into training mode.
    was_training = self.training
    self.eval()

    mu_law = mu_law if self.mode == 'RAW' else False

    output = []
    start = time.time()
    rnn1 = self.get_gru_cell(self.rnn1)
    rnn2 = self.get_gru_cell(self.rnn2)
    # pseudo batch
    # (T, C_aux) -> (1, C_aux, T)
    c = paddle.transpose(c, [1, 0]).unsqueeze(0)

    wave_len = (paddle.shape(c)[-1] - 1) * self.hop_length
    # TODO remove two transpose op by modifying function pad_tensor
    c = self.pad_tensor(
        c.transpose([0, 2, 1]), pad=self.aux_context_window,
        side='both').transpose([0, 2, 1])
    c, aux = self.upsample(c)

    if batched:
        # (num_folds, target + 2 * overlap, features)
        c = self.fold_with_overlap(c, target, overlap)
        aux = self.fold_with_overlap(aux, target, overlap)

    b_size, seq_len, _ = paddle.shape(c)
    h1 = paddle.zeros([b_size, self.rnn_dims])
    h2 = paddle.zeros([b_size, self.rnn_dims])
    x = paddle.zeros([b_size, 1])

    d = self.aux_dims
    aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]

    for i in range(seq_len):
        m_t = c[:, i, :]

        a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
        x = paddle.concat([x, m_t, a1_t], axis=1)
        x = self.I(x)
        h1, _ = rnn1(x, h1)
        x = x + h1
        inp = paddle.concat([x, a2_t], axis=1)
        h2, _ = rnn2(inp, h2)

        x = x + h2
        x = paddle.concat([x, a3_t], axis=1)
        x = F.relu(self.fc1(x))

        x = paddle.concat([x, a4_t], axis=1)
        x = F.relu(self.fc2(x))

        logits = self.fc3(x)

        if self.mode == 'MOL':
            sample = sample_from_discretized_mix_logistic(
                logits.unsqueeze(0).transpose([0, 2, 1]))
            output.append(sample.reshape([-1]))
            # The sampler reduces over the mixture axis, so `sample` is 2-D
            # (1, b_size); swap the two axes to get the (b_size, 1) input
            # for the next step. A 3-element permutation is invalid here.
            x = sample.transpose([1, 0])

        elif self.mode == 'RAW':
            posterior = F.softmax(logits, axis=1)
            distrib = paddle.distribution.Categorical(posterior)
            # map the sampled class index to [-1, 1]; this mirrors the
            # quantization [np.floor((fx + 1) / 2 * mu + 0.5)] in encode_mu_law
            sample = 2 * distrib.sample([1])[0].cast('float32') / (
                self.n_classes - 1.) - 1.
            output.append(sample)
            x = sample.unsqueeze(-1)
        else:
            raise RuntimeError('Unknown model mode value - ', self.mode)

        if gen_display:
            if i % 1000 == 0:
                self.gen_display(i, int(seq_len), int(b_size), start)

    # list of (b_size,) -> (b_size, seq_len)
    output = paddle.stack(output).transpose([1, 0])

    if mu_law:
        output = decode_mu_law(output, self.n_classes, False)

    if batched:
        output = self.xfade_and_unfold(output, target, overlap)
    else:
        output = output[0]

    # Fade-out at the end to avoid signal cutting out suddenly
    output = output[:wave_len]
    # Clamp the fade to the available samples so that very short inputs
    # (fewer than 20 hops) do not fail on a shape mismatch.
    fade_len = min(20 * self.hop_length, int(output.shape[0]))
    if fade_len > 0:
        fade_out = paddle.linspace(1, 0, fade_len)
        output[-fade_len:] *= fade_out

    if was_training:
        self.train()

    # add the C_out dimension
    return output.unsqueeze(-1)
def xfade_and_unfold(self, y, target: int=12000, overlap: int=600):
    ''' Applies a crossfade and unfolds into a 1d array.

    Parameters
    ----------
    y : Tensor
        Batched sequences of audio samples
        shape=(num_folds, target + 2 * overlap)
        dtype=paddle.float64
        Note: the overlap regions of `y` are scaled in place.
    target : int
        Kept for interface compatibility; the value actually used is
        recomputed from the width of `y` and `overlap`.
    overlap : int
        Timesteps for both xfade and rnn warmup

    Returns
    ----------
    Tensor
        audio samples in a 1d array
        shape=(total_len)
        dtype=paddle.float64

    Details
    ----------
    y = [[seq1],
         [seq2],
         [seq3]]

    Apply a gain envelope at both ends of the sequences

    y = [[seq1_in, seq1_target, seq1_out],
         [seq2_in, seq2_target, seq2_out],
         [seq3_in, seq3_target, seq3_out]]

    Stagger and add up the groups of samples:

    [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]

    '''
    # num_folds = (total_len - overlap) // (target + overlap)
    num_folds, length = y.shape
    target = length - 2 * overlap
    total_len = num_folds * (target + overlap) + overlap

    # Need some silence for the run warmup
    silence_len = overlap // 2
    fade_len = overlap - silence_len
    silence = paddle.zeros([silence_len], dtype=paddle.float64)
    # The flat part of the fade-out must be `silence_len` long so that the
    # whole envelope is exactly `overlap` samples. Using `fade_len` here
    # made it one sample too long whenever `overlap` is odd.
    linear = paddle.ones([silence_len], dtype=paddle.float64)

    # Equal power crossfade
    # fade_in increase from 0 to 1, fade_out reduces from 1 to 0
    t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float64)
    fade_in = paddle.sqrt(0.5 * (1 + t))
    fade_out = paddle.sqrt(0.5 * (1 - t))
    # Concat the silence to the fades
    fade_out = paddle.concat([linear, fade_out])
    fade_in = paddle.concat([silence, fade_in])

    # Apply the gain to the overlap samples
    y[:, :overlap] *= fade_in
    y[:, -overlap:] *= fade_out

    unfolded = paddle.zeros([total_len], dtype=paddle.float64)

    # Loop to add up all the samples; consecutive folds start
    # (target + overlap) apart, so their overlap regions are summed.
    for i in range(num_folds):
        start = i * (target + overlap)
        end = start + target + 2 * overlap
        unfolded[start:end] += y[i]

    return unfolded
def progbar(self, i, n, size=16):
    '''
    Render a text progress bar for step `i` out of `n`.

    Parameters
    ----------
    i : int
        Current step.
    n : int
        Total number of steps.
    size : int, optional
        Width of the bar in characters.
    Returns
    ----------
    str
        `size` characters; cells up to and including index
        `int(i * size) // n` are drawn filled, the rest empty.
    '''
    last_filled = int(i * size) // n
    return ''.join('█' if cell <= last_filled else '░'
                   for cell in range(size))
class WaveRNNUpdater(StandardUpdater):
    '''Runs one optimization step of a WaveRNN model per batch and logs it.'''

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 criterion: Layer,
                 dataloader: DataLoader,
                 init_state=None,
                 output_dir: Path=None,
                 mode='RAW'):
        '''
        Parameters
        ----------
        model : Layer
            Model to train.
        optimizer : Optimizer
            Optimizer stepping the model's parameters.
        criterion : Layer
            Loss called as `criterion(y_hat, y)`.
        dataloader : DataLoader
            Training data loader.
        init_state : optional
            Initial updater state, forwarded to the base class.
        output_dir : Path
            Directory that receives the per-rank log file `worker_<rank>.log`.
        mode : str
            `RAW` or `MOL`; selects how predictions and targets are shaped
            before the loss.
        '''
        # Forward the caller's state; it used to be replaced by None.
        super().__init__(model, optimizer, dataloader, init_state=init_state)

        self.criterion = criterion
        # self.scheduler = scheduler

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""
        self.mode = mode

    def update_core(self, batch):
        '''
        Run forward, backward and one optimizer step on `batch`.

        Parameters
        ----------
        batch : tuple
            `(wav, y, mel)` as produced by the data loader.
        '''
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}
        # parse batch
        self.model.train()
        self.optimizer.clear_grad()

        wav, y, mel = batch

        y_hat = self.model(wav, mel)
        if self.mode == 'RAW':
            y_hat = y_hat.transpose([0, 2, 1]).unsqueeze(-1)
        elif self.mode == 'MOL':
            # The target must be float for the mixture loss. Casting into
            # `y_hat` threw the prediction away and compared y with itself.
            y = paddle.cast(y, dtype='float32')

        y = y.unsqueeze(-1)
        loss = self.criterion(y_hat, y)
        loss.backward()
        # measured after backward and before the step, for reporting only
        grad_norm = float(
            calculate_grad_norm(self.model.parameters(), norm_type=2))

        self.optimizer.step()

        loss_value = float(loss)
        report("train/loss", loss_value)
        report("train/grad_norm", grad_norm)

        losses_dict["loss"] = loss_value
        losses_dict["grad_norm"] = grad_norm
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        # publish the step count through the module-level ITERATION variable
        global ITERATION
        ITERATION = self.state.iteration + 1
def log_sum_exp(x):
    """ Numerically stable log-sum-exp over the last axis of `x`.

    The per-row maximum is subtracted before exponentiating and added back
    afterwards, which prevents overflow in `exp`.
    """
    # TF ordering: reduce over the trailing axis
    last_axis = len(x.shape) - 1
    peak = paddle.max(x, axis=last_axis)
    peak_keepdim = paddle.max(x, axis=last_axis, keepdim=True)
    shifted_sum = paddle.sum(paddle.exp(x - peak_keepdim), axis=last_axis)
    return peak + paddle.log(shifted_sum)
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions
    Parameters
    ----------
    y : Tensor
        (B, C, T), where C is a multiple of 3: mixture logits, means and
        log-scales, in that order along the channel axis.
    log_scale_min : float
        Log scale minimum value; defaults to log(1e-14).
    Returns
    ----------
    Tensor
        (B, T) sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    assert y.shape[1] % 3 == 0
    nr_mix = y.shape[1] // 3

    # (B, T, C)
    y = y.transpose([0, 2, 1])
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax
    # (adding -log(-log(u)) noise to the logits and taking the argmax)
    temp = paddle.uniform(
        logit_probs.shape, dtype=logit_probs.dtype, min=1e-5, max=1.0 - 1e-5)
    temp = logit_probs - paddle.log(-paddle.log(temp))
    argmax = paddle.argmax(temp, axis=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix)
    one_hot = paddle.cast(one_hot, dtype=paddle.get_default_dtype())

    # select logistic parameters of the chosen mixture component
    means = paddle.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
    log_scales = paddle.clip(
        paddle.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1),
        min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = paddle.uniform(means.shape, min=1e-5, max=1.0 - 1e-5)
    x = means + paddle.exp(log_scales) * (paddle.log(u) - paddle.log(1. - u))
    # Clip to [-1, 1]; the upper bound used to be -1, which collapsed every
    # sample to -1.
    x = paddle.clip(x, min=-1., max=1.)

    return x