Merge pull request #1855 from yt605155624/add_vits
[TTS]add vits network scripts, test=ttspull/1948/merge
commit
5ee3cc0c31
@ -0,0 +1,183 @@
|
|||||||
|
# This configuration tested on 4 GPUs (V100) with 32GB GPU
|
||||||
|
# memory. It takes around 2 weeks to finish the training
|
||||||
|
# but 100k iters model should generate reasonable results.
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
|
||||||
|
fs: 22050 # sr
|
||||||
|
n_fft: 1024 # FFT size (samples).
|
||||||
|
n_shift: 256 # Hop size (samples). 12.5ms
|
||||||
|
win_length: null # Window length (samples). 50ms
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
# TTS MODEL SETTING #
|
||||||
|
##########################################################
|
||||||
|
model:
|
||||||
|
# generator related
|
||||||
|
generator_type: vits_generator
|
||||||
|
generator_params:
|
||||||
|
hidden_channels: 192
|
||||||
|
spks: -1
|
||||||
|
global_channels: -1
|
||||||
|
segment_size: 32
|
||||||
|
text_encoder_attention_heads: 2
|
||||||
|
text_encoder_ffn_expand: 4
|
||||||
|
text_encoder_blocks: 6
|
||||||
|
text_encoder_positionwise_layer_type: "conv1d"
|
||||||
|
text_encoder_positionwise_conv_kernel_size: 3
|
||||||
|
text_encoder_positional_encoding_layer_type: "rel_pos"
|
||||||
|
text_encoder_self_attention_layer_type: "rel_selfattn"
|
||||||
|
text_encoder_activation_type: "swish"
|
||||||
|
text_encoder_normalize_before: True
|
||||||
|
text_encoder_dropout_rate: 0.1
|
||||||
|
text_encoder_positional_dropout_rate: 0.0
|
||||||
|
text_encoder_attention_dropout_rate: 0.1
|
||||||
|
use_macaron_style_in_text_encoder: True
|
||||||
|
use_conformer_conv_in_text_encoder: False
|
||||||
|
text_encoder_conformer_kernel_size: -1
|
||||||
|
decoder_kernel_size: 7
|
||||||
|
decoder_channels: 512
|
||||||
|
decoder_upsample_scales: [8, 8, 2, 2]
|
||||||
|
decoder_upsample_kernel_sizes: [16, 16, 4, 4]
|
||||||
|
decoder_resblock_kernel_sizes: [3, 7, 11]
|
||||||
|
decoder_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
|
use_weight_norm_in_decoder: True
|
||||||
|
posterior_encoder_kernel_size: 5
|
||||||
|
posterior_encoder_layers: 16
|
||||||
|
posterior_encoder_stacks: 1
|
||||||
|
posterior_encoder_base_dilation: 1
|
||||||
|
posterior_encoder_dropout_rate: 0.0
|
||||||
|
use_weight_norm_in_posterior_encoder: True
|
||||||
|
flow_flows: 4
|
||||||
|
flow_kernel_size: 5
|
||||||
|
flow_base_dilation: 1
|
||||||
|
flow_layers: 4
|
||||||
|
flow_dropout_rate: 0.0
|
||||||
|
use_weight_norm_in_flow: True
|
||||||
|
use_only_mean_in_flow: True
|
||||||
|
stochastic_duration_predictor_kernel_size: 3
|
||||||
|
stochastic_duration_predictor_dropout_rate: 0.5
|
||||||
|
stochastic_duration_predictor_flows: 4
|
||||||
|
stochastic_duration_predictor_dds_conv_layers: 3
|
||||||
|
# discriminator related
|
||||||
|
discriminator_type: hifigan_multi_scale_multi_period_discriminator
|
||||||
|
discriminator_params:
|
||||||
|
scales: 1
|
||||||
|
scale_downsample_pooling: "AvgPool1D"
|
||||||
|
scale_downsample_pooling_params:
|
||||||
|
kernel_size: 4
|
||||||
|
stride: 2
|
||||||
|
padding: 2
|
||||||
|
scale_discriminator_params:
|
||||||
|
in_channels: 1
|
||||||
|
out_channels: 1
|
||||||
|
kernel_sizes: [15, 41, 5, 3]
|
||||||
|
channels: 128
|
||||||
|
max_downsample_channels: 1024
|
||||||
|
max_groups: 16
|
||||||
|
bias: True
|
||||||
|
downsample_scales: [2, 2, 4, 4, 1]
|
||||||
|
nonlinear_activation: "leakyrelu"
|
||||||
|
nonlinear_activation_params:
|
||||||
|
negative_slope: 0.1
|
||||||
|
use_weight_norm: True
|
||||||
|
use_spectral_norm: False
|
||||||
|
follow_official_norm: False
|
||||||
|
periods: [2, 3, 5, 7, 11]
|
||||||
|
period_discriminator_params:
|
||||||
|
in_channels: 1
|
||||||
|
out_channels: 1
|
||||||
|
kernel_sizes: [5, 3]
|
||||||
|
channels: 32
|
||||||
|
downsample_scales: [3, 3, 3, 3, 1]
|
||||||
|
max_downsample_channels: 1024
|
||||||
|
bias: True
|
||||||
|
nonlinear_activation: "leakyrelu"
|
||||||
|
nonlinear_activation_params:
|
||||||
|
negative_slope: 0.1
|
||||||
|
use_weight_norm: True
|
||||||
|
use_spectral_norm: False
|
||||||
|
# others
|
||||||
|
sampling_rate: 22050 # needed in the inference for saving wav
|
||||||
|
cache_generator_outputs: True # whether to cache generator outputs in the training
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
# loss function related
|
||||||
|
generator_adv_loss_params:
|
||||||
|
average_by_discriminators: False # whether to average loss value by #discriminators
|
||||||
|
loss_type: mse # loss type, "mse" or "hinge"
|
||||||
|
discriminator_adv_loss_params:
|
||||||
|
average_by_discriminators: False # whether to average loss value by #discriminators
|
||||||
|
loss_type: mse # loss type, "mse" or "hinge"
|
||||||
|
feat_match_loss_params:
|
||||||
|
average_by_discriminators: False # whether to average loss value by #discriminators
|
||||||
|
average_by_layers: False # whether to average loss value by #layers of each discriminator
|
||||||
|
include_final_outputs: True # whether to include final outputs for loss calculation
|
||||||
|
mel_loss_params:
|
||||||
|
fs: 22050 # must be the same as the training data
|
||||||
|
fft_size: 1024 # fft points
|
||||||
|
hop_size: 256 # hop size
|
||||||
|
win_length: null # window length
|
||||||
|
window: hann # window type
|
||||||
|
num_mels: 80 # number of Mel basis
|
||||||
|
fmin: 0 # minimum frequency for Mel basis
|
||||||
|
fmax: null # maximum frequency for Mel basis
|
||||||
|
log_base: null # null represent natural log
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# ADVERSARIAL LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
lambda_adv: 1.0 # loss scaling coefficient for adversarial loss
|
||||||
|
lambda_mel: 45.0 # loss scaling coefficient for Mel loss
|
||||||
|
lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss
|
||||||
|
lambda_dur: 1.0 # loss scaling coefficient for duration loss
|
||||||
|
lambda_kl: 1.0 # loss scaling coefficient for KL divergence loss
|
||||||
|
# others
|
||||||
|
sampling_rate: 22050 # needed in the inference for saving wav
|
||||||
|
cache_generator_outputs: True # whether to cache generator outputs in the training
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA LOADER SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 64 # Batch size.
|
||||||
|
num_workers: 4 # Number of workers in DataLoader.
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
# OPTIMIZER & SCHEDULER SETTING #
|
||||||
|
##########################################################
|
||||||
|
# optimizer setting for generator
|
||||||
|
generator_optimizer_params:
|
||||||
|
beta1: 0.8
|
||||||
|
beta2: 0.99
|
||||||
|
epsilon: 1.0e-9
|
||||||
|
weight_decay: 0.0
|
||||||
|
generator_scheduler: exponential_decay
|
||||||
|
generator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4
|
||||||
|
gamma: 0.999875
|
||||||
|
|
||||||
|
# optimizer setting for discriminator
|
||||||
|
discriminator_optimizer_params:
|
||||||
|
beta1: 0.8
|
||||||
|
beta2: 0.99
|
||||||
|
epsilon: 1.0e-9
|
||||||
|
weight_decay: 0.0
|
||||||
|
discriminator_scheduler: exponential_decay
|
||||||
|
discriminator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4
|
||||||
|
gamma: 0.999875
|
||||||
|
generator_first: False # whether to start updating generator first
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
# OTHER TRAINING SETTING #
|
||||||
|
##########################################################
|
||||||
|
max_epoch: 1000 # number of epochs
|
||||||
|
num_snapshots: 10 # max number of snapshots to keep while training
|
||||||
|
seed: 777 # random seed number
|
@ -0,0 +1,64 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# get durations from MFA's result
|
||||||
|
echo "Generate durations.txt from MFA results ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
|
||||||
|
--inputdir=./baker_alignment_tone \
|
||||||
|
--output=durations.txt \
|
||||||
|
--config=${config_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# extract features
|
||||||
|
echo "Extract features ..."
|
||||||
|
python3 ${BIN_DIR}/preprocess.py \
|
||||||
|
--dataset=baker \
|
||||||
|
--rootdir=~/datasets/BZNSYP/ \
|
||||||
|
--dumpdir=dump \
|
||||||
|
--dur-file=durations.txt \
|
||||||
|
--config=${config_path} \
|
||||||
|
--num-cpu=20 \
|
||||||
|
--cut-sil=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# get features' stats(mean and std)
|
||||||
|
echo "Get features' stats ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="feats"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# normalize and covert phone/speaker to id, dev and test should use train's stats
|
||||||
|
echo "Normalize ..."
|
||||||
|
python3 ${BIN_DIR}/normalize.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/train/norm \
|
||||||
|
--feats-stats=dump/train/feats_stats.npy \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--speaker-dict=dump/speaker_id_map.txt \
|
||||||
|
--skip-wav-copy
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/normalize.py \
|
||||||
|
--metadata=dump/dev/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/dev/norm \
|
||||||
|
--feats-stats=dump/train/feats_stats.npy \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--speaker-dict=dump/speaker_id_map.txt \
|
||||||
|
--skip-wav-copy
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/normalize.py \
|
||||||
|
--metadata=dump/test/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/test/norm \
|
||||||
|
--feats-stats=dump/train/feats_stats.npy \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--speaker-dict=dump/speaker_id_map.txt \
|
||||||
|
--skip-wav-copy
|
||||||
|
fi
|
@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
stage=0
|
||||||
|
stop_stage=0
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/synthesize.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--phones_dict=dump/phone_id_map.txt \
|
||||||
|
--test_metadata=dump/test/norm/metadata.jsonl \
|
||||||
|
--output_dir=${train_output_path}/test
|
||||||
|
fi
|
@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
stage=0
|
||||||
|
stop_stage=0
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/synthesize_e2e.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--phones_dict=dump/phone_id_map.txt \
|
||||||
|
--output_dir=${train_output_path}/test_e2e \
|
||||||
|
--text=${BIN_DIR}/../sentences.txt
|
||||||
|
fi
|
@ -0,0 +1,12 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/train.py \
|
||||||
|
--train-metadata=dump/train/norm/metadata.jsonl \
|
||||||
|
--dev-metadata=dump/dev/norm/metadata.jsonl \
|
||||||
|
--config=${config_path} \
|
||||||
|
--output-dir=${train_output_path} \
|
||||||
|
--ngpu=4 \
|
||||||
|
--phones-dict=dump/phone_id_map.txt
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
MODEL=vits
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
|
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0,1
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/default.yaml
|
||||||
|
train_output_path=exp/default
|
||||||
|
ckpt_name=snapshot_iter_153.pdz
|
||||||
|
|
||||||
|
# with the following command, you can choose the stage range you want to run
|
||||||
|
# such as `./run.sh --stage 0 --stop-stage 0`
|
||||||
|
# this can not be mixed use with `$1`, `$2` ...
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
./local/preprocess.sh ${conf_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# synthesize_e2e, vocoder is pwgan
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
@ -0,0 +1,165 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Normalize feature files and dump them."""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from operator import itemgetter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run preprocessing process."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="directory including feature files to be normalized. "
|
||||||
|
"you need to specify either *-scp or rootdir.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dumpdir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="directory to dump normalized feature files.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--feats-stats",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="speech statistics file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-wav-copy",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether to skip the copy of wav files.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--speaker-dict", type=str, default=None, help="speaker id map file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="logging level. higher is more logging. (default=1)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# set logger
|
||||||
|
if args.verbose > 1:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
elif args.verbose > 0:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARN,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
logging.warning('Skip DEBUG/INFO messages')
|
||||||
|
|
||||||
|
dumpdir = Path(args.dumpdir).expanduser()
|
||||||
|
# use absolute path
|
||||||
|
dumpdir = dumpdir.resolve()
|
||||||
|
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# get dataset
|
||||||
|
with jsonlines.open(args.metadata, 'r') as reader:
|
||||||
|
metadata = list(reader)
|
||||||
|
dataset = DataTable(
|
||||||
|
metadata,
|
||||||
|
converters={
|
||||||
|
"feats": np.load,
|
||||||
|
"wave": None if args.skip_wav_copy else np.load,
|
||||||
|
})
|
||||||
|
logging.info(f"The number of files = {len(dataset)}.")
|
||||||
|
|
||||||
|
# restore scaler
|
||||||
|
feats_scaler = StandardScaler()
|
||||||
|
feats_scaler.mean_ = np.load(args.feats_stats)[0]
|
||||||
|
feats_scaler.scale_ = np.load(args.feats_stats)[1]
|
||||||
|
feats_scaler.n_features_in_ = feats_scaler.mean_.shape[0]
|
||||||
|
|
||||||
|
vocab_phones = {}
|
||||||
|
with open(args.phones_dict, 'rt') as f:
|
||||||
|
phn_id = [line.strip().split() for line in f.readlines()]
|
||||||
|
for phn, id in phn_id:
|
||||||
|
vocab_phones[phn] = int(id)
|
||||||
|
|
||||||
|
vocab_speaker = {}
|
||||||
|
with open(args.speaker_dict, 'rt') as f:
|
||||||
|
spk_id = [line.strip().split() for line in f.readlines()]
|
||||||
|
for spk, id in spk_id:
|
||||||
|
vocab_speaker[spk] = int(id)
|
||||||
|
|
||||||
|
# process each file
|
||||||
|
output_metadata = []
|
||||||
|
|
||||||
|
for item in tqdm(dataset):
|
||||||
|
utt_id = item['utt_id']
|
||||||
|
feats = item['feats']
|
||||||
|
wave = item['wave']
|
||||||
|
|
||||||
|
# normalize
|
||||||
|
feats = feats_scaler.transform(feats)
|
||||||
|
feats_path = dumpdir / f"{utt_id}_feats.npy"
|
||||||
|
np.save(feats_path, feats.astype(np.float32), allow_pickle=False)
|
||||||
|
|
||||||
|
if not args.skip_wav_copy:
|
||||||
|
wav_path = dumpdir / f"{utt_id}_wave.npy"
|
||||||
|
np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
|
||||||
|
else:
|
||||||
|
wav_path = wave
|
||||||
|
|
||||||
|
phone_ids = [vocab_phones[p] for p in item['phones']]
|
||||||
|
spk_id = vocab_speaker[item["speaker"]]
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"utt_id": item['utt_id'],
|
||||||
|
"text": phone_ids,
|
||||||
|
"text_lengths": item['text_lengths'],
|
||||||
|
'feats': str(feats_path),
|
||||||
|
"feats_lengths": item['feats_lengths'],
|
||||||
|
"wave": str(wav_path),
|
||||||
|
"spk_id": spk_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# add spk_emb for voice cloning
|
||||||
|
if "spk_emb" in item:
|
||||||
|
record["spk_emb"] = str(item["spk_emb"])
|
||||||
|
|
||||||
|
output_metadata.append(record)
|
||||||
|
output_metadata.sort(key=itemgetter('utt_id'))
|
||||||
|
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
|
||||||
|
with jsonlines.open(output_metadata_path, 'w') as writer:
|
||||||
|
for item in output_metadata:
|
||||||
|
writer.write(item)
|
||||||
|
logging.info(f"metadata dumped into {output_metadata_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,348 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from operator import itemgetter
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import tqdm
|
||||||
|
import yaml
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram
|
||||||
|
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
|
||||||
|
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
|
||||||
|
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
|
||||||
|
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
|
||||||
|
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
|
||||||
|
from paddlespeech.t2s.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def process_sentence(config: Dict[str, Any],
|
||||||
|
fp: Path,
|
||||||
|
sentences: Dict,
|
||||||
|
output_dir: Path,
|
||||||
|
spec_extractor=None,
|
||||||
|
cut_sil: bool=True,
|
||||||
|
spk_emb_dir: Path=None):
|
||||||
|
utt_id = fp.stem
|
||||||
|
# for vctk
|
||||||
|
if utt_id.endswith("_mic2"):
|
||||||
|
utt_id = utt_id[:-5]
|
||||||
|
record = None
|
||||||
|
if utt_id in sentences:
|
||||||
|
# reading, resampling may occur
|
||||||
|
wav, _ = librosa.load(str(fp), sr=config.fs)
|
||||||
|
if len(wav.shape) != 1:
|
||||||
|
return record
|
||||||
|
max_value = np.abs(wav).max()
|
||||||
|
if max_value > 1.0:
|
||||||
|
wav = wav / max_value
|
||||||
|
assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
|
||||||
|
assert np.abs(wav).max(
|
||||||
|
) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
|
||||||
|
phones = sentences[utt_id][0]
|
||||||
|
durations = sentences[utt_id][1]
|
||||||
|
speaker = sentences[utt_id][2]
|
||||||
|
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
|
||||||
|
# little imprecise than use *.TextGrid directly
|
||||||
|
times = librosa.frames_to_time(
|
||||||
|
d_cumsum, sr=config.fs, hop_length=config.n_shift)
|
||||||
|
if cut_sil:
|
||||||
|
start = 0
|
||||||
|
end = d_cumsum[-1]
|
||||||
|
if phones[0] == "sil" and len(durations) > 1:
|
||||||
|
start = times[1]
|
||||||
|
durations = durations[1:]
|
||||||
|
phones = phones[1:]
|
||||||
|
if phones[-1] == 'sil' and len(durations) > 1:
|
||||||
|
end = times[-2]
|
||||||
|
durations = durations[:-1]
|
||||||
|
phones = phones[:-1]
|
||||||
|
sentences[utt_id][0] = phones
|
||||||
|
sentences[utt_id][1] = durations
|
||||||
|
start, end = librosa.time_to_samples([start, end], sr=config.fs)
|
||||||
|
wav = wav[start:end]
|
||||||
|
# extract mel feats
|
||||||
|
spec = spec_extractor.get_linear_spectrogram(wav)
|
||||||
|
# change duration according to mel_length
|
||||||
|
compare_duration_and_mel_length(sentences, utt_id, spec)
|
||||||
|
# utt_id may be popped in compare_duration_and_mel_length
|
||||||
|
if utt_id not in sentences:
|
||||||
|
return None
|
||||||
|
phones = sentences[utt_id][0]
|
||||||
|
durations = sentences[utt_id][1]
|
||||||
|
num_frames = spec.shape[0]
|
||||||
|
assert sum(durations) == num_frames
|
||||||
|
|
||||||
|
if wav.size < num_frames * config.n_shift:
|
||||||
|
wav = np.pad(
|
||||||
|
wav, (0, num_frames * config.n_shift - wav.size),
|
||||||
|
mode="reflect")
|
||||||
|
else:
|
||||||
|
wav = wav[:num_frames * config.n_shift]
|
||||||
|
num_samples = wav.shape[0]
|
||||||
|
|
||||||
|
spec_path = output_dir / (utt_id + "_feats.npy")
|
||||||
|
wav_path = output_dir / (utt_id + "_wave.npy")
|
||||||
|
# (num_samples, )
|
||||||
|
np.save(wav_path, wav)
|
||||||
|
# (num_frames, aux_channels)
|
||||||
|
np.save(spec_path, spec)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"utt_id": utt_id,
|
||||||
|
"phones": phones,
|
||||||
|
"text_lengths": len(phones),
|
||||||
|
"feats": str(spec_path),
|
||||||
|
"feats_lengths": num_frames,
|
||||||
|
"wave": str(wav_path),
|
||||||
|
"speaker": speaker
|
||||||
|
}
|
||||||
|
if spk_emb_dir:
|
||||||
|
if speaker in os.listdir(spk_emb_dir):
|
||||||
|
embed_name = utt_id + ".npy"
|
||||||
|
embed_path = spk_emb_dir / speaker / embed_name
|
||||||
|
if embed_path.is_file():
|
||||||
|
record["spk_emb"] = str(embed_path)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
def process_sentences(config,
|
||||||
|
fps: List[Path],
|
||||||
|
sentences: Dict,
|
||||||
|
output_dir: Path,
|
||||||
|
spec_extractor=None,
|
||||||
|
nprocs: int=1,
|
||||||
|
cut_sil: bool=True,
|
||||||
|
spk_emb_dir: Path=None):
|
||||||
|
if nprocs == 1:
|
||||||
|
results = []
|
||||||
|
for fp in tqdm.tqdm(fps, total=len(fps)):
|
||||||
|
record = process_sentence(
|
||||||
|
config=config,
|
||||||
|
fp=fp,
|
||||||
|
sentences=sentences,
|
||||||
|
output_dir=output_dir,
|
||||||
|
spec_extractor=spec_extractor,
|
||||||
|
cut_sil=cut_sil,
|
||||||
|
spk_emb_dir=spk_emb_dir)
|
||||||
|
if record:
|
||||||
|
results.append(record)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(nprocs) as pool:
|
||||||
|
futures = []
|
||||||
|
with tqdm.tqdm(total=len(fps)) as progress:
|
||||||
|
for fp in fps:
|
||||||
|
future = pool.submit(process_sentence, config, fp,
|
||||||
|
sentences, output_dir, spec_extractor,
|
||||||
|
cut_sil, spk_emb_dir)
|
||||||
|
future.add_done_callback(lambda p: progress.update())
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for ft in futures:
|
||||||
|
record = ft.result()
|
||||||
|
if record:
|
||||||
|
results.append(record)
|
||||||
|
|
||||||
|
results.sort(key=itemgetter("utt_id"))
|
||||||
|
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
|
||||||
|
for item in results:
|
||||||
|
writer.write(item)
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse config and args
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Preprocess audio and then extract features.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
default="baker",
|
||||||
|
type=str,
|
||||||
|
help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rootdir", default=None, type=str, help="directory to dataset.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dumpdir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="directory to dump feature files.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dur-file", default=None, type=str, help="path to durations.txt.")
|
||||||
|
|
||||||
|
parser.add_argument("--config", type=str, help="fastspeech2 config file.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="logging level. higher is more logging. (default=1)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-cpu", type=int, default=1, help="number of process.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cut-sil",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="whether cut sil in the edge of audio")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--spk_emb_dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="directory to speaker embedding files.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
rootdir = Path(args.rootdir).expanduser()
|
||||||
|
dumpdir = Path(args.dumpdir).expanduser()
|
||||||
|
# use absolute path
|
||||||
|
dumpdir = dumpdir.resolve()
|
||||||
|
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dur_file = Path(args.dur_file).expanduser()
|
||||||
|
|
||||||
|
if args.spk_emb_dir:
|
||||||
|
spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
|
||||||
|
else:
|
||||||
|
spk_emb_dir = None
|
||||||
|
|
||||||
|
assert rootdir.is_dir()
|
||||||
|
assert dur_file.is_file()
|
||||||
|
|
||||||
|
with open(args.config, 'rt') as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
if args.verbose > 1:
|
||||||
|
print(vars(args))
|
||||||
|
print(config)
|
||||||
|
|
||||||
|
sentences, speaker_set = get_phn_dur(dur_file)
|
||||||
|
|
||||||
|
merge_silence(sentences)
|
||||||
|
phone_id_map_path = dumpdir / "phone_id_map.txt"
|
||||||
|
speaker_id_map_path = dumpdir / "speaker_id_map.txt"
|
||||||
|
get_input_token(sentences, phone_id_map_path, args.dataset)
|
||||||
|
get_spk_id_map(speaker_set, speaker_id_map_path)
|
||||||
|
|
||||||
|
if args.dataset == "baker":
|
||||||
|
wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
|
||||||
|
# split data into 3 sections
|
||||||
|
num_train = 9800
|
||||||
|
num_dev = 100
|
||||||
|
train_wav_files = wav_files[:num_train]
|
||||||
|
dev_wav_files = wav_files[num_train:num_train + num_dev]
|
||||||
|
test_wav_files = wav_files[num_train + num_dev:]
|
||||||
|
elif args.dataset == "aishell3":
|
||||||
|
sub_num_dev = 5
|
||||||
|
wav_dir = rootdir / "train" / "wav"
|
||||||
|
train_wav_files = []
|
||||||
|
dev_wav_files = []
|
||||||
|
test_wav_files = []
|
||||||
|
for speaker in os.listdir(wav_dir):
|
||||||
|
wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
|
||||||
|
if len(wav_files) > 100:
|
||||||
|
train_wav_files += wav_files[:-sub_num_dev * 2]
|
||||||
|
dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
|
||||||
|
test_wav_files += wav_files[-sub_num_dev:]
|
||||||
|
else:
|
||||||
|
train_wav_files += wav_files
|
||||||
|
|
||||||
|
elif args.dataset == "ljspeech":
|
||||||
|
wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
|
||||||
|
# split data into 3 sections
|
||||||
|
num_train = 12900
|
||||||
|
num_dev = 100
|
||||||
|
train_wav_files = wav_files[:num_train]
|
||||||
|
dev_wav_files = wav_files[num_train:num_train + num_dev]
|
||||||
|
test_wav_files = wav_files[num_train + num_dev:]
|
||||||
|
elif args.dataset == "vctk":
|
||||||
|
sub_num_dev = 5
|
||||||
|
wav_dir = rootdir / "wav48_silence_trimmed"
|
||||||
|
train_wav_files = []
|
||||||
|
dev_wav_files = []
|
||||||
|
test_wav_files = []
|
||||||
|
for speaker in os.listdir(wav_dir):
|
||||||
|
wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
|
||||||
|
if len(wav_files) > 100:
|
||||||
|
train_wav_files += wav_files[:-sub_num_dev * 2]
|
||||||
|
dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
|
||||||
|
test_wav_files += wav_files[-sub_num_dev:]
|
||||||
|
else:
|
||||||
|
train_wav_files += wav_files
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("dataset should in {baker, aishell3, ljspeech, vctk} now!")
|
||||||
|
|
||||||
|
train_dump_dir = dumpdir / "train" / "raw"
|
||||||
|
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dev_dump_dir = dumpdir / "dev" / "raw"
|
||||||
|
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
test_dump_dir = dumpdir / "test" / "raw"
|
||||||
|
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Extractor
|
||||||
|
|
||||||
|
spec_extractor = LinearSpectrogram(
|
||||||
|
n_fft=config.n_fft,
|
||||||
|
hop_length=config.n_shift,
|
||||||
|
win_length=config.win_length,
|
||||||
|
window=config.window)
|
||||||
|
|
||||||
|
# process for the 3 sections
|
||||||
|
if train_wav_files:
|
||||||
|
process_sentences(
|
||||||
|
config=config,
|
||||||
|
fps=train_wav_files,
|
||||||
|
sentences=sentences,
|
||||||
|
output_dir=train_dump_dir,
|
||||||
|
spec_extractor=spec_extractor,
|
||||||
|
nprocs=args.num_cpu,
|
||||||
|
cut_sil=args.cut_sil,
|
||||||
|
spk_emb_dir=spk_emb_dir)
|
||||||
|
if dev_wav_files:
|
||||||
|
process_sentences(
|
||||||
|
config=config,
|
||||||
|
fps=dev_wav_files,
|
||||||
|
sentences=sentences,
|
||||||
|
output_dir=dev_dump_dir,
|
||||||
|
spec_extractor=spec_extractor,
|
||||||
|
cut_sil=args.cut_sil,
|
||||||
|
spk_emb_dir=spk_emb_dir)
|
||||||
|
if test_wav_files:
|
||||||
|
process_sentences(
|
||||||
|
config=config,
|
||||||
|
fps=test_wav_files,
|
||||||
|
sentences=sentences,
|
||||||
|
output_dir=test_dump_dir,
|
||||||
|
spec_extractor=spec_extractor,
|
||||||
|
nprocs=args.num_cpu,
|
||||||
|
cut_sil=args.cut_sil,
|
||||||
|
spk_emb_dir=spk_emb_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,117 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import paddle
|
||||||
|
import soundfile as sf
|
||||||
|
import yaml
|
||||||
|
from timer import timer
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||||
|
from paddlespeech.t2s.models.vits import VITS
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(args):
|
||||||
|
|
||||||
|
# construct dataset for evaluation
|
||||||
|
with jsonlines.open(args.test_metadata, 'r') as reader:
|
||||||
|
test_metadata = list(reader)
|
||||||
|
# Init body.
|
||||||
|
with open(args.config) as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
|
||||||
|
fields = ["utt_id", "text"]
|
||||||
|
|
||||||
|
test_dataset = DataTable(data=test_metadata, fields=fields)
|
||||||
|
|
||||||
|
with open(args.phones_dict, "r") as f:
|
||||||
|
phn_id = [line.strip().split() for line in f.readlines()]
|
||||||
|
vocab_size = len(phn_id)
|
||||||
|
print("vocab_size:", vocab_size)
|
||||||
|
|
||||||
|
odim = config.n_fft // 2 + 1
|
||||||
|
|
||||||
|
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
|
||||||
|
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
|
||||||
|
vits.eval()
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
N = 0
|
||||||
|
T = 0
|
||||||
|
|
||||||
|
for datum in test_dataset:
|
||||||
|
utt_id = datum["utt_id"]
|
||||||
|
phone_ids = paddle.to_tensor(datum["text"])
|
||||||
|
with timer() as t:
|
||||||
|
with paddle.no_grad():
|
||||||
|
out = vits.inference(text=phone_ids)
|
||||||
|
wav = out["wav"]
|
||||||
|
wav = wav.numpy()
|
||||||
|
N += wav.size
|
||||||
|
T += t.elapse
|
||||||
|
speed = wav.size / t.elapse
|
||||||
|
rtf = config.fs / speed
|
||||||
|
print(
|
||||||
|
f"{utt_id}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
|
||||||
|
)
|
||||||
|
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
|
||||||
|
print(f"{utt_id} done!")
|
||||||
|
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
# parse args and config
|
||||||
|
parser = argparse.ArgumentParser(description="Synthesize with VITS")
|
||||||
|
# model
|
||||||
|
parser.add_argument(
|
||||||
|
'--config', type=str, default=None, help='Config of VITS.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
|
||||||
|
# other
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||||
|
parser.add_argument("--test_metadata", type=str, help="test metadata.")
|
||||||
|
parser.add_argument("--output_dir", type=str, help="output dir.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
if args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
elif args.ngpu > 0:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
else:
|
||||||
|
print("ngpu should >= 0 !")
|
||||||
|
|
||||||
|
evaluate(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,146 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import soundfile as sf
|
||||||
|
import yaml
|
||||||
|
from timer import timer
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.exps.syn_utils import get_frontend
|
||||||
|
from paddlespeech.t2s.exps.syn_utils import get_sentences
|
||||||
|
from paddlespeech.t2s.models.vits import VITS
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(args):
|
||||||
|
|
||||||
|
# Init body.
|
||||||
|
with open(args.config) as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
|
||||||
|
sentences = get_sentences(text_file=args.text, lang=args.lang)
|
||||||
|
|
||||||
|
# frontend
|
||||||
|
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
|
||||||
|
|
||||||
|
with open(args.phones_dict, "r") as f:
|
||||||
|
phn_id = [line.strip().split() for line in f.readlines()]
|
||||||
|
vocab_size = len(phn_id)
|
||||||
|
print("vocab_size:", vocab_size)
|
||||||
|
|
||||||
|
odim = config.n_fft // 2 + 1
|
||||||
|
|
||||||
|
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
|
||||||
|
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
|
||||||
|
vits.eval()
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
merge_sentences = False
|
||||||
|
|
||||||
|
N = 0
|
||||||
|
T = 0
|
||||||
|
for utt_id, sentence in sentences:
|
||||||
|
with timer() as t:
|
||||||
|
if args.lang == 'zh':
|
||||||
|
input_ids = frontend.get_input_ids(
|
||||||
|
sentence, merge_sentences=merge_sentences)
|
||||||
|
phone_ids = input_ids["phone_ids"]
|
||||||
|
elif args.lang == 'en':
|
||||||
|
input_ids = frontend.get_input_ids(
|
||||||
|
sentence, merge_sentences=merge_sentences)
|
||||||
|
phone_ids = input_ids["phone_ids"]
|
||||||
|
else:
|
||||||
|
print("lang should in {'zh', 'en'}!")
|
||||||
|
with paddle.no_grad():
|
||||||
|
flags = 0
|
||||||
|
for i in range(len(phone_ids)):
|
||||||
|
part_phone_ids = phone_ids[i]
|
||||||
|
out = vits.inference(text=part_phone_ids)
|
||||||
|
wav = out["wav"]
|
||||||
|
if flags == 0:
|
||||||
|
wav_all = wav
|
||||||
|
flags = 1
|
||||||
|
else:
|
||||||
|
wav_all = paddle.concat([wav_all, wav])
|
||||||
|
wav = wav_all.numpy()
|
||||||
|
N += wav.size
|
||||||
|
T += t.elapse
|
||||||
|
speed = wav.size / t.elapse
|
||||||
|
rtf = config.fs / speed
|
||||||
|
print(
|
||||||
|
f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
|
||||||
|
)
|
||||||
|
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
|
||||||
|
print(f"{utt_id} done!")
|
||||||
|
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
# parse args and config
|
||||||
|
parser = argparse.ArgumentParser(description="Synthesize with VITS")
|
||||||
|
|
||||||
|
# model
|
||||||
|
parser.add_argument(
|
||||||
|
'--config', type=str, default=None, help='Config of VITS.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
|
||||||
|
# other
|
||||||
|
parser.add_argument(
|
||||||
|
'--lang',
|
||||||
|
type=str,
|
||||||
|
default='zh',
|
||||||
|
help='Choose model language. zh or en')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--inference_dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="dir to save inference models")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--text",
|
||||||
|
type=str,
|
||||||
|
help="text to synthesize, a 'utt_id sentence' pair per line.")
|
||||||
|
parser.add_argument("--output_dir", type=str, help="output dir.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
if args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
elif args.ngpu > 0:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
else:
|
||||||
|
print("ngpu should >= 0 !")
|
||||||
|
|
||||||
|
evaluate(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,261 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import yaml
|
||||||
|
from paddle import DataParallel
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
from paddle.optimizer import Adam
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
|
||||||
|
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||||
|
from paddlespeech.t2s.models.vits import VITS
|
||||||
|
from paddlespeech.t2s.models.vits import VITSEvaluator
|
||||||
|
from paddlespeech.t2s.models.vits import VITSUpdater
|
||||||
|
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
|
||||||
|
from paddlespeech.t2s.modules.losses import FeatureMatchLoss
|
||||||
|
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
|
||||||
|
from paddlespeech.t2s.modules.losses import KLDivergenceLoss
|
||||||
|
from paddlespeech.t2s.modules.losses import MelSpectrogramLoss
|
||||||
|
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
|
||||||
|
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
|
||||||
|
from paddlespeech.t2s.training.optimizer import scheduler_classes
|
||||||
|
from paddlespeech.t2s.training.seeding import seed_everything
|
||||||
|
from paddlespeech.t2s.training.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def train_sp(args, config):
|
||||||
|
# decides device type and whether to run in parallel
|
||||||
|
# setup running environment correctly
|
||||||
|
world_size = paddle.distributed.get_world_size()
|
||||||
|
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
else:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
if world_size > 1:
|
||||||
|
paddle.distributed.init_parallel_env()
|
||||||
|
|
||||||
|
# set the random seed, it is a must for multiprocess training
|
||||||
|
seed_everything(config.seed)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# dataloader has been too verbose
|
||||||
|
logging.getLogger("DataLoader").disabled = True
|
||||||
|
|
||||||
|
fields = ["text", "text_lengths", "feats", "feats_lengths", "wave"]
|
||||||
|
|
||||||
|
converters = {
|
||||||
|
"wave": np.load,
|
||||||
|
"feats": np.load,
|
||||||
|
}
|
||||||
|
|
||||||
|
# construct dataset for training and validation
|
||||||
|
with jsonlines.open(args.train_metadata, 'r') as reader:
|
||||||
|
train_metadata = list(reader)
|
||||||
|
train_dataset = DataTable(
|
||||||
|
data=train_metadata,
|
||||||
|
fields=fields,
|
||||||
|
converters=converters, )
|
||||||
|
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
||||||
|
dev_metadata = list(reader)
|
||||||
|
dev_dataset = DataTable(
|
||||||
|
data=dev_metadata,
|
||||||
|
fields=fields,
|
||||||
|
converters=converters, )
|
||||||
|
|
||||||
|
# collate function and dataloader
|
||||||
|
train_sampler = DistributedBatchSampler(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True)
|
||||||
|
dev_sampler = DistributedBatchSampler(
|
||||||
|
dev_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False)
|
||||||
|
print("samplers done!")
|
||||||
|
|
||||||
|
train_batch_fn = vits_single_spk_batch_fn
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=train_batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
|
||||||
|
dev_dataloader = DataLoader(
|
||||||
|
dev_dataset,
|
||||||
|
batch_sampler=dev_sampler,
|
||||||
|
collate_fn=train_batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
print("dataloaders done!")
|
||||||
|
|
||||||
|
with open(args.phones_dict, "r") as f:
|
||||||
|
phn_id = [line.strip().split() for line in f.readlines()]
|
||||||
|
vocab_size = len(phn_id)
|
||||||
|
print("vocab_size:", vocab_size)
|
||||||
|
|
||||||
|
odim = config.n_fft // 2 + 1
|
||||||
|
model = VITS(idim=vocab_size, odim=odim, **config["model"])
|
||||||
|
gen_parameters = model.generator.parameters()
|
||||||
|
dis_parameters = model.discriminator.parameters()
|
||||||
|
if world_size > 1:
|
||||||
|
model = DataParallel(model)
|
||||||
|
gen_parameters = model._layers.generator.parameters()
|
||||||
|
dis_parameters = model._layers.discriminator.parameters()
|
||||||
|
|
||||||
|
print("model done!")
|
||||||
|
|
||||||
|
# loss
|
||||||
|
criterion_mel = MelSpectrogramLoss(
|
||||||
|
**config["mel_loss_params"], )
|
||||||
|
criterion_feat_match = FeatureMatchLoss(
|
||||||
|
**config["feat_match_loss_params"], )
|
||||||
|
criterion_gen_adv = GeneratorAdversarialLoss(
|
||||||
|
**config["generator_adv_loss_params"], )
|
||||||
|
criterion_dis_adv = DiscriminatorAdversarialLoss(
|
||||||
|
**config["discriminator_adv_loss_params"], )
|
||||||
|
criterion_kl = KLDivergenceLoss()
|
||||||
|
|
||||||
|
print("criterions done!")
|
||||||
|
|
||||||
|
lr_schedule_g = scheduler_classes[config["generator_scheduler"]](
|
||||||
|
**config["generator_scheduler_params"])
|
||||||
|
optimizer_g = Adam(
|
||||||
|
learning_rate=lr_schedule_g,
|
||||||
|
parameters=gen_parameters,
|
||||||
|
**config["generator_optimizer_params"])
|
||||||
|
|
||||||
|
lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]](
|
||||||
|
**config["discriminator_scheduler_params"])
|
||||||
|
optimizer_d = Adam(
|
||||||
|
learning_rate=lr_schedule_d,
|
||||||
|
parameters=dis_parameters,
|
||||||
|
**config["discriminator_optimizer_params"])
|
||||||
|
|
||||||
|
print("optimizers done!")
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
config_name = args.config.split("/")[-1]
|
||||||
|
# copy conf to output_dir
|
||||||
|
shutil.copyfile(args.config, output_dir / config_name)
|
||||||
|
|
||||||
|
updater = VITSUpdater(
|
||||||
|
model=model,
|
||||||
|
optimizers={
|
||||||
|
"generator": optimizer_g,
|
||||||
|
"discriminator": optimizer_d,
|
||||||
|
},
|
||||||
|
criterions={
|
||||||
|
"mel": criterion_mel,
|
||||||
|
"feat_match": criterion_feat_match,
|
||||||
|
"gen_adv": criterion_gen_adv,
|
||||||
|
"dis_adv": criterion_dis_adv,
|
||||||
|
"kl": criterion_kl,
|
||||||
|
},
|
||||||
|
schedulers={
|
||||||
|
"generator": lr_schedule_g,
|
||||||
|
"discriminator": lr_schedule_d,
|
||||||
|
},
|
||||||
|
dataloader=train_dataloader,
|
||||||
|
lambda_adv=config.lambda_adv,
|
||||||
|
lambda_mel=config.lambda_mel,
|
||||||
|
lambda_kl=config.lambda_kl,
|
||||||
|
lambda_feat_match=config.lambda_feat_match,
|
||||||
|
lambda_dur=config.lambda_dur,
|
||||||
|
generator_first=config.generator_first,
|
||||||
|
output_dir=output_dir)
|
||||||
|
|
||||||
|
evaluator = VITSEvaluator(
|
||||||
|
model=model,
|
||||||
|
criterions={
|
||||||
|
"mel": criterion_mel,
|
||||||
|
"feat_match": criterion_feat_match,
|
||||||
|
"gen_adv": criterion_gen_adv,
|
||||||
|
"dis_adv": criterion_dis_adv,
|
||||||
|
"kl": criterion_kl,
|
||||||
|
},
|
||||||
|
dataloader=dev_dataloader,
|
||||||
|
lambda_adv=config.lambda_adv,
|
||||||
|
lambda_mel=config.lambda_mel,
|
||||||
|
lambda_kl=config.lambda_kl,
|
||||||
|
lambda_feat_match=config.lambda_feat_match,
|
||||||
|
lambda_dur=config.lambda_dur,
|
||||||
|
generator_first=config.generator_first,
|
||||||
|
output_dir=output_dir)
|
||||||
|
|
||||||
|
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
trainer.extend(evaluator, trigger=(1, "epoch"))
|
||||||
|
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
|
||||||
|
trainer.extend(
|
||||||
|
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
|
||||||
|
|
||||||
|
print("Trainer Done!")
|
||||||
|
trainer.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse args and config and redirect to train_sp
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="config file to overwrite default config.")
|
||||||
|
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||||
|
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.config, 'rt') as f:
|
||||||
|
config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
print(
|
||||||
|
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# dispatch
|
||||||
|
if args.ngpu > 1:
|
||||||
|
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
|
||||||
|
else:
|
||||||
|
train_sp(args, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .vits import *
|
||||||
|
from .vits_updater import *
|
@ -0,0 +1,172 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Stochastic duration predictor modules in VITS.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.vits.flow import ConvFlow
|
||||||
|
from paddlespeech.t2s.models.vits.flow import DilatedDepthSeparableConv
|
||||||
|
from paddlespeech.t2s.models.vits.flow import ElementwiseAffineFlow
|
||||||
|
from paddlespeech.t2s.models.vits.flow import FlipFlow
|
||||||
|
from paddlespeech.t2s.models.vits.flow import LogFlow
|
||||||
|
|
||||||
|
|
||||||
|
class StochasticDurationPredictor(nn.Layer):
|
||||||
|
"""Stochastic duration predictor module.
|
||||||
|
This is a module of stochastic duration predictor described in `Conditional
|
||||||
|
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||||
|
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`: https://arxiv.org/abs/2106.06103
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int=192,
|
||||||
|
kernel_size: int=3,
|
||||||
|
dropout_rate: float=0.5,
|
||||||
|
flows: int=4,
|
||||||
|
dds_conv_layers: int=3,
|
||||||
|
global_channels: int=-1, ):
|
||||||
|
"""Initialize StochasticDurationPredictor module.
|
||||||
|
Args:
|
||||||
|
channels (int): Number of channels.
|
||||||
|
kernel_size (int): Kernel size.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
flows (int): Number of flows.
|
||||||
|
dds_conv_layers (int): Number of conv layers in DDS conv.
|
||||||
|
global_channels (int): Number of global conditioning channels.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pre = nn.Conv1D(channels, channels, 1)
|
||||||
|
self.dds = DilatedDepthSeparableConv(
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
layers=dds_conv_layers,
|
||||||
|
dropout_rate=dropout_rate, )
|
||||||
|
self.proj = nn.Conv1D(channels, channels, 1)
|
||||||
|
|
||||||
|
self.log_flow = LogFlow()
|
||||||
|
self.flows = nn.LayerList()
|
||||||
|
self.flows.append(ElementwiseAffineFlow(2))
|
||||||
|
for i in range(flows):
|
||||||
|
self.flows.append(
|
||||||
|
ConvFlow(
|
||||||
|
2,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
layers=dds_conv_layers, ))
|
||||||
|
self.flows.append(FlipFlow())
|
||||||
|
|
||||||
|
self.post_pre = nn.Conv1D(1, channels, 1)
|
||||||
|
self.post_dds = DilatedDepthSeparableConv(
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
layers=dds_conv_layers,
|
||||||
|
dropout_rate=dropout_rate, )
|
||||||
|
self.post_proj = nn.Conv1D(channels, channels, 1)
|
||||||
|
self.post_flows = nn.LayerList()
|
||||||
|
self.post_flows.append(ElementwiseAffineFlow(2))
|
||||||
|
for i in range(flows):
|
||||||
|
self.post_flows.append(
|
||||||
|
ConvFlow(
|
||||||
|
2,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
layers=dds_conv_layers, ))
|
||||||
|
self.post_flows.append(FlipFlow())
|
||||||
|
|
||||||
|
if global_channels > 0:
|
||||||
|
self.global_conv = nn.Conv1D(global_channels, channels, 1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
w: Optional[paddle.Tensor]=None,
|
||||||
|
g: Optional[paddle.Tensor]=None,
|
||||||
|
inverse: bool=False,
|
||||||
|
noise_scale: float=1.0, ) -> paddle.Tensor:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, channels, T_text).
|
||||||
|
x_mask (Tensor): Mask tensor (B, 1, T_text).
|
||||||
|
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
noise_scale (float): Noise scale value.
|
||||||
|
Returns:
|
||||||
|
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
|
||||||
|
If inverse, log-duration tensor (B, 1, T_text).
|
||||||
|
"""
|
||||||
|
# stop gradient
|
||||||
|
# x = x.detach()
|
||||||
|
x = self.pre(x)
|
||||||
|
if g is not None:
|
||||||
|
# stop gradient
|
||||||
|
x = x + self.global_conv(g.detach())
|
||||||
|
x = self.dds(x, x_mask)
|
||||||
|
x = self.proj(x) * x_mask
|
||||||
|
|
||||||
|
if not inverse:
|
||||||
|
assert w is not None, "w must be provided."
|
||||||
|
h_w = self.post_pre(w)
|
||||||
|
h_w = self.post_dds(h_w, x_mask)
|
||||||
|
h_w = self.post_proj(h_w) * x_mask
|
||||||
|
e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) *
|
||||||
|
x_mask)
|
||||||
|
z_q = e_q
|
||||||
|
logdet_tot_q = 0.0
|
||||||
|
for i, flow in enumerate(self.post_flows):
|
||||||
|
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||||
|
logdet_tot_q += logdet_q
|
||||||
|
z_u, z1 = paddle.split(z_q, [1, 1], 1)
|
||||||
|
u = F.sigmoid(z_u) * x_mask
|
||||||
|
z0 = (w - u) * x_mask
|
||||||
|
logdet_tot_q += paddle.sum(
|
||||||
|
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
|
||||||
|
logq = (paddle.sum(-0.5 *
|
||||||
|
(math.log(2 * math.pi) +
|
||||||
|
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
|
||||||
|
|
||||||
|
logdet_tot = 0
|
||||||
|
z0, logdet = self.log_flow(z0, x_mask)
|
||||||
|
logdet_tot += logdet
|
||||||
|
z = paddle.concat([z0, z1], 1)
|
||||||
|
for flow in self.flows:
|
||||||
|
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
|
||||||
|
logdet_tot = logdet_tot + logdet
|
||||||
|
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
|
||||||
|
(z**2)) * x_mask, [1, 2]) - logdet_tot)
|
||||||
|
# (B,)
|
||||||
|
return nll + logq
|
||||||
|
else:
|
||||||
|
flows = list(reversed(self.flows))
|
||||||
|
# remove a useless vflow
|
||||||
|
flows = flows[:-2] + [flows[-1]]
|
||||||
|
z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) *
|
||||||
|
noise_scale)
|
||||||
|
for flow in flows:
|
||||||
|
z = flow(z, x_mask, g=x, inverse=inverse)
|
||||||
|
z0, z1 = paddle.split(z, 2, axis=1)
|
||||||
|
logw = z0
|
||||||
|
return logw
|
@ -0,0 +1,313 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Basic Flow modules used in VITS.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.vits.transform import piecewise_rational_quadratic_transform
|
||||||
|
|
||||||
|
|
||||||
|
class FlipFlow(nn.Layer):
|
||||||
|
"""Flip flow module."""
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor, *args, inverse: bool=False, **kwargs
|
||||||
|
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, channels, T).
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
Returns:
|
||||||
|
Tensor: Flipped tensor (B, channels, T).
|
||||||
|
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||||
|
"""
|
||||||
|
x = paddle.flip(x, [1])
|
||||||
|
if not inverse:
|
||||||
|
logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype)
|
||||||
|
return x, logdet
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LogFlow(nn.Layer):
|
||||||
|
"""Log flow module."""
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
inverse: bool=False,
|
||||||
|
eps: float=1e-5,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, channels, T).
|
||||||
|
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
eps (float): Epsilon for log.
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, channels, T).
|
||||||
|
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||||
|
"""
|
||||||
|
if not inverse:
|
||||||
|
y = paddle.log(paddle.clip(x, min=eps)) * x_mask
|
||||||
|
logdet = paddle.sum(-y, [1, 2])
|
||||||
|
return y, logdet
|
||||||
|
else:
|
||||||
|
x = paddle.exp(x) * x_mask
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAffineFlow(nn.Layer):
|
||||||
|
"""Elementwise affine flow module."""
|
||||||
|
|
||||||
|
def __init__(self, channels: int):
|
||||||
|
"""Initialize ElementwiseAffineFlow module.
|
||||||
|
Args:
|
||||||
|
channels (int): Number of channels.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
m = paddle.zeros([channels, 1])
|
||||||
|
self.m = paddle.create_parameter(
|
||||||
|
shape=m.shape,
|
||||||
|
dtype=str(m.numpy().dtype),
|
||||||
|
default_initializer=paddle.nn.initializer.Assign(m))
|
||||||
|
logs = paddle.zeros([channels, 1])
|
||||||
|
self.logs = paddle.create_parameter(
|
||||||
|
shape=logs.shape,
|
||||||
|
dtype=str(logs.numpy().dtype),
|
||||||
|
default_initializer=paddle.nn.initializer.Assign(logs))
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
inverse: bool=False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, channels, T).
|
||||||
|
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, channels, T).
|
||||||
|
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||||
|
"""
|
||||||
|
if not inverse:
|
||||||
|
y = self.m + paddle.exp(self.logs) * x
|
||||||
|
y = y * x_mask
|
||||||
|
logdet = paddle.sum(self.logs * x_mask, [1, 2])
|
||||||
|
return y, logdet
|
||||||
|
else:
|
||||||
|
x = (x - self.m) * paddle.exp(-self.logs) * x_mask
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transpose(nn.Layer):
|
||||||
|
"""Transpose module for paddle.nn.Sequential()."""
|
||||||
|
|
||||||
|
def __init__(self, dim1: int, dim2: int):
|
||||||
|
"""Initialize Transpose module."""
|
||||||
|
super().__init__()
|
||||||
|
self.dim1 = dim1
|
||||||
|
self.dim2 = dim2
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""Transpose."""
|
||||||
|
len_dim = len(x.shape)
|
||||||
|
orig_perm = list(range(len_dim))
|
||||||
|
new_perm = orig_perm[:]
|
||||||
|
temp = new_perm[self.dim1]
|
||||||
|
new_perm[self.dim1] = new_perm[self.dim2]
|
||||||
|
new_perm[self.dim2] = temp
|
||||||
|
|
||||||
|
return paddle.transpose(x, new_perm)
|
||||||
|
|
||||||
|
|
||||||
|
class DilatedDepthSeparableConv(nn.Layer):
|
||||||
|
"""Dilated depth-separable conv module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
layers: int,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
eps: float=1e-5, ):
|
||||||
|
"""Initialize DilatedDepthSeparableConv module.
|
||||||
|
Args:
|
||||||
|
channels (int): Number of channels.
|
||||||
|
kernel_size (int): Kernel size.
|
||||||
|
layers (int): Number of layers.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
eps (float): Epsilon for layer norm.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.convs = nn.LayerList()
|
||||||
|
for i in range(layers):
|
||||||
|
dilation = kernel_size**i
|
||||||
|
padding = (kernel_size * dilation - dilation) // 2
|
||||||
|
self.convs.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv1D(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
groups=channels,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding, ),
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.LayerNorm(channels, epsilon=eps),
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv1D(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
1, ),
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.LayerNorm(channels, epsilon=eps),
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout_rate), ))
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
g: Optional[paddle.Tensor]=None) -> paddle.Tensor:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, in_channels, T).
|
||||||
|
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, channels, T).
|
||||||
|
"""
|
||||||
|
if g is not None:
|
||||||
|
x = x + g
|
||||||
|
for f in self.convs:
|
||||||
|
y = f(x * x_mask)
|
||||||
|
x = x + y
|
||||||
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class ConvFlow(nn.Layer):
|
||||||
|
"""Convolutional flow module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
hidden_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
layers: int,
|
||||||
|
bins: int=10,
|
||||||
|
tail_bound: float=5.0, ):
|
||||||
|
"""Initialize ConvFlow module.
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
kernel_size (int): Kernel size.
|
||||||
|
layers (int): Number of layers.
|
||||||
|
bins (int): Number of bins.
|
||||||
|
tail_bound (float): Tail bound value.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.half_channels = in_channels // 2
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.bins = bins
|
||||||
|
self.tail_bound = tail_bound
|
||||||
|
|
||||||
|
self.input_conv = nn.Conv1D(
|
||||||
|
self.half_channels,
|
||||||
|
hidden_channels,
|
||||||
|
1, )
|
||||||
|
self.dds_conv = DilatedDepthSeparableConv(
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
layers,
|
||||||
|
dropout_rate=0.0, )
|
||||||
|
self.proj = nn.Conv1D(
|
||||||
|
hidden_channels,
|
||||||
|
self.half_channels * (bins * 3 - 1),
|
||||||
|
1, )
|
||||||
|
|
||||||
|
weight = paddle.zeros(paddle.shape(self.proj.weight))
|
||||||
|
|
||||||
|
self.proj.weight = paddle.create_parameter(
|
||||||
|
shape=weight.shape,
|
||||||
|
dtype=str(weight.numpy().dtype),
|
||||||
|
default_initializer=paddle.nn.initializer.Assign(weight))
|
||||||
|
|
||||||
|
bias = paddle.zeros(paddle.shape(self.proj.bias))
|
||||||
|
|
||||||
|
self.proj.bias = paddle.create_parameter(
|
||||||
|
shape=bias.shape,
|
||||||
|
dtype=str(bias.numpy().dtype),
|
||||||
|
default_initializer=paddle.nn.initializer.Assign(bias))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
g: Optional[paddle.Tensor]=None,
|
||||||
|
inverse: bool=False,
|
||||||
|
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, channels, T).
|
||||||
|
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, channels, T).
|
||||||
|
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||||
|
"""
|
||||||
|
xa, xb = x.split(2, 1)
|
||||||
|
h = self.input_conv(xa)
|
||||||
|
h = self.dds_conv(h, x_mask, g=g)
|
||||||
|
# (B, half_channels * (bins * 3 - 1), T)
|
||||||
|
h = self.proj(h) * x_mask
|
||||||
|
|
||||||
|
b, c, t = xa.shape
|
||||||
|
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
|
||||||
|
h = h.reshape([b, c, -1, t]).transpose([0, 1, 3, 2])
|
||||||
|
|
||||||
|
denom = math.sqrt(self.hidden_channels)
|
||||||
|
unnorm_widths = h[..., :self.bins] / denom
|
||||||
|
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
|
||||||
|
unnorm_derivatives = h[..., 2 * self.bins:]
|
||||||
|
xb, logdet_abs = piecewise_rational_quadratic_transform(
|
||||||
|
xb,
|
||||||
|
unnorm_widths,
|
||||||
|
unnorm_heights,
|
||||||
|
unnorm_derivatives,
|
||||||
|
inverse=inverse,
|
||||||
|
tails="linear",
|
||||||
|
tail_bound=self.tail_bound, )
|
||||||
|
x = paddle.concat([xa, xb], 1) * x_mask
|
||||||
|
logdet = paddle.sum(logdet_abs * x_mask, [1, 2])
|
||||||
|
if not inverse:
|
||||||
|
return x, logdet
|
||||||
|
else:
|
||||||
|
return x
|
@ -0,0 +1,550 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Generator module in VITS.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.hifigan import HiFiGANGenerator
|
||||||
|
from paddlespeech.t2s.models.vits.duration_predictor import StochasticDurationPredictor
|
||||||
|
from paddlespeech.t2s.models.vits.posterior_encoder import PosteriorEncoder
|
||||||
|
from paddlespeech.t2s.models.vits.residual_coupling import ResidualAffineCouplingBlock
|
||||||
|
from paddlespeech.t2s.models.vits.text_encoder import TextEncoder
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import get_random_segments
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
||||||
|
|
||||||
|
|
||||||
|
class VITSGenerator(nn.Layer):
|
||||||
|
"""Generator module in VITS.
|
||||||
|
This is a module of VITS described in `Conditional Variational Autoencoder
|
||||||
|
with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||||
|
As text encoder, we use conformer architecture instead of the relative positional
|
||||||
|
Transformer, which contains additional convolution layers.
|
||||||
|
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocabs: int,
|
||||||
|
aux_channels: int=513,
|
||||||
|
hidden_channels: int=192,
|
||||||
|
spks: Optional[int]=None,
|
||||||
|
langs: Optional[int]=None,
|
||||||
|
spk_embed_dim: Optional[int]=None,
|
||||||
|
global_channels: int=-1,
|
||||||
|
segment_size: int=32,
|
||||||
|
text_encoder_attention_heads: int=2,
|
||||||
|
text_encoder_ffn_expand: int=4,
|
||||||
|
text_encoder_blocks: int=6,
|
||||||
|
text_encoder_positionwise_layer_type: str="conv1d",
|
||||||
|
text_encoder_positionwise_conv_kernel_size: int=1,
|
||||||
|
text_encoder_positional_encoding_layer_type: str="rel_pos",
|
||||||
|
text_encoder_self_attention_layer_type: str="rel_selfattn",
|
||||||
|
text_encoder_activation_type: str="swish",
|
||||||
|
text_encoder_normalize_before: bool=True,
|
||||||
|
text_encoder_dropout_rate: float=0.1,
|
||||||
|
text_encoder_positional_dropout_rate: float=0.0,
|
||||||
|
text_encoder_attention_dropout_rate: float=0.0,
|
||||||
|
text_encoder_conformer_kernel_size: int=7,
|
||||||
|
use_macaron_style_in_text_encoder: bool=True,
|
||||||
|
use_conformer_conv_in_text_encoder: bool=True,
|
||||||
|
decoder_kernel_size: int=7,
|
||||||
|
decoder_channels: int=512,
|
||||||
|
decoder_upsample_scales: List[int]=[8, 8, 2, 2],
|
||||||
|
decoder_upsample_kernel_sizes: List[int]=[16, 16, 4, 4],
|
||||||
|
decoder_resblock_kernel_sizes: List[int]=[3, 7, 11],
|
||||||
|
decoder_resblock_dilations: List[List[int]]=[[1, 3, 5], [1, 3, 5],
|
||||||
|
[1, 3, 5]],
|
||||||
|
use_weight_norm_in_decoder: bool=True,
|
||||||
|
posterior_encoder_kernel_size: int=5,
|
||||||
|
posterior_encoder_layers: int=16,
|
||||||
|
posterior_encoder_stacks: int=1,
|
||||||
|
posterior_encoder_base_dilation: int=1,
|
||||||
|
posterior_encoder_dropout_rate: float=0.0,
|
||||||
|
use_weight_norm_in_posterior_encoder: bool=True,
|
||||||
|
flow_flows: int=4,
|
||||||
|
flow_kernel_size: int=5,
|
||||||
|
flow_base_dilation: int=1,
|
||||||
|
flow_layers: int=4,
|
||||||
|
flow_dropout_rate: float=0.0,
|
||||||
|
use_weight_norm_in_flow: bool=True,
|
||||||
|
use_only_mean_in_flow: bool=True,
|
||||||
|
stochastic_duration_predictor_kernel_size: int=3,
|
||||||
|
stochastic_duration_predictor_dropout_rate: float=0.5,
|
||||||
|
stochastic_duration_predictor_flows: int=4,
|
||||||
|
stochastic_duration_predictor_dds_conv_layers: int=3, ):
|
||||||
|
"""Initialize VITS generator module.
|
||||||
|
Args:
|
||||||
|
vocabs (int): Input vocabulary size.
|
||||||
|
aux_channels (int): Number of acoustic feature channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
|
||||||
|
sids will be provided as the input and use sid embedding layer.
|
||||||
|
langs (Optional[int]): Number of languages. If set to > 1, assume that the
|
||||||
|
lids will be provided as the input and use sid embedding layer.
|
||||||
|
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
|
||||||
|
assume that spembs will be provided as the input.
|
||||||
|
global_channels (int): Number of global conditioning channels.
|
||||||
|
segment_size (int): Segment size for decoder.
|
||||||
|
text_encoder_attention_heads (int): Number of heads in conformer block
|
||||||
|
of text encoder.
|
||||||
|
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
|
||||||
|
of text encoder.
|
||||||
|
text_encoder_blocks (int): Number of conformer blocks in text encoder.
|
||||||
|
text_encoder_positionwise_layer_type (str): Position-wise layer type in
|
||||||
|
conformer block of text encoder.
|
||||||
|
text_encoder_positionwise_conv_kernel_size (int): Position-wise convolution
|
||||||
|
kernel size in conformer block of text encoder. Only used when the
|
||||||
|
above layer type is conv1d or conv1d-linear.
|
||||||
|
text_encoder_positional_encoding_layer_type (str): Positional encoding layer
|
||||||
|
type in conformer block of text encoder.
|
||||||
|
text_encoder_self_attention_layer_type (str): Self-attention layer type in
|
||||||
|
conformer block of text encoder.
|
||||||
|
text_encoder_activation_type (str): Activation function type in conformer
|
||||||
|
block of text encoder.
|
||||||
|
text_encoder_normalize_before (bool): Whether to apply layer norm before
|
||||||
|
self-attention in conformer block of text encoder.
|
||||||
|
text_encoder_dropout_rate (float): Dropout rate in conformer block of
|
||||||
|
text encoder.
|
||||||
|
text_encoder_positional_dropout_rate (float): Dropout rate for positional
|
||||||
|
encoding in conformer block of text encoder.
|
||||||
|
text_encoder_attention_dropout_rate (float): Dropout rate for attention in
|
||||||
|
conformer block of text encoder.
|
||||||
|
text_encoder_conformer_kernel_size (int): Conformer conv kernel size. It
|
||||||
|
will be used when only use_conformer_conv_in_text_encoder = True.
|
||||||
|
use_macaron_style_in_text_encoder (bool): Whether to use macaron style FFN
|
||||||
|
in conformer block of text encoder.
|
||||||
|
use_conformer_conv_in_text_encoder (bool): Whether to use covolution in
|
||||||
|
conformer block of text encoder.
|
||||||
|
decoder_kernel_size (int): Decoder kernel size.
|
||||||
|
decoder_channels (int): Number of decoder initial channels.
|
||||||
|
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
|
||||||
|
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
|
||||||
|
upsampling layers in decoder.
|
||||||
|
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
|
||||||
|
in decoder.
|
||||||
|
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
|
||||||
|
resblocks in decoder.
|
||||||
|
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
|
||||||
|
decoder.
|
||||||
|
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
|
||||||
|
posterior_encoder_layers (int): Number of layers of posterior encoder.
|
||||||
|
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
|
||||||
|
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
|
||||||
|
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
|
||||||
|
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
|
||||||
|
normalization in posterior encoder.
|
||||||
|
flow_flows (int): Number of flows in flow.
|
||||||
|
flow_kernel_size (int): Kernel size in flow.
|
||||||
|
flow_base_dilation (int): Base dilation in flow.
|
||||||
|
flow_layers (int): Number of layers in flow.
|
||||||
|
flow_dropout_rate (float): Dropout rate in flow
|
||||||
|
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
|
||||||
|
flow.
|
||||||
|
use_only_mean_in_flow (bool): Whether to use only mean in flow.
|
||||||
|
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
|
||||||
|
duration predictor.
|
||||||
|
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
|
||||||
|
stochastic duration predictor.
|
||||||
|
stochastic_duration_predictor_flows (int): Number of flows in stochastic
|
||||||
|
duration predictor.
|
||||||
|
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
|
||||||
|
layers in stochastic duration predictor.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.segment_size = segment_size
|
||||||
|
self.text_encoder = TextEncoder(
|
||||||
|
vocabs=vocabs,
|
||||||
|
attention_dim=hidden_channels,
|
||||||
|
attention_heads=text_encoder_attention_heads,
|
||||||
|
linear_units=hidden_channels * text_encoder_ffn_expand,
|
||||||
|
blocks=text_encoder_blocks,
|
||||||
|
positionwise_layer_type=text_encoder_positionwise_layer_type,
|
||||||
|
positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size,
|
||||||
|
positional_encoding_layer_type=text_encoder_positional_encoding_layer_type,
|
||||||
|
self_attention_layer_type=text_encoder_self_attention_layer_type,
|
||||||
|
activation_type=text_encoder_activation_type,
|
||||||
|
normalize_before=text_encoder_normalize_before,
|
||||||
|
dropout_rate=text_encoder_dropout_rate,
|
||||||
|
positional_dropout_rate=text_encoder_positional_dropout_rate,
|
||||||
|
attention_dropout_rate=text_encoder_attention_dropout_rate,
|
||||||
|
conformer_kernel_size=text_encoder_conformer_kernel_size,
|
||||||
|
use_macaron_style=use_macaron_style_in_text_encoder,
|
||||||
|
use_conformer_conv=use_conformer_conv_in_text_encoder, )
|
||||||
|
self.decoder = HiFiGANGenerator(
|
||||||
|
in_channels=hidden_channels,
|
||||||
|
out_channels=1,
|
||||||
|
channels=decoder_channels,
|
||||||
|
global_channels=global_channels,
|
||||||
|
kernel_size=decoder_kernel_size,
|
||||||
|
upsample_scales=decoder_upsample_scales,
|
||||||
|
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
|
||||||
|
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
|
||||||
|
resblock_dilations=decoder_resblock_dilations,
|
||||||
|
use_weight_norm=use_weight_norm_in_decoder, )
|
||||||
|
self.posterior_encoder = PosteriorEncoder(
|
||||||
|
in_channels=aux_channels,
|
||||||
|
out_channels=hidden_channels,
|
||||||
|
hidden_channels=hidden_channels,
|
||||||
|
kernel_size=posterior_encoder_kernel_size,
|
||||||
|
layers=posterior_encoder_layers,
|
||||||
|
stacks=posterior_encoder_stacks,
|
||||||
|
base_dilation=posterior_encoder_base_dilation,
|
||||||
|
global_channels=global_channels,
|
||||||
|
dropout_rate=posterior_encoder_dropout_rate,
|
||||||
|
use_weight_norm=use_weight_norm_in_posterior_encoder, )
|
||||||
|
self.flow = ResidualAffineCouplingBlock(
|
||||||
|
in_channels=hidden_channels,
|
||||||
|
hidden_channels=hidden_channels,
|
||||||
|
flows=flow_flows,
|
||||||
|
kernel_size=flow_kernel_size,
|
||||||
|
base_dilation=flow_base_dilation,
|
||||||
|
layers=flow_layers,
|
||||||
|
global_channels=global_channels,
|
||||||
|
dropout_rate=flow_dropout_rate,
|
||||||
|
use_weight_norm=use_weight_norm_in_flow,
|
||||||
|
use_only_mean=use_only_mean_in_flow, )
|
||||||
|
# TODO: Add deterministic version as an option
|
||||||
|
self.duration_predictor = StochasticDurationPredictor(
|
||||||
|
channels=hidden_channels,
|
||||||
|
kernel_size=stochastic_duration_predictor_kernel_size,
|
||||||
|
dropout_rate=stochastic_duration_predictor_dropout_rate,
|
||||||
|
flows=stochastic_duration_predictor_flows,
|
||||||
|
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
|
||||||
|
global_channels=global_channels, )
|
||||||
|
|
||||||
|
self.upsample_factor = int(np.prod(decoder_upsample_scales))
|
||||||
|
self.spks = None
|
||||||
|
if spks is not None and spks > 1:
|
||||||
|
assert global_channels > 0
|
||||||
|
self.spks = spks
|
||||||
|
self.global_emb = nn.Embedding(spks, global_channels)
|
||||||
|
self.spk_embed_dim = None
|
||||||
|
if spk_embed_dim is not None and spk_embed_dim > 0:
|
||||||
|
assert global_channels > 0
|
||||||
|
self.spk_embed_dim = spk_embed_dim
|
||||||
|
self.spemb_proj = nn.Linear(spk_embed_dim, global_channels)
|
||||||
|
self.langs = None
|
||||||
|
if langs is not None and langs > 1:
|
||||||
|
assert global_channels > 0
|
||||||
|
self.langs = langs
|
||||||
|
self.lang_emb = nn.Embedding(langs, global_channels)
|
||||||
|
|
||||||
|
# delayed import
|
||||||
|
from paddlespeech.t2s.models.vits.monotonic_align import maximum_path
|
||||||
|
|
||||||
|
self.maximum_path = maximum_path
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
text: paddle.Tensor,
|
||||||
|
text_lengths: paddle.Tensor,
|
||||||
|
feats: paddle.Tensor,
|
||||||
|
feats_lengths: paddle.Tensor,
|
||||||
|
sids: Optional[paddle.Tensor]=None,
|
||||||
|
spembs: Optional[paddle.Tensor]=None,
|
||||||
|
lids: Optional[paddle.Tensor]=None,
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
|
||||||
|
paddle.Tensor, paddle.Tensor,
|
||||||
|
Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
|
||||||
|
paddle.Tensor, paddle.Tensor, ], ]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
text (Tensor): Text index tensor (B, T_text).
|
||||||
|
text_lengths (Tensor): Text length tensor (B,).
|
||||||
|
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
|
||||||
|
feats_lengths (Tensor): Feature length tensor (B,).
|
||||||
|
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||||
|
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||||
|
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||||
|
Returns:
|
||||||
|
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
|
||||||
|
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
|
||||||
|
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
|
||||||
|
Tensor: Segments start index tensor (B,).
|
||||||
|
Tensor: Text mask tensor (B, 1, T_text).
|
||||||
|
Tensor: Feature mask tensor (B, 1, T_feats).
|
||||||
|
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||||
|
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
|
||||||
|
- Tensor: Flow hidden representation (B, H, T_feats).
|
||||||
|
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
|
||||||
|
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
|
||||||
|
- Tensor: Posterior encoder projected mean (B, H, T_feats).
|
||||||
|
- Tensor: Posterior encoder projected scale (B, H, T_feats).
|
||||||
|
"""
|
||||||
|
# forward text encoder
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||||
|
|
||||||
|
# calculate global conditioning
|
||||||
|
g = None
|
||||||
|
if self.spks is not None:
|
||||||
|
# speaker one-hot vector embedding: (B, global_channels, 1)
|
||||||
|
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
|
||||||
|
if self.spk_embed_dim is not None:
|
||||||
|
# pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
|
||||||
|
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
|
||||||
|
if g is None:
|
||||||
|
g = g_
|
||||||
|
else:
|
||||||
|
g = g + g_
|
||||||
|
if self.langs is not None:
|
||||||
|
# language one-hot vector embedding: (B, global_channels, 1)
|
||||||
|
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
|
||||||
|
if g is None:
|
||||||
|
g = g_
|
||||||
|
else:
|
||||||
|
g = g + g_
|
||||||
|
|
||||||
|
# forward posterior encoder
|
||||||
|
z, m_q, logs_q, y_mask = self.posterior_encoder(
|
||||||
|
feats, feats_lengths, g=g)
|
||||||
|
|
||||||
|
# forward flow
|
||||||
|
# (B, H, T_feats)
|
||||||
|
z_p = self.flow(z, y_mask, g=g)
|
||||||
|
|
||||||
|
# monotonic alignment search
|
||||||
|
with paddle.no_grad():
|
||||||
|
# negative cross-entropy
|
||||||
|
# (B, H, T_text)
|
||||||
|
s_p_sq_r = paddle.exp(-2 * logs_p)
|
||||||
|
# (B, 1, T_text)
|
||||||
|
neg_x_ent_1 = paddle.sum(
|
||||||
|
-0.5 * math.log(2 * math.pi) - logs_p,
|
||||||
|
[1],
|
||||||
|
keepdim=True, )
|
||||||
|
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||||
|
neg_x_ent_2 = paddle.matmul(
|
||||||
|
-0.5 * (z_p**2).transpose([0, 2, 1]),
|
||||||
|
s_p_sq_r, )
|
||||||
|
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||||
|
neg_x_ent_3 = paddle.matmul(
|
||||||
|
z_p.transpose([0, 2, 1]),
|
||||||
|
(m_p * s_p_sq_r), )
|
||||||
|
# (B, 1, T_text)
|
||||||
|
neg_x_ent_4 = paddle.sum(
|
||||||
|
-0.5 * (m_p**2) * s_p_sq_r,
|
||||||
|
[1],
|
||||||
|
keepdim=True, )
|
||||||
|
# (B, T_feats, T_text)
|
||||||
|
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
|
||||||
|
# (B, 1, T_feats, T_text)
|
||||||
|
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
|
||||||
|
-1)
|
||||||
|
# monotonic attention weight: (B, 1, T_feats, T_text)
|
||||||
|
attn = (self.maximum_path(
|
||||||
|
neg_x_ent,
|
||||||
|
attn_mask.squeeze(1), ).unsqueeze(1).detach())
|
||||||
|
|
||||||
|
# forward duration predictor
|
||||||
|
# (B, 1, T_text)
|
||||||
|
w = attn.sum(2)
|
||||||
|
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
|
||||||
|
dur_nll = dur_nll / paddle.sum(x_mask)
|
||||||
|
|
||||||
|
# expand the length to match with the feature sequence
|
||||||
|
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||||
|
m_p = paddle.matmul(attn.squeeze(1),
|
||||||
|
m_p.transpose([0, 2, 1])).transpose([0, 2, 1])
|
||||||
|
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||||
|
logs_p = paddle.matmul(attn.squeeze(1),
|
||||||
|
logs_p.transpose([0, 2, 1])).transpose([0, 2, 1])
|
||||||
|
|
||||||
|
# get random segments
|
||||||
|
z_segments, z_start_idxs = get_random_segments(
|
||||||
|
z,
|
||||||
|
feats_lengths,
|
||||||
|
self.segment_size, )
|
||||||
|
|
||||||
|
# forward decoder with random segments
|
||||||
|
wav = self.decoder(z_segments, g=g)
|
||||||
|
|
||||||
|
return (wav, dur_nll, attn, z_start_idxs, x_mask, y_mask,
|
||||||
|
(z, z_p, m_p, logs_p, m_q, logs_q), )
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
text: paddle.Tensor,
|
||||||
|
text_lengths: paddle.Tensor,
|
||||||
|
feats: Optional[paddle.Tensor]=None,
|
||||||
|
feats_lengths: Optional[paddle.Tensor]=None,
|
||||||
|
sids: Optional[paddle.Tensor]=None,
|
||||||
|
spembs: Optional[paddle.Tensor]=None,
|
||||||
|
lids: Optional[paddle.Tensor]=None,
|
||||||
|
dur: Optional[paddle.Tensor]=None,
|
||||||
|
noise_scale: float=0.667,
|
||||||
|
noise_scale_dur: float=0.8,
|
||||||
|
alpha: float=1.0,
|
||||||
|
max_len: Optional[int]=None,
|
||||||
|
use_teacher_forcing: bool=False,
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Run inference.
|
||||||
|
Args:
|
||||||
|
text (Tensor): Input text index tensor (B, T_text,).
|
||||||
|
text_lengths (Tensor): Text length tensor (B,).
|
||||||
|
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
|
||||||
|
feats_lengths (Tensor): Feature length tensor (B,).
|
||||||
|
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||||
|
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||||
|
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||||
|
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
|
||||||
|
skip the prediction of durations (i.e., teacher forcing).
|
||||||
|
noise_scale (float): Noise scale parameter for flow.
|
||||||
|
noise_scale_dur (float): Noise scale parameter for duration predictor.
|
||||||
|
alpha (float): Alpha parameter to control the speed of generated speech.
|
||||||
|
max_len (Optional[int]): Maximum length of acoustic feature sequence.
|
||||||
|
use_teacher_forcing (bool): Whether to use teacher forcing.
|
||||||
|
Returns:
|
||||||
|
Tensor: Generated waveform tensor (B, T_wav).
|
||||||
|
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
|
||||||
|
Tensor: Duration tensor (B, T_text).
|
||||||
|
"""
|
||||||
|
# encoder
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||||
|
g = None
|
||||||
|
if self.spks is not None:
|
||||||
|
# (B, global_channels, 1)
|
||||||
|
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
|
||||||
|
if self.spk_embed_dim is not None:
|
||||||
|
# (B, global_channels, 1)
|
||||||
|
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
|
||||||
|
if g is None:
|
||||||
|
g = g_
|
||||||
|
else:
|
||||||
|
g = g + g_
|
||||||
|
if self.langs is not None:
|
||||||
|
# (B, global_channels, 1)
|
||||||
|
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
|
||||||
|
if g is None:
|
||||||
|
g = g_
|
||||||
|
else:
|
||||||
|
g = g + g_
|
||||||
|
|
||||||
|
if use_teacher_forcing:
|
||||||
|
# forward posterior encoder
|
||||||
|
z, m_q, logs_q, y_mask = self.posterior_encoder(
|
||||||
|
feats, feats_lengths, g=g)
|
||||||
|
|
||||||
|
# forward flow
|
||||||
|
# (B, H, T_feats)
|
||||||
|
z_p = self.flow(z, y_mask, g=g)
|
||||||
|
|
||||||
|
# monotonic alignment search
|
||||||
|
# (B, H, T_text)
|
||||||
|
s_p_sq_r = paddle.exp(-2 * logs_p)
|
||||||
|
# (B, 1, T_text)
|
||||||
|
neg_x_ent_1 = paddle.sum(
|
||||||
|
-0.5 * math.log(2 * math.pi) - logs_p,
|
||||||
|
[1],
|
||||||
|
keepdim=True, )
|
||||||
|
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||||
|
neg_x_ent_2 = paddle.matmul(
|
||||||
|
-0.5 * (z_p**2).transpose([0, 2, 1]),
|
||||||
|
s_p_sq_r, )
|
||||||
|
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||||
|
neg_x_ent_3 = paddle.matmul(
|
||||||
|
z_p.transpose([0, 2, 1]),
|
||||||
|
(m_p * s_p_sq_r), )
|
||||||
|
# (B, 1, T_text)
|
||||||
|
neg_x_ent_4 = paddle.sum(
|
||||||
|
-0.5 * (m_p**2) * s_p_sq_r,
|
||||||
|
[1],
|
||||||
|
keepdim=True, )
|
||||||
|
# (B, T_feats, T_text)
|
||||||
|
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
|
||||||
|
# (B, 1, T_feats, T_text)
|
||||||
|
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
|
||||||
|
-1)
|
||||||
|
# monotonic attention weight: (B, 1, T_feats, T_text)
|
||||||
|
attn = self.maximum_path(
|
||||||
|
neg_x_ent,
|
||||||
|
attn_mask.squeeze(1), ).unsqueeze(1)
|
||||||
|
# (B, 1, T_text)
|
||||||
|
dur = attn.sum(2)
|
||||||
|
|
||||||
|
# forward decoder with random segments
|
||||||
|
wav = self.decoder(z * y_mask, g=g)
|
||||||
|
else:
|
||||||
|
# duration
|
||||||
|
if dur is None:
|
||||||
|
logw = self.duration_predictor(
|
||||||
|
x,
|
||||||
|
x_mask,
|
||||||
|
g=g,
|
||||||
|
inverse=True,
|
||||||
|
noise_scale=noise_scale_dur, )
|
||||||
|
w = paddle.exp(logw) * x_mask * alpha
|
||||||
|
dur = paddle.ceil(w)
|
||||||
|
y_lengths = paddle.cast(
|
||||||
|
paddle.clip(paddle.sum(dur, [1, 2]), min=1), dtype='int64')
|
||||||
|
y_mask = make_non_pad_mask(y_lengths).unsqueeze(1)
|
||||||
|
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
|
||||||
|
-1)
|
||||||
|
attn = self._generate_path(dur, attn_mask)
|
||||||
|
|
||||||
|
# expand the length to match with the feature sequence
|
||||||
|
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||||
|
m_p = paddle.matmul(
|
||||||
|
attn.squeeze(1),
|
||||||
|
m_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
|
||||||
|
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||||
|
logs_p = paddle.matmul(
|
||||||
|
attn.squeeze(1),
|
||||||
|
logs_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
z_p = m_p + paddle.randn(
|
||||||
|
paddle.shape(m_p)) * paddle.exp(logs_p) * noise_scale
|
||||||
|
z = self.flow(z_p, y_mask, g=g, inverse=True)
|
||||||
|
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
|
||||||
|
|
||||||
|
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
|
||||||
|
|
||||||
|
def _generate_path(self, dur: paddle.Tensor,
|
||||||
|
mask: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""Generate path a.k.a. monotonic attention.
|
||||||
|
Args:
|
||||||
|
dur (Tensor): Duration tensor (B, 1, T_text).
|
||||||
|
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
|
||||||
|
Returns:
|
||||||
|
Tensor: Path tensor (B, 1, T_feats, T_text).
|
||||||
|
"""
|
||||||
|
b, _, t_y, t_x = paddle.shape(mask)
|
||||||
|
cum_dur = paddle.cumsum(dur, -1)
|
||||||
|
cum_dur_flat = paddle.reshape(cum_dur, [b * t_x])
|
||||||
|
|
||||||
|
path = paddle.arange(t_y, dtype=dur.dtype)
|
||||||
|
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
|
||||||
|
path = paddle.reshape(path, [b, t_x, t_y])
|
||||||
|
'''
|
||||||
|
path will be like (t_x = 3, t_y = 5):
|
||||||
|
[[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
|
||||||
|
[1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
|
||||||
|
[1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
|
||||||
|
'''
|
||||||
|
|
||||||
|
path = paddle.cast(path, dtype='float32')
|
||||||
|
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
|
||||||
|
return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask
|
@ -0,0 +1,94 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Maximum path calculation module.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from numba import njit
|
||||||
|
from numba import prange
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .core import maximum_path_c
|
||||||
|
|
||||||
|
is_cython_avalable = True
|
||||||
|
except ImportError:
|
||||||
|
is_cython_avalable = False
|
||||||
|
warnings.warn(
|
||||||
|
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
|
||||||
|
"If you want to use the cython version, please build it as follows: "
|
||||||
|
"`cd paddlespeech/t2s/models/vits/monotonic_align; python setup.py build_ext --inplace`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def maximum_path(neg_x_ent: paddle.Tensor,
|
||||||
|
attn_mask: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""Calculate maximum path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
|
||||||
|
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Maximum path tensor (B, T_feats, T_text).
|
||||||
|
|
||||||
|
"""
|
||||||
|
dtype = neg_x_ent.dtype
|
||||||
|
neg_x_ent = neg_x_ent.numpy().astype(np.float32)
|
||||||
|
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
|
||||||
|
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
|
||||||
|
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
|
||||||
|
if is_cython_avalable:
|
||||||
|
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
|
||||||
|
else:
|
||||||
|
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
|
||||||
|
|
||||||
|
return paddle.cast(paddle.to_tensor(path), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@njit
|
||||||
|
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
|
||||||
|
"""Calculate a single maximum path with numba."""
|
||||||
|
index = t_x - 1
|
||||||
|
for y in range(t_y):
|
||||||
|
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
||||||
|
if x == y:
|
||||||
|
v_cur = max_neg_val
|
||||||
|
else:
|
||||||
|
v_cur = value[y - 1, x]
|
||||||
|
if x == 0:
|
||||||
|
if y == 0:
|
||||||
|
v_prev = 0.0
|
||||||
|
else:
|
||||||
|
v_prev = max_neg_val
|
||||||
|
else:
|
||||||
|
v_prev = value[y - 1, x - 1]
|
||||||
|
value[y, x] += max(v_prev, v_cur)
|
||||||
|
|
||||||
|
for y in range(t_y - 1, -1, -1):
|
||||||
|
path[y, index] = 1
|
||||||
|
if index != 0 and (index == y or
|
||||||
|
value[y - 1, index] < value[y - 1, index - 1]):
|
||||||
|
index = index - 1
|
||||||
|
|
||||||
|
|
||||||
|
@njit(parallel=True)
|
||||||
|
def maximum_path_numba(paths, values, t_ys, t_xs):
|
||||||
|
"""Calculate batch maximum path with numba."""
|
||||||
|
for i in prange(paths.shape[0]):
|
||||||
|
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
|
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Maximum path calculation module with cython optimization.
|
||||||
|
|
||||||
|
This code is copied from https://github.com/jaywalnut310/vits and modifed code format.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
cimport cython
|
||||||
|
|
||||||
|
from cython.parallel import prange
|
||||||
|
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
|
||||||
|
@cython.wraparound(False)
|
||||||
|
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
|
||||||
|
cdef int x
|
||||||
|
cdef int y
|
||||||
|
cdef float v_prev
|
||||||
|
cdef float v_cur
|
||||||
|
cdef float tmp
|
||||||
|
cdef int index = t_x - 1
|
||||||
|
|
||||||
|
for y in range(t_y):
|
||||||
|
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
||||||
|
if x == y:
|
||||||
|
v_cur = max_neg_val
|
||||||
|
else:
|
||||||
|
v_cur = value[y - 1, x]
|
||||||
|
if x == 0:
|
||||||
|
if y == 0:
|
||||||
|
v_prev = 0.0
|
||||||
|
else:
|
||||||
|
v_prev = max_neg_val
|
||||||
|
else:
|
||||||
|
v_prev = value[y - 1, x - 1]
|
||||||
|
value[y, x] += max(v_prev, v_cur)
|
||||||
|
|
||||||
|
for y in range(t_y - 1, -1, -1):
|
||||||
|
path[y, index] = 1
|
||||||
|
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
|
||||||
|
index = index - 1
|
||||||
|
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
|
||||||
|
@cython.wraparound(False)
|
||||||
|
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
|
||||||
|
cdef int b = paths.shape[0]
|
||||||
|
cdef int i
|
||||||
|
for i in prange(b, nogil=True):
|
||||||
|
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
|
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Setup cython code."""
|
||||||
|
from Cython.Build import cythonize
|
||||||
|
from setuptools import Extension
|
||||||
|
from setuptools import setup
|
||||||
|
from setuptools.command.build_ext import build_ext as _build_ext
|
||||||
|
|
||||||
|
|
||||||
|
class build_ext(_build_ext):
|
||||||
|
"""Overwrite build_ext."""
|
||||||
|
|
||||||
|
def finalize_options(self):
|
||||||
|
"""Prevent numpy from thinking it is still in its setup process."""
|
||||||
|
_build_ext.finalize_options(self)
|
||||||
|
__builtins__.__NUMPY_SETUP__ = False
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
self.include_dirs.append(numpy.get_include())
|
||||||
|
|
||||||
|
|
||||||
|
exts = [Extension(
|
||||||
|
name="core",
|
||||||
|
sources=["core.pyx"], )]
|
||||||
|
setup(
|
||||||
|
name="monotonic_align",
|
||||||
|
ext_modules=cythonize(exts, language_level=3),
|
||||||
|
cmdclass={"build_ext": build_ext}, )
|
@ -0,0 +1,120 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Text encoder module in VITS.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
||||||
|
|
||||||
|
|
||||||
|
class PosteriorEncoder(nn.Layer):
|
||||||
|
"""Posterior encoder module in VITS.
|
||||||
|
|
||||||
|
This is a module of posterior encoder described in `Conditional Variational
|
||||||
|
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||||
|
|
||||||
|
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int=513,
|
||||||
|
out_channels: int=192,
|
||||||
|
hidden_channels: int=192,
|
||||||
|
kernel_size: int=5,
|
||||||
|
layers: int=16,
|
||||||
|
stacks: int=1,
|
||||||
|
base_dilation: int=1,
|
||||||
|
global_channels: int=-1,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
bias: bool=True,
|
||||||
|
use_weight_norm: bool=True, ):
|
||||||
|
"""Initilialize PosteriorEncoder module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
kernel_size (int): Kernel size in WaveNet.
|
||||||
|
layers (int): Number of layers of WaveNet.
|
||||||
|
stacks (int): Number of repeat stacking of WaveNet.
|
||||||
|
base_dilation (int): Base dilation factor.
|
||||||
|
global_channels (int): Number of global conditioning channels.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
bias (bool): Whether to use bias parameters in conv.
|
||||||
|
use_weight_norm (bool): Whether to apply weight norm.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# define modules
|
||||||
|
self.input_conv = nn.Conv1D(in_channels, hidden_channels, 1)
|
||||||
|
self.encoder = WaveNet(
|
||||||
|
in_channels=-1,
|
||||||
|
out_channels=-1,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
layers=layers,
|
||||||
|
stacks=stacks,
|
||||||
|
base_dilation=base_dilation,
|
||||||
|
residual_channels=hidden_channels,
|
||||||
|
aux_channels=-1,
|
||||||
|
gate_channels=hidden_channels * 2,
|
||||||
|
skip_channels=hidden_channels,
|
||||||
|
global_channels=global_channels,
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
bias=bias,
|
||||||
|
use_weight_norm=use_weight_norm,
|
||||||
|
use_first_conv=False,
|
||||||
|
use_last_conv=False,
|
||||||
|
scale_residual=False,
|
||||||
|
scale_skip_connect=True, )
|
||||||
|
self.proj = nn.Conv1D(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_lengths: paddle.Tensor,
|
||||||
|
g: Optional[paddle.Tensor]=None
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, in_channels, T_feats).
|
||||||
|
x_lengths (Tensor): Length tensor (B,).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
|
||||||
|
Tensor: Projected mean tensor (B, out_channels, T_feats).
|
||||||
|
Tensor: Projected scale tensor (B, out_channels, T_feats).
|
||||||
|
Tensor: Mask tensor for input tensor (B, 1, T_feats).
|
||||||
|
|
||||||
|
"""
|
||||||
|
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
|
||||||
|
x = self.input_conv(x) * x_mask
|
||||||
|
x = self.encoder(x, x_mask, g=g)
|
||||||
|
stats = self.proj(x) * x_mask
|
||||||
|
m, logs = paddle.split(stats, 2, axis=1)
|
||||||
|
z = (m + paddle.randn(paddle.shape(m)) * paddle.exp(logs)) * x_mask
|
||||||
|
|
||||||
|
return z, m, logs, x_mask
|
@ -0,0 +1,242 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Residual affine coupling modules in VITS.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.vits.flow import FlipFlow
|
||||||
|
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAffineCouplingBlock(nn.Layer):
|
||||||
|
"""Residual affine coupling block module.
|
||||||
|
|
||||||
|
This is a module of residual affine coupling block, which used as "Flow" in
|
||||||
|
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`_.
|
||||||
|
|
||||||
|
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int=192,
|
||||||
|
hidden_channels: int=192,
|
||||||
|
flows: int=4,
|
||||||
|
kernel_size: int=5,
|
||||||
|
base_dilation: int=1,
|
||||||
|
layers: int=4,
|
||||||
|
global_channels: int=-1,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
use_weight_norm: bool=True,
|
||||||
|
bias: bool=True,
|
||||||
|
use_only_mean: bool=True, ):
|
||||||
|
"""Initilize ResidualAffineCouplingBlock module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
flows (int): Number of flows.
|
||||||
|
kernel_size (int): Kernel size for WaveNet.
|
||||||
|
base_dilation (int): Base dilation factor for WaveNet.
|
||||||
|
layers (int): Number of layers of WaveNet.
|
||||||
|
stacks (int): Number of stacks of WaveNet.
|
||||||
|
global_channels (int): Number of global channels.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
|
||||||
|
bias (bool): Whether to use bias paramters in WaveNet.
|
||||||
|
use_only_mean (bool): Whether to estimate only mean.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.flows = nn.LayerList()
|
||||||
|
for i in range(flows):
|
||||||
|
self.flows.append(
|
||||||
|
ResidualAffineCouplingLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
hidden_channels=hidden_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
base_dilation=base_dilation,
|
||||||
|
layers=layers,
|
||||||
|
stacks=1,
|
||||||
|
global_channels=global_channels,
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
use_weight_norm=use_weight_norm,
|
||||||
|
bias=bias,
|
||||||
|
use_only_mean=use_only_mean, ))
|
||||||
|
self.flows.append(FlipFlow())
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
g: Optional[paddle.Tensor]=None,
|
||||||
|
inverse: bool=False, ) -> paddle.Tensor:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, in_channels, T).
|
||||||
|
x_mask (Tensor): Length tensor (B, 1, T).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, in_channels, T).
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not inverse:
|
||||||
|
for flow in self.flows:
|
||||||
|
x, _ = flow(x, x_mask, g=g, inverse=inverse)
|
||||||
|
else:
|
||||||
|
for flow in reversed(self.flows):
|
||||||
|
x = flow(x, x_mask, g=g, inverse=inverse)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAffineCouplingLayer(nn.Layer):
|
||||||
|
"""Residual affine coupling layer."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int=192,
|
||||||
|
hidden_channels: int=192,
|
||||||
|
kernel_size: int=5,
|
||||||
|
base_dilation: int=1,
|
||||||
|
layers: int=5,
|
||||||
|
stacks: int=1,
|
||||||
|
global_channels: int=-1,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
use_weight_norm: bool=True,
|
||||||
|
bias: bool=True,
|
||||||
|
use_only_mean: bool=True, ):
|
||||||
|
"""Initialzie ResidualAffineCouplingLayer module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
hidden_channels (int): Number of hidden channels.
|
||||||
|
kernel_size (int): Kernel size for WaveNet.
|
||||||
|
base_dilation (int): Base dilation factor for WaveNet.
|
||||||
|
layers (int): Number of layers of WaveNet.
|
||||||
|
stacks (int): Number of stacks of WaveNet.
|
||||||
|
global_channels (int): Number of global channels.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
|
||||||
|
bias (bool): Whether to use bias paramters in WaveNet.
|
||||||
|
use_only_mean (bool): Whether to estimate only mean.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
|
||||||
|
super().__init__()
|
||||||
|
self.half_channels = in_channels // 2
|
||||||
|
self.use_only_mean = use_only_mean
|
||||||
|
|
||||||
|
# define modules
|
||||||
|
self.input_conv = nn.Conv1D(
|
||||||
|
self.half_channels,
|
||||||
|
hidden_channels,
|
||||||
|
1, )
|
||||||
|
self.encoder = WaveNet(
|
||||||
|
in_channels=-1,
|
||||||
|
out_channels=-1,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
layers=layers,
|
||||||
|
stacks=stacks,
|
||||||
|
base_dilation=base_dilation,
|
||||||
|
residual_channels=hidden_channels,
|
||||||
|
aux_channels=-1,
|
||||||
|
gate_channels=hidden_channels * 2,
|
||||||
|
skip_channels=hidden_channels,
|
||||||
|
global_channels=global_channels,
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
bias=bias,
|
||||||
|
use_weight_norm=use_weight_norm,
|
||||||
|
use_first_conv=False,
|
||||||
|
use_last_conv=False,
|
||||||
|
scale_residual=False,
|
||||||
|
scale_skip_connect=True, )
|
||||||
|
if use_only_mean:
|
||||||
|
self.proj = nn.Conv1D(
|
||||||
|
hidden_channels,
|
||||||
|
self.half_channels,
|
||||||
|
1, )
|
||||||
|
else:
|
||||||
|
self.proj = nn.Conv1D(
|
||||||
|
hidden_channels,
|
||||||
|
self.half_channels * 2,
|
||||||
|
1, )
|
||||||
|
|
||||||
|
weight = paddle.zeros(paddle.shape(self.proj.weight))
|
||||||
|
|
||||||
|
self.proj.weight = paddle.create_parameter(
|
||||||
|
shape=weight.shape,
|
||||||
|
dtype=str(weight.numpy().dtype),
|
||||||
|
default_initializer=paddle.nn.initializer.Assign(weight))
|
||||||
|
|
||||||
|
bias = paddle.zeros(paddle.shape(self.proj.bias))
|
||||||
|
|
||||||
|
self.proj.bias = paddle.create_parameter(
|
||||||
|
shape=bias.shape,
|
||||||
|
dtype=str(bias.numpy().dtype),
|
||||||
|
default_initializer=paddle.nn.initializer.Assign(bias))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: paddle.Tensor,
|
||||||
|
g: Optional[paddle.Tensor]=None,
|
||||||
|
inverse: bool=False,
|
||||||
|
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, in_channels, T).
|
||||||
|
x_lengths (Tensor): Length tensor (B,).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||||
|
inverse (bool): Whether to inverse the flow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, in_channels, T).
|
||||||
|
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||||
|
|
||||||
|
"""
|
||||||
|
xa, xb = paddle.split(x, 2, axis=1)
|
||||||
|
h = self.input_conv(xa) * x_mask
|
||||||
|
h = self.encoder(h, x_mask, g=g)
|
||||||
|
stats = self.proj(h) * x_mask
|
||||||
|
if not self.use_only_mean:
|
||||||
|
m, logs = paddle.split(stats, 2, axis=1)
|
||||||
|
else:
|
||||||
|
m = stats
|
||||||
|
logs = paddle.zeros(paddle.shape(m))
|
||||||
|
|
||||||
|
if not inverse:
|
||||||
|
xb = m + xb * paddle.exp(logs) * x_mask
|
||||||
|
x = paddle.concat([xa, xb], 1)
|
||||||
|
logdet = paddle.sum(logs, [1, 2])
|
||||||
|
return x, logdet
|
||||||
|
else:
|
||||||
|
xb = (xb - m) * paddle.exp(-logs) * x_mask
|
||||||
|
x = paddle.concat([xa, xb], 1)
|
||||||
|
return x
|
@ -0,0 +1,145 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Text encoder module in VITS.
|
||||||
|
|
||||||
|
This code is based on https://github.com/jaywalnut310/vits.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
||||||
|
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoder(nn.Layer):
|
||||||
|
"""Text encoder module in VITS.
|
||||||
|
|
||||||
|
This is a module of text encoder described in `Conditional Variational Autoencoder
|
||||||
|
with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||||
|
|
||||||
|
Instead of the relative positional Transformer, we use conformer architecture as
|
||||||
|
the encoder module, which contains additional convolution layers.
|
||||||
|
|
||||||
|
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocabs: int,
|
||||||
|
attention_dim: int=192,
|
||||||
|
attention_heads: int=2,
|
||||||
|
linear_units: int=768,
|
||||||
|
blocks: int=6,
|
||||||
|
positionwise_layer_type: str="conv1d",
|
||||||
|
positionwise_conv_kernel_size: int=3,
|
||||||
|
positional_encoding_layer_type: str="rel_pos",
|
||||||
|
self_attention_layer_type: str="rel_selfattn",
|
||||||
|
activation_type: str="swish",
|
||||||
|
normalize_before: bool=True,
|
||||||
|
use_macaron_style: bool=False,
|
||||||
|
use_conformer_conv: bool=False,
|
||||||
|
conformer_kernel_size: int=7,
|
||||||
|
dropout_rate: float=0.1,
|
||||||
|
positional_dropout_rate: float=0.0,
|
||||||
|
attention_dropout_rate: float=0.0, ):
|
||||||
|
"""Initialize TextEncoder module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocabs (int): Vocabulary size.
|
||||||
|
attention_dim (int): Attention dimension.
|
||||||
|
attention_heads (int): Number of attention heads.
|
||||||
|
linear_units (int): Number of linear units of positionwise layers.
|
||||||
|
blocks (int): Number of encoder blocks.
|
||||||
|
positionwise_layer_type (str): Positionwise layer type.
|
||||||
|
positionwise_conv_kernel_size (int): Positionwise layer's kernel size.
|
||||||
|
positional_encoding_layer_type (str): Positional encoding layer type.
|
||||||
|
self_attention_layer_type (str): Self-attention layer type.
|
||||||
|
activation_type (str): Activation function type.
|
||||||
|
normalize_before (bool): Whether to apply LayerNorm before attention.
|
||||||
|
use_macaron_style (bool): Whether to use macaron style components.
|
||||||
|
use_conformer_conv (bool): Whether to use conformer conv layers.
|
||||||
|
conformer_kernel_size (int): Conformer's conv kernel size.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
positional_dropout_rate (float): Dropout rate for positional encoding.
|
||||||
|
attention_dropout_rate (float): Dropout rate for attention.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# store for forward
|
||||||
|
self.attention_dim = attention_dim
|
||||||
|
|
||||||
|
# define modules
|
||||||
|
self.emb = nn.Embedding(vocabs, attention_dim)
|
||||||
|
|
||||||
|
dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
|
||||||
|
w = dist.sample(self.emb.weight.shape)
|
||||||
|
self.emb.weight.set_value(w)
|
||||||
|
|
||||||
|
self.encoder = Encoder(
|
||||||
|
idim=-1,
|
||||||
|
input_layer=None,
|
||||||
|
attention_dim=attention_dim,
|
||||||
|
attention_heads=attention_heads,
|
||||||
|
linear_units=linear_units,
|
||||||
|
num_blocks=blocks,
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
positional_dropout_rate=positional_dropout_rate,
|
||||||
|
attention_dropout_rate=attention_dropout_rate,
|
||||||
|
normalize_before=normalize_before,
|
||||||
|
positionwise_layer_type=positionwise_layer_type,
|
||||||
|
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
||||||
|
macaron_style=use_macaron_style,
|
||||||
|
pos_enc_layer_type=positional_encoding_layer_type,
|
||||||
|
selfattention_layer_type=self_attention_layer_type,
|
||||||
|
activation_type=activation_type,
|
||||||
|
use_cnn_module=use_conformer_conv,
|
||||||
|
cnn_module_kernel=conformer_kernel_size, )
|
||||||
|
self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_lengths: paddle.Tensor,
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input index tensor (B, T_text).
|
||||||
|
x_lengths (Tensor): Length tensor (B,).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Encoded hidden representation (B, attention_dim, T_text).
|
||||||
|
Tensor: Projected mean tensor (B, attention_dim, T_text).
|
||||||
|
Tensor: Projected scale tensor (B, attention_dim, T_text).
|
||||||
|
Tensor: Mask tensor for input tensor (B, 1, T_text).
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = self.emb(x) * math.sqrt(self.attention_dim)
|
||||||
|
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
|
||||||
|
# encoder assume the channel last (B, T_text, attention_dim)
|
||||||
|
# but mask shape shoud be (B, 1, T_text)
|
||||||
|
x, _ = self.encoder(x, x_mask)
|
||||||
|
|
||||||
|
# convert the channel first (B, attention_dim, T_text)
|
||||||
|
x = paddle.transpose(x, [0, 2, 1])
|
||||||
|
stats = self.proj(x) * x_mask
|
||||||
|
m, logs = paddle.split(stats, 2, axis=1)
|
||||||
|
|
||||||
|
return x, m, logs, x_mask
|
@ -0,0 +1,238 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Flow-related transformation.
|
||||||
|
|
||||||
|
This code is based on https://github.com/bayesiains/nflows.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import paddle_gather
|
||||||
|
|
||||||
|
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
||||||
|
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||||
|
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
def piecewise_rational_quadratic_transform(
|
||||||
|
inputs,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=False,
|
||||||
|
tails=None,
|
||||||
|
tail_bound=1.0,
|
||||||
|
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||||
|
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||||
|
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
|
||||||
|
if tails is None:
|
||||||
|
spline_fn = rational_quadratic_spline
|
||||||
|
spline_kwargs = {}
|
||||||
|
else:
|
||||||
|
spline_fn = unconstrained_rational_quadratic_spline
|
||||||
|
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
||||||
|
|
||||||
|
outputs, logabsdet = spline_fn(
|
||||||
|
inputs=inputs,
|
||||||
|
unnormalized_widths=unnormalized_widths,
|
||||||
|
unnormalized_heights=unnormalized_heights,
|
||||||
|
unnormalized_derivatives=unnormalized_derivatives,
|
||||||
|
inverse=inverse,
|
||||||
|
min_bin_width=min_bin_width,
|
||||||
|
min_bin_height=min_bin_height,
|
||||||
|
min_derivative=min_derivative,
|
||||||
|
**spline_kwargs)
|
||||||
|
return outputs, logabsdet
|
||||||
|
|
||||||
|
|
||||||
|
def mask_preprocess(x, mask):
|
||||||
|
B, C, T, bins = paddle.shape(x)
|
||||||
|
new_x = paddle.zeros([mask.sum(), bins])
|
||||||
|
for i in range(bins):
|
||||||
|
new_x[:, i] = x[:, :, :, i][mask]
|
||||||
|
return new_x
|
||||||
|
|
||||||
|
|
||||||
|
def unconstrained_rational_quadratic_spline(
|
||||||
|
inputs,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=False,
|
||||||
|
tails="linear",
|
||||||
|
tail_bound=1.0,
|
||||||
|
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||||
|
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||||
|
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
|
||||||
|
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
||||||
|
outside_interval_mask = ~inside_interval_mask
|
||||||
|
|
||||||
|
outputs = paddle.zeros(paddle.shape(inputs))
|
||||||
|
logabsdet = paddle.zeros(paddle.shape(inputs))
|
||||||
|
if tails == "linear":
|
||||||
|
unnormalized_derivatives = F.pad(
|
||||||
|
unnormalized_derivatives,
|
||||||
|
pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
|
||||||
|
constant = np.log(np.exp(1 - min_derivative) - 1)
|
||||||
|
unnormalized_derivatives[..., 0] = constant
|
||||||
|
unnormalized_derivatives[..., -1] = constant
|
||||||
|
|
||||||
|
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
||||||
|
logabsdet[outside_interval_mask] = 0
|
||||||
|
else:
|
||||||
|
raise RuntimeError("{} tails are not implemented.".format(tails))
|
||||||
|
|
||||||
|
unnormalized_widths = mask_preprocess(unnormalized_widths,
|
||||||
|
inside_interval_mask)
|
||||||
|
unnormalized_heights = mask_preprocess(unnormalized_heights,
|
||||||
|
inside_interval_mask)
|
||||||
|
unnormalized_derivatives = mask_preprocess(unnormalized_derivatives,
|
||||||
|
inside_interval_mask)
|
||||||
|
|
||||||
|
(outputs[inside_interval_mask],
|
||||||
|
logabsdet[inside_interval_mask], ) = rational_quadratic_spline(
|
||||||
|
inputs=inputs[inside_interval_mask],
|
||||||
|
unnormalized_widths=unnormalized_widths,
|
||||||
|
unnormalized_heights=unnormalized_heights,
|
||||||
|
unnormalized_derivatives=unnormalized_derivatives,
|
||||||
|
inverse=inverse,
|
||||||
|
left=-tail_bound,
|
||||||
|
right=tail_bound,
|
||||||
|
bottom=-tail_bound,
|
||||||
|
top=tail_bound,
|
||||||
|
min_bin_width=min_bin_width,
|
||||||
|
min_bin_height=min_bin_height,
|
||||||
|
min_derivative=min_derivative, )
|
||||||
|
|
||||||
|
return outputs, logabsdet
|
||||||
|
|
||||||
|
|
||||||
|
def rational_quadratic_spline(
|
||||||
|
inputs,
|
||||||
|
unnormalized_widths,
|
||||||
|
unnormalized_heights,
|
||||||
|
unnormalized_derivatives,
|
||||||
|
inverse=False,
|
||||||
|
left=0.0,
|
||||||
|
right=1.0,
|
||||||
|
bottom=0.0,
|
||||||
|
top=1.0,
|
||||||
|
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||||
|
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||||
|
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
|
||||||
|
if paddle.min(inputs) < left or paddle.max(inputs) > right:
|
||||||
|
raise ValueError("Input to a transform is not within its domain")
|
||||||
|
|
||||||
|
num_bins = unnormalized_widths.shape[-1]
|
||||||
|
|
||||||
|
if min_bin_width * num_bins > 1.0:
|
||||||
|
raise ValueError("Minimal bin width too large for the number of bins")
|
||||||
|
if min_bin_height * num_bins > 1.0:
|
||||||
|
raise ValueError("Minimal bin height too large for the number of bins")
|
||||||
|
|
||||||
|
widths = F.softmax(unnormalized_widths, axis=-1)
|
||||||
|
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
||||||
|
cumwidths = paddle.cumsum(widths, axis=-1)
|
||||||
|
cumwidths = F.pad(
|
||||||
|
cumwidths,
|
||||||
|
pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
|
||||||
|
mode="constant",
|
||||||
|
value=0.0)
|
||||||
|
cumwidths = (right - left) * cumwidths + left
|
||||||
|
cumwidths[..., 0] = left
|
||||||
|
cumwidths[..., -1] = right
|
||||||
|
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
||||||
|
|
||||||
|
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
||||||
|
|
||||||
|
heights = F.softmax(unnormalized_heights, axis=-1)
|
||||||
|
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
||||||
|
cumheights = paddle.cumsum(heights, axis=-1)
|
||||||
|
cumheights = F.pad(
|
||||||
|
cumheights,
|
||||||
|
pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
|
||||||
|
mode="constant",
|
||||||
|
value=0.0)
|
||||||
|
cumheights = (top - bottom) * cumheights + bottom
|
||||||
|
cumheights[..., 0] = bottom
|
||||||
|
cumheights[..., -1] = top
|
||||||
|
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
||||||
|
|
||||||
|
if inverse:
|
||||||
|
bin_idx = _searchsorted(cumheights, inputs)[..., None]
|
||||||
|
else:
|
||||||
|
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
|
||||||
|
input_cumwidths = paddle_gather(cumwidths, -1, bin_idx)[..., 0]
|
||||||
|
input_bin_widths = paddle_gather(widths, -1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
input_cumheights = paddle_gather(cumheights, -1, bin_idx)[..., 0]
|
||||||
|
delta = heights / widths
|
||||||
|
input_delta = paddle_gather(delta, -1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
input_derivatives = paddle_gather(derivatives, -1, bin_idx)[..., 0]
|
||||||
|
input_derivatives_plus_one = paddle_gather(derivatives[..., 1:], -1,
|
||||||
|
bin_idx)[..., 0]
|
||||||
|
|
||||||
|
input_heights = paddle_gather(heights, -1, bin_idx)[..., 0]
|
||||||
|
|
||||||
|
if inverse:
|
||||||
|
a = (inputs - input_cumheights) * (
|
||||||
|
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||||
|
) + input_heights * (input_delta - input_derivatives)
|
||||||
|
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
||||||
|
input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||||
|
c = -input_delta * (inputs - input_cumheights)
|
||||||
|
|
||||||
|
discriminant = b.pow(2) - 4 * a * c
|
||||||
|
assert (discriminant >= 0).all()
|
||||||
|
|
||||||
|
root = (2 * c) / (-b - paddle.sqrt(discriminant))
|
||||||
|
outputs = root * input_bin_widths + input_cumwidths
|
||||||
|
|
||||||
|
theta_one_minus_theta = root * (1 - root)
|
||||||
|
denominator = input_delta + (
|
||||||
|
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||||
|
) * theta_one_minus_theta)
|
||||||
|
derivative_numerator = input_delta.pow(2) * (
|
||||||
|
input_derivatives_plus_one * root.pow(2) + 2 * input_delta *
|
||||||
|
theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
|
||||||
|
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
|
||||||
|
denominator)
|
||||||
|
|
||||||
|
return outputs, -logabsdet
|
||||||
|
else:
|
||||||
|
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||||
|
theta_one_minus_theta = theta * (1 - theta)
|
||||||
|
|
||||||
|
numerator = input_heights * (input_delta * theta.pow(2) +
|
||||||
|
input_derivatives * theta_one_minus_theta)
|
||||||
|
denominator = input_delta + (
|
||||||
|
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||||
|
) * theta_one_minus_theta)
|
||||||
|
outputs = input_cumheights + numerator / denominator
|
||||||
|
|
||||||
|
derivative_numerator = input_delta.pow(2) * (
|
||||||
|
input_derivatives_plus_one * theta.pow(2) + 2 * input_delta *
|
||||||
|
theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
|
||||||
|
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
|
||||||
|
denominator)
|
||||||
|
|
||||||
|
return outputs, logabsdet
|
||||||
|
|
||||||
|
|
||||||
|
def _searchsorted(bin_locations, inputs, eps=1e-6):
|
||||||
|
bin_locations[..., -1] += eps
|
||||||
|
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
|
@ -0,0 +1,412 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
"""VITS module"""
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
|
||||||
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
|
||||||
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
|
||||||
|
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
|
||||||
|
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
|
||||||
|
from paddlespeech.t2s.models.vits.generator import VITSGenerator
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import initialize
|
||||||
|
|
||||||
|
AVAILABLE_GENERATERS = {
|
||||||
|
"vits_generator": VITSGenerator,
|
||||||
|
}
|
||||||
|
AVAILABLE_DISCRIMINATORS = {
|
||||||
|
"hifigan_period_discriminator":
|
||||||
|
HiFiGANPeriodDiscriminator,
|
||||||
|
"hifigan_scale_discriminator":
|
||||||
|
HiFiGANScaleDiscriminator,
|
||||||
|
"hifigan_multi_period_discriminator":
|
||||||
|
HiFiGANMultiPeriodDiscriminator,
|
||||||
|
"hifigan_multi_scale_discriminator":
|
||||||
|
HiFiGANMultiScaleDiscriminator,
|
||||||
|
"hifigan_multi_scale_multi_period_discriminator":
|
||||||
|
HiFiGANMultiScaleMultiPeriodDiscriminator,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class VITS(nn.Layer):
|
||||||
|
"""VITS module (generator + discriminator).
|
||||||
|
This is a module of VITS described in `Conditional Variational Autoencoder
|
||||||
|
with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||||
|
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||||
|
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# generator related
|
||||||
|
idim: int,
|
||||||
|
odim: int,
|
||||||
|
sampling_rate: int=22050,
|
||||||
|
generator_type: str="vits_generator",
|
||||||
|
generator_params: Dict[str, Any]={
|
||||||
|
"hidden_channels": 192,
|
||||||
|
"spks": None,
|
||||||
|
"langs": None,
|
||||||
|
"spk_embed_dim": None,
|
||||||
|
"global_channels": -1,
|
||||||
|
"segment_size": 32,
|
||||||
|
"text_encoder_attention_heads": 2,
|
||||||
|
"text_encoder_ffn_expand": 4,
|
||||||
|
"text_encoder_blocks": 6,
|
||||||
|
"text_encoder_positionwise_layer_type": "conv1d",
|
||||||
|
"text_encoder_positionwise_conv_kernel_size": 1,
|
||||||
|
"text_encoder_positional_encoding_layer_type": "rel_pos",
|
||||||
|
"text_encoder_self_attention_layer_type": "rel_selfattn",
|
||||||
|
"text_encoder_activation_type": "swish",
|
||||||
|
"text_encoder_normalize_before": True,
|
||||||
|
"text_encoder_dropout_rate": 0.1,
|
||||||
|
"text_encoder_positional_dropout_rate": 0.0,
|
||||||
|
"text_encoder_attention_dropout_rate": 0.0,
|
||||||
|
"text_encoder_conformer_kernel_size": 7,
|
||||||
|
"use_macaron_style_in_text_encoder": True,
|
||||||
|
"use_conformer_conv_in_text_encoder": True,
|
||||||
|
"decoder_kernel_size": 7,
|
||||||
|
"decoder_channels": 512,
|
||||||
|
"decoder_upsample_scales": [8, 8, 2, 2],
|
||||||
|
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
|
||||||
|
"decoder_resblock_kernel_sizes": [3, 7, 11],
|
||||||
|
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
"use_weight_norm_in_decoder": True,
|
||||||
|
"posterior_encoder_kernel_size": 5,
|
||||||
|
"posterior_encoder_layers": 16,
|
||||||
|
"posterior_encoder_stacks": 1,
|
||||||
|
"posterior_encoder_base_dilation": 1,
|
||||||
|
"posterior_encoder_dropout_rate": 0.0,
|
||||||
|
"use_weight_norm_in_posterior_encoder": True,
|
||||||
|
"flow_flows": 4,
|
||||||
|
"flow_kernel_size": 5,
|
||||||
|
"flow_base_dilation": 1,
|
||||||
|
"flow_layers": 4,
|
||||||
|
"flow_dropout_rate": 0.0,
|
||||||
|
"use_weight_norm_in_flow": True,
|
||||||
|
"use_only_mean_in_flow": True,
|
||||||
|
"stochastic_duration_predictor_kernel_size": 3,
|
||||||
|
"stochastic_duration_predictor_dropout_rate": 0.5,
|
||||||
|
"stochastic_duration_predictor_flows": 4,
|
||||||
|
"stochastic_duration_predictor_dds_conv_layers": 3,
|
||||||
|
},
|
||||||
|
# discriminator related
|
||||||
|
discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
|
||||||
|
discriminator_params: Dict[str, Any]={
|
||||||
|
"scales": 1,
|
||||||
|
"scale_downsample_pooling": "AvgPool1D",
|
||||||
|
"scale_downsample_pooling_params": {
|
||||||
|
"kernel_size": 4,
|
||||||
|
"stride": 2,
|
||||||
|
"padding": 2,
|
||||||
|
},
|
||||||
|
"scale_discriminator_params": {
|
||||||
|
"in_channels": 1,
|
||||||
|
"out_channels": 1,
|
||||||
|
"kernel_sizes": [15, 41, 5, 3],
|
||||||
|
"channels": 128,
|
||||||
|
"max_downsample_channels": 1024,
|
||||||
|
"max_groups": 16,
|
||||||
|
"bias": True,
|
||||||
|
"downsample_scales": [2, 2, 4, 4, 1],
|
||||||
|
"nonlinear_activation": "leakyrelu",
|
||||||
|
"nonlinear_activation_params": {
|
||||||
|
"negative_slope": 0.1
|
||||||
|
},
|
||||||
|
"use_weight_norm": True,
|
||||||
|
"use_spectral_norm": False,
|
||||||
|
},
|
||||||
|
"follow_official_norm": False,
|
||||||
|
"periods": [2, 3, 5, 7, 11],
|
||||||
|
"period_discriminator_params": {
|
||||||
|
"in_channels": 1,
|
||||||
|
"out_channels": 1,
|
||||||
|
"kernel_sizes": [5, 3],
|
||||||
|
"channels": 32,
|
||||||
|
"downsample_scales": [3, 3, 3, 3, 1],
|
||||||
|
"max_downsample_channels": 1024,
|
||||||
|
"bias": True,
|
||||||
|
"nonlinear_activation": "leakyrelu",
|
||||||
|
"nonlinear_activation_params": {
|
||||||
|
"negative_slope": 0.1
|
||||||
|
},
|
||||||
|
"use_weight_norm": True,
|
||||||
|
"use_spectral_norm": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
cache_generator_outputs: bool=True,
|
||||||
|
init_type: str="xavier_uniform", ):
|
||||||
|
"""Initialize VITS module.
|
||||||
|
Args:
|
||||||
|
idim (int): Input vocabrary size.
|
||||||
|
odim (int): Acoustic feature dimension. The actual output channels will
|
||||||
|
be 1 since VITS is the end-to-end text-to-wave model but for the
|
||||||
|
compatibility odim is used to indicate the acoustic feature dimension.
|
||||||
|
sampling_rate (int): Sampling rate, not used for the training but it will
|
||||||
|
be referred in saving waveform during the inference.
|
||||||
|
generator_type (str): Generator type.
|
||||||
|
generator_params (Dict[str, Any]): Parameter dict for generator.
|
||||||
|
discriminator_type (str): Discriminator type.
|
||||||
|
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
|
||||||
|
cache_generator_outputs (bool): Whether to cache generator outputs.
|
||||||
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# initialize parameters
|
||||||
|
initialize(self, init_type)
|
||||||
|
|
||||||
|
# define modules
|
||||||
|
generator_class = AVAILABLE_GENERATERS[generator_type]
|
||||||
|
if generator_type == "vits_generator":
|
||||||
|
# NOTE: Update parameters for the compatibility.
|
||||||
|
# The idim and odim is automatically decided from input data,
|
||||||
|
# where idim represents #vocabularies and odim represents
|
||||||
|
# the input acoustic feature dimension.
|
||||||
|
generator_params.update(vocabs=idim, aux_channels=odim)
|
||||||
|
self.generator = generator_class(
|
||||||
|
**generator_params, )
|
||||||
|
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
|
||||||
|
self.discriminator = discriminator_class(
|
||||||
|
**discriminator_params, )
|
||||||
|
|
||||||
|
nn.initializer.set_global_initializer(None)
|
||||||
|
|
||||||
|
# cache
|
||||||
|
self.cache_generator_outputs = cache_generator_outputs
|
||||||
|
self._cache = None
|
||||||
|
|
||||||
|
# store sampling rate for saving wav file
|
||||||
|
# (not used for the training)
|
||||||
|
self.fs = sampling_rate
|
||||||
|
|
||||||
|
# store parameters for test compatibility
|
||||||
|
self.spks = self.generator.spks
|
||||||
|
self.langs = self.generator.langs
|
||||||
|
self.spk_embed_dim = self.generator.spk_embed_dim
|
||||||
|
|
||||||
|
self.reuse_cache_gen = True
|
||||||
|
self.reuse_cache_dis = True
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
text: paddle.Tensor,
|
||||||
|
text_lengths: paddle.Tensor,
|
||||||
|
feats: paddle.Tensor,
|
||||||
|
feats_lengths: paddle.Tensor,
|
||||||
|
sids: Optional[paddle.Tensor]=None,
|
||||||
|
spembs: Optional[paddle.Tensor]=None,
|
||||||
|
lids: Optional[paddle.Tensor]=None,
|
||||||
|
forward_generator: bool=True, ) -> Dict[str, Any]:
|
||||||
|
"""Perform generator forward.
|
||||||
|
Args:
|
||||||
|
text (Tensor): Text index tensor (B, T_text).
|
||||||
|
text_lengths (Tensor): Text length tensor (B,).
|
||||||
|
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||||
|
feats_lengths (Tensor): Feature length tensor (B,).
|
||||||
|
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||||
|
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||||
|
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||||
|
forward_generator (bool): Whether to forward generator.
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]:
|
||||||
|
- loss (Tensor): Loss scalar tensor.
|
||||||
|
- stats (Dict[str, float]): Statistics to be monitored.
|
||||||
|
- weight (Tensor): Weight tensor to summarize losses.
|
||||||
|
- optim_idx (int): Optimizer index (0 for G and 1 for D).
|
||||||
|
"""
|
||||||
|
if forward_generator:
|
||||||
|
return self._forward_generator(
|
||||||
|
text=text,
|
||||||
|
text_lengths=text_lengths,
|
||||||
|
feats=feats,
|
||||||
|
feats_lengths=feats_lengths,
|
||||||
|
sids=sids,
|
||||||
|
spembs=spembs,
|
||||||
|
lids=lids, )
|
||||||
|
else:
|
||||||
|
return self._forward_discrminator(
|
||||||
|
text=text,
|
||||||
|
text_lengths=text_lengths,
|
||||||
|
feats=feats,
|
||||||
|
feats_lengths=feats_lengths,
|
||||||
|
sids=sids,
|
||||||
|
spembs=spembs,
|
||||||
|
lids=lids, )
|
||||||
|
|
||||||
|
def _forward_generator(
|
||||||
|
self,
|
||||||
|
text: paddle.Tensor,
|
||||||
|
text_lengths: paddle.Tensor,
|
||||||
|
feats: paddle.Tensor,
|
||||||
|
feats_lengths: paddle.Tensor,
|
||||||
|
sids: Optional[paddle.Tensor]=None,
|
||||||
|
spembs: Optional[paddle.Tensor]=None,
|
||||||
|
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
|
||||||
|
"""Perform generator forward.
|
||||||
|
Args:
|
||||||
|
text (Tensor): Text index tensor (B, T_text).
|
||||||
|
text_lengths (Tensor): Text length tensor (B,).
|
||||||
|
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||||
|
feats_lengths (Tensor): Feature length tensor (B,).
|
||||||
|
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||||
|
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||||
|
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
# setup
|
||||||
|
feats = feats.transpose([0, 2, 1])
|
||||||
|
|
||||||
|
# calculate generator outputs
|
||||||
|
self.reuse_cache_gen = True
|
||||||
|
if not self.cache_generator_outputs or self._cache is None:
|
||||||
|
self.reuse_cache_gen = False
|
||||||
|
outs = self.generator(
|
||||||
|
text=text,
|
||||||
|
text_lengths=text_lengths,
|
||||||
|
feats=feats,
|
||||||
|
feats_lengths=feats_lengths,
|
||||||
|
sids=sids,
|
||||||
|
spembs=spembs,
|
||||||
|
lids=lids, )
|
||||||
|
else:
|
||||||
|
outs = self._cache
|
||||||
|
|
||||||
|
# store cache
|
||||||
|
if self.training and self.cache_generator_outputs and not self.reuse_cache_gen:
|
||||||
|
self._cache = outs
|
||||||
|
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def _forward_discrminator(
|
||||||
|
self,
|
||||||
|
text: paddle.Tensor,
|
||||||
|
text_lengths: paddle.Tensor,
|
||||||
|
feats: paddle.Tensor,
|
||||||
|
feats_lengths: paddle.Tensor,
|
||||||
|
sids: Optional[paddle.Tensor]=None,
|
||||||
|
spembs: Optional[paddle.Tensor]=None,
|
||||||
|
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
|
||||||
|
"""Perform discriminator forward.
|
||||||
|
Args:
|
||||||
|
text (Tensor): Text index tensor (B, T_text).
|
||||||
|
text_lengths (Tensor): Text length tensor (B,).
|
||||||
|
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||||
|
feats_lengths (Tensor): Feature length tensor (B,).
|
||||||
|
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||||
|
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||||
|
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
# setup
|
||||||
|
feats = feats.transpose([0, 2, 1])
|
||||||
|
|
||||||
|
# calculate generator outputs
|
||||||
|
self.reuse_cache_dis = True
|
||||||
|
if not self.cache_generator_outputs or self._cache is None:
|
||||||
|
self.reuse_cache_dis = False
|
||||||
|
outs = self.generator(
|
||||||
|
text=text,
|
||||||
|
text_lengths=text_lengths,
|
||||||
|
feats=feats,
|
||||||
|
feats_lengths=feats_lengths,
|
||||||
|
sids=sids,
|
||||||
|
spembs=spembs,
|
||||||
|
lids=lids, )
|
||||||
|
else:
|
||||||
|
outs = self._cache
|
||||||
|
|
||||||
|
# store cache
|
||||||
|
if self.cache_generator_outputs and not self.reuse_cache_dis:
|
||||||
|
self._cache = outs
|
||||||
|
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
text: paddle.Tensor,
|
||||||
|
feats: Optional[paddle.Tensor]=None,
|
||||||
|
sids: Optional[paddle.Tensor]=None,
|
||||||
|
spembs: Optional[paddle.Tensor]=None,
|
||||||
|
lids: Optional[paddle.Tensor]=None,
|
||||||
|
durations: Optional[paddle.Tensor]=None,
|
||||||
|
noise_scale: float=0.667,
|
||||||
|
noise_scale_dur: float=0.8,
|
||||||
|
alpha: float=1.0,
|
||||||
|
max_len: Optional[int]=None,
|
||||||
|
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
|
||||||
|
"""Run inference.
|
||||||
|
Args:
|
||||||
|
text (Tensor): Input text index tensor (T_text,).
|
||||||
|
feats (Tensor): Feature tensor (T_feats, aux_channels).
|
||||||
|
sids (Tensor): Speaker index tensor (1,).
|
||||||
|
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
|
||||||
|
lids (Tensor): Language index tensor (1,).
|
||||||
|
durations (Tensor): Ground-truth duration tensor (T_text,).
|
||||||
|
noise_scale (float): Noise scale value for flow.
|
||||||
|
noise_scale_dur (float): Noise scale value for duration predictor.
|
||||||
|
alpha (float): Alpha parameter to control the speed of generated speech.
|
||||||
|
max_len (Optional[int]): Maximum length.
|
||||||
|
use_teacher_forcing (bool): Whether to use teacher forcing.
|
||||||
|
Returns:
|
||||||
|
Dict[str, Tensor]:
|
||||||
|
* wav (Tensor): Generated waveform tensor (T_wav,).
|
||||||
|
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
|
||||||
|
* duration (Tensor): Predicted duration tensor (T_text,).
|
||||||
|
"""
|
||||||
|
# setup
|
||||||
|
text = text[None]
|
||||||
|
text_lengths = paddle.to_tensor(paddle.shape(text)[1])
|
||||||
|
|
||||||
|
if durations is not None:
|
||||||
|
durations = paddle.reshape(durations, [1, 1, -1])
|
||||||
|
|
||||||
|
# inference
|
||||||
|
if use_teacher_forcing:
|
||||||
|
assert feats is not None
|
||||||
|
feats = feats[None].transpose([0, 2, 1])
|
||||||
|
feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
|
||||||
|
wav, att_w, dur = self.generator.inference(
|
||||||
|
text=text,
|
||||||
|
text_lengths=text_lengths,
|
||||||
|
feats=feats,
|
||||||
|
feats_lengths=feats_lengths,
|
||||||
|
sids=sids,
|
||||||
|
spembs=spembs,
|
||||||
|
lids=lids,
|
||||||
|
max_len=max_len,
|
||||||
|
use_teacher_forcing=use_teacher_forcing, )
|
||||||
|
else:
|
||||||
|
wav, att_w, dur = self.generator.inference(
|
||||||
|
text=text,
|
||||||
|
text_lengths=text_lengths,
|
||||||
|
sids=sids,
|
||||||
|
spembs=spembs,
|
||||||
|
lids=lids,
|
||||||
|
dur=durations,
|
||||||
|
noise_scale=noise_scale,
|
||||||
|
noise_scale_dur=noise_scale_dur,
|
||||||
|
alpha=alpha,
|
||||||
|
max_len=max_len, )
|
||||||
|
return dict(
|
||||||
|
wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])
|
@ -0,0 +1,353 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.nn import Layer
|
||||||
|
from paddle.optimizer import Optimizer
|
||||||
|
from paddle.optimizer.lr import LRScheduler
|
||||||
|
|
||||||
|
from paddlespeech.t2s.modules.nets_utils import get_segments
|
||||||
|
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
|
||||||
|
from paddlespeech.t2s.training.reporter import report
|
||||||
|
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
|
||||||
|
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||||
|
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class VITSUpdater(StandardUpdater):
|
||||||
|
def __init__(self,
|
||||||
|
model: Layer,
|
||||||
|
optimizers: Dict[str, Optimizer],
|
||||||
|
criterions: Dict[str, Layer],
|
||||||
|
schedulers: Dict[str, LRScheduler],
|
||||||
|
dataloader: DataLoader,
|
||||||
|
generator_train_start_steps: int=0,
|
||||||
|
discriminator_train_start_steps: int=100000,
|
||||||
|
lambda_adv: float=1.0,
|
||||||
|
lambda_mel: float=45.0,
|
||||||
|
lambda_feat_match: float=2.0,
|
||||||
|
lambda_dur: float=1.0,
|
||||||
|
lambda_kl: float=1.0,
|
||||||
|
generator_first: bool=False,
|
||||||
|
output_dir=None):
|
||||||
|
# it is designed to hold multiple models
|
||||||
|
# 因为输入的是单模型,但是没有用到父类的 init(), 所以需要重新写这部分
|
||||||
|
models = {"main": model}
|
||||||
|
self.models: Dict[str, Layer] = models
|
||||||
|
# self.model = model
|
||||||
|
|
||||||
|
self.model = model._layers if isinstance(model, paddle.DataParallel) else model
|
||||||
|
|
||||||
|
self.optimizers = optimizers
|
||||||
|
self.optimizer_g: Optimizer = optimizers['generator']
|
||||||
|
self.optimizer_d: Optimizer = optimizers['discriminator']
|
||||||
|
|
||||||
|
self.criterions = criterions
|
||||||
|
self.criterion_mel = criterions['mel']
|
||||||
|
self.criterion_feat_match = criterions['feat_match']
|
||||||
|
self.criterion_gen_adv = criterions["gen_adv"]
|
||||||
|
self.criterion_dis_adv = criterions["dis_adv"]
|
||||||
|
self.criterion_kl = criterions["kl"]
|
||||||
|
|
||||||
|
self.schedulers = schedulers
|
||||||
|
self.scheduler_g = schedulers['generator']
|
||||||
|
self.scheduler_d = schedulers['discriminator']
|
||||||
|
|
||||||
|
self.dataloader = dataloader
|
||||||
|
|
||||||
|
self.generator_train_start_steps = generator_train_start_steps
|
||||||
|
self.discriminator_train_start_steps = discriminator_train_start_steps
|
||||||
|
|
||||||
|
self.lambda_adv = lambda_adv
|
||||||
|
self.lambda_mel = lambda_mel
|
||||||
|
self.lambda_feat_match = lambda_feat_match
|
||||||
|
self.lambda_dur = lambda_dur
|
||||||
|
self.lambda_kl = lambda_kl
|
||||||
|
|
||||||
|
if generator_first:
|
||||||
|
self.turns = ["generator", "discriminator"]
|
||||||
|
else:
|
||||||
|
self.turns = ["discriminator", "generator"]
|
||||||
|
|
||||||
|
self.state = UpdaterState(iteration=0, epoch=0)
|
||||||
|
self.train_iterator = iter(self.dataloader)
|
||||||
|
|
||||||
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||||
|
self.filehandler = logging.FileHandler(str(log_file))
|
||||||
|
logger.addHandler(self.filehandler)
|
||||||
|
self.logger = logger
|
||||||
|
self.msg = ""
|
||||||
|
|
||||||
|
def update_core(self, batch):
|
||||||
|
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||||
|
losses_dict = {}
|
||||||
|
|
||||||
|
for turn in self.turns:
|
||||||
|
speech = batch["speech"]
|
||||||
|
speech = speech.unsqueeze(1)
|
||||||
|
outs = self.model(
|
||||||
|
text=batch["text"],
|
||||||
|
text_lengths=batch["text_lengths"],
|
||||||
|
feats=batch["feats"],
|
||||||
|
feats_lengths=batch["feats_lengths"],
|
||||||
|
forward_generator=turn == "generator")
|
||||||
|
# Generator
|
||||||
|
if turn == "generator":
|
||||||
|
# parse outputs
|
||||||
|
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
|
||||||
|
_, z_p, m_p, logs_p, _, logs_q = outs_
|
||||||
|
speech_ = get_segments(
|
||||||
|
x=speech,
|
||||||
|
start_idxs=start_idxs *
|
||||||
|
self.model.generator.upsample_factor,
|
||||||
|
segment_size=self.model.generator.segment_size *
|
||||||
|
self.model.generator.upsample_factor, )
|
||||||
|
|
||||||
|
# calculate discriminator outputs
|
||||||
|
p_hat = self.model.discriminator(speech_hat_)
|
||||||
|
with paddle.no_grad():
|
||||||
|
# do not store discriminator gradient in generator turn
|
||||||
|
p = self.model.discriminator(speech_)
|
||||||
|
|
||||||
|
# calculate losses
|
||||||
|
mel_loss = self.criterion_mel(speech_hat_, speech_)
|
||||||
|
kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
|
||||||
|
dur_loss = paddle.sum(dur_nll)
|
||||||
|
adv_loss = self.criterion_gen_adv(p_hat)
|
||||||
|
feat_match_loss = self.criterion_feat_match(p_hat, p)
|
||||||
|
|
||||||
|
mel_loss = mel_loss * self.lambda_mel
|
||||||
|
kl_loss = kl_loss * self.lambda_kl
|
||||||
|
dur_loss = dur_loss * self.lambda_dur
|
||||||
|
adv_loss = adv_loss * self.lambda_adv
|
||||||
|
feat_match_loss = feat_match_loss * self.lambda_feat_match
|
||||||
|
gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
|
||||||
|
|
||||||
|
report("train/generator_loss", float(gen_loss))
|
||||||
|
report("train/generator_mel_loss", float(mel_loss))
|
||||||
|
report("train/generator_kl_loss", float(kl_loss))
|
||||||
|
report("train/generator_dur_loss", float(dur_loss))
|
||||||
|
report("train/generator_adv_loss", float(adv_loss))
|
||||||
|
report("train/generator_feat_match_loss",
|
||||||
|
float(feat_match_loss))
|
||||||
|
|
||||||
|
losses_dict["generator_loss"] = float(gen_loss)
|
||||||
|
losses_dict["generator_mel_loss"] = float(mel_loss)
|
||||||
|
losses_dict["generator_kl_loss"] = float(kl_loss)
|
||||||
|
losses_dict["generator_dur_loss"] = float(dur_loss)
|
||||||
|
losses_dict["generator_adv_loss"] = float(adv_loss)
|
||||||
|
losses_dict["generator_feat_match_loss"] = float(
|
||||||
|
feat_match_loss)
|
||||||
|
|
||||||
|
self.optimizer_g.clear_grad()
|
||||||
|
gen_loss.backward()
|
||||||
|
|
||||||
|
self.optimizer_g.step()
|
||||||
|
self.scheduler_g.step()
|
||||||
|
|
||||||
|
# reset cache
|
||||||
|
if self.model.reuse_cache_gen or not self.model.training:
|
||||||
|
self.model._cache = None
|
||||||
|
|
||||||
|
# Disctiminator
|
||||||
|
elif turn == "discriminator":
|
||||||
|
# parse outputs
|
||||||
|
speech_hat_, _, _, start_idxs, *_ = outs
|
||||||
|
speech_ = get_segments(
|
||||||
|
x=speech,
|
||||||
|
start_idxs=start_idxs *
|
||||||
|
self.model.generator.upsample_factor,
|
||||||
|
segment_size=self.model.generator.segment_size *
|
||||||
|
self.model.generator.upsample_factor, )
|
||||||
|
|
||||||
|
# calculate discriminator outputs
|
||||||
|
p_hat = self.model.discriminator(speech_hat_.detach())
|
||||||
|
p = self.model.discriminator(speech_)
|
||||||
|
|
||||||
|
# calculate losses
|
||||||
|
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
|
||||||
|
dis_loss = real_loss + fake_loss
|
||||||
|
|
||||||
|
report("train/real_loss", float(real_loss))
|
||||||
|
report("train/fake_loss", float(fake_loss))
|
||||||
|
report("train/discriminator_loss", float(dis_loss))
|
||||||
|
losses_dict["real_loss"] = float(real_loss)
|
||||||
|
losses_dict["fake_loss"] = float(fake_loss)
|
||||||
|
losses_dict["discriminator_loss"] = float(dis_loss)
|
||||||
|
|
||||||
|
self.optimizer_d.clear_grad()
|
||||||
|
dis_loss.backward()
|
||||||
|
|
||||||
|
self.optimizer_d.step()
|
||||||
|
self.scheduler_d.step()
|
||||||
|
|
||||||
|
# reset cache
|
||||||
|
if self.model.reuse_cache_dis or not self.model.training:
|
||||||
|
self.model._cache = None
|
||||||
|
|
||||||
|
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||||
|
for k, v in losses_dict.items())
|
||||||
|
|
||||||
|
|
||||||
|
class VITSEvaluator(StandardEvaluator):
|
||||||
|
def __init__(self,
|
||||||
|
model,
|
||||||
|
criterions: Dict[str, Layer],
|
||||||
|
dataloader: DataLoader,
|
||||||
|
lambda_adv: float=1.0,
|
||||||
|
lambda_mel: float=45.0,
|
||||||
|
lambda_feat_match: float=2.0,
|
||||||
|
lambda_dur: float=1.0,
|
||||||
|
lambda_kl: float=1.0,
|
||||||
|
generator_first: bool=False,
|
||||||
|
output_dir=None):
|
||||||
|
# 因为输入的是单模型,但是没有用到父类的 init(), 所以需要重新写这部分
|
||||||
|
models = {"main": model}
|
||||||
|
self.models: Dict[str, Layer] = models
|
||||||
|
# self.model = model
|
||||||
|
self.model = model._layers if isinstance(model, paddle.DataParallel) else model
|
||||||
|
|
||||||
|
self.criterions = criterions
|
||||||
|
self.criterion_mel = criterions['mel']
|
||||||
|
self.criterion_feat_match = criterions['feat_match']
|
||||||
|
self.criterion_gen_adv = criterions["gen_adv"]
|
||||||
|
self.criterion_dis_adv = criterions["dis_adv"]
|
||||||
|
self.criterion_kl = criterions["kl"]
|
||||||
|
|
||||||
|
self.dataloader = dataloader
|
||||||
|
|
||||||
|
self.lambda_adv = lambda_adv
|
||||||
|
self.lambda_mel = lambda_mel
|
||||||
|
self.lambda_feat_match = lambda_feat_match
|
||||||
|
self.lambda_dur = lambda_dur
|
||||||
|
self.lambda_kl = lambda_kl
|
||||||
|
|
||||||
|
if generator_first:
|
||||||
|
self.turns = ["generator", "discriminator"]
|
||||||
|
else:
|
||||||
|
self.turns = ["discriminator", "generator"]
|
||||||
|
|
||||||
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||||
|
self.filehandler = logging.FileHandler(str(log_file))
|
||||||
|
logger.addHandler(self.filehandler)
|
||||||
|
self.logger = logger
|
||||||
|
self.msg = ""
|
||||||
|
|
||||||
|
def evaluate_core(self, batch):
|
||||||
|
# logging.debug("Evaluate: ")
|
||||||
|
self.msg = "Evaluate: "
|
||||||
|
losses_dict = {}
|
||||||
|
|
||||||
|
for turn in self.turns:
|
||||||
|
speech = batch["speech"]
|
||||||
|
speech = speech.unsqueeze(1)
|
||||||
|
outs = self.model(
|
||||||
|
text=batch["text"],
|
||||||
|
text_lengths=batch["text_lengths"],
|
||||||
|
feats=batch["feats"],
|
||||||
|
feats_lengths=batch["feats_lengths"],
|
||||||
|
forward_generator=turn == "generator")
|
||||||
|
# Generator
|
||||||
|
if turn == "generator":
|
||||||
|
# parse outputs
|
||||||
|
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
|
||||||
|
_, z_p, m_p, logs_p, _, logs_q = outs_
|
||||||
|
speech_ = get_segments(
|
||||||
|
x=speech,
|
||||||
|
start_idxs=start_idxs *
|
||||||
|
self.model.generator.upsample_factor,
|
||||||
|
segment_size=self.model.generator.segment_size *
|
||||||
|
self.model.generator.upsample_factor, )
|
||||||
|
|
||||||
|
# calculate discriminator outputs
|
||||||
|
p_hat = self.model.discriminator(speech_hat_)
|
||||||
|
with paddle.no_grad():
|
||||||
|
# do not store discriminator gradient in generator turn
|
||||||
|
p = self.model.discriminator(speech_)
|
||||||
|
|
||||||
|
# calculate losses
|
||||||
|
mel_loss = self.criterion_mel(speech_hat_, speech_)
|
||||||
|
kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
|
||||||
|
dur_loss = paddle.sum(dur_nll)
|
||||||
|
adv_loss = self.criterion_gen_adv(p_hat)
|
||||||
|
feat_match_loss = self.criterion_feat_match(p_hat, p)
|
||||||
|
|
||||||
|
mel_loss = mel_loss * self.lambda_mel
|
||||||
|
kl_loss = kl_loss * self.lambda_kl
|
||||||
|
dur_loss = dur_loss * self.lambda_dur
|
||||||
|
adv_loss = adv_loss * self.lambda_adv
|
||||||
|
feat_match_loss = feat_match_loss * self.lambda_feat_match
|
||||||
|
gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
|
||||||
|
|
||||||
|
report("eval/generator_loss", float(gen_loss))
|
||||||
|
report("eval/generator_mel_loss", float(mel_loss))
|
||||||
|
report("eval/generator_kl_loss", float(kl_loss))
|
||||||
|
report("eval/generator_dur_loss", float(dur_loss))
|
||||||
|
report("eval/generator_adv_loss", float(adv_loss))
|
||||||
|
report("eval/generator_feat_match_loss", float(feat_match_loss))
|
||||||
|
|
||||||
|
losses_dict["generator_loss"] = float(gen_loss)
|
||||||
|
losses_dict["generator_mel_loss"] = float(mel_loss)
|
||||||
|
losses_dict["generator_kl_loss"] = float(kl_loss)
|
||||||
|
losses_dict["generator_dur_loss"] = float(dur_loss)
|
||||||
|
losses_dict["generator_adv_loss"] = float(adv_loss)
|
||||||
|
losses_dict["generator_feat_match_loss"] = float(
|
||||||
|
feat_match_loss)
|
||||||
|
|
||||||
|
# reset cache
|
||||||
|
if self.model.reuse_cache_gen or not self.model.training:
|
||||||
|
self.model._cache = None
|
||||||
|
|
||||||
|
# Disctiminator
|
||||||
|
elif turn == "discriminator":
|
||||||
|
# parse outputs
|
||||||
|
speech_hat_, _, _, start_idxs, *_ = outs
|
||||||
|
speech_ = get_segments(
|
||||||
|
x=speech,
|
||||||
|
start_idxs=start_idxs *
|
||||||
|
self.model.generator.upsample_factor,
|
||||||
|
segment_size=self.model.generator.segment_size *
|
||||||
|
self.model.generator.upsample_factor, )
|
||||||
|
|
||||||
|
# calculate discriminator outputs
|
||||||
|
p_hat = self.model.discriminator(speech_hat_.detach())
|
||||||
|
p = self.model.discriminator(speech_)
|
||||||
|
|
||||||
|
# calculate losses
|
||||||
|
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
|
||||||
|
dis_loss = real_loss + fake_loss
|
||||||
|
|
||||||
|
report("eval/real_loss", float(real_loss))
|
||||||
|
report("eval/fake_loss", float(fake_loss))
|
||||||
|
report("eval/discriminator_loss", float(dis_loss))
|
||||||
|
losses_dict["real_loss"] = float(real_loss)
|
||||||
|
losses_dict["fake_loss"] = float(fake_loss)
|
||||||
|
losses_dict["discriminator_loss"] = float(dis_loss)
|
||||||
|
|
||||||
|
# reset cache
|
||||||
|
if self.model.reuse_cache_dis or not self.model.training:
|
||||||
|
self.model._cache = None
|
||||||
|
|
||||||
|
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||||
|
for k, v in losses_dict.items())
|
||||||
|
self.logger.info(self.msg)
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,154 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Layer):
|
||||||
|
"""Residual block module in WaveNet."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size: int=3,
|
||||||
|
residual_channels: int=64,
|
||||||
|
gate_channels: int=128,
|
||||||
|
skip_channels: int=64,
|
||||||
|
aux_channels: int=80,
|
||||||
|
global_channels: int=-1,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
dilation: int=1,
|
||||||
|
bias: bool=True,
|
||||||
|
scale_residual: bool=False, ):
|
||||||
|
"""Initialize ResidualBlock module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kernel_size (int): Kernel size of dilation convolution layer.
|
||||||
|
residual_channels (int): Number of channels for residual connection.
|
||||||
|
skip_channels (int): Number of channels for skip connection.
|
||||||
|
aux_channels (int): Number of local conditioning channels.
|
||||||
|
dropout (float): Dropout probability.
|
||||||
|
dilation (int): Dilation factor.
|
||||||
|
bias (bool): Whether to add bias parameter in convolution layers.
|
||||||
|
scale_residual (bool): Whether to scale the residual outputs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.residual_channels = residual_channels
|
||||||
|
self.skip_channels = skip_channels
|
||||||
|
self.scale_residual = scale_residual
|
||||||
|
|
||||||
|
# check
|
||||||
|
assert (
|
||||||
|
kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
||||||
|
assert gate_channels % 2 == 0
|
||||||
|
|
||||||
|
# dilation conv
|
||||||
|
padding = (kernel_size - 1) // 2 * dilation
|
||||||
|
self.conv = nn.Conv1D(
|
||||||
|
residual_channels,
|
||||||
|
gate_channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias_attr=bias, )
|
||||||
|
|
||||||
|
# local conditioning
|
||||||
|
if aux_channels > 0:
|
||||||
|
self.conv1x1_aux = nn.Conv1D(
|
||||||
|
aux_channels, gate_channels, kernel_size=1, bias_attr=False)
|
||||||
|
else:
|
||||||
|
self.conv1x1_aux = None
|
||||||
|
|
||||||
|
# global conditioning
|
||||||
|
if global_channels > 0:
|
||||||
|
self.conv1x1_glo = nn.Conv1D(
|
||||||
|
global_channels, gate_channels, kernel_size=1, bias_attr=False)
|
||||||
|
else:
|
||||||
|
self.conv1x1_glo = None
|
||||||
|
|
||||||
|
# conv output is split into two groups
|
||||||
|
gate_out_channels = gate_channels // 2
|
||||||
|
|
||||||
|
# NOTE: concat two convs into a single conv for the efficiency
|
||||||
|
# (integrate res 1x1 + skip 1x1 convs)
|
||||||
|
self.conv1x1_out = nn.Conv1D(
|
||||||
|
gate_out_channels,
|
||||||
|
residual_channels + skip_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
bias_attr=bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: Optional[paddle.Tensor]=None,
|
||||||
|
c: Optional[paddle.Tensor]=None,
|
||||||
|
g: Optional[paddle.Tensor]=None,
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor (B, residual_channels, T).
|
||||||
|
x_mask Optional[paddle.Tensor]: Mask tensor (B, 1, T).
|
||||||
|
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
|
||||||
|
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor for residual connection (B, residual_channels, T).
|
||||||
|
Tensor: Output tensor for skip connection (B, skip_channels, T).
|
||||||
|
|
||||||
|
"""
|
||||||
|
residual = x
|
||||||
|
x = F.dropout(x, p=self.dropout_rate, training=self.training)
|
||||||
|
x = self.conv(x)
|
||||||
|
|
||||||
|
# split into two part for gated activation
|
||||||
|
splitdim = 1
|
||||||
|
xa, xb = paddle.split(x, 2, axis=splitdim)
|
||||||
|
|
||||||
|
# local conditioning
|
||||||
|
if c is not None:
|
||||||
|
c = self.conv1x1_aux(c)
|
||||||
|
ca, cb = paddle.split(c, 2, axis=splitdim)
|
||||||
|
xa, xb = xa + ca, xb + cb
|
||||||
|
|
||||||
|
# global conditioning
|
||||||
|
if g is not None:
|
||||||
|
g = self.conv1x1_glo(g)
|
||||||
|
ga, gb = paddle.split(g, 2, axis=splitdim)
|
||||||
|
xa, xb = xa + ga, xb + gb
|
||||||
|
|
||||||
|
x = paddle.tanh(xa) * F.sigmoid(xb)
|
||||||
|
|
||||||
|
# residual + skip 1x1 conv
|
||||||
|
x = self.conv1x1_out(x)
|
||||||
|
if x_mask is not None:
|
||||||
|
x = x * x_mask
|
||||||
|
|
||||||
|
# split integrated conv results
|
||||||
|
x, s = paddle.split(
|
||||||
|
x, [self.residual_channels, self.skip_channels], axis=1)
|
||||||
|
|
||||||
|
# for residual connection
|
||||||
|
x = x + residual
|
||||||
|
if self.scale_residual:
|
||||||
|
x = x * math.sqrt(0.5)
|
||||||
|
|
||||||
|
return x, s
|
@ -0,0 +1,175 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from paddlespeech.t2s.models.vits.wavenet.residual_block import ResidualBlock
|
||||||
|
|
||||||
|
|
||||||
|
class WaveNet(nn.Layer):
|
||||||
|
"""WaveNet with global conditioning."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int=1,
|
||||||
|
out_channels: int=1,
|
||||||
|
kernel_size: int=3,
|
||||||
|
layers: int=30,
|
||||||
|
stacks: int=3,
|
||||||
|
base_dilation: int=2,
|
||||||
|
residual_channels: int=64,
|
||||||
|
aux_channels: int=-1,
|
||||||
|
gate_channels: int=128,
|
||||||
|
skip_channels: int=64,
|
||||||
|
global_channels: int=-1,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
bias: bool=True,
|
||||||
|
use_weight_norm: bool=True,
|
||||||
|
use_first_conv: bool=False,
|
||||||
|
use_last_conv: bool=False,
|
||||||
|
scale_residual: bool=False,
|
||||||
|
scale_skip_connect: bool=False, ):
|
||||||
|
"""Initialize WaveNet module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
kernel_size (int): Kernel size of dilated convolution.
|
||||||
|
layers (int): Number of residual block layers.
|
||||||
|
stacks (int): Number of stacks i.e., dilation cycles.
|
||||||
|
base_dilation (int): Base dilation factor.
|
||||||
|
residual_channels (int): Number of channels in residual conv.
|
||||||
|
gate_channels (int): Number of channels in gated conv.
|
||||||
|
skip_channels (int): Number of channels in skip conv.
|
||||||
|
aux_channels (int): Number of channels for local conditioning feature.
|
||||||
|
global_channels (int): Number of channels for global conditioning feature.
|
||||||
|
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
|
||||||
|
bias (bool): Whether to use bias parameter in conv layer.
|
||||||
|
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
|
||||||
|
be applied to all of the conv layers.
|
||||||
|
use_first_conv (bool): Whether to use the first conv layers.
|
||||||
|
use_last_conv (bool): Whether to use the last conv layers.
|
||||||
|
scale_residual (bool): Whether to scale the residual outputs.
|
||||||
|
scale_skip_connect (bool): Whether to scale the skip connection outputs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.layers = layers
|
||||||
|
self.stacks = stacks
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.base_dilation = base_dilation
|
||||||
|
self.use_first_conv = use_first_conv
|
||||||
|
self.use_last_conv = use_last_conv
|
||||||
|
self.scale_skip_connect = scale_skip_connect
|
||||||
|
|
||||||
|
# check the number of layers and stacks
|
||||||
|
assert layers % stacks == 0
|
||||||
|
layers_per_stack = layers // stacks
|
||||||
|
|
||||||
|
# define first convolution
|
||||||
|
if self.use_first_conv:
|
||||||
|
self.first_conv = nn.Conv1D(
|
||||||
|
in_channels, residual_channels, kernel_size=1, bias_attr=True)
|
||||||
|
|
||||||
|
# define residual blocks
|
||||||
|
self.conv_layers = nn.LayerList()
|
||||||
|
for layer in range(layers):
|
||||||
|
dilation = base_dilation**(layer % layers_per_stack)
|
||||||
|
conv = ResidualBlock(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
residual_channels=residual_channels,
|
||||||
|
gate_channels=gate_channels,
|
||||||
|
skip_channels=skip_channels,
|
||||||
|
aux_channels=aux_channels,
|
||||||
|
global_channels=global_channels,
|
||||||
|
dilation=dilation,
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
bias=bias,
|
||||||
|
scale_residual=scale_residual, )
|
||||||
|
self.conv_layers.append(conv)
|
||||||
|
|
||||||
|
# define output layers
|
||||||
|
if self.use_last_conv:
|
||||||
|
self.last_conv = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1D(
|
||||||
|
skip_channels, skip_channels, kernel_size=1,
|
||||||
|
bias_attr=True),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1D(
|
||||||
|
skip_channels, out_channels, kernel_size=1, bias_attr=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply weight norm
|
||||||
|
if use_weight_norm:
|
||||||
|
self.apply_weight_norm()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
x_mask: Optional[paddle.Tensor]=None,
|
||||||
|
c: Optional[paddle.Tensor]=None,
|
||||||
|
g: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
|
||||||
|
(B, residual_channels, T).
|
||||||
|
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
|
||||||
|
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
|
||||||
|
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
|
||||||
|
(B, residual_channels, T).
|
||||||
|
|
||||||
|
"""
|
||||||
|
# encode to hidden representation
|
||||||
|
if self.use_first_conv:
|
||||||
|
x = self.first_conv(x)
|
||||||
|
|
||||||
|
# residual block
|
||||||
|
skips = 0.0
|
||||||
|
for f in self.conv_layers:
|
||||||
|
x, h = f(x, x_mask=x_mask, c=c, g=g)
|
||||||
|
skips = skips + h
|
||||||
|
x = skips
|
||||||
|
if self.scale_skip_connect:
|
||||||
|
x = x * math.sqrt(1.0 / len(self.conv_layers))
|
||||||
|
|
||||||
|
# apply final layers
|
||||||
|
if self.use_last_conv:
|
||||||
|
x = self.last_conv(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def apply_weight_norm(self):
|
||||||
|
def _apply_weight_norm(layer):
|
||||||
|
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
||||||
|
nn.utils.weight_norm(layer)
|
||||||
|
|
||||||
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
def _remove_weight_norm(layer):
|
||||||
|
try:
|
||||||
|
nn.utils.remove_weight_norm(layer)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.apply(_remove_weight_norm)
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from espnet(https://github.com/espnet/espnet)
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
__all__ = ["dynamic_import"]
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_import(import_path, alias=dict()):
|
||||||
|
"""dynamic import module and class
|
||||||
|
|
||||||
|
:param str import_path: syntax 'module_name:class_name'
|
||||||
|
e.g., 'paddlespeech.s2t.models.u2:U2Model'
|
||||||
|
:param dict alias: shortcut for registered class
|
||||||
|
:return: imported class
|
||||||
|
"""
|
||||||
|
if import_path not in alias and ":" not in import_path:
|
||||||
|
raise ValueError(
|
||||||
|
"import_path should be one of {} or "
|
||||||
|
'include ":", e.g. "paddlespeech.s2t.models.u2:U2Model" : '
|
||||||
|
"{}".format(set(alias), import_path))
|
||||||
|
if ":" not in import_path:
|
||||||
|
import_path = alias[import_path]
|
||||||
|
|
||||||
|
module_name, objname = import_path.split(":")
|
||||||
|
m = importlib.import_module(module_name)
|
||||||
|
return getattr(m, objname)
|
Loading…
Reference in new issue