fix diffsinger, test=tts

pull/2834/head
liangym 3 years ago
parent ef7d15dc02
commit c91dc02931

@@ -24,14 +24,19 @@ f0max: 750 # Maximum f0 for pitch extraction.
 # DATA SETTING #
 ###########################################################
 batch_size: 32
-num_workers: 4
+num_workers: 1
 ###########################################################
 # MODEL SETTING #
 ###########################################################
+model:
+    # music score related
+    note_num: 300
+    is_slur_num: 2
 # fastspeech2 module
-fs2_model:
+    fastspeech2_params:
 adim: 256 # attention dimension
 aheads: 2 # number of attention heads
 elayers: 4 # number of encoder layers
@@ -73,15 +78,14 @@ fs2_model:
 energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
 energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
 stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
-note_num: 300
-is_slur_num: 2
-denoiser_model:
+    # denoiser module
+    denoiser_params:
 in_channels: 80
 out_channels: 80
 kernel_size: 3
 layers: 20
-stacks: 4
+stacks: 5
 residual_channels: 256
 gate_channels: 512
 skip_channels: 256
@@ -89,9 +93,10 @@ denoiser_model:
 dropout: 0.1
 bias: True
 use_weight_norm: False
-init_type: kaiming_uniform
+init_type: "kaiming_normal"
-diffusion:
+    # diffusion module
+    diffusion_params:
 num_train_timesteps: 100
 beta_start: 0.0001
 beta_end: 0.06
@@ -112,7 +117,6 @@ ds_updater:
 ###########################################################
 # OPTIMIZER SETTING #
 ###########################################################
-# gpu_num=2 config
 # fastspeech2 optimizer
 fs2_optimizer:
 optim: adam # optimizer type
@@ -134,10 +138,10 @@ ds_grad_norm: 1
 ###########################################################
 # INTERVAL SETTING #
 ###########################################################
-ds_train_start_steps: 80000 # Number of steps to start to train diffusion module.
-train_max_steps: 160000 # Number of training steps.
+ds_train_start_steps: 160000 # Number of steps to start to train diffusion module.
+train_max_steps: 320000 # Number of training steps.
 save_interval_steps: 1000 # Interval steps to save checkpoint.
-eval_interval_steps: 250 # Interval steps to evaluate the network.
+eval_interval_steps: 1000 # Interval steps to evaluate the network.
 num_snapshots: 5
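Note: the restructured `model:` block above is consumed directly as constructor kwargs (see the train.py and synthesis-utility hunks below). A minimal sketch of how the new layout is used, assuming this example's conf/default.yaml and illustrative idim/spk_num values:

import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.models.diffsinger import DiffSinger

# load the config shown above (path is illustrative)
with open("conf/default.yaml", 'rt') as f:
    config = CfgNode(yaml.safe_load(f))

# train.py fills spk_num / idim / odim at runtime; the values here are illustrative
config["model"]["fastspeech2_params"]["spk_num"] = 1
model = DiffSinger(
    idim=92,                 # vocab size taken from dump/phone_id_map.txt
    odim=config.n_mels,      # number of mel bins
    **config["model"])       # note_num, is_slur_num, fastspeech2_params,
                             # denoiser_params, diffusion_params

With `ds_train_start_steps: 160000` and `train_max_steps: 320000`, the FastSpeech2 module is trained for the first 160k steps and the diffusion denoiser for the remaining 160k (see DiffSingerUpdater below).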

@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=2 \
+    --ngpu=1 \
     --phones-dict=dump/phone_id_map.txt

@@ -3,13 +3,13 @@
 set -e
 source path.sh
-gpus=4,5
+gpus=0
 stage=0
 stop_stage=100
 conf_path=conf/default.yaml
 train_output_path=exp/default
-ckpt_name=snapshot_iter_153.pdz
+ckpt_name=snapshot_iter_320000.pdz
 # with the following command, you can choose the stage range you want to run
 # such as `./run.sh --stage 0 --stop-stage 0`
@@ -30,8 +30,3 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     # synthesize, vocoder is pwgan by default
     CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
 fi
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    # synthesize_e2e, vocoder is pwgan by default
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
-fi

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import List
 import librosa
 import numpy as np
@@ -42,7 +44,16 @@ def get_phn_dur(file_name):
     f.close()
     return sentence, speaker_set
-def note2midi(notes):
+def note2midi(notes: List[str]) -> List[str]:
+    """Convert note string to note id, for example: ["C1"] -> [24]
+    Args:
+        notes (List[str]): the list of note strings
+    Returns:
+        List[str]: the list of note ids
+    """
     midis = []
     for note in notes:
         if note == 'rest':
@@ -53,7 +64,21 @@ def note2midi(notes):
     return midis
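For reference, a standalone sketch of the conversion note2midi performs, assuming rests map to 0 and other note names go through librosa.note_to_midi (the handling of compound spellings such as 'A#4/Bb4' is an assumption, not taken from this diff):

from typing import List

import librosa


def note2midi_sketch(notes: List[str]) -> List[int]:
    # 'rest' -> 0; otherwise a MIDI note number, e.g. 'C1' -> 24, 'C4' -> 60
    midis = []
    for note in notes:
        if note == 'rest':
            midis.append(0)
        else:
            midis.append(int(librosa.note_to_midi(note.split('/')[0])))
    return midis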
-def time2frame(times, sample_rate: int=24000, n_shift: int=128,):
+def time2frame(
+        times: List[float],
+        sample_rate: int=24000,
+        n_shift: int=128, ) -> List[int]:
+    """Convert phoneme durations from seconds to frames.
+    Args:
+        times (List[float]): phoneme durations in seconds
+        sample_rate (int, optional): sample rate. Defaults to 24000.
+        n_shift (int, optional): frame shift. Defaults to 128.
+    Returns:
+        List[int]: phoneme durations in frames
+    """
     end = 0.0
     ends = []
     for t in times:
@@ -63,14 +88,20 @@ def time2frame(times, sample_rate: int=24000, n_shift: int=128,):
     durations = np.diff(frame_pos, prepend=0)
     return durations
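time2frame accumulates per-phoneme end times, maps them to frame indices at the given hop size, and differences the result (the np.diff(frame_pos, prepend=0) line above). A standalone sketch, assuming librosa.time_to_frames is used for the time-to-frame mapping:

from typing import List

import librosa
import numpy as np


def time2frame_sketch(times: List[float],
                      sample_rate: int=24000,
                      n_shift: int=128) -> List[int]:
    # cumulative end time (in seconds) of each phoneme
    ends = np.cumsum(times)
    # end times -> frame indices, then per-phoneme frame counts
    frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift)
    return np.diff(frame_pos, prepend=0).tolist()

For example, a 0.5 s phoneme at 24 kHz with a 128-sample hop spans roughly 0.5 * 24000 / 128 ≈ 93 frames.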
-def get_sentences_svs(file_name, dataset: str='opencpop', sample_rate: int=24000, n_shift: int=128,):
+def get_sentences_svs(
+        file_name,
+        dataset: str='opencpop',
+        sample_rate: int=24000,
+        n_shift: int=128, ):
     '''
     read label file
     Args:
         file_name (str or Path): path of gen_duration_from_textgrid.py's result
         dataset (str): dataset name
     Returns:
-        Dict: sentence: {'utt': ([char], [int])}
-        tunple: speaker name
+        Dict: sentence info, including [phone ids (int)], [phone durations in frames (int)], [note ids (int)], [note durations (float)], [is_slur flags (int)], text (str), speaker name (str)
     '''
     f = open(file_name, 'r')
     sentence = {}
@@ -87,7 +118,10 @@ def get_sentences_svs(file_name, dataset: str='opencpop', sample_rate: int=24000
             ph_dur = time2frame([float(t) for t in line_list[5].split()])
             is_slur = line_list[6].split()
             assert len(ph) == len(midi) == len(midi_dur) == len(is_slur)
-            sentence[utt] = (ph, [int(i) for i in ph_dur], [int(i) for i in midi], [float(i) for i in midi_dur], [int(i) for i in is_slur], text, "opencpop")
+            sentence[utt] = (ph, [int(i) for i in ph_dur],
+                             [int(i) for i in midi],
+                             [float(i) for i in midi_dur],
+                             [int(i) for i in is_slur], text, "opencpop")
         else:
             print("dataset should in {opencpop} now!")

@@ -37,13 +37,20 @@ from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
 from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
 from paddlespeech.t2s.utils import str2bool
-ALL_SHENGMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j',
-               'q', 'x', 'r', 'z', 'c', 's', 'y', 'w']
-ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian',
-             'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong', 'ou',
-             'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn']
+ALL_INITIALS = [
+    'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h',
+    'j', 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'
+]
+ALL_FINALS = [
+    'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia',
+    'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong',
+    'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've',
+    'vn'
+]
-def process_sentence(config: Dict[str, Any],
+def process_sentence(
+        config: Dict[str, Any],
         fp: Path,
         sentences: Dict,
         output_dir: Path,
@@ -82,9 +89,13 @@ def process_sentence(config: Dict[str, Any],
     phones = sentences[utt_id][0]
     durations = sentences[utt_id][1]
     num_frames = logmel.shape[0]
-    word_boundary = [1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in phones]
+    word_boundary = [
+        1 if x in ALL_FINALS + ['AP', 'SP'] else 0 for x in phones
+    ]
     # print(sum(durations), num_frames)
-    assert sum(durations) == num_frames, "the sum of durations doesn't equal to the num of mel frames. "
+    assert sum(
+        durations
+    ) == num_frames, "the sum of durations doesn't equal to the num of mel frames. "
     speech_dir = output_dir / "data_speech"
     speech_dir.mkdir(parents=True, exist_ok=True)
     speech_path = speech_dir / (utt_id + "_speech.npy")
@@ -128,7 +139,8 @@ def process_sentence(config: Dict[str, Any],
     return record
-def process_sentences(config,
+def process_sentences(
+        config,
         fps: List[Path],
         sentences: Dict,
         output_dir: Path,
@@ -159,10 +171,17 @@ def process_sentences(config,
         futures = []
         with tqdm.tqdm(total=len(fps)) as progress:
             for fp in fps:
-                future = pool.submit(process_sentence, config, fp,
-                                     sentences, output_dir, mel_extractor,
-                                     pitch_extractor, energy_extractor,
-                                     cut_sil, spk_emb_dir,)
+                future = pool.submit(
+                    process_sentence,
+                    config,
+                    fp,
+                    sentences,
+                    output_dir,
+                    mel_extractor,
+                    pitch_extractor,
+                    energy_extractor,
+                    cut_sil,
+                    spk_emb_dir, )
                 future.add_done_callback(lambda p: progress.update())
                 futures.append(future)
@@ -235,7 +254,6 @@ def main():
     dumpdir.mkdir(parents=True, exist_ok=True)
     label_file = Path(args.label_file).expanduser()
     if args.spk_emb_dir:
         spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
     else:
@@ -247,7 +265,11 @@ def main():
     with open(args.config, 'rt') as f:
         config = CfgNode(yaml.safe_load(f))
-    sentences, speaker_set = get_sentences_svs(label_file, dataset=args.dataset, sample_rate=config.fs, n_shift=config.n_shift,)
+    sentences, speaker_set = get_sentences_svs(
+        label_file,
+        dataset=args.dataset,
+        sample_rate=config.fs,
+        n_shift=config.n_shift, )
     # merge_silence(sentences)
     phone_id_map_path = dumpdir / "phone_id_map.txt"

@@ -37,7 +37,7 @@ from paddlespeech.t2s.models.diffsinger import DiffSinger
 from paddlespeech.t2s.models.diffsinger import DiffSingerEvaluator
 from paddlespeech.t2s.models.diffsinger import DiffSingerUpdater
 from paddlespeech.t2s.models.diffsinger import DiffusionLoss
-from paddlespeech.t2s.models.diffsinger import FastSpeech2MIDILoss
+from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDILoss
 from paddlespeech.t2s.training.extensions.snapshot import Snapshot
 from paddlespeech.t2s.training.extensions.visualizer import VisualDL
 from paddlespeech.t2s.training.optimizer import build_optimizers
@@ -45,6 +45,9 @@ from paddlespeech.t2s.training.seeding import seed_everything
 from paddlespeech.t2s.training.trainer import Trainer
 from paddlespeech.t2s.utils import str2bool
+# from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
 def train_sp(args, config):
     # decides device type and whether to run in parallel
     # setup running environment correctly
@@ -75,11 +78,6 @@ def train_sp(args, config):
         spk_id = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id)
         fields += ["spk_id"]
-    elif args.voice_cloning:
-        print("Training voice cloning!")
-        collate_fn = diffsinger_multi_spk_batch_fn
-        fields += ["spk_emb"]
-        converters["spk_emb"] = np.load
     else:
         collate_fn = diffsinger_single_spk_batch_fn
         print("single speaker diffsinger!")
@@ -133,30 +131,28 @@ def train_sp(args, config):
     print("vocab_size:", vocab_size)
     odim = config.n_mels
-    config["fs2_model"]["idim"] = vocab_size
-    config["fs2_model"]["odim"] = odim
-    config["fs2_model"]["spk_num"] = spk_num
-    model = DiffSinger(
-        fs2_config=config["fs2_model"],
-        denoiser_config=config["denoiser_model"],
-        diffusion_config=config["diffusion"])
+    config["model"]["fastspeech2_params"]["spk_num"] = spk_num
+    model = DiffSinger(idim=vocab_size, odim=odim, **config["model"])
+    model_fs2 = model.fs2
+    model_ds = model.diffusion
     if world_size > 1:
         model = DataParallel(model)
+        model_fs2 = model._layers.fs2
+        model_ds = model._layers.diffusion
     print("models done!")
+    # criterion_fs2 = FastSpeech2Loss(**config["fs2_updater"])
     criterion_fs2 = FastSpeech2MIDILoss(**config["fs2_updater"])
     criterion_ds = DiffusionLoss(**config["ds_updater"])
     print("criterions done!")
-    optimizer_fs2 = build_optimizers(model._layers.fs2,
-                                     **config["fs2_optimizer"])
+    optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
     lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
     gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
     optimizer_ds = AdamW(
         learning_rate=lr_schedule_ds,
         grad_clip=gradient_clip_ds,
-        parameters=model._layers.diffusion.parameters(),
+        parameters=model_ds.parameters(),
         **config["ds_optimizer_params"])
     # optimizer_ds = build_optimizers(ds, **config["ds_optimizer"])
     print("optimizer done!")
@@ -189,7 +185,8 @@ def train_sp(args, config):
             "ds": criterion_ds,
         },
         dataloader=dev_dataloader,
-        output_dir=output_dir)
+        output_dir=output_dir,)
     trainer = Trainer(
         updater,
         stop_trigger=(config.train_max_steps, "iteration"),
@@ -224,12 +221,6 @@ def main():
         default=None,
         help="speaker id map file for multiple speaker model.")
-    parser.add_argument(
-        "--voice-cloning",
-        type=str2bool,
-        default=False,
-        help="whether training voice cloning model.")
     args = parser.parse_args()
     with open(args.config) as f:

@@ -23,7 +23,6 @@ from typing import Optional
 import numpy as np
 import onnxruntime as ort
 import paddle
-import yaml
 from paddle import inference
 from paddle import jit
 from paddle.io import DataLoader
@@ -358,13 +357,8 @@ def get_am_inference(am: str='fastspeech2_csmsc',
         am = am_class(
             idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
     elif am_name == 'diffsinger':
-        am_config["fs2_model"]["idim"] = vocab_size
-        am_config["fs2_model"]["odim"] = am_config.n_mels
-        am_config["fs2_model"]["spk_num"] = spk_num
-        am = am_class(
-            fs2_config=am_config["fs2_model"],
-            denoiser_config=am_config["denoiser_model"],
-            diffusion_config=am_config["diffusion"])
+        am_config["model"]["fastspeech2_params"]["spk_num"] = spk_num
+        am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
     elif am_name == 'speedyspeech':
         am = am_class(
             vocab_size=vocab_size,

File diff suppressed because it is too large.

@@ -44,8 +44,9 @@ class DiffSingerUpdater(StandardUpdater):
             fs2_train_start_steps: int=0,
             ds_train_start_steps: int=160000,
             output_dir: Path=None, ):
         super().__init__(model, optimizers, dataloader, init_state=None)
+        self.model = model._layers if isinstance(model,
+                                                 paddle.DataParallel) else model
         self.optimizers = optimizers
         self.optimizer_fs2: Optimizer = optimizers['fs2']
@@ -79,7 +80,7 @@ class DiffSingerUpdater(StandardUpdater):
         if spk_emb is not None:
             spk_id = None
-        # fastspeech2
+        # only train fastspeech2 module firstly
         if self.state.iteration > self.fs2_train_start_steps and self.state.iteration < self.ds_train_start_steps:
             before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
                 text=batch["text"],
@@ -133,8 +134,9 @@ class DiffSingerUpdater(StandardUpdater):
             self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                   for k, v in losses_dict.items())
+        # Then only train diffusion module, freeze fastspeech2 parameters.
         if self.state.iteration > self.ds_train_start_steps:
-            for param in self.model._layers.fs2.parameters():
+            for param in self.model.fs2.parameters():
                 param.trainable = False
             mel, mel_masks = self.model(
@@ -183,12 +185,12 @@ class DiffSingerEvaluator(StandardEvaluator):
             dataloader: DataLoader,
             output_dir: Path=None, ):
         super().__init__(model, dataloader)
-        self.model = model
+        self.model = model._layers if isinstance(model,
+                                                 paddle.DataParallel) else model
         self.criterions = criterions
         self.criterion_fs2 = criterions['fs2']
         self.criterion_ds = criterions['ds']
         self.dataloader = dataloader
         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@@ -206,6 +208,7 @@ class DiffSingerEvaluator(StandardEvaluator):
         if spk_emb is not None:
             spk_id = None
+        # Here show diffsinger eval
         mel, mel_masks = self.model(
             text=batch["text"],
             note=batch["note"],
@@ -227,11 +230,10 @@ class DiffSingerEvaluator(StandardEvaluator):
             ref_mels=batch["speech"],
             out_mels=mel,
             mel_masks=mel_masks, )
         loss_ds = l1_loss_ds
-        report("train/loss_ds", float(loss_ds))
-        report("train/l1_loss_ds", float(l1_loss_ds))
+        report("eval/loss_ds", float(loss_ds))
+        report("eval/l1_loss_ds", float(l1_loss_ds))
         losses_dict["l1_loss_ds"] = float(l1_loss_ds)
         losses_dict["loss_ds"] = float(loss_ds)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
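The `model._layers if isinstance(model, paddle.DataParallel) else model` pattern added to both the updater and the evaluator unwraps the DataParallel container, so that submodules such as `fs2` and `diffusion` can be reached the same way in single-GPU and multi-GPU runs. A minimal illustration of the idea (not DiffSinger-specific):

import paddle
from paddle import nn


def unwrap(model: nn.Layer) -> nn.Layer:
    # paddle.DataParallel keeps the wrapped network in ._layers; plain Layers
    # are returned unchanged, so model.fs2 / model.diffusion always resolve.
    return model._layers if isinstance(model, paddle.DataParallel) else model


net = nn.Linear(4, 4)
assert unwrap(net) is net
# inside a paddle.distributed run, unwrap(paddle.DataParallel(net)) would
# return the inner `net` as well.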

@@ -0,0 +1,625 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Sequence
from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictorLoss
class FastSpeech2MIDI(FastSpeech2):
"""The Fastspeech2 module of DiffSinger.
"""
def __init__(
self,
# fastspeech2 network structure related
idim: int,
odim: int,
fastspeech2_config: Dict[str, Any],
# note emb
note_num: int=300,
# is_slur emb
is_slur_num: int=2, ):
"""Initialize FastSpeech2 module for svs.
Args:
fastspeech2_config (Dict):
The config of the FastSpeech2 module of the DiffSinger model.
note_num (Optional[int]):
Number of notes. If not None, assume that the
note_ids will be provided as the input and use note_embedding_table.
is_slur_num (Optional[int]):
Number of slur states. If not None, assume that the
is_slur_ids will be provided as the input.
"""
assert check_argument_types()
super().__init__(idim=idim, odim=odim, **fastspeech2_config)
self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_config[
"adim"]
if note_num is not None:
self.note_embedding_table = nn.Embedding(
num_embeddings=note_num,
embedding_dim=self.note_embed_dim,
padding_idx=self.padding_idx)
self.note_dur_layer = nn.Linear(1, self.note_embed_dim)
if is_slur_num is not None:
self.is_slur_embedding_table = nn.Embedding(
num_embeddings=is_slur_num,
embedding_dim=self.is_slur_embed_dim,
padding_idx=self.padding_idx)
def forward(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
text_lengths: paddle.Tensor,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
durations: paddle.Tensor,
pitch: paddle.Tensor,
energy: paddle.Tensor,
spk_emb: paddle.Tensor=None,
spk_id: paddle.Tensor=None,
) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
"""Calculate forward propagation.
Args:
text(Tensor(int64)):
Batch of padded token (phone) ids (B, Tmax).
note(Tensor(int64)):
Batch of padded note (element in music score) ids (B, Tmax).
note_dur(Tensor(float32)):
Batch of padded note durations in seconds (element in music score) (B, Tmax).
is_slur(Tensor(int64)):
Batch of padded slur (element in music score) ids (B, Tmax).
text_lengths(Tensor(int64)):
Batch of phone lengths of each input (B,).
speech(Tensor[float32]):
Batch of padded target features (e.g. mel) (B, Lmax, odim).
speech_lengths(Tensor(int64)):
Batch of the lengths of each target features (B,).
durations(Tensor(int64)):
Batch of padded token durations in frame (B, Tmax).
pitch(Tensor[float32]):
Batch of padded frame-averaged pitch (B, Lmax, 1).
energy(Tensor[float32]):
Batch of padded frame-averaged energy (B, Lmax, 1).
spk_emb(Tensor[float32], optional):
Batch of speaker embeddings (B, spk_embed_dim).
spk_id(Tensor(int64), optional):
Batch of speaker ids (B,)
Returns:
"""
xs = paddle.cast(text, 'int64')
note = paddle.cast(note, 'int64')
note_dur = paddle.cast(note_dur, 'float32')
is_slur = paddle.cast(is_slur, 'int64')
ilens = paddle.cast(text_lengths, 'int64')
olens = paddle.cast(speech_lengths, 'int64')
ds = paddle.cast(durations, 'int64')
ps = pitch
es = energy
ys = speech
olens = speech_lengths
if spk_id is not None:
spk_id = paddle.cast(spk_id, 'int64')
# forward propagation
before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
xs,
note,
note_dur,
is_slur,
ilens,
olens,
ds,
ps,
es,
is_inference=False,
spk_emb=spk_emb,
spk_id=spk_id, )
# modify mod part of groundtruth
if self.reduction_factor > 1:
olens = olens - olens % self.reduction_factor
max_olen = max(olens)
ys = ys[:, :max_olen]
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
def _forward(
self,
xs: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor=None,
ds: paddle.Tensor=None,
ps: paddle.Tensor=None,
es: paddle.Tensor=None,
is_inference: bool=False,
is_train_diffusion: bool=False,
return_after_enc=False,
alpha: float=1.0,
spk_emb=None,
spk_id=None, ) -> Sequence[paddle.Tensor]:
# forward encoder
x_masks = self._source_mask(ilens)
note_emb = self.note_embedding_table(note)
note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
is_slur_emb = self.is_slur_embedding_table(is_slur)
# (B, Tmax, adim)
hs, _ = self.encoder(
xs,
x_masks,
note_emb,
note_dur_emb,
is_slur_emb, )
if self.spk_num and self.enable_speaker_classifier and not is_inference:
hs_for_spk_cls = self.grad_reverse(hs)
spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
else:
spk_logits = None
# integrate speaker embedding
if self.spk_embed_dim is not None:
# spk_emb has a higher priority than spk_id
if spk_emb is not None:
hs = self._integrate_with_spk_embed(hs, spk_emb)
elif spk_id is not None:
spk_emb = self.spk_embedding_table(spk_id)
hs = self._integrate_with_spk_embed(hs, spk_emb)
# forward duration predictor and variance predictors
d_masks = make_pad_mask(ilens)
if olens is not None:
pitch_masks = make_pad_mask(olens).unsqueeze(-1)
else:
pitch_masks = None
# inference for decoder input for diffusion
if is_train_diffusion:
hs = self.length_regulator(hs, ds, is_inference=False)
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
elif is_inference:
# (B, Tmax)
if ds is not None:
d_outs = ds
else:
d_outs = self.duration_predictor.inference(hs, d_masks)
# (B, Lmax, adim)
hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)
if ps is not None:
p_outs = ps
else:
if self.stop_gradient_from_pitch_predictor:
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
else:
p_outs = self.pitch_predictor(hs, pitch_masks)
if es is not None:
e_outs = es
else:
if self.stop_gradient_from_energy_predictor:
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
else:
e_outs = self.energy_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
# training
else:
d_outs = self.duration_predictor(hs, d_masks)
# (B, Lmax, adim)
hs = self.length_regulator(hs, ds, is_inference=False)
if self.stop_gradient_from_pitch_predictor:
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
else:
p_outs = self.pitch_predictor(hs, pitch_masks)
if self.stop_gradient_from_energy_predictor:
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
else:
e_outs = self.energy_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
(0, 2, 1))
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
# forward decoder
if olens is not None and not is_inference:
if self.reduction_factor > 1:
olens_in = paddle.to_tensor(
[olen // self.reduction_factor for olen in olens.numpy()])
else:
olens_in = olens
# (B, 1, T)
h_masks = self._source_mask(olens_in)
else:
h_masks = None
if return_after_enc:
return hs, h_masks
if self.decoder_type == 'cnndecoder':
# remove output masks for dygraph to static graph
zs = self.decoder(hs, h_masks)
before_outs = zs
else:
# (B, Lmax, adim)
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
def encoder_infer(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
alpha: float=1.0,
spk_emb=None,
spk_id=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
xs = paddle.cast(text, 'int64').unsqueeze(0)
note = paddle.cast(note, 'int64').unsqueeze(0)
note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
# setup batch axis
ilens = paddle.shape(xs)[1]
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
# (1, L, odim)
# use *_ to avoid bug in dygraph to static graph
hs, _ = self._forward(
xs,
note,
note_dur,
is_slur,
ilens,
is_inference=True,
return_after_enc=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id, )
return hs
# get encoder output for diffusion training
def encoder_infer_batch(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
text_lengths: paddle.Tensor,
speech_lengths: paddle.Tensor,
ds: paddle.Tensor=None,
ps: paddle.Tensor=None,
es: paddle.Tensor=None,
alpha: float=1.0,
spk_emb=None,
spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
xs = paddle.cast(text, 'int64')
note = paddle.cast(note, 'int64')
note_dur = paddle.cast(note_dur, 'float32')
is_slur = paddle.cast(is_slur, 'int64')
ilens = paddle.cast(text_lengths, 'int64')
olens = paddle.cast(speech_lengths, 'int64')
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
# (1, L, odim)
# use *_ to avoid bug in dygraph to static graph
hs, h_masks = self._forward(
xs,
note,
note_dur,
is_slur,
ilens,
olens,
ds,
ps,
es,
return_after_enc=True,
is_train_diffusion=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id, )
return hs, h_masks
def inference(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
durations: paddle.Tensor=None,
pitch: paddle.Tensor=None,
energy: paddle.Tensor=None,
alpha: float=1.0,
use_teacher_forcing: bool=False,
spk_emb=None,
spk_id=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Generate the sequence of features given the sequences of characters.
Args:
text(Tensor(int64)):
Input sequence of characters (T,).
note(Tensor(int64)):
Input note (element in music score) ids (T,).
note_dur(Tensor(float32)):
Input note durations in seconds (element in music score) (T,).
is_slur(Tensor(int64)):
Input slur (element in music score) ids (T,).
durations(Tensor, optional (int64)):
Groundtruth of duration (T,).
pitch(Tensor, optional):
Groundtruth of token-averaged pitch (T, 1).
energy(Tensor, optional):
Groundtruth of token-averaged energy (T, 1).
alpha(float, optional):
Alpha to control the speed.
use_teacher_forcing(bool, optional):
Whether to use teacher forcing.
If true, groundtruth of duration, pitch and energy will be used.
spk_emb(Tensor, optional):
Speaker embedding vector (spk_embed_dim,). (Default value = None)
spk_id(Tensor(int64), optional):
spk ids (1,). (Default value = None)
Returns:
"""
xs = paddle.cast(text, 'int64').unsqueeze(0)
note = paddle.cast(note, 'int64').unsqueeze(0)
note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
d, p, e = durations, pitch, energy
# setup batch axis
ilens = paddle.shape(xs)[1]
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
if use_teacher_forcing:
# use groundtruth of duration, pitch, and energy
ds = d.unsqueeze(0) if d is not None else None
ps = p.unsqueeze(0) if p is not None else None
es = e.unsqueeze(0) if e is not None else None
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
note,
note_dur,
is_slur,
ilens,
ds=ds,
ps=ps,
es=es,
spk_emb=spk_emb,
spk_id=spk_id,
is_inference=True)
else:
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
note,
note_dur,
is_slur,
ilens,
is_inference=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id, )
return outs[0], d_outs[0], p_outs[0], e_outs[0]
class FastSpeech2MIDILoss(nn.Layer):
"""Loss function module for DiffSinger."""
def __init__(self, use_masking: bool=True,
use_weighted_masking: bool=False):
"""Initialize feed-forward Transformer loss module.
Args:
use_masking (bool):
Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool):
Whether to weighted masking in loss calculation.
"""
assert check_argument_types()
super().__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = nn.L1Loss(reduction=reduction)
self.mse_criterion = nn.MSELoss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
self.ce_criterion = nn.CrossEntropyLoss()
def forward(
self,
after_outs: paddle.Tensor,
before_outs: paddle.Tensor,
d_outs: paddle.Tensor,
p_outs: paddle.Tensor,
e_outs: paddle.Tensor,
ys: paddle.Tensor,
ds: paddle.Tensor,
ps: paddle.Tensor,
es: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor,
spk_logits: paddle.Tensor=None,
spk_ids: paddle.Tensor=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, ]:
"""Calculate forward propagation.
Args:
after_outs(Tensor):
Batch of outputs after postnets (B, Lmax, odim).
before_outs(Tensor):
Batch of outputs before postnets (B, Lmax, odim).
d_outs(Tensor):
Batch of outputs of duration predictor (B, Tmax).
p_outs(Tensor):
Batch of outputs of pitch predictor (B, Lmax, 1).
e_outs(Tensor):
Batch of outputs of energy predictor (B, Lmax, 1).
ys(Tensor):
Batch of target features (B, Lmax, odim).
ds(Tensor):
Batch of durations (B, Tmax).
ps(Tensor):
Batch of target frame-averaged pitch (B, Lmax, 1).
es(Tensor):
Batch of target frame-averaged energy (B, Lmax, 1).
ilens(Tensor):
Batch of the lengths of each input (B,).
olens(Tensor):
Batch of the lengths of each target (B,).
spk_logits(Option[Tensor]):
Batch of outputs after speaker classifier (B, Lmax, num_spk)
spk_ids(Option[Tensor]):
Batch of target spk_id (B,)
Returns:
"""
speaker_loss = 0.0
# apply mask to remove padded part
if self.use_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
before_outs = before_outs.masked_select(
out_masks.broadcast_to(before_outs.shape))
if after_outs is not None:
after_outs = after_outs.masked_select(
out_masks.broadcast_to(after_outs.shape))
ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
duration_masks = make_non_pad_mask(ilens)
d_outs = d_outs.masked_select(
duration_masks.broadcast_to(d_outs.shape))
ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
pitch_masks = out_masks
p_outs = p_outs.masked_select(
pitch_masks.broadcast_to(p_outs.shape))
e_outs = e_outs.masked_select(
pitch_masks.broadcast_to(e_outs.shape))
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
if spk_logits is not None and spk_ids is not None:
batch_size = spk_ids.shape[0]
spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
None)
spk_logits = paddle.reshape(spk_logits,
[-1, spk_logits.shape[-1]])
mask_index = spk_logits.abs().sum(axis=1) != 0
spk_ids = spk_ids[mask_index]
spk_logits = spk_logits[mask_index]
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
if after_outs is not None:
l1_loss += self.l1_criterion(after_outs, ys)
duration_loss = self.duration_criterion(d_outs, ds)
pitch_loss = self.mse_criterion(p_outs, ps)
energy_loss = self.mse_criterion(e_outs, es)
if spk_logits is not None and spk_ids is not None:
speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
# make weighted mask and apply it
if self.use_weighted_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
dtype=paddle.float32).sum(
axis=1, keepdim=True)
out_weights /= ys.shape[0] * ys.shape[2]
duration_masks = make_non_pad_mask(ilens)
duration_weights = (duration_masks.cast(dtype=paddle.float32) /
duration_masks.cast(dtype=paddle.float32).sum(
axis=1, keepdim=True))
duration_weights /= ds.shape[0]
# apply weight
l1_loss = l1_loss.multiply(out_weights)
l1_loss = l1_loss.masked_select(
out_masks.broadcast_to(l1_loss.shape)).sum()
duration_loss = (duration_loss.multiply(duration_weights)
.masked_select(duration_masks).sum())
pitch_masks = out_masks
pitch_weights = out_weights
pitch_loss = pitch_loss.multiply(pitch_weights)
pitch_loss = pitch_loss.masked_select(
pitch_masks.broadcast_to(pitch_loss.shape)).sum()
energy_loss = energy_loss.multiply(pitch_weights)
energy_loss = energy_loss.masked_select(
pitch_masks.broadcast_to(energy_loss.shape)).sum()
return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss

@@ -15,6 +15,7 @@
 from typing import List
 from typing import Union
+import paddle
 from paddle import nn
 from paddlespeech.t2s.modules.activation import get_activation
@@ -390,20 +391,26 @@ class TransformerEncoder(BaseEncoder):
             padding_idx=padding_idx,
             encoder_type="transformer")
-    def forward(self, xs, masks, note_emb=None, note_dur_emb=None, is_slur_emb=None, scale=16):
+    def forward(self,
+                xs: paddle.Tensor,
+                masks: paddle.Tensor,
+                note_emb: paddle.Tensor=None,
+                note_dur_emb: paddle.Tensor=None,
+                is_slur_emb: paddle.Tensor=None,
+                scale: int=16):
         """Encoder input sequence.
         Args:
             xs(Tensor):
                 Input tensor (#batch, time, idim).
-            masks(Tensor):
-                Mask tensor (#batch, 1, time).
             note_emb(Tensor):
                 Input tensor (#batch, time, attention_dim).
             note_dur_emb(Tensor):
                 Input tensor (#batch, time, attention_dim).
             is_slur_emb(Tensor):
                 Input tensor (#batch, time, attention_dim).
+            masks(Tensor):
+                Mask tensor (#batch, 1, time).
         Returns:
             Tensor: