From c463b35fa86fab5af1fb432ff13d07c96cbf2afd Mon Sep 17 00:00:00 2001 From: liangym Date: Mon, 16 Jan 2023 02:41:26 +0000 Subject: [PATCH] diffsinger opencpop fft train, test=tts --- examples/opencpop/svs1/local/preprocess.sh | 66 + examples/opencpop/svs1/local/synthesize.sh | 100 ++ examples/opencpop/svs1/local/train.sh | 12 + examples/opencpop/svs1/run.sh | 37 + paddlespeech/t2s/datasets/am_batch_fn.py | 123 ++ paddlespeech/t2s/datasets/get_feats.py | 5 + paddlespeech/t2s/datasets/preprocess_utils.py | 57 + paddlespeech/t2s/exps/diffsinger/__init__.py | 13 + paddlespeech/t2s/exps/diffsinger/normalize.py | 165 +++ .../t2s/exps/diffsinger/preprocess.py | 359 +++++ paddlespeech/t2s/exps/diffsinger/train.py | 222 ++++ paddlespeech/t2s/exps/syn_utils.py | 9 + paddlespeech/t2s/exps/synthesize.py | 11 +- .../t2s/models/diffsinger/__init__.py | 15 + .../t2s/models/diffsinger/diffsinger.py | 1153 +++++++++++++++++ .../models/diffsinger/diffsinger_updater.py | 248 ++++ .../t2s/modules/transformer/encoder.py | 10 +- 17 files changed, 2602 insertions(+), 3 deletions(-) create mode 100755 examples/opencpop/svs1/local/preprocess.sh create mode 100755 examples/opencpop/svs1/local/synthesize.sh create mode 100755 examples/opencpop/svs1/local/train.sh create mode 100755 examples/opencpop/svs1/run.sh create mode 100644 paddlespeech/t2s/exps/diffsinger/__init__.py create mode 100644 paddlespeech/t2s/exps/diffsinger/normalize.py create mode 100644 paddlespeech/t2s/exps/diffsinger/preprocess.py create mode 100644 paddlespeech/t2s/exps/diffsinger/train.py create mode 100644 paddlespeech/t2s/models/diffsinger/__init__.py create mode 100644 paddlespeech/t2s/models/diffsinger/diffsinger.py create mode 100644 paddlespeech/t2s/models/diffsinger/diffsinger_updater.py diff --git a/examples/opencpop/svs1/local/preprocess.sh b/examples/opencpop/svs1/local/preprocess.sh new file mode 100755 index 000000000..78803b83f --- /dev/null +++ b/examples/opencpop/svs1/local/preprocess.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +stage=1 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=opencpop \ + --rootdir=~/datasets/SVS/Opencpop/segments \ + --dumpdir=dump \ + --label-file=~/datasets/SVS/Opencpop/segments/transcriptions.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="speech" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="pitch" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="energy" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." 
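+    # normalize.py rescales speech/pitch/energy with the train-set statistics and writes a new metadata.jsonl under dump/{train,dev,test}/norm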
+ python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/opencpop/svs1/local/synthesize.sh b/examples/opencpop/svs1/local/synthesize.sh new file mode 100755 index 000000000..238437b4d --- /dev/null +++ b/examples/opencpop/svs1/local/synthesize.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +stage=0 +stop_stage=0 + +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --am=diffsinger_opencpop \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=pwgan_opencpop \ + --voc_config=pwgan_opencpop/default.yaml \ + --voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \ + --voc_stat=pwgan_opencpop/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# for more GAN Vocoders +# multi band melgan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --am=fastspeech2_csmsc \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=mb_melgan_csmsc \ + --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\ + --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# style melgan +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --am=fastspeech2_csmsc \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=style_melgan_csmsc \ + --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \ + --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# hifigan +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "in hifigan syn" + 
FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --am=fastspeech2_csmsc \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \ + --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# wavernn +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "in wavernn syn" + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --am=fastspeech2_csmsc \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=wavernn_csmsc \ + --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \ + --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \ + --voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi diff --git a/examples/opencpop/svs1/local/train.sh b/examples/opencpop/svs1/local/train.sh new file mode 100755 index 000000000..5e255fb8d --- /dev/null +++ b/examples/opencpop/svs1/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=2 \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/opencpop/svs1/run.sh b/examples/opencpop/svs1/run.sh new file mode 100755 index 000000000..44d8efc66 --- /dev/null +++ b/examples/opencpop/svs1/run.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=4,5 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_153.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... 
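+# stages below: 0 = preprocess, 1 = train, 2 = synthesize (test set), 3 = synthesize_e2e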
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize, vocoder is pwgan by default + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # synthesize_e2e, vocoder is pwgan by default + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index c95d908dc..9ae791b48 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -414,6 +414,129 @@ def fastspeech2_multi_spk_batch_fn(examples): return batch +def diffsinger_single_spk_batch_fn(examples): + # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"] + text = [np.array(item["text"], dtype=np.int64) for item in examples] + note = [np.array(item["note"], dtype=np.int64) for item in examples] + note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples] + is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples] + speech = [np.array(item["speech"], dtype=np.float32) for item in examples] + pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] + energy = [np.array(item["energy"], dtype=np.float32) for item in examples] + durations = [ + np.array(item["durations"], dtype=np.int64) for item in examples + ] + + text_lengths = [ + np.array(item["text_lengths"], dtype=np.int64) for item in examples + ] + speech_lengths = [ + np.array(item["speech_lengths"], dtype=np.int64) for item in examples + ] + + text = batch_sequences(text) + note = batch_sequences(note) + note_dur = batch_sequences(note_dur) + is_slur = batch_sequences(is_slur) + pitch = batch_sequences(pitch) + speech = batch_sequences(speech) + durations = batch_sequences(durations) + energy = batch_sequences(energy) + + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + note = paddle.to_tensor(note) + note_dur = paddle.to_tensor(note_dur) + is_slur = paddle.to_tensor(is_slur) + pitch = paddle.to_tensor(pitch) + speech = paddle.to_tensor(speech) + durations = paddle.to_tensor(durations) + energy = paddle.to_tensor(energy) + text_lengths = paddle.to_tensor(text_lengths) + speech_lengths = paddle.to_tensor(speech_lengths) + + batch = { + "text": text, + "note": note, + "note_dur": note_dur, + "is_slur": is_slur, + "text_lengths": text_lengths, + "durations": durations, + "speech": speech, + "speech_lengths": speech_lengths, + "pitch": pitch, + "energy": energy + } + return batch + + +def diffsinger_multi_spk_batch_fn(examples): + # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"] + text = [np.array(item["text"], dtype=np.int64) for item in examples] + note = [np.array(item["note"], dtype=np.int64) for item in examples] + note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples] + is_slur = 
[np.array(item["is_slur"], dtype=np.int64) for item in examples] + speech = [np.array(item["speech"], dtype=np.float32) for item in examples] + pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] + energy = [np.array(item["energy"], dtype=np.float32) for item in examples] + durations = [ + np.array(item["durations"], dtype=np.int64) for item in examples + ] + text_lengths = [ + np.array(item["text_lengths"], dtype=np.int64) for item in examples + ] + speech_lengths = [ + np.array(item["speech_lengths"], dtype=np.int64) for item in examples + ] + + text = batch_sequences(text) + note = batch_sequences(note) + note_dur = batch_sequences(note_dur) + is_slur = batch_sequences(is_slur) + pitch = batch_sequences(pitch) + speech = batch_sequences(speech) + durations = batch_sequences(durations) + energy = batch_sequences(energy) + + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + note = paddle.to_tensor(note) + note_dur = paddle.to_tensor(note_dur) + is_slur = paddle.to_tensor(is_slur) + pitch = paddle.to_tensor(pitch) + speech = paddle.to_tensor(speech) + durations = paddle.to_tensor(durations) + energy = paddle.to_tensor(energy) + text_lengths = paddle.to_tensor(text_lengths) + speech_lengths = paddle.to_tensor(speech_lengths) + + batch = { + "text": text, + "note": note, + "note_dur": note_dur, + "is_slur": is_slur, + "text_lengths": text_lengths, + "durations": durations, + "speech": speech, + "speech_lengths": speech_lengths, + "pitch": pitch, + "energy": energy + } + # spk_emb has a higher priority than spk_id + if "spk_emb" in examples[0]: + spk_emb = [ + np.array(item["spk_emb"], dtype=np.float32) for item in examples + ] + spk_emb = batch_sequences(spk_emb) + spk_emb = paddle.to_tensor(spk_emb) + batch["spk_emb"] = spk_emb + elif "spk_id" in examples[0]: + spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples] + spk_id = paddle.to_tensor(spk_id) + batch["spk_id"] = spk_id + return batch + + def transformer_single_spk_batch_fn(examples): # fields = ["text", "text_lengths", "speech", "speech_lengths"] text = [np.array(item["text"], dtype=np.int64) for item in examples] diff --git a/paddlespeech/t2s/datasets/get_feats.py b/paddlespeech/t2s/datasets/get_feats.py index 21458f152..bad1d43ce 100644 --- a/paddlespeech/t2s/datasets/get_feats.py +++ b/paddlespeech/t2s/datasets/get_feats.py @@ -16,6 +16,7 @@ import librosa import numpy as np import pyworld from scipy.interpolate import interp1d +from typing import List class LogMelFBank(): @@ -166,6 +167,8 @@ class Pitch(): f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0) if use_token_averaged_f0 and duration is not None: f0 = self._average_by_duration(f0, duration) + else: + f0 = np.expand_dims(np.array(f0), 0).T return f0 @@ -224,6 +227,8 @@ class Energy(): energy = self._calculate_energy(wav) if use_token_averaged_energy and duration is not None: energy = self._average_by_duration(energy, duration) + else: + energy = np.expand_dims(np.array(energy), 0).T return energy diff --git a/paddlespeech/t2s/datasets/preprocess_utils.py b/paddlespeech/t2s/datasets/preprocess_utils.py index 445b69bda..075a80e13 100644 --- a/paddlespeech/t2s/datasets/preprocess_utils.py +++ b/paddlespeech/t2s/datasets/preprocess_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +import librosa +import numpy as np # speaker|utt_id|phn dur phn dur ... 
@@ -40,6 +42,58 @@ def get_phn_dur(file_name): f.close() return sentence, speaker_set +def note2midi(notes): + midis = [] + for note in notes: + if note == 'rest': + midi = 0 + else: + midi = librosa.note_to_midi(note.split("/")[0]) + midis.append(midi) + + return midis + +def time2frame(times, sample_rate: int=24000, n_shift: int=128,): + end = 0.0 + ends = [] + for t in times: + end += t + ends.append(end) + frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift) + durations = np.diff(frame_pos, prepend=0) + return durations + +def get_sentences_svs(file_name, dataset: str='opencpop', sample_rate: int=24000, n_shift: int=128,): + ''' + read label file + Args: + file_name (str or Path): path of gen_duration_from_textgrid.py's result + dataset (str): dataset name + Returns: + Dict: sentence: {'utt': ([char], [int])} + ''' + f = open(file_name, 'r') + sentence = {} + speaker_set = set() + if dataset == 'opencpop': + speaker_set.add("opencpop") + for line in f: + line_list = line.strip().split('|') + utt = line_list[0] + text = line_list[1] + ph = line_list[2].split() + midi = note2midi(line_list[3].split()) + midi_dur = line_list[4].split() + ph_dur = time2frame([float(t) for t in line_list[5].split()]) + is_slur = line_list[6].split() + assert len(ph) == len(midi) == len(midi_dur) == len(is_slur) + sentence[utt] = (ph, [int(i) for i in ph_dur], [int(i) for i in midi], [float(i) for i in midi_dur], [int(i) for i in is_slur], text, "opencpop") + else: + print("dataset should in {opencpop} now!") + + f.close() + return sentence, speaker_set + def merge_silence(sentence): ''' @@ -88,6 +142,9 @@ def get_input_token(sentence, output_path, dataset="baker"): phn_token = ["", ""] + phn_token if dataset in {"baker", "aishell3"}: phn_token += [",", "。", "?", "!"] + # svs dataset + elif dataset in {"opencpop"}: + pass else: phn_token += [",", ".", "?", "!"] phn_token += [""] diff --git a/paddlespeech/t2s/exps/diffsinger/__init__.py b/paddlespeech/t2s/exps/diffsinger/__init__.py new file mode 100644 index 000000000..abf198b97 --- /dev/null +++ b/paddlespeech/t2s/exps/diffsinger/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/t2s/exps/diffsinger/normalize.py b/paddlespeech/t2s/exps/diffsinger/normalize.py new file mode 100644 index 000000000..ec80a229e --- /dev/null +++ b/paddlespeech/t2s/exps/diffsinger/normalize.py @@ -0,0 +1,165 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Normalize feature files and dump them.""" +import argparse +import logging +from operator import itemgetter +from pathlib import Path + +import jsonlines +import numpy as np +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from paddlespeech.t2s.datasets.data_table import DataTable + + +def main(): + """Run preprocessing process.""" + parser = argparse.ArgumentParser( + description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)." + ) + parser.add_argument( + "--metadata", + type=str, + required=True, + help="directory including feature files to be normalized. " + "you need to specify either *-scp or rootdir.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump normalized feature files.") + parser.add_argument( + "--speech-stats", + type=str, + required=True, + help="speech statistics file.") + parser.add_argument( + "--pitch-stats", type=str, required=True, help="pitch statistics file.") + parser.add_argument( + "--energy-stats", + type=str, + required=True, + help="energy statistics file.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") + + args = parser.parse_args() + + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + with jsonlines.open(args.metadata, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, + converters={ + "speech": np.load, + "pitch": np.load, + "energy": np.load, + }) + logging.info(f"The number of files = {len(dataset)}.") + + # restore scaler + speech_scaler = StandardScaler() + speech_scaler.mean_ = np.load(args.speech_stats)[0] + speech_scaler.scale_ = np.load(args.speech_stats)[1] + speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0] + + pitch_scaler = StandardScaler() + pitch_scaler.mean_ = np.load(args.pitch_stats)[0] + pitch_scaler.scale_ = np.load(args.pitch_stats)[1] + pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0] + + energy_scaler = StandardScaler() + energy_scaler.mean_ = np.load(args.energy_stats)[0] + energy_scaler.scale_ = np.load(args.energy_stats)[1] + energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0] + + vocab_phones = {} + with open(args.phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + # process each file + output_metadata = [] + + for item in tqdm(dataset): + utt_id = item['utt_id'] + speech = item['speech'] + pitch = item['pitch'] + energy = item['energy'] + # normalize + speech = speech_scaler.transform(speech) + speech_dir = dumpdir / "data_speech" + speech_dir.mkdir(parents=True, exist_ok=True) + speech_path = speech_dir / f"{utt_id}_speech.npy" + np.save(speech_path, speech.astype(np.float32), allow_pickle=False) + + pitch = pitch_scaler.transform(pitch) + pitch_dir = dumpdir / "data_pitch" + pitch_dir.mkdir(parents=True, exist_ok=True) + pitch_path = pitch_dir / f"{utt_id}_pitch.npy" + np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False) + + energy 
= energy_scaler.transform(energy) + energy_dir = dumpdir / "data_energy" + energy_dir.mkdir(parents=True, exist_ok=True) + energy_path = energy_dir / f"{utt_id}_energy.npy" + np.save(energy_path, energy.astype(np.float32), allow_pickle=False) + phone_ids = [vocab_phones[p] for p in item['phones']] + spk_id = vocab_speaker[item["speaker"]] + record = { + "utt_id": item['utt_id'], + "spk_id": spk_id, + "text": phone_ids, + "text_lengths": item['text_lengths'], + "speech_lengths": item['speech_lengths'], + "durations": item['durations'], + "speech": str(speech_path), + "pitch": str(pitch_path), + "energy": str(energy_path), + "note": item['note'], + "note_dur": item['note_dur'], + "is_slur": item['is_slur'], + } + # add spk_emb for voice cloning + if "spk_emb" in item: + record["spk_emb"] = str(item["spk_emb"]) + + output_metadata.append(record) + output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/diffsinger/preprocess.py b/paddlespeech/t2s/exps/diffsinger/preprocess.py new file mode 100644 index 000000000..e89a2f31e --- /dev/null +++ b/paddlespeech/t2s/exps/diffsinger/preprocess.py @@ -0,0 +1,359 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
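+"""Preprocess the Opencpop dataset for DiffSinger: extract mel, pitch and energy features and dump metadata."""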
+import argparse +import os +from concurrent.futures import ThreadPoolExecutor +from operator import itemgetter +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List + +import jsonlines +import librosa +import numpy as np +import tqdm +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.get_feats import Energy +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import Pitch +from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length +from paddlespeech.t2s.datasets.preprocess_utils import get_input_token +from paddlespeech.t2s.datasets.preprocess_utils import get_sentences_svs +from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map +from paddlespeech.t2s.datasets.preprocess_utils import merge_silence +from paddlespeech.t2s.utils import str2bool + +ALL_SHENGMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j', + 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'] +ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian', + 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong', 'ou', + 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn'] + +def process_sentence(config: Dict[str, Any], + fp: Path, + sentences: Dict, + output_dir: Path, + mel_extractor=None, + pitch_extractor=None, + energy_extractor=None, + cut_sil: bool=True, + spk_emb_dir: Path=None,): + utt_id = fp.stem + record = None + if utt_id in sentences: + # reading, resampling may occur + wav, _ = librosa.load(str(fp), sr=config.fs) + if len(wav.shape) != 1: + return record + max_value = np.abs(wav).max() + if max_value > 1.0: + wav = wav / max_value + assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio." + assert np.abs(wav).max( + ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM." + phones = sentences[utt_id][0] + durations = sentences[utt_id][1] + note = sentences[utt_id][2] + note_dur = sentences[utt_id][3] + is_slur = sentences[utt_id][4] + speaker = sentences[utt_id][-1] + + # extract mel feats + logmel = mel_extractor.get_log_mel_fbank(wav) + # change duration according to mel_length + compare_duration_and_mel_length(sentences, utt_id, logmel) + # utt_id may be popped in compare_duration_and_mel_length + if utt_id not in sentences: + return None + phones = sentences[utt_id][0] + durations = sentences[utt_id][1] + num_frames = logmel.shape[0] + word_boundary = [1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in phones] + # print(sum(durations), num_frames) + assert sum(durations) == num_frames, "the sum of durations doesn't equal to the num of mel frames. 
" + speech_dir = output_dir / "data_speech" + speech_dir.mkdir(parents=True, exist_ok=True) + speech_path = speech_dir / (utt_id + "_speech.npy") + np.save(speech_path, logmel) + # extract pitch and energy + pitch = pitch_extractor.get_pitch(wav) + assert pitch.shape[0] == num_frames + pitch_dir = output_dir / "data_pitch" + pitch_dir.mkdir(parents=True, exist_ok=True) + pitch_path = pitch_dir / (utt_id + "_pitch.npy") + np.save(pitch_path, pitch) + energy = energy_extractor.get_energy(wav) + assert energy.shape[0] == num_frames + energy_dir = output_dir / "data_energy" + energy_dir.mkdir(parents=True, exist_ok=True) + energy_path = energy_dir / (utt_id + "_energy.npy") + np.save(energy_path, energy) + + record = { + "utt_id": utt_id, + "phones": phones, + "text_lengths": len(phones), + "speech_lengths": num_frames, + "durations": durations, + "speech": str(speech_path), + "pitch": str(pitch_path), + "energy": str(energy_path), + "speaker": speaker, + "note": note, + "note_dur": note_dur, + "is_slur": is_slur, + } + if spk_emb_dir: + if speaker in os.listdir(spk_emb_dir): + embed_name = utt_id + ".npy" + embed_path = spk_emb_dir / speaker / embed_name + if embed_path.is_file(): + record["spk_emb"] = str(embed_path) + else: + return None + return record + + +def process_sentences(config, + fps: List[Path], + sentences: Dict, + output_dir: Path, + mel_extractor=None, + pitch_extractor=None, + energy_extractor=None, + nprocs: int=1, + cut_sil: bool=True, + spk_emb_dir: Path=None, + write_metadata_method: str='w',): + if nprocs == 1: + results = [] + for fp in tqdm.tqdm(fps, total=len(fps)): + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir,) + if record: + results.append(record) + else: + with ThreadPoolExecutor(nprocs) as pool: + futures = [] + with tqdm.tqdm(total=len(fps)) as progress: + for fp in fps: + future = pool.submit(process_sentence, config, fp, + sentences, output_dir, mel_extractor, + pitch_extractor, energy_extractor, + cut_sil, spk_emb_dir,) + future.add_done_callback(lambda p: progress.update()) + futures.append(future) + + results = [] + for ft in futures: + record = ft.result() + if record: + results.append(record) + + results.sort(key=itemgetter("utt_id")) + with jsonlines.open(output_dir / "metadata.jsonl", + write_metadata_method) as writer: + for item in results: + writer.write(item) + print("Done") + + +def main(): + # parse config and args + parser = argparse.ArgumentParser( + description="Preprocess audio and then extract features.") + + parser.add_argument( + "--dataset", + default="opencpop", + type=str, + help="name of dataset, should in {opencpop} now") + + parser.add_argument( + "--rootdir", default=None, type=str, help="directory to dataset.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump feature files.") + + parser.add_argument( + "--label-file", default=None, type=str, help="path to label file.") + + parser.add_argument("--config", type=str, help="diffsinger config file.") + + parser.add_argument( + "--num-cpu", type=int, default=1, help="number of process.") + + parser.add_argument( + "--cut-sil", + type=str2bool, + default=True, + help="whether cut sil in the edge of audio") + + parser.add_argument( + "--spk_emb_dir", + default=None, + type=str, + help="directory to speaker embedding files.") + + 
parser.add_argument( + "--write_metadata_method", + default="w", + type=str, + choices=["w", "a"], + help="How the metadata.jsonl file is written.") + args = parser.parse_args() + + rootdir = Path(args.rootdir).expanduser() + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + label_file = Path(args.label_file).expanduser() + + + if args.spk_emb_dir: + spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve() + else: + spk_emb_dir = None + + assert rootdir.is_dir() + assert label_file.is_file() + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + sentences, speaker_set = get_sentences_svs(label_file, dataset=args.dataset, sample_rate=config.fs, n_shift=config.n_shift,) + + # merge_silence(sentences) + phone_id_map_path = dumpdir / "phone_id_map.txt" + speaker_id_map_path = dumpdir / "speaker_id_map.txt" + get_input_token(sentences, phone_id_map_path, args.dataset) + get_spk_id_map(speaker_set, speaker_id_map_path) + + if args.dataset == "opencpop": + wavdir = rootdir / "wavs" + # split data into 3 sections + train_file = rootdir / "train.txt" + train_wav_files = [] + with open(train_file, "r") as f_train: + for line in f_train.readlines(): + utt = line.split("|")[0] + wav_name = utt + ".wav" + wav_path = wavdir / wav_name + train_wav_files.append(wav_path) + + test_file = rootdir / "test.txt" + dev_wav_files = [] + test_wav_files = [] + num_dev = 106 + count = 0 + with open(test_file, "r") as f_test: + for line in f_test.readlines(): + count += 1 + utt = line.split("|")[0] + wav_name = utt + ".wav" + wav_path = wavdir / wav_name + if count > num_dev: + test_wav_files.append(wav_path) + else: + dev_wav_files.append(wav_path) + + else: + print("dataset should in {opencpop} now!") + + train_dump_dir = dumpdir / "train" / "raw" + train_dump_dir.mkdir(parents=True, exist_ok=True) + dev_dump_dir = dumpdir / "dev" / "raw" + dev_dump_dir.mkdir(parents=True, exist_ok=True) + test_dump_dir = dumpdir / "test" / "raw" + test_dump_dir.mkdir(parents=True, exist_ok=True) + + # Extractor + mel_extractor = LogMelFBank( + sr=config.fs, + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window, + n_mels=config.n_mels, + fmin=config.fmin, + fmax=config.fmax) + pitch_extractor = Pitch( + sr=config.fs, + hop_length=config.n_shift, + f0min=config.f0min, + f0max=config.f0max) + energy_extractor = Energy( + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window) + + # process for the 3 sections + if train_wav_files: + process_sentences( + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=args.write_metadata_method) + if dev_wav_files: + process_sentences( + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=args.write_metadata_method) + if test_wav_files: + process_sentences( + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + 
energy_extractor=energy_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=args.write_metadata_method) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/diffsinger/train.py b/paddlespeech/t2s/exps/diffsinger/train.py new file mode 100644 index 000000000..ce612d037 --- /dev/null +++ b/paddlespeech/t2s/exps/diffsinger/train.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_multi_spk_batch_fn +from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_single_spk_batch_fn +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.diffsinger import DiffSinger +from paddlespeech.t2s.models.diffsinger import DiffSingerEvaluator +from paddlespeech.t2s.models.diffsinger import DiffSingerUpdater +from paddlespeech.t2s.training.extensions.snapshot import Snapshot +from paddlespeech.t2s.training.extensions.visualizer import VisualDL +from paddlespeech.t2s.training.optimizer import build_optimizers +from paddlespeech.t2s.training.seeding import seed_everything +from paddlespeech.t2s.training.trainer import Trainer +from paddlespeech.t2s.utils import str2bool + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + world_size = paddle.distributed.get_world_size() + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + fields = [ + "text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", + "note", "note_dur", "is_slur"] + converters = {"speech": np.load, "pitch": np.load, "energy": np.load} + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker diffsinger!") + collate_fn = diffsinger_multi_spk_batch_fn + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + fields += ["spk_id"] + elif args.voice_cloning: + print("Training voice cloning!") + collate_fn = diffsinger_multi_spk_batch_fn + fields += ["spk_emb"] + converters["spk_emb"] = np.load + else: + collate_fn = diffsinger_single_spk_batch_fn + print("single speaker diffsinger!") + + print("spk_num:", spk_num) + + # dataloader has been too verbose + 
logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + + print("samplers done!") + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + shuffle=False, + drop_last=False, + batch_size=config.batch_size, + collate_fn=collate_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_mels + model = DiffSinger( + idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"]) + if world_size > 1: + model = DataParallel(model) + print("model done!") + + optimizer = build_optimizers(model, **config["optimizer"]) + print("optimizer done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + if "enable_speaker_classifier" in config.model: + enable_spk_cls = config.model.enable_speaker_classifier + else: + enable_spk_cls = False + + updater = DiffSingerUpdater( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + output_dir=output_dir, + enable_spk_cls=enable_spk_cls, + **config["updater"], ) + + trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) + + evaluator = DiffSingerEvaluator( + model, + dev_dataloader, + output_dir=output_dir, + enable_spk_cls=enable_spk_cls, + **config["updater"], ) + + if dist.get_rank() == 0: + trainer.extend(evaluator, trigger=(1, "epoch")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) + trainer.extend( + Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) + trainer.run() + + +def main(): + # parse args and config and redirect to train_sp + parser = argparse.ArgumentParser(description="Train a DiffSinger model.") + parser.add_argument("--config", type=str, help="diffsinger config file.") + parser.add_argument("--train-metadata", type=str, help="training data.") + parser.add_argument("--dev-metadata", type=str, help="dev data.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", + type=str, + default=None, + help="speaker id map file for multiple speaker model.") + + parser.add_argument( + "--voice-cloning", + type=str2bool, + default=False, + help="whether training voice cloning model.") + + args = parser.parse_args() + + with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") 
+ print(config) + print( + f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" + ) + + # dispatch + if args.ngpu > 1: + dist.spawn(train_sp, (args, config), nprocs=args.ngpu) + else: + train_sp(args, config) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 6b693440c..a8508019d 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -55,6 +55,10 @@ model_alias = { "paddlespeech.t2s.models.tacotron2:Tacotron2", "tacotron2_inference": "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", + "diffsinger": + "paddlespeech.t2s.models.diffsinger:DiffSinger", + "diffsinger_inference": + "paddlespeech.t2s.models.diffsinger:DiffSingerInference", # voc "pwgan": "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", @@ -141,6 +145,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]], fields += ["spk_emb"] else: print("single speaker fastspeech2!") + elif am_name == 'diffsinger': + fields = ["utt_id", "text", "note", "note_dur", "is_slur"] elif am_name == 'speedyspeech': fields = ["utt_id", "phones", "tones"] elif am_name == 'tacotron2': @@ -347,6 +353,9 @@ def get_am_inference(am: str='fastspeech2_csmsc', if am_name == 'fastspeech2': am = am_class( idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) + if am_name == 'diffsinger': + am = am_class( + idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) elif am_name == 'speedyspeech': am = am_class( vocab_size=vocab_size, diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index a8e18150e..9e060b42d 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -107,6 +107,12 @@ def evaluate(args): if args.voice_cloning and "spk_emb" in datum: spk_emb = paddle.to_tensor(np.load(datum["spk_emb"])) mel = am_inference(phone_ids, spk_emb=spk_emb) + elif am_name == 'diffsinger': + phone_ids = paddle.to_tensor(datum["text"]) + note = paddle.to_tensor(datum["note"]) + note_dur = paddle.to_tensor(datum["note_dur"]) + is_slur = paddle.to_tensor(datum["is_slur"]) + mel = am_inference(phone_ids, note=note, note_dur=note_dur, is_slur=is_slur) # vocoder wav = voc_inference(mel) @@ -136,7 +142,8 @@ def parse_args(): choices=[ 'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc', - 'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix' + 'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix', + "diffsinger_opencpop" ], help='Choose acoustic model type of tts task.') parser.add_argument( @@ -172,7 +179,7 @@ def parse_args(): 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc', 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk', - 'style_melgan_csmsc' + 'style_melgan_csmsc', "pwgan_opencpop", ], help='Choose vocoder type of tts task.') parser.add_argument( diff --git a/paddlespeech/t2s/models/diffsinger/__init__.py b/paddlespeech/t2s/models/diffsinger/__init__.py new file mode 100644 index 000000000..d07a45711 --- /dev/null +++ b/paddlespeech/t2s/models/diffsinger/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .diffsinger import * +from .diffsinger_updater import * diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger.py b/paddlespeech/t2s/models/diffsinger/diffsinger.py new file mode 100644 index 000000000..8abf13b00 --- /dev/null +++ b/paddlespeech/t2s/models/diffsinger/diffsinger.py @@ -0,0 +1,1153 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +"""DiffSinger related modules for paddle""" +from typing import Dict +from typing import List +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn +from typeguard import check_argument_types + +from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer +from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier +from paddlespeech.t2s.modules.nets_utils import initialize +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask +from paddlespeech.t2s.modules.nets_utils import make_pad_mask +from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictor +from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss +from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator +from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor +from paddlespeech.t2s.modules.tacotron2.decoder import Postnet +from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder +from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet +from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder +from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder + + +class DiffSinger(nn.Layer): + """DiffSinger module. + + This is a module of DiffSinger described in `DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism`._ + .. 
_`DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism`: + https://arxiv.org/pdf/2105.02446.pdf + + Args: + + Returns: + + """ + + def __init__( + self, + # network structure related + idim: int, + odim: int, + adim: int=384, + aheads: int=4, + elayers: int=6, + eunits: int=1536, + dlayers: int=6, + dunits: int=1536, + postnet_layers: int=5, + postnet_chans: int=512, + postnet_filts: int=5, + postnet_dropout_rate: float=0.5, + positionwise_layer_type: str="conv1d", + positionwise_conv_kernel_size: int=1, + use_scaled_pos_enc: bool=True, + use_batch_norm: bool=True, + encoder_normalize_before: bool=True, + decoder_normalize_before: bool=True, + encoder_concat_after: bool=False, + decoder_concat_after: bool=False, + reduction_factor: int=1, + encoder_type: str="transformer", + decoder_type: str="transformer", + # for transformer + transformer_enc_dropout_rate: float=0.1, + transformer_enc_positional_dropout_rate: float=0.1, + transformer_enc_attn_dropout_rate: float=0.1, + transformer_dec_dropout_rate: float=0.1, + transformer_dec_positional_dropout_rate: float=0.1, + transformer_dec_attn_dropout_rate: float=0.1, + # for conformer + conformer_pos_enc_layer_type: str="rel_pos", + conformer_self_attn_layer_type: str="rel_selfattn", + conformer_activation_type: str="swish", + use_macaron_style_in_conformer: bool=True, + use_cnn_in_conformer: bool=True, + zero_triu: bool=False, + conformer_enc_kernel_size: int=7, + conformer_dec_kernel_size: int=31, + # for CNN Decoder + cnn_dec_dropout_rate: float=0.2, + cnn_postnet_dropout_rate: float=0.2, + cnn_postnet_resblock_kernel_sizes: List[int]=[256, 256], + cnn_postnet_kernel_size: int=5, + cnn_decoder_embedding_dim: int=256, + # duration predictor + duration_predictor_layers: int=2, + duration_predictor_chans: int=384, + duration_predictor_kernel_size: int=3, + duration_predictor_dropout_rate: float=0.1, + # energy predictor + energy_predictor_layers: int=2, + energy_predictor_chans: int=384, + energy_predictor_kernel_size: int=3, + energy_predictor_dropout: float=0.5, + energy_embed_kernel_size: int=9, + energy_embed_dropout: float=0.5, + stop_gradient_from_energy_predictor: bool=False, + # pitch predictor + pitch_predictor_layers: int=2, + pitch_predictor_chans: int=384, + pitch_predictor_kernel_size: int=3, + pitch_predictor_dropout: float=0.5, + pitch_embed_kernel_size: int=9, + pitch_embed_dropout: float=0.5, + stop_gradient_from_pitch_predictor: bool=False, + # spk emb + spk_num: int=None, + spk_embed_dim: int=None, + spk_embed_integration_type: str="add", + # tone emb + tone_num: int=None, + tone_embed_dim: int=None, + tone_embed_integration_type: str="add", + # note emb + note_num: int=300, + # note_embed_dim: int=384, + note_embed_integration_type: str="add", + # is_slur emb + is_slur_num: int=2, + # is_slur_embed_dim: int=384, + is_slur_embed_integration_type: str="add", + # training related + init_type: str="xavier_uniform", + init_enc_alpha: float=1.0, + init_dec_alpha: float=1.0, + # speaker classifier + enable_speaker_classifier: bool=False, + hidden_sc_dim: int=256, ): + """Initialize DiffSinger module. + Args: + idim (int): + Dimension of the inputs. + odim (int): + Dimension of the outputs. + adim (int): + Attention dimension. + aheads (int): + Number of attention heads. + elayers (int): + Number of encoder layers. + eunits (int): + Number of encoder hidden units. + dlayers (int): + Number of decoder layers. + dunits (int): + Number of decoder hidden units. + postnet_layers (int): + Number of postnet layers. 
+ postnet_chans (int): + Number of postnet channels. + postnet_filts (int): + Kernel size of postnet. + postnet_dropout_rate (float): + Dropout rate in postnet. + use_scaled_pos_enc (bool): + Whether to use trainable scaled pos encoding. + use_batch_norm (bool): + Whether to use batch normalization in encoder prenet. + encoder_normalize_before (bool): + Whether to apply layernorm layer before encoder block. + decoder_normalize_before (bool): + Whether to apply layernorm layer before decoder block. + encoder_concat_after (bool): + Whether to concatenate attention layer's input and output in encoder. + decoder_concat_after (bool): + Whether to concatenate attention layer's input and output in decoder. + reduction_factor (int): + Reduction factor. + encoder_type (str): + Encoder type ("transformer" or "conformer"). + decoder_type (str): + Decoder type ("transformer" or "conformer"). + transformer_enc_dropout_rate (float): + Dropout rate in encoder except attention and positional encoding. + transformer_enc_positional_dropout_rate (float): + Dropout rate after encoder positional encoding. + transformer_enc_attn_dropout_rate (float): + Dropout rate in encoder self-attention module. + transformer_dec_dropout_rate (float): + Dropout rate in decoder except attention & positional encoding. + transformer_dec_positional_dropout_rate (float): + Dropout rate after decoder positional encoding. + transformer_dec_attn_dropout_rate (float): + Dropout rate in decoder self-attention module. + conformer_pos_enc_layer_type (str): + Pos encoding layer type in conformer. + conformer_self_attn_layer_type (str): + Self-attention layer type in conformer + conformer_activation_type (str): + Activation function type in conformer. + use_macaron_style_in_conformer (bool): + Whether to use macaron style FFN. + use_cnn_in_conformer (bool): + Whether to use CNN in conformer. + zero_triu (bool): + Whether to use zero triu in relative self-attention module. + conformer_enc_kernel_size (int): + Kernel size of encoder conformer. + conformer_dec_kernel_size (int): + Kernel size of decoder conformer. + duration_predictor_layers (int): + Number of duration predictor layers. + duration_predictor_chans (int): + Number of duration predictor channels. + duration_predictor_kernel_size (int): + Kernel size of duration predictor. + duration_predictor_dropout_rate (float): + Dropout rate in duration predictor. + pitch_predictor_layers (int): + Number of pitch predictor layers. + pitch_predictor_chans (int): + Number of pitch predictor channels. + pitch_predictor_kernel_size (int): + Kernel size of pitch predictor. + pitch_predictor_dropout_rate (float): + Dropout rate in pitch predictor. + pitch_embed_kernel_size (float): + Kernel size of pitch embedding. + pitch_embed_dropout_rate (float): + Dropout rate for pitch embedding. + stop_gradient_from_pitch_predictor (bool): + Whether to stop gradient from pitch predictor to encoder. + energy_predictor_layers (int): + Number of energy predictor layers. + energy_predictor_chans (int): + Number of energy predictor channels. + energy_predictor_kernel_size (int): + Kernel size of energy predictor. + energy_predictor_dropout_rate (float): + Dropout rate in energy predictor. + energy_embed_kernel_size (float): + Kernel size of energy embedding. + energy_embed_dropout_rate (float): + Dropout rate for energy embedding. + stop_gradient_from_energy_predictor(bool): + Whether to stop gradient from energy predictor to encoder. + spk_num (Optional[int]): + Number of speakers. 
If not None, assume that the spk_embed_dim is not None, + spk_ids will be provided as the input and use spk_embedding_table. + spk_embed_dim (Optional[int]): + Speaker embedding dimension. If not None, + assume that spk_emb will be provided as the input or spk_num is not None. + spk_embed_integration_type (str): + How to integrate speaker embedding. + tone_num (Optional[int]): + Number of tones. If not None, assume that the + tone_ids will be provided as the input and use tone_embedding_table. + tone_embed_dim (Optional[int]): + Tone embedding dimension. If not None, assume that tone_num is not None. + tone_embed_integration_type (str): + How to integrate tone embedding. + init_type (str): + How to initialize transformer parameters. + init_enc_alpha (float): + Initial value of alpha in scaled pos encoding of the encoder. + init_dec_alpha (float): + Initial value of alpha in scaled pos encoding of the decoder. + enable_speaker_classifier (bool): + Whether to use speaker classifier module + hidden_sc_dim (int): + The hidden layer dim of speaker classifier + + """ + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.odim = odim + self.reduction_factor = reduction_factor + self.encoder_type = encoder_type + self.decoder_type = decoder_type + self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor + self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor + self.use_scaled_pos_enc = use_scaled_pos_enc + self.hidden_sc_dim = hidden_sc_dim + self.spk_num = spk_num + self.enable_speaker_classifier = enable_speaker_classifier + + self.spk_embed_dim = spk_embed_dim + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = spk_embed_integration_type + + self.tone_embed_dim = tone_embed_dim + if self.tone_embed_dim is not None: + self.tone_embed_integration_type = tone_embed_integration_type + + self.note_embed_dim = adim + if self.note_embed_dim is not None: + self.note_embed_integration_type = note_embed_integration_type + self.note_dur_layer = nn.Linear(1, self.note_embed_dim) + + self.is_slur_embed_dim = adim + if self.is_slur_embed_dim is not None: + self.is_slur_embed_integration_type = is_slur_embed_integration_type + + # use idx 0 as padding idx + self.padding_idx = 0 + + # initialize parameters + initialize(self, init_type) + + if spk_num and self.spk_embed_dim: + self.spk_embedding_table = nn.Embedding( + num_embeddings=spk_num, + embedding_dim=self.spk_embed_dim, + padding_idx=self.padding_idx) + + if self.tone_embed_dim is not None: + self.tone_embedding_table = nn.Embedding( + num_embeddings=tone_num, + embedding_dim=self.tone_embed_dim, + padding_idx=self.padding_idx) + + if note_num and self.note_embed_dim: + self.note_embedding_table = nn.Embedding( + num_embeddings=note_num, + embedding_dim=self.note_embed_dim, + padding_idx=self.padding_idx) + + if is_slur_num and self.is_slur_embed_dim: + self.is_slur_embedding_table = nn.Embedding( + num_embeddings=is_slur_num, + embedding_dim=self.is_slur_embed_dim, + padding_idx=self.padding_idx) + + # get positional encoding layer type + transformer_pos_enc_layer_type = "scaled_abs_pos" if self.use_scaled_pos_enc else "abs_pos" + + # define encoder + encoder_input_layer = nn.Embedding( + num_embeddings=idim, + embedding_dim=adim, + padding_idx=self.padding_idx) + + if encoder_type == "transformer": + self.encoder = TransformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + 
input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_layer_type=transformer_pos_enc_layer_type, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) + elif encoder_type == "conformer": + self.encoder = ConformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_enc_kernel_size, + zero_triu=zero_triu, ) + else: + raise ValueError(f"{encoder_type} is not supported.") + + # define additional projection for speaker embedding + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.spk_projection = nn.Linear(self.spk_embed_dim, adim) + else: + self.spk_projection = nn.Linear(adim + self.spk_embed_dim, adim) + + # define additional projection for tone embedding + if self.tone_embed_dim is not None: + if self.tone_embed_integration_type == "add": + self.tone_projection = nn.Linear(self.tone_embed_dim, adim) + else: + self.tone_projection = nn.Linear(adim + self.tone_embed_dim, + adim) + + if self.spk_num and self.enable_speaker_classifier: + # set lambda = 1 + self.grad_reverse = GradientReversalLayer(1) + self.speaker_classifier = SpeakerClassifier( + idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=adim, + n_layers=duration_predictor_layers, + n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, ) + + # define pitch predictor + self.pitch_predictor = VariancePredictor( + idim=adim, + n_layers=pitch_predictor_layers, + n_chans=pitch_predictor_chans, + kernel_size=pitch_predictor_kernel_size, + dropout_rate=pitch_predictor_dropout, ) + # We use continuous pitch + FastPitch style avg + self.pitch_embed = nn.Sequential( + nn.Conv1D( + in_channels=1, + out_channels=adim, + kernel_size=pitch_embed_kernel_size, + padding=(pitch_embed_kernel_size - 1) // 2, ), + nn.Dropout(pitch_embed_dropout), ) + + # define energy predictor + self.energy_predictor = VariancePredictor( + idim=adim, + n_layers=energy_predictor_layers, + n_chans=energy_predictor_chans, + kernel_size=energy_predictor_kernel_size, + dropout_rate=energy_predictor_dropout, ) + # We use continuous enegy + FastPitch style avg + self.energy_embed = nn.Sequential( + nn.Conv1D( + in_channels=1, + out_channels=adim, + kernel_size=energy_embed_kernel_size, + padding=(energy_embed_kernel_size - 1) // 2, ), + nn.Dropout(energy_embed_dropout), ) + + # define length regulator + self.length_regulator = LengthRegulator() + 
+ # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + if decoder_type == "transformer": + self.decoder = TransformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + # in decoder, don't need layer before pos_enc_class (we use embedding here in encoder) + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + pos_enc_layer_type=transformer_pos_enc_layer_type, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) + elif decoder_type == "conformer": + self.decoder = ConformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_dec_kernel_size, ) + elif decoder_type == 'cnndecoder': + self.decoder = CNNDecoder( + emb_dim=adim, + odim=odim, + kernel_size=cnn_postnet_kernel_size, + dropout_rate=cnn_dec_dropout_rate, + resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes) + else: + raise ValueError(f"{decoder_type} is not supported.") + + # define final projection + self.feat_out = nn.Linear(adim, odim * reduction_factor) + + # define postnet + if decoder_type == 'cnndecoder': + self.postnet = CNNPostnet( + odim=odim, + kernel_size=cnn_postnet_kernel_size, + dropout_rate=cnn_postnet_dropout_rate, + resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes) + else: + self.postnet = (None if postnet_layers == 0 else Postnet( + idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=postnet_dropout_rate, )) + + nn.initializer.set_global_initializer(None) + + self._reset_parameters( + init_enc_alpha=init_enc_alpha, + init_dec_alpha=init_dec_alpha, ) + + def forward( + self, + text: paddle.Tensor, + note: paddle.Tensor, + note_dur: paddle.Tensor, + is_slur: paddle.Tensor, + text_lengths: paddle.Tensor, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + durations: paddle.Tensor, + pitch: paddle.Tensor, + energy: paddle.Tensor, + tone_id: paddle.Tensor=None, + spk_emb: paddle.Tensor=None, + spk_id: paddle.Tensor=None, + ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]: + """Calculate forward propagation. + + Args: + text(Tensor(int64)): + Batch of padded token ids (B, Tmax). + text_lengths(Tensor(int64)): + Batch of lengths of each input (B,). + speech(Tensor): + Batch of padded target features (B, Lmax, odim). + speech_lengths(Tensor(int64)): + Batch of the lengths of each target (B,). + durations(Tensor(int64)): + Batch of padded durations (B, Tmax). 
+ pitch(Tensor): + Batch of padded token-averaged pitch (B, Tmax, 1). + energy(Tensor): + Batch of padded token-averaged energy (B, Tmax, 1). + tone_id(Tensor, optional(int64)): + Batch of padded tone ids (B, Tmax). + spk_emb(Tensor, optional): + Batch of speaker embeddings (B, spk_embed_dim). + spk_id(Tnesor, optional(int64)): + Batch of speaker ids (B,) + + Returns: + + """ + + # input of embedding must be int64 + xs = paddle.cast(text, 'int64') + note = paddle.cast(note, 'int64') + note_dur = paddle.cast(note_dur, 'float32') + is_slur = paddle.cast(is_slur, 'int64') + ilens = paddle.cast(text_lengths, 'int64') + ds = paddle.cast(durations, 'int64') + olens = paddle.cast(speech_lengths, 'int64') + ys = speech + ps = pitch + es = energy + if spk_id is not None: + spk_id = paddle.cast(spk_id, 'int64') + if tone_id is not None: + tone_id = paddle.cast(tone_id, 'int64') + # forward propagation + before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward( + xs, + note, + note_dur, + is_slur, + ilens, + olens, + ds, + ps, + es, + is_inference=False, + spk_emb=spk_emb, + spk_id=spk_id, + tone_id=tone_id) + # modify mod part of groundtruth + if self.reduction_factor > 1: + olens = olens - olens % self.reduction_factor + max_olen = max(olens) + ys = ys[:, :max_olen] + + return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits + + def _forward(self, + xs: paddle.Tensor, + note: paddle.Tensor, + note_dur: paddle.Tensor, + is_slur: paddle.Tensor, + ilens: paddle.Tensor, + olens: paddle.Tensor=None, + ds: paddle.Tensor=None, + ps: paddle.Tensor=None, + es: paddle.Tensor=None, + is_inference: bool=False, + return_after_enc=False, + alpha: float=1.0, + spk_emb=None, + spk_id=None, + tone_id=None) -> Sequence[paddle.Tensor]: + # forward encoder + x_masks = self._source_mask(ilens) + note_emb = self.note_embedding_table(note) + note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1)) + is_slur_emb = self.is_slur_embedding_table(is_slur) + + # (B, Tmax, adim) + hs, _ = self.encoder(xs, x_masks, note_emb, note_dur_emb, is_slur_emb,) + + if self.spk_num and self.enable_speaker_classifier and not is_inference: + hs_for_spk_cls = self.grad_reverse(hs) + spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens) + else: + spk_logits = None + + # integrate speaker embedding + if self.spk_embed_dim is not None: + # spk_emb has a higher priority than spk_id + if spk_emb is not None: + hs = self._integrate_with_spk_embed(hs, spk_emb) + elif spk_id is not None: + spk_emb = self.spk_embedding_table(spk_id) + hs = self._integrate_with_spk_embed(hs, spk_emb) + + # integrate tone embedding + if self.tone_embed_dim is not None: + if tone_id is not None: + tone_embs = self.tone_embedding_table(tone_id) + hs = self._integrate_with_tone_embed(hs, tone_embs) + # forward duration predictor and variance predictors + d_masks = make_pad_mask(ilens) + # forward decoder + if olens is not None and not is_inference: + if self.reduction_factor > 1: + olens_in = paddle.to_tensor( + [olen // self.reduction_factor for olen in olens.numpy()]) + else: + olens_in = olens + # (B, 1, T) + h_masks = self._source_mask(olens_in) + pitch_masks = h_masks.transpose((0, 2, 1)) + else: + h_masks = None + pitch_masks = h_masks + + + if is_inference: + # (B, Tmax) + if ds is not None: + d_outs = ds + else: + d_outs = self.duration_predictor.inference(hs, d_masks) + # (B, Lmax, adim) + hs = self.length_regulator(hs, d_outs, alpha, is_inference=True) + + if ps is not None: + p_outs = ps + else: + if 
self.stop_gradient_from_pitch_predictor: + p_outs = self.pitch_predictor(hs.detach(), pitch_masks) + else: + p_outs = self.pitch_predictor(hs, pitch_masks) + + if es is not None: + e_outs = es + else: + if self.stop_gradient_from_energy_predictor: + e_outs = self.energy_predictor(hs.detach(), pitch_masks) + else: + e_outs = self.energy_predictor(hs, pitch_masks) + + p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose( + (0, 2, 1)) + e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( + (0, 2, 1)) + hs = hs + e_embs + p_embs + + else: + d_outs = self.duration_predictor(hs, d_masks) + # (B, Lmax, adim) + hs = self.length_regulator(hs, ds, is_inference=False) + if self.stop_gradient_from_pitch_predictor: + p_outs = self.pitch_predictor(hs.detach(), pitch_masks) + else: + p_outs = self.pitch_predictor(hs, pitch_masks) + if self.stop_gradient_from_energy_predictor: + e_outs = self.energy_predictor(hs.detach(), pitch_masks) + else: + e_outs = self.energy_predictor(hs, pitch_masks) + p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose( + (0, 2, 1)) + e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( + (0, 2, 1)) + hs = hs + e_embs + p_embs + + if return_after_enc: + return hs, h_masks + + + if self.decoder_type == 'cnndecoder': + # remove output masks for dygraph to static graph + zs = self.decoder(hs, h_masks) + before_outs = zs + else: + # (B, Lmax, adim) + zs, _ = self.decoder(hs, h_masks) + # (B, Lmax, odim) + before_outs = self.feat_out(zs).reshape( + (paddle.shape(zs)[0], -1, self.odim)) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + + return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits + + def encoder_infer( + self, + text: paddle.Tensor, + note: paddle.Tensor, + note_dur: paddle.Tensor, + is_slur: paddle.Tensor, + alpha: float=1.0, + spk_emb=None, + spk_id=None, + tone_id=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # input of embedding must be int64 + x = paddle.cast(text, 'int64') + note = paddle.cast(note, 'int64') + note_dur = paddle.cast(note_dur, 'float32') + is_slur = paddle.cast(is_slur, 'int64') + # setup batch axis + ilens = paddle.shape(x)[0] + + xs = x.unsqueeze(0) + note = note.unsqueeze(0) + note_dur = note_dur.unsqueeze(0) + is_slur = is_slur.unsqueeze(0) + + if spk_emb is not None: + spk_emb = spk_emb.unsqueeze(0) + + if tone_id is not None: + tone_id = tone_id.unsqueeze(0) + + # (1, L, odim) + # use *_ to avoid bug in dygraph to static graph + hs, *_ = self._forward( + xs, + note, + note_dur, + is_slur, + ilens, + is_inference=True, + return_after_enc=True, + alpha=alpha, + spk_emb=spk_emb, + spk_id=spk_id, + tone_id=tone_id) + return hs + + def inference( + self, + text: paddle.Tensor, + note: paddle.Tensor, + note_dur: paddle.Tensor, + is_slur: paddle.Tensor, + durations: paddle.Tensor=None, + pitch: paddle.Tensor=None, + energy: paddle.Tensor=None, + alpha: float=1.0, + use_teacher_forcing: bool=False, + spk_emb=None, + spk_id=None, + tone_id=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Generate the sequence of features given the sequences of characters. + + Args: + text(Tensor(int64)): + Input sequence of characters (T,). + durations(Tensor, optional (int64)): + Groundtruth of duration (T,). + pitch(Tensor, optional): + Groundtruth of token-averaged pitch (T, 1). 
+            energy(Tensor, optional):
+                Groundtruth of token-averaged energy (T, 1).
+            alpha(float, optional):
+                Alpha to control the speed.
+            use_teacher_forcing(bool, optional):
+                Whether to use teacher forcing.
+                If true, groundtruth of duration, pitch and energy will be used.
+            spk_emb(Tensor, optional):
+                Speaker embedding vector (spk_embed_dim,). (Default value = None)
+            spk_id(Tensor, optional(int64)):
+                Speaker ids (1,). (Default value = None)
+            tone_id(Tensor, optional(int64)):
+                Tone ids (T,). (Default value = None)
+
+        Returns:
+
+        """
+        # input of embedding must be int64
+        x = paddle.cast(text, 'int64')
+        note = paddle.cast(note, 'int64')
+        note_dur = paddle.cast(note_dur, 'float32')
+        is_slur = paddle.cast(is_slur, 'int64')
+        d, p, e = durations, pitch, energy
+        # setup batch axis
+        ilens = paddle.shape(x)[0]
+
+        xs = x.unsqueeze(0)
+        note = note.unsqueeze(0)
+        note_dur = note_dur.unsqueeze(0)
+        is_slur = is_slur.unsqueeze(0)
+
+        if spk_emb is not None:
+            spk_emb = spk_emb.unsqueeze(0)
+
+        if tone_id is not None:
+            tone_id = tone_id.unsqueeze(0)
+
+        if use_teacher_forcing:
+            # use groundtruth of duration, pitch, and energy
+            ds = d.unsqueeze(0) if d is not None else None
+            ps = p.unsqueeze(0) if p is not None else None
+            es = e.unsqueeze(0) if e is not None else None
+
+            # (1, L, odim)
+            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
+                xs,
+                note,
+                note_dur,
+                is_slur,
+                ilens,
+                ds=ds,
+                ps=ps,
+                es=es,
+                spk_emb=spk_emb,
+                spk_id=spk_id,
+                tone_id=tone_id,
+                is_inference=True)
+        else:
+            # (1, L, odim)
+            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
+                xs,
+                note,
+                note_dur,
+                is_slur,
+                ilens,
+                is_inference=True,
+                alpha=alpha,
+                spk_emb=spk_emb,
+                spk_id=spk_id,
+                tone_id=tone_id)
+
+        return outs[0], d_outs[0], p_outs[0], e_outs[0]
+
+    def _integrate_with_spk_embed(self, hs, spk_emb):
+        """Integrate speaker embedding with hidden states.
+
+        Args:
+            hs(Tensor):
+                Batch of hidden state sequences (B, Tmax, adim).
+            spk_emb(Tensor):
+                Batch of speaker embeddings (B, spk_embed_dim).
+
+        Returns:
+
+
+        """
+        if self.spk_embed_integration_type == "add":
+            # apply projection and then add to hidden states
+            spk_emb = self.spk_projection(F.normalize(spk_emb))
+            hs = hs + spk_emb.unsqueeze(1)
+        elif self.spk_embed_integration_type == "concat":
+            # concat hidden states with spk embeds and then apply projection
+            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
+                shape=[-1, paddle.shape(hs)[1], -1])
+            hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
+        else:
+            raise NotImplementedError("support only add or concat.")
+
+        return hs
+
+    def _integrate_with_tone_embed(self, hs, tone_embs):
+        """Integrate tone embedding with hidden states.
+
+        Args:
+            hs(Tensor):
+                Batch of hidden state sequences (B, Tmax, adim).
+            tone_embs(Tensor):
+                Batch of tone embeddings (B, Tmax, tone_embed_dim).
+
+        Returns:
+
+        """
+        if self.tone_embed_integration_type == "add":
+            # apply projection and then add to hidden states
+            tone_embs = self.tone_projection(F.normalize(tone_embs))
+            hs = hs + tone_embs
+
+        elif self.tone_embed_integration_type == "concat":
+            # concat hidden states with tone embeds and then apply projection
+            tone_embs = F.normalize(tone_embs).expand(
+                shape=[-1, hs.shape[1], -1])
+            hs = self.tone_projection(paddle.concat([hs, tone_embs], axis=-1))
+        else:
+            raise NotImplementedError("support only add or concat.")
+        return hs
+
+    def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor:
+        """Make masks for self-attention.
+ + Args: + ilens(Tensor): + Batch of lengths (B,). + + Returns: + Tensor: + Mask tensor for self-attention. dtype=paddle.bool + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]]) bool + """ + x_masks = make_non_pad_mask(ilens) + return x_masks.unsqueeze(-2) + + def _reset_parameters(self, init_enc_alpha: float, init_dec_alpha: float): + + # initialize alpha in scaled positional encoding + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + init_enc_alpha = paddle.to_tensor(init_enc_alpha) + self.encoder.embed[-1].alpha = paddle.create_parameter( + shape=init_enc_alpha.shape, + dtype=str(init_enc_alpha.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign( + init_enc_alpha)) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + init_dec_alpha = paddle.to_tensor(init_dec_alpha) + self.decoder.embed[-1].alpha = paddle.create_parameter( + shape=init_dec_alpha.shape, + dtype=str(init_dec_alpha.numpy().dtype), + default_initializer=paddle.nn.initializer.Assign( + init_dec_alpha)) + + +class DiffSingerInference(nn.Layer): + def __init__(self, normalizer, model): + super().__init__() + self.normalizer = normalizer + self.acoustic_model = model + + def forward(self, text, note, note_dur, is_slur, spk_id=None, spk_emb=None): + normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference( + text, note=note, note_dur=note_dur, is_slur=is_slur, spk_id=spk_id, spk_emb=spk_emb) + logmel = self.normalizer.inverse(normalized_mel) + return logmel + + +class DiffSingerLoss(nn.Layer): + """Loss function module for DiffSinger.""" + + def __init__(self, use_masking: bool=True, + use_weighted_masking: bool=False): + """Initialize feed-forward Transformer loss module. + Args: + use_masking (bool): + Whether to apply masking for padded part in loss calculation. + use_weighted_masking (bool): + Whether to weighted masking in loss calculation. + """ + assert check_argument_types() + super().__init__() + + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + reduction = "none" if self.use_weighted_masking else "mean" + self.l1_criterion = nn.L1Loss(reduction=reduction) + self.mse_criterion = nn.MSELoss(reduction=reduction) + self.duration_criterion = DurationPredictorLoss(reduction=reduction) + self.ce_criterion = nn.CrossEntropyLoss() + + def forward( + self, + after_outs: paddle.Tensor, + before_outs: paddle.Tensor, + d_outs: paddle.Tensor, + p_outs: paddle.Tensor, + e_outs: paddle.Tensor, + ys: paddle.Tensor, + ds: paddle.Tensor, + ps: paddle.Tensor, + es: paddle.Tensor, + ilens: paddle.Tensor, + olens: paddle.Tensor, + spk_logits: paddle.Tensor=None, + spk_ids: paddle.Tensor=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, + paddle.Tensor, ]: + """Calculate forward propagation. + + Args: + after_outs(Tensor): + Batch of outputs after postnets (B, Lmax, odim). + before_outs(Tensor): + Batch of outputs before postnets (B, Lmax, odim). + d_outs(Tensor): + Batch of outputs of duration predictor (B, Tmax). + p_outs(Tensor): + Batch of outputs of pitch predictor (B, Tmax, 1). + e_outs(Tensor): + Batch of outputs of energy predictor (B, Tmax, 1). + ys(Tensor): + Batch of target features (B, Lmax, odim). + ds(Tensor): + Batch of durations (B, Tmax). + ps(Tensor): + Batch of target token-averaged pitch (B, Tmax, 1). 
+ es(Tensor): + Batch of target token-averaged energy (B, Tmax, 1). + ilens(Tensor): + Batch of the lengths of each input (B,). + olens(Tensor): + Batch of the lengths of each target (B,). + spk_logits(Option[Tensor]): + Batch of outputs after speaker classifier (B, Lmax, num_spk) + spk_ids(Option[Tensor]): + Batch of target spk_id (B,) + + + Returns: + + + """ + speaker_loss = 0.0 + + # apply mask to remove padded part + if self.use_masking: + out_masks = make_non_pad_mask(olens).unsqueeze(-1) + before_outs = before_outs.masked_select( + out_masks.broadcast_to(before_outs.shape)) + if after_outs is not None: + after_outs = after_outs.masked_select( + out_masks.broadcast_to(after_outs.shape)) + ys = ys.masked_select(out_masks.broadcast_to(ys.shape)) + duration_masks = make_non_pad_mask(ilens) + d_outs = d_outs.masked_select( + duration_masks.broadcast_to(d_outs.shape)) + ds = ds.masked_select(duration_masks.broadcast_to(ds.shape)) + pitch_masks = out_masks + p_outs = p_outs.masked_select( + pitch_masks.broadcast_to(p_outs.shape)) + e_outs = e_outs.masked_select( + pitch_masks.broadcast_to(e_outs.shape)) + ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape)) + es = es.masked_select(pitch_masks.broadcast_to(es.shape)) + + if spk_logits is not None and spk_ids is not None: + batch_size = spk_ids.shape[0] + spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], + None) + spk_logits = paddle.reshape(spk_logits, + [-1, spk_logits.shape[-1]]) + mask_index = spk_logits.abs().sum(axis=1) != 0 + spk_ids = spk_ids[mask_index] + spk_logits = spk_logits[mask_index] + + # calculate loss + l1_loss = self.l1_criterion(before_outs, ys) + if after_outs is not None: + l1_loss += self.l1_criterion(after_outs, ys) + duration_loss = self.duration_criterion(d_outs, ds) + pitch_loss = self.mse_criterion(p_outs, ps) + energy_loss = self.mse_criterion(e_outs, es) + + if spk_logits is not None and spk_ids is not None: + speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size + + # make weighted mask and apply it + if self.use_weighted_masking: + out_masks = make_non_pad_mask(olens).unsqueeze(-1) + out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast( + dtype=paddle.float32).sum( + axis=1, keepdim=True) + out_weights /= ys.shape[0] * ys.shape[2] + duration_masks = make_non_pad_mask(ilens) + duration_weights = (duration_masks.cast(dtype=paddle.float32) / + duration_masks.cast(dtype=paddle.float32).sum( + axis=1, keepdim=True)) + duration_weights /= ds.shape[0] + + # apply weight + + l1_loss = l1_loss.multiply(out_weights) + l1_loss = l1_loss.masked_select( + out_masks.broadcast_to(l1_loss.shape)).sum() + duration_loss = (duration_loss.multiply(duration_weights) + .masked_select(duration_masks).sum()) + pitch_masks = out_masks + pitch_weights = duration_weights.unsqueeze(-1) + pitch_loss = pitch_loss.multiply(pitch_weights) + pitch_loss = pitch_loss.masked_select( + pitch_masks.broadcast_to(pitch_loss.shape)).sum() + energy_loss = energy_loss.multiply(pitch_weights) + energy_loss = energy_loss.masked_select( + pitch_masks.broadcast_to(energy_loss.shape)).sum() + + return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py b/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py new file mode 100644 index 000000000..1f5b15c7b --- /dev/null +++ b/paddlespeech/t2s/models/diffsinger/diffsinger_updater.py @@ -0,0 +1,248 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from pathlib import Path + +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer + +from paddlespeech.t2s.models.diffsinger import DiffSingerLoss +from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator +from paddlespeech.t2s.training.reporter import report +from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater + +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class DiffSingerUpdater(StandardUpdater): + def __init__( + self, + model: Layer, + optimizer: Optimizer, + dataloader: DataLoader, + init_state=None, + use_masking: bool=False, + spk_loss_scale: float=0.02, + use_weighted_masking: bool=False, + output_dir: Path=None, + enable_spk_cls: bool=False, ): + super().__init__(model, optimizer, dataloader, init_state=None) + + self.criterion = DiffSingerLoss( + use_masking=use_masking, + use_weighted_masking=use_weighted_masking, ) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + self.spk_loss_scale = spk_loss_scale + self.enable_spk_cls = enable_spk_cls + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + # spk_id!=None in multiple spk diffsinger + spk_id = batch["spk_id"] if "spk_id" in batch else None + spk_emb = batch["spk_emb"] if "spk_emb" in batch else None + # No explicit speaker identifier labels are used during voice cloning training. 
+        if spk_emb is not None:
+            spk_id = None
+
+        if type(
+                self.model
+        ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
+            with self.model.no_sync():
+                before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                    text=batch["text"],
+                    note=batch["note"],
+                    note_dur=batch["note_dur"],
+                    is_slur=batch["is_slur"],
+                    text_lengths=batch["text_lengths"],
+                    speech=batch["speech"],
+                    speech_lengths=batch["speech_lengths"],
+                    durations=batch["durations"],
+                    pitch=batch["pitch"],
+                    energy=batch["energy"],
+                    spk_id=spk_id,
+                    spk_emb=spk_emb)
+        else:
+            before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                text=batch["text"],
+                note=batch["note"],
+                note_dur=batch["note_dur"],
+                is_slur=batch["is_slur"],
+                text_lengths=batch["text_lengths"],
+                speech=batch["speech"],
+                speech_lengths=batch["speech_lengths"],
+                durations=batch["durations"],
+                pitch=batch["pitch"],
+                energy=batch["energy"],
+                spk_id=spk_id,
+                spk_emb=spk_emb)
+
+        l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
+            after_outs=after_outs,
+            before_outs=before_outs,
+            d_outs=d_outs,
+            p_outs=p_outs,
+            e_outs=e_outs,
+            ys=ys,
+            ds=batch["durations"],
+            ps=batch["pitch"],
+            es=batch["energy"],
+            ilens=batch["text_lengths"],
+            olens=olens,
+            spk_logits=spk_logits,
+            spk_ids=spk_id, )
+
+        scaled_speaker_loss = self.spk_loss_scale * speaker_loss
+        loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
+
+        optimizer = self.optimizer
+        optimizer.clear_grad()
+        loss.backward()
+        optimizer.step()
+
+        report("train/loss", float(loss))
+        report("train/l1_loss", float(l1_loss))
+        report("train/duration_loss", float(duration_loss))
+        report("train/pitch_loss", float(pitch_loss))
+        report("train/energy_loss", float(energy_loss))
+        if self.enable_spk_cls:
+            report("train/speaker_loss", float(speaker_loss))
+            report("train/scaled_speaker_loss", float(scaled_speaker_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["pitch_loss"] = float(pitch_loss)
+        losses_dict["energy_loss"] = float(energy_loss)
+        if self.enable_spk_cls:
+            losses_dict["speaker_loss"] = float(speaker_loss)
+            losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class DiffSingerEvaluator(StandardEvaluator):
+    def __init__(self,
+                 model: Layer,
+                 dataloader: DataLoader,
+                 use_masking: bool=False,
+                 use_weighted_masking: bool=False,
+                 spk_loss_scale: float=0.02,
+                 output_dir: Path=None,
+                 enable_spk_cls: bool=False):
+        super().__init__(model, dataloader)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+        self.spk_loss_scale = spk_loss_scale
+        self.enable_spk_cls = enable_spk_cls
+
+        self.criterion = DiffSingerLoss(
+            use_masking=use_masking, use_weighted_masking=use_weighted_masking)
+
+    def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+        # spk_id!=None in multiple spk diffsinger
+        spk_id = batch["spk_id"] if "spk_id" in batch else None
+        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
+        if spk_emb is not None:
+            spk_id = None
+
+        if type(
+                self.model
+        ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
+            with self.model.no_sync():
+                before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                    text=batch["text"],
+                    note=batch["note"],
+                    note_dur=batch["note_dur"],
+                    is_slur=batch["is_slur"],
+                    text_lengths=batch["text_lengths"],
+                    speech=batch["speech"],
+                    speech_lengths=batch["speech_lengths"],
+                    durations=batch["durations"],
+                    pitch=batch["pitch"],
+                    energy=batch["energy"],
+                    spk_id=spk_id,
+                    spk_emb=spk_emb)
+        else:
+            before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                text=batch["text"],
+                note=batch["note"],
+                note_dur=batch["note_dur"],
+                is_slur=batch["is_slur"],
+                text_lengths=batch["text_lengths"],
+                speech=batch["speech"],
+                speech_lengths=batch["speech_lengths"],
+                durations=batch["durations"],
+                pitch=batch["pitch"],
+                energy=batch["energy"],
+                spk_id=spk_id,
+                spk_emb=spk_emb)
+
+        l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
+            after_outs=after_outs,
+            before_outs=before_outs,
+            d_outs=d_outs,
+            p_outs=p_outs,
+            e_outs=e_outs,
+            ys=ys,
+            ds=batch["durations"],
+            ps=batch["pitch"],
+            es=batch["energy"],
+            ilens=batch["text_lengths"],
+            olens=olens,
+            spk_logits=spk_logits,
+            spk_ids=spk_id, )
+
+        scaled_speaker_loss = self.spk_loss_scale * speaker_loss
+        loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
+
+        report("eval/loss", float(loss))
+        report("eval/l1_loss", float(l1_loss))
+        report("eval/duration_loss", float(duration_loss))
+        report("eval/pitch_loss", float(pitch_loss))
+        report("eval/energy_loss", float(energy_loss))
+        if self.enable_spk_cls:
+            report("eval/speaker_loss", float(speaker_loss))
+            report("eval/scaled_speaker_loss", float(scaled_speaker_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["pitch_loss"] = float(pitch_loss)
+        losses_dict["energy_loss"] = float(energy_loss)
+        if self.enable_spk_cls:
+            losses_dict["speaker_loss"] = float(speaker_loss)
+            losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
diff --git a/paddlespeech/t2s/modules/transformer/encoder.py b/paddlespeech/t2s/modules/transformer/encoder.py
index f2aed5892..91848de76 100644
--- a/paddlespeech/t2s/modules/transformer/encoder.py
+++ b/paddlespeech/t2s/modules/transformer/encoder.py
@@ -390,12 +390,18 @@ class TransformerEncoder(BaseEncoder):
             padding_idx=padding_idx,
             encoder_type="transformer")
-    def forward(self, xs, masks):
+    def forward(self, xs, masks, note_emb=None, note_dur_emb=None, is_slur_emb=None, scale=16):
         """Encoder input sequence.
         Args:
             xs(Tensor):
                 Input tensor (#batch, time, idim).
+            note_emb(Tensor):
+                Note embedding tensor (#batch, time, attention_dim).
+            note_dur_emb(Tensor):
+                Note duration embedding tensor (#batch, time, attention_dim).
+            is_slur_emb(Tensor):
+                Slur embedding tensor (#batch, time, attention_dim).
             masks(Tensor):
                 Mask tensor (#batch, 1, time).
@@ -406,6 +412,8 @@ class TransformerEncoder(BaseEncoder):
             Mask tensor (#batch, 1, time).
         """
         xs = self.embed(xs)
+        if note_emb is not None:
+            xs = scale * xs + note_emb + note_dur_emb + is_slur_emb
         xs, masks = self.encoders(xs, masks)
         if self.normalize_before:
             xs = self.after_norm(xs)
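
Note for reviewers: the core of the encoder change above is that the phone embedding is scaled and summed with the note, note-duration and slur embeddings before self-attention. Below is a minimal, standalone sketch of that fusion step. It is not part of the patch: the table sizes and dimensions are toy values chosen for illustration, and it deliberately ignores the positional encoding and padding masks that the real TransformerEncoder applies. Only the scale=16 default and the elementwise fusion follow the diff; the patch itself does not document why 16 was chosen, a plausible reading is that it re-weights the phone embedding against the three conditioning embeddings.

# Standalone sketch (assumed toy sizes, not from the patch): fuse phone,
# note, note-duration and slur embeddings the way the modified
# TransformerEncoder.forward does.
import paddle
import paddle.nn as nn

adim = 256                                # attention_dim shared by all embeddings
phone_num, note_num, slur_num = 60, 300, 2

phone_embed = nn.Embedding(phone_num, adim, padding_idx=0)
note_embed = nn.Embedding(note_num, adim, padding_idx=0)
slur_embed = nn.Embedding(slur_num, adim, padding_idx=0)
note_dur_layer = nn.Linear(1, adim)       # mirrors self.note_dur_layer in the patch

# fake batch: (B, Tmax) = (2, 8)
text = paddle.randint(1, phone_num, shape=[2, 8])
note = paddle.randint(1, note_num, shape=[2, 8])
note_dur = paddle.rand([2, 8])            # note durations in seconds
is_slur = paddle.randint(0, slur_num, shape=[2, 8])

xs = phone_embed(text)                                  # (B, Tmax, adim)
note_emb = note_embed(note)                             # (B, Tmax, adim)
note_dur_emb = note_dur_layer(note_dur.unsqueeze(-1))   # (B, Tmax, adim)
is_slur_emb = slur_embed(is_slur)                       # (B, Tmax, adim)

scale = 16
xs = scale * xs + note_emb + note_dur_emb + is_slur_emb  # fused encoder input
print(xs.shape)                                          # [2, 8, 256]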