diffsinger opencpop fft train, test=tts

pull/2834/head
liangym 3 years ago
parent 82378e5519
commit c463b35fa8

@@ -0,0 +1,66 @@
#!/bin/bash
stage=1
stop_stage=100
config_path=$1
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=opencpop \
--rootdir=~/datasets/SVS/Opencpop/segments \
--dumpdir=dump \
--label-file=~/datasets/SVS/Opencpop/segments/transcriptions.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats (mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="speech"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="pitch"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="energy"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize and convert phone/speaker to id; dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
fi

@@ -0,0 +1,100 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
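# note: the voc_config/voc_ckpt/voc_stat paths below assume pretrained vocoder
# checkpoint directories (e.g. pwgan_opencpop/) have been downloaded and
# unpacked next to this script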
# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=diffsinger_opencpop \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=pwgan_opencpop \
--voc_config=pwgan_opencpop/default.yaml \
--voc_ckpt=pwgan_opencpop/snapshot_iter_100000.pdz \
--voc_stat=pwgan_opencpop/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
fi
# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=diffsinger_opencpop \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=mb_melgan_csmsc \
--voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\
--voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
fi
# style melgan
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=diffsinger_opencpop \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=style_melgan_csmsc \
--voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
--voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "in hifigan syn"
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=diffsinger_opencpop \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=hifigan_csmsc \
--voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
--voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
fi
# wavernn
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "in wavernn syn"
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=diffsinger_opencpop \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=wavernn_csmsc \
--voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \
--voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \
--voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt
fi

@@ -0,0 +1,12 @@
#!/bin/bash
config_path=$1
train_output_path=$2
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=2 \
--phones-dict=dump/phone_id_map.txt

@@ -0,0 +1,37 @@
#!/bin/bash
set -e
source path.sh
gpus=4,5
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train the model; all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize, vocoder is pwgan by default
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize_e2e, vocoder is pwgan by default
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
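# For example, to run only preprocessing and training with the defaults above
# (a usage sketch; the `--stage/--stop-stage` flags are handled by parse_options.sh):
# ./run.sh --stage 0 --stop-stage 1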

@@ -414,6 +414,129 @@ def fastspeech2_multi_spk_batch_fn(examples):
return batch
def diffsinger_single_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
energy = [np.array(item["energy"], dtype=np.float32) for item in examples]
durations = [
np.array(item["durations"], dtype=np.int64) for item in examples
]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
speech_lengths = [
np.array(item["speech_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text)
note = batch_sequences(note)
note_dur = batch_sequences(note_dur)
is_slur = batch_sequences(is_slur)
pitch = batch_sequences(pitch)
speech = batch_sequences(speech)
durations = batch_sequences(durations)
energy = batch_sequences(energy)
# convert each batch to paddle.Tensor
text = paddle.to_tensor(text)
note = paddle.to_tensor(note)
note_dur = paddle.to_tensor(note_dur)
is_slur = paddle.to_tensor(is_slur)
pitch = paddle.to_tensor(pitch)
speech = paddle.to_tensor(speech)
durations = paddle.to_tensor(durations)
energy = paddle.to_tensor(energy)
text_lengths = paddle.to_tensor(text_lengths)
speech_lengths = paddle.to_tensor(speech_lengths)
batch = {
"text": text,
"note": note,
"note_dur": note_dur,
"is_slur": is_slur,
"text_lengths": text_lengths,
"durations": durations,
"speech": speech,
"speech_lengths": speech_lengths,
"pitch": pitch,
"energy": energy
}
return batch
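# A minimal usage sketch for the collate function above (assumptions: the
# normalized metadata produced by the recipe's preprocessing stage, and the
# DataTable/DataLoader pattern from the training script further down):
#
#     import jsonlines
#     import numpy as np
#     from paddle.io import DataLoader
#     from paddlespeech.t2s.datasets.data_table import DataTable
#
#     with jsonlines.open("dump/train/norm/metadata.jsonl", 'r') as reader:
#         metadata = list(reader)
#     dataset = DataTable(
#         data=metadata,
#         fields=["text", "note", "note_dur", "is_slur", "text_lengths",
#                 "speech", "speech_lengths", "durations", "pitch", "energy"],
#         converters={"speech": np.load, "pitch": np.load, "energy": np.load})
#     loader = DataLoader(
#         dataset, batch_size=8, collate_fn=diffsinger_single_spk_batch_fn)
#     batch = next(iter(loader))  # dict of padded paddle.Tensors, keyed as above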
def diffsinger_multi_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
energy = [np.array(item["energy"], dtype=np.float32) for item in examples]
durations = [
np.array(item["durations"], dtype=np.int64) for item in examples
]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
speech_lengths = [
np.array(item["speech_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text)
note = batch_sequences(note)
note_dur = batch_sequences(note_dur)
is_slur = batch_sequences(is_slur)
pitch = batch_sequences(pitch)
speech = batch_sequences(speech)
durations = batch_sequences(durations)
energy = batch_sequences(energy)
# convert each batch to paddle.Tensor
text = paddle.to_tensor(text)
note = paddle.to_tensor(note)
note_dur = paddle.to_tensor(note_dur)
is_slur = paddle.to_tensor(is_slur)
pitch = paddle.to_tensor(pitch)
speech = paddle.to_tensor(speech)
durations = paddle.to_tensor(durations)
energy = paddle.to_tensor(energy)
text_lengths = paddle.to_tensor(text_lengths)
speech_lengths = paddle.to_tensor(speech_lengths)
batch = {
"text": text,
"note": note,
"note_dur": note_dur,
"is_slur": is_slur,
"text_lengths": text_lengths,
"durations": durations,
"speech": speech,
"speech_lengths": speech_lengths,
"pitch": pitch,
"energy": energy
}
# spk_emb has a higher priority than spk_id
if "spk_emb" in examples[0]:
spk_emb = [
np.array(item["spk_emb"], dtype=np.float32) for item in examples
]
spk_emb = batch_sequences(spk_emb)
spk_emb = paddle.to_tensor(spk_emb)
batch["spk_emb"] = spk_emb
elif "spk_id" in examples[0]:
spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
spk_id = paddle.to_tensor(spk_id)
batch["spk_id"] = spk_id
return batch
def transformer_single_spk_batch_fn(examples):
# fields = ["text", "text_lengths", "speech", "speech_lengths"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]

@@ -16,6 +16,7 @@ import librosa
import numpy as np
import pyworld
from scipy.interpolate import interp1d
from typing import List
class LogMelFBank():
@@ -166,6 +167,8 @@ class Pitch():
f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0)
if use_token_averaged_f0 and duration is not None:
f0 = self._average_by_duration(f0, duration)
else:
f0 = np.expand_dims(np.array(f0), 0).T
return f0
@@ -224,6 +227,8 @@ class Energy():
energy = self._calculate_energy(wav)
if use_token_averaged_energy and duration is not None:
energy = self._average_by_duration(energy, duration)
else:
energy = np.expand_dims(np.array(energy), 0).T
return energy

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import librosa
import numpy as np
# speaker|utt_id|phn dur phn dur ...
@@ -40,6 +42,58 @@ def get_phn_dur(file_name):
f.close()
return sentence, speaker_set
def note2midi(notes):
midis = []
for note in notes:
if note == 'rest':
midi = 0
else:
midi = librosa.note_to_midi(note.split("/")[0])
midis.append(midi)
return midis
def time2frame(times, sample_rate: int=24000, n_shift: int=128,):
end = 0.0
ends = []
for t in times:
end += t
ends.append(end)
frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift)
durations = np.diff(frame_pos, prepend=0)
return durations
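# A tiny hedged example of the two helpers above ('rest' maps to midi 0,
# slash-separated notes keep only the first pitch, and frame counts assume
# librosa's default flooring when converting times to frames):
#   note2midi(['C4', 'rest', 'A#4/B4'])  -> [60, 0, 70]
#   time2frame([0.5, 0.5])               -> [93, 94]  (ends 0.5s/1.0s -> frames 93/187)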
def get_sentences_svs(file_name, dataset: str='opencpop', sample_rate: int=24000, n_shift: int=128,):
'''
Read the transcription label file of an SVS dataset.
Args:
file_name (str or Path): path of the transcription label file (e.g. Opencpop's transcriptions.txt)
dataset (str): dataset name
Returns:
Dict: sentence: {utt_id: (phones, durations, midis, midi durations, is_slur flags, text, speaker)}
'''
f = open(file_name, 'r')
sentence = {}
speaker_set = set()
if dataset == 'opencpop':
speaker_set.add("opencpop")
for line in f:
line_list = line.strip().split('|')
utt = line_list[0]
text = line_list[1]
ph = line_list[2].split()
midi = note2midi(line_list[3].split())
midi_dur = line_list[4].split()
ph_dur = time2frame([float(t) for t in line_list[5].split()])
is_slur = line_list[6].split()
assert len(ph) == len(midi) == len(midi_dur) == len(is_slur)
sentence[utt] = (ph, [int(i) for i in ph_dur], [int(i) for i in midi], [float(i) for i in midi_dur], [int(i) for i in is_slur], text, "opencpop")
else:
print("dataset should in {opencpop} now!")
f.close()
return sentence, speaker_set
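# For reference, each line of Opencpop's transcriptions.txt is '|'-separated as
# utt_id|text|phones|notes|note durations|phone durations|is_slur.
# An illustrative (not verbatim) line:
# 2001000001|小酒窝|x iao j iu w o|C4 C4 D4 D4 E4 E4|0.5 0.5 0.4 0.4 0.6 0.6|0.2 0.3 0.1 0.3 0.2 0.4|0 0 0 0 0 0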
def merge_silence(sentence):
'''
@@ -88,6 +142,9 @@ def get_input_token(sentence, output_path, dataset="baker"):
phn_token = ["<pad>", "<unk>"] + phn_token
if dataset in {"baker", "aishell3"}:
phn_token += ["，", "。", "？", "！"]
# svs dataset
elif dataset in {"opencpop"}:
pass
else:
phn_token += [",", ".", "?", "!"]
phn_token += ["<eos>"]

@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,165 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""
import argparse
import logging
from operator import itemgetter
from pathlib import Path
import jsonlines
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
)
parser.add_argument(
"--metadata",
type=str,
required=True,
help="directory including feature files to be normalized. "
"you need to specify either *-scp or rootdir.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump normalized feature files.")
parser.add_argument(
"--speech-stats",
type=str,
required=True,
help="speech statistics file.")
parser.add_argument(
"--pitch-stats", type=str, required=True, help="pitch statistics file.")
parser.add_argument(
"--energy-stats",
type=str,
required=True,
help="energy statistics file.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.")
args = parser.parse_args()
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
# get dataset
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata,
converters={
"speech": np.load,
"pitch": np.load,
"energy": np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
# restore scaler
speech_scaler = StandardScaler()
speech_scaler.mean_ = np.load(args.speech_stats)[0]
speech_scaler.scale_ = np.load(args.speech_stats)[1]
speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
pitch_scaler = StandardScaler()
pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]
energy_scaler = StandardScaler()
energy_scaler.mean_ = np.load(args.energy_stats)[0]
energy_scaler.scale_ = np.load(args.energy_stats)[1]
energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
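# For reference, a sketch of how a *_stats.npy file is laid out, inferred from
# the loads above (row 0 = mean, row 1 = scale); compute_statistics.py is
# assumed to write this shape:
#
#     from sklearn.preprocessing import StandardScaler
#     scaler = StandardScaler()
#     scaler.partial_fit(np.random.randn(100, 80))  # e.g. 80-dim log-mel frames
#     np.save("speech_stats.npy", np.stack([scaler.mean_, scaler.scale_]))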
vocab_phones = {}
with open(args.phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
vocab_phones[phn] = int(id)
vocab_speaker = {}
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
for spk, id in spk_id:
vocab_speaker[spk] = int(id)
# process each file
output_metadata = []
for item in tqdm(dataset):
utt_id = item['utt_id']
speech = item['speech']
pitch = item['pitch']
energy = item['energy']
# normalize
speech = speech_scaler.transform(speech)
speech_dir = dumpdir / "data_speech"
speech_dir.mkdir(parents=True, exist_ok=True)
speech_path = speech_dir / f"{utt_id}_speech.npy"
np.save(speech_path, speech.astype(np.float32), allow_pickle=False)
pitch = pitch_scaler.transform(pitch)
pitch_dir = dumpdir / "data_pitch"
pitch_dir.mkdir(parents=True, exist_ok=True)
pitch_path = pitch_dir / f"{utt_id}_pitch.npy"
np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False)
energy = energy_scaler.transform(energy)
energy_dir = dumpdir / "data_energy"
energy_dir.mkdir(parents=True, exist_ok=True)
energy_path = energy_dir / f"{utt_id}_energy.npy"
np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
phone_ids = [vocab_phones[p] for p in item['phones']]
spk_id = vocab_speaker[item["speaker"]]
record = {
"utt_id": item['utt_id'],
"spk_id": spk_id,
"text": phone_ids,
"text_lengths": item['text_lengths'],
"speech_lengths": item['speech_lengths'],
"durations": item['durations'],
"speech": str(speech_path),
"pitch": str(pitch_path),
"energy": str(energy_path),
"note": item['note'],
"note_dur": item['note_dur'],
"is_slur": item['is_slur'],
}
# add spk_emb for voice cloning
if "spk_emb" in item:
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:
writer.write(item)
logging.info(f"metadata dumped into {output_metadata_path}")
if __name__ == "__main__":
main()

@@ -0,0 +1,359 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.get_feats import Energy
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.get_feats import Pitch
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_sentences_svs
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
from paddlespeech.t2s.utils import str2bool
ALL_SHENGMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j',
'q', 'x', 'r', 'z', 'c', 's', 'y', 'w']
ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian',
'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong', 'ou',
'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn']
def process_sentence(config: Dict[str, Any],
fp: Path,
sentences: Dict,
output_dir: Path,
mel_extractor=None,
pitch_extractor=None,
energy_extractor=None,
cut_sil: bool=True,
spk_emb_dir: Path=None,):
utt_id = fp.stem
record = None
if utt_id in sentences:
# reading, resampling may occur
wav, _ = librosa.load(str(fp), sr=config.fs)
if len(wav.shape) != 1:
return record
max_value = np.abs(wav).max()
if max_value > 1.0:
wav = wav / max_value
assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
assert np.abs(wav).max() <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
note = sentences[utt_id][2]
note_dur = sentences[utt_id][3]
is_slur = sentences[utt_id][4]
speaker = sentences[utt_id][-1]
# extract mel feats
logmel = mel_extractor.get_log_mel_fbank(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, logmel)
# utt_id may be popped in compare_duration_and_mel_length
if utt_id not in sentences:
return None
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
num_frames = logmel.shape[0]
word_boundary = [1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in phones]
# print(sum(durations), num_frames)
assert sum(durations) == num_frames, "the sum of durations doesn't equal the number of mel frames."
speech_dir = output_dir / "data_speech"
speech_dir.mkdir(parents=True, exist_ok=True)
speech_path = speech_dir / (utt_id + "_speech.npy")
np.save(speech_path, logmel)
# extract pitch and energy
pitch = pitch_extractor.get_pitch(wav)
assert pitch.shape[0] == num_frames
pitch_dir = output_dir / "data_pitch"
pitch_dir.mkdir(parents=True, exist_ok=True)
pitch_path = pitch_dir / (utt_id + "_pitch.npy")
np.save(pitch_path, pitch)
energy = energy_extractor.get_energy(wav)
assert energy.shape[0] == num_frames
energy_dir = output_dir / "data_energy"
energy_dir.mkdir(parents=True, exist_ok=True)
energy_path = energy_dir / (utt_id + "_energy.npy")
np.save(energy_path, energy)
record = {
"utt_id": utt_id,
"phones": phones,
"text_lengths": len(phones),
"speech_lengths": num_frames,
"durations": durations,
"speech": str(speech_path),
"pitch": str(pitch_path),
"energy": str(energy_path),
"speaker": speaker,
"note": note,
"note_dur": note_dur,
"is_slur": is_slur,
}
if spk_emb_dir:
if speaker in os.listdir(spk_emb_dir):
embed_name = utt_id + ".npy"
embed_path = spk_emb_dir / speaker / embed_name
if embed_path.is_file():
record["spk_emb"] = str(embed_path)
else:
return None
return record
def process_sentences(config,
fps: List[Path],
sentences: Dict,
output_dir: Path,
mel_extractor=None,
pitch_extractor=None,
energy_extractor=None,
nprocs: int=1,
cut_sil: bool=True,
spk_emb_dir: Path=None,
write_metadata_method: str='w',):
if nprocs == 1:
results = []
for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(
config=config,
fp=fp,
sentences=sentences,
output_dir=output_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir,)
if record:
results.append(record)
else:
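# note: despite the name `nprocs`, this branch is thread-based
# (ThreadPoolExecutor); numpy/librosa release the GIL for much of the
# feature extraction, so threads can still give a speedup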
with ThreadPoolExecutor(nprocs) as pool:
futures = []
with tqdm.tqdm(total=len(fps)) as progress:
for fp in fps:
future = pool.submit(process_sentence, config, fp,
sentences, output_dir, mel_extractor,
pitch_extractor, energy_extractor,
cut_sil, spk_emb_dir,)
future.add_done_callback(lambda p: progress.update())
futures.append(future)
results = []
for ft in futures:
record = ft.result()
if record:
results.append(record)
results.sort(key=itemgetter("utt_id"))
with jsonlines.open(output_dir / "metadata.jsonl",
write_metadata_method) as writer:
for item in results:
writer.write(item)
print("Done")
def main():
# parse config and args
parser = argparse.ArgumentParser(
description="Preprocess audio and then extract features.")
parser.add_argument(
"--dataset",
default="opencpop",
type=str,
help="name of dataset, should in {opencpop} now")
parser.add_argument(
"--rootdir", default=None, type=str, help="directory to dataset.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump feature files.")
parser.add_argument(
"--label-file", default=None, type=str, help="path to label file.")
parser.add_argument("--config", type=str, help="diffsinger config file.")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
parser.add_argument(
"--cut-sil",
type=str2bool,
default=True,
help="whether cut sil in the edge of audio")
parser.add_argument(
"--spk_emb_dir",
default=None,
type=str,
help="directory to speaker embedding files.")
parser.add_argument(
"--write_metadata_method",
default="w",
type=str,
choices=["w", "a"],
help="How the metadata.jsonl file is written.")
args = parser.parse_args()
rootdir = Path(args.rootdir).expanduser()
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
label_file = Path(args.label_file).expanduser()
if args.spk_emb_dir:
spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
else:
spk_emb_dir = None
assert rootdir.is_dir()
assert label_file.is_file()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
sentences, speaker_set = get_sentences_svs(label_file, dataset=args.dataset, sample_rate=config.fs, n_shift=config.n_shift,)
# merge_silence(sentences)
phone_id_map_path = dumpdir / "phone_id_map.txt"
speaker_id_map_path = dumpdir / "speaker_id_map.txt"
get_input_token(sentences, phone_id_map_path, args.dataset)
get_spk_id_map(speaker_set, speaker_id_map_path)
if args.dataset == "opencpop":
wavdir = rootdir / "wavs"
# split data into 3 sections
train_file = rootdir / "train.txt"
train_wav_files = []
with open(train_file, "r") as f_train:
for line in f_train.readlines():
utt = line.split("|")[0]
wav_name = utt + ".wav"
wav_path = wavdir / wav_name
train_wav_files.append(wav_path)
test_file = rootdir / "test.txt"
dev_wav_files = []
test_wav_files = []
num_dev = 106
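# the first num_dev utterances listed in test.txt are held out as the dev set; the rest form the test set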
count = 0
with open(test_file, "r") as f_test:
for line in f_test.readlines():
count += 1
utt = line.split("|")[0]
wav_name = utt + ".wav"
wav_path = wavdir / wav_name
if count > num_dev:
test_wav_files.append(wav_path)
else:
dev_wav_files.append(wav_path)
else:
print("dataset should in {opencpop} now!")
train_dump_dir = dumpdir / "train" / "raw"
train_dump_dir.mkdir(parents=True, exist_ok=True)
dev_dump_dir = dumpdir / "dev" / "raw"
dev_dump_dir.mkdir(parents=True, exist_ok=True)
test_dump_dir = dumpdir / "test" / "raw"
test_dump_dir.mkdir(parents=True, exist_ok=True)
# Extractor
mel_extractor = LogMelFBank(
sr=config.fs,
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window,
n_mels=config.n_mels,
fmin=config.fmin,
fmax=config.fmax)
pitch_extractor = Pitch(
sr=config.fs,
hop_length=config.n_shift,
f0min=config.f0min,
f0max=config.f0max)
energy_extractor = Energy(
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window)
# process for the 3 sections
if train_wav_files:
process_sentences(
config=config,
fps=train_wav_files,
sentences=sentences,
output_dir=train_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
if dev_wav_files:
process_sentences(
config=config,
fps=dev_wav_files,
sentences=sentences,
output_dir=dev_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
if test_wav_files:
process_sentences(
config=config,
fps=test_wav_files,
sentences=sentences,
output_dir=test_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
if __name__ == "__main__":
main()

@@ -0,0 +1,222 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.diffsinger import DiffSinger
from paddlespeech.t2s.models.diffsinger import DiffSingerEvaluator
from paddlespeech.t2s.models.diffsinger import DiffSingerUpdater
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
from paddlespeech.t2s.utils import str2bool
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
world_size = paddle.distributed.get_world_size()
if world_size > 1:
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
fields = [
"text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy",
"note", "note_dur", "is_slur"]
converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker diffsinger!")
collate_fn = diffsinger_multi_spk_batch_fn
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Training voice cloning!")
collate_fn = diffsinger_multi_spk_batch_fn
fields += ["spk_emb"]
converters["spk_emb"] = np.load
else:
collate_fn = diffsinger_single_spk_batch_fn
print("single speaker diffsinger!")
print("spk_num:", spk_num)
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters=converters, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters=converters, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
print("samplers done!")
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=collate_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
shuffle=False,
drop_last=False,
batch_size=config.batch_size,
collate_fn=collate_fn,
num_workers=config.num_workers)
print("dataloaders done!")
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = config.n_mels
model = DiffSinger(
idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"])
if world_size > 1:
model = DataParallel(model)
print("model done!")
optimizer = build_optimizers(model, **config["optimizer"])
print("optimizer done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
if "enable_speaker_classifier" in config.model:
enable_spk_cls = config.model.enable_speaker_classifier
else:
enable_spk_cls = False
updater = DiffSingerUpdater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
enable_spk_cls=enable_spk_cls,
**config["updater"], )
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = DiffSingerEvaluator(
model,
dev_dataloader,
output_dir=output_dir,
enable_spk_cls=enable_spk_cls,
**config["updater"], )
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a DiffSinger model.")
parser.add_argument("--config", type=str, help="diffsinger config file.")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
args = parser.parse_args()
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
# dispatch
if args.ngpu > 1:
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
else:
train_sp(args, config)
if __name__ == "__main__":
main()

@@ -55,6 +55,10 @@ model_alias = {
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
"diffsinger":
"paddlespeech.t2s.models.diffsinger:DiffSinger",
"diffsinger_inference":
"paddlespeech.t2s.models.diffsinger:DiffSingerInference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
@@ -141,6 +145,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
fields += ["spk_emb"]
else:
print("single speaker fastspeech2!")
elif am_name == 'diffsinger':
fields = ["utt_id", "text", "note", "note_dur", "is_slur"]
elif am_name == 'speedyspeech':
fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2':
@@ -347,6 +353,9 @@ def get_am_inference(am: str='fastspeech2_csmsc',
if am_name == 'fastspeech2':
am = am_class(
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
elif am_name == 'diffsinger':
am = am_class(
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
elif am_name == 'speedyspeech':
am = am_class(
vocab_size=vocab_size,

@@ -107,6 +107,12 @@ def evaluate(args):
if args.voice_cloning and "spk_emb" in datum:
spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
mel = am_inference(phone_ids, spk_emb=spk_emb)
elif am_name == 'diffsinger':
phone_ids = paddle.to_tensor(datum["text"])
note = paddle.to_tensor(datum["note"])
note_dur = paddle.to_tensor(datum["note_dur"])
is_slur = paddle.to_tensor(datum["is_slur"])
mel = am_inference(phone_ids, note=note, note_dur=note_dur, is_slur=is_slur)
# vocoder
wav = voc_inference(mel)
@@ -136,7 +142,8 @@ def parse_args():
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix',
'diffsinger_opencpop'
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
@@ -172,7 +179,7 @@ def parse_args():
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
'style_melgan_csmsc', 'pwgan_opencpop',
],
help='Choose vocoder type of tts task.')
parser.add_argument(

@@ -0,0 +1,15 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .diffsinger import *
from .diffsinger_updater import *

File diff suppressed because it is too large

@@ -0,0 +1,248 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddlespeech.t2s.models.diffsinger import DiffSingerLoss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class DiffSingerUpdater(StandardUpdater):
def __init__(
self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
use_masking: bool=False,
spk_loss_scale: float=0.02,
use_weighted_masking: bool=False,
output_dir: Path=None,
enable_spk_cls: bool=False, ):
super().__init__(model, optimizer, dataloader, init_state=None)
self.criterion = DiffSingerLoss(
use_masking=use_masking,
use_weighted_masking=use_weighted_masking, )
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# spk_id!=None in multiple spk diffsinger
spk_id = batch["spk_id"] if "spk_id" in batch else None
spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
# No explicit speaker identifier labels are used during voice cloning training.
if spk_emb is not None:
spk_id = None
if type(
self.model
) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
with self.model.no_sync():
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
else:
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
p_outs=p_outs,
e_outs=e_outs,
ys=ys,
ds=batch["durations"],
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
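# the total objective is a plain sum of the component losses, with the
# speaker-classifier term down-weighted by spk_loss_scale (0.02 by default)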
scaled_speaker_loss = self.spk_loss_scale * speaker_loss
loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
optimizer = self.optimizer
optimizer.clear_grad()
loss.backward()
optimizer.step()
report("train/loss", float(loss))
report("train/l1_loss", float(l1_loss))
report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss))
report("train/scaled_speaker_loss", float(scaled_speaker_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["energy_loss"] = float(energy_loss)
if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class DiffSingerEvaluator(StandardEvaluator):
def __init__(self,
model: Layer,
dataloader: DataLoader,
use_masking: bool=False,
use_weighted_masking: bool=False,
spk_loss_scale: float=0.02,
output_dir: Path=None,
enable_spk_cls: bool=False):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
self.criterion = DiffSingerLoss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
# spk_id!=None in multiple spk diffsinger
spk_id = batch["spk_id"] if "spk_id" in batch else None
spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
if spk_emb is not None:
spk_id = None
if type(
self.model
) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
with self.model.no_sync():
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
else:
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
p_outs=p_outs,
e_outs=e_outs,
ys=ys,
ds=batch["durations"],
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
scaled_speaker_loss = self.spk_loss_scale * speaker_loss
loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
report("eval/loss", float(loss))
report("eval/l1_loss", float(l1_loss))
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss))
report("train/scaled_speaker_loss", float(scaled_speaker_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

@@ -390,12 +390,18 @@ class TransformerEncoder(BaseEncoder):
padding_idx=padding_idx,
encoder_type="transformer")
def forward(self, xs, masks, note_emb=None, note_dur_emb=None, is_slur_emb=None, scale=16):
"""Encoder input sequence. """Encoder input sequence.
Args: Args:
xs(Tensor): xs(Tensor):
Input tensor (#batch, time, idim). Input tensor (#batch, time, idim).
note_emb(Tensor):
Input tensor (#batch, time, attention_dim).
note_dur_emb(Tensor):
Input tensor (#batch, time, attention_dim).
is_slur_emb(Tensor):
Input tensor (#batch, time, attention_dim).
masks(Tensor):
Mask tensor (#batch, 1, time).
@@ -406,6 +412,8 @@ class TransformerEncoder(BaseEncoder):
Mask tensor (#batch, 1, time).
"""
xs = self.embed(xs)
if note_emb is not None:
xs = scale * xs + note_emb + note_dur_emb + is_slur_emb
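# the phone embedding is up-weighted by `scale` (default 16, presumably
# sqrt(attention_dim) for a 256-dim encoder) before the musical conditioning
# (note, note-duration and slur embeddings) is added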
xs, masks = self.encoders(xs, masks)
if self.normalize_before:
xs = self.after_norm(xs)
