add CNNDecoder, test=tts

pull/1634/head
TianYuan 2 years ago
parent b5315657ff
commit 0fc79f474d

@ -0,0 +1,107 @@
# use CNNDecoder
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # sr
n_fft: 2048 # FFT size (samples).
n_shift: 300 # Hop size (samples). 12.5ms
win_length: 1200 # Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
# Only used for feats_type != raw
fmin: 80 # Minimum frequency of Mel basis.
fmax: 7600 # Maximum frequency of Mel basis.
n_mels: 80 # The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80 # Minimum f0 for pitch extraction.
f0max: 400 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 4
###########################################################
# MODEL SETTING #
###########################################################
model:
adim: 384 # attention dimension
aheads: 2 # number of attention heads
elayers: 4 # number of encoder layers
eunits: 1536 # number of encoder ff units
dlayers: 4 # number of decoder layers
dunits: 1536 # number of decoder ff units
positionwise_layer_type: conv1d # type of position-wise layer
positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer
duration_predictor_layers: 2 # number of layers of duration predictor
duration_predictor_chans: 256 # number of channels of duration predictor
duration_predictor_kernel_size: 3 # filter size of duration predictor
postnet_layers: 5 # number of layers of postnet
postnet_filts: 5 # filter size of conv layers in postnet
postnet_chans: 256 # number of channels of conv layers in postnet
use_scaled_pos_enc: True # whether to use scaled positional encoding
encoder_normalize_before: True # whether to perform layer normalization before the input
decoder_normalize_before: True # whether to perform layer normalization before the input
reduction_factor: 1 # reduction factor
encoder_type: transformer # encoder type
decoder_type: cnndecoder # decoder type
init_type: xavier_uniform # initialization type
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer
cnn_dec_dropout_rate: 0.2 # dropout rate for cnn decoder layer
cnn_postnet_dropout_rate: 0.2
cnn_postnet_resblock_kernel_sizes: [256, 256] # kernel sizes for residual block of cnn_postnet
cnn_postnet_kernel_size: 5 # kernel size of cnn_postnet
cnn_decoder_embedding_dim: 256
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size: 5 # kernel size of conv layers in pitch predictor
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers: 2 # number of conv layers in energy predictor
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
energy_predictor_kernel_size: 3 # kernel size of conv layers in energy predictor
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
###########################################################
# UPDATER SETTING #
###########################################################
updater:
use_masking: True # whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
optim: adam # optimizer type
learning_rate: 0.001 # learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 1000
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

@ -0,0 +1,92 @@
#!/bin/bash
# Run synthesize_streaming.py with a FastSpeech2 acoustic model and one of
# several GAN vocoders.
#
# Positional arguments:
#   $1  config_path        acoustic model config (passed as --am_config)
#   $2  train_output_path  experiment dir; checkpoints are read from
#                          <train_output_path>/checkpoints/ and results are
#                          written under <train_output_path>/test_e2e
#   $3  ckpt_name          checkpoint file name inside the checkpoints dir
#
# BIN_DIR is not set in this script; it must be provided by the environment.

config_path=$1
train_output_path=$2
ckpt_name=$3

# Only stages within [stage, stop_stage] run. Both are fixed to 0 here, so
# only the pwgan stage executes unless these two values are edited.
stage=0
stop_stage=0

# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 "${BIN_DIR}/../synthesize_streaming.py" \
        --am=fastspeech2_csmsc \
        --am_config="${config_path}" \
        --am_ckpt="${train_output_path}/checkpoints/${ckpt_name}" \
        --am_stat=dump/train/speech_stats.npy \
        --voc=pwgan_csmsc \
        --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
        --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
        --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
        --lang=zh \
        --text="${BIN_DIR}/../sentences.txt" \
        --output_dir="${train_output_path}/test_e2e" \
        --phones_dict=dump/phone_id_map.txt \
        --inference_dir="${train_output_path}/inference"
fi

# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 "${BIN_DIR}/../synthesize_streaming.py" \
        --am=fastspeech2_csmsc \
        --am_config="${config_path}" \
        --am_ckpt="${train_output_path}/checkpoints/${ckpt_name}" \
        --am_stat=dump/train/speech_stats.npy \
        --voc=mb_melgan_csmsc \
        --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz \
        --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text="${BIN_DIR}/../sentences.txt" \
        --output_dir="${train_output_path}/test_e2e" \
        --phones_dict=dump/phone_id_map.txt \
        --inference_dir="${train_output_path}/inference"
fi

# the pretrained models have not been released yet
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 "${BIN_DIR}/../synthesize_streaming.py" \
        --am=fastspeech2_csmsc \
        --am_config="${config_path}" \
        --am_ckpt="${train_output_path}/checkpoints/${ckpt_name}" \
        --am_stat=dump/train/speech_stats.npy \
        --voc=style_melgan_csmsc \
        --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
        --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text="${BIN_DIR}/../sentences.txt" \
        --output_dir="${train_output_path}/test_e2e" \
        --phones_dict=dump/phone_id_map.txt
        # --inference_dir=${train_output_path}/inference
fi

# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "in hifigan syn_e2e"
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 "${BIN_DIR}/../synthesize_streaming.py" \
        --am=fastspeech2_csmsc \
        --am_config="${config_path}" \
        --am_ckpt="${train_output_path}/checkpoints/${ckpt_name}" \
        --am_stat=dump/train/speech_stats.npy \
        --voc=hifigan_csmsc \
        --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
        --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text="${BIN_DIR}/../sentences.txt" \
        --output_dir="${train_output_path}/test_e2e" \
        --phones_dict=dump/phone_id_map.txt \
        --inference_dir="${train_output_path}/inference"
fi

@ -0,0 +1,48 @@
#!/bin/bash
# Staged driver for the CNN-decoder experiment: preprocess, train, synthesize,
# static-graph inference and streaming synthesis. Stops on the first failing
# command (set -e).

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/cnndecoder.yaml
train_output_path=exp/cnndecoder
ckpt_name=snapshot_iter_153.pdz

# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with positional arguments `$1`, `$2` ...
source "${MAIN_ROOT}/utils/parse_options.sh" || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh "${conf_path}" || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh "${conf_path}" "${train_output_path}" || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh "${conf_path}" "${train_output_path}" "${ckpt_name}" || exit -1
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # synthesize_e2e, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh "${conf_path}" "${train_output_path}" "${ckpt_name}" || exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # inference with static model
    CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh "${train_output_path}" || exit -1
fi

if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    # streaming synthesis (runs ./local/synthesize_streaming.sh)
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh "${conf_path}" "${train_output_path}" "${ckpt_name}" || exit -1
fi

@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Fastspeech2 related modules for paddle"""
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union
@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
zero_triu: bool=False,
conformer_enc_kernel_size: int=7,
conformer_dec_kernel_size: int=31,
# for CNN Decoder
cnn_dec_dropout_rate: float=0.2,
cnn_postnet_dropout_rate: float=0.2,
cnn_postnet_resblock_kernel_sizes: List[int]=[256, 256],
cnn_postnet_kernel_size: int=5,
cnn_decoder_embedding_dim: int=256,
# duration predictor
duration_predictor_layers: int=2,
duration_predictor_chans: int=384,
@ -392,6 +401,13 @@ class FastSpeech2(nn.Layer):
activation_type=conformer_activation_type,
use_cnn_module=use_cnn_in_conformer,
cnn_module_kernel=conformer_dec_kernel_size, )
elif decoder_type == 'cnndecoder':
self.decoder = CNNDecoder(
emb_dim=adim,
odim=odim,
kernel_size=cnn_postnet_kernel_size,
dropout_rate=cnn_dec_dropout_rate,
resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
else:
raise ValueError(f"{decoder_type} is not supported.")
@ -399,14 +415,21 @@ class FastSpeech2(nn.Layer):
self.feat_out = nn.Linear(adim, odim * reduction_factor)
# define postnet
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, ))
if decoder_type == 'cnndecoder':
self.postnet = CNNPostnet(
odim=odim,
kernel_size=cnn_postnet_kernel_size,
dropout_rate=cnn_postnet_dropout_rate,
resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
else:
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, ))
nn.initializer.set_global_initializer(None)
@ -562,6 +585,7 @@ class FastSpeech2(nn.Layer):
[olen // self.reduction_factor for olen in olens.numpy()])
else:
olens_in = olens
# (B, 1, T)
h_masks = self._source_mask(olens_in)
else:
h_masks = None
@ -569,8 +593,11 @@ class FastSpeech2(nn.Layer):
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
if self.decoder_type == 'cnndecoder':
before_outs = zs
else:
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:

@ -515,3 +515,136 @@ class ConformerEncoder(BaseEncoder):
if self.intermediate_layers is not None:
return xs, masks, intermediate_outputs
return xs, masks
class Conv1dResidualBlock(nn.Layer):
    """Conv1D block with a 1x1-convolution shortcut.

    The main branch applies Conv1D -> ReLU -> BatchNorm1D -> Dropout. The
    shortcut branch is a kernel-size-1 Conv1D over the same input, which lets
    the channel count change from ``idim`` to ``odim``. The forward pass
    returns the sum of both branches.
    """

    def __init__(self,
                 idim: int=256,
                 odim: int=256,
                 kernel_size: int=5,
                 dropout_rate: float=0.2):
        """Build the main branch and the shortcut projection.

        Args:
            idim (int): Number of input channels.
            odim (int): Number of output channels.
            kernel_size (int): Kernel size of the main-branch convolution.
            dropout_rate (float): Dropout probability at the end of the
                main branch.
        """
        super().__init__()
        # Pad by half the kernel on the main-branch convolution.
        half_kernel = kernel_size // 2
        main_layers = [
            nn.Conv1D(
                idim, odim, kernel_size=kernel_size, padding=half_kernel),
            nn.ReLU(),
            nn.BatchNorm1D(odim),
            nn.Dropout(p=dropout_rate),
        ]
        self.main_block = nn.Sequential(*main_layers)
        self.conv1d_residual = nn.Conv1D(idim, odim, kernel_size=1)

    def forward(self, xs):
        """Apply the main branch and add the projected shortcut.

        Args:
            xs (Tensor): Input tensor (#batch, idim, T).
        Returns:
            Tensor: Output tensor (#batch, odim, T).
        """
        main = self.main_block(xs)
        shortcut = self.conv1d_residual(xs)
        return shortcut + main
class CNNDecoder(nn.Layer):
    """Convolutional decoder: a stack of residual Conv1D blocks and a 1x1 conv.

    Much simplified decoder than the original one with Prenet.
    """

    def __init__(
            self,
            emb_dim: int=256,
            odim: int=80,
            kernel_size: int=5,
            dropout_rate: float=0.2,
            resblock_kernel_sizes: List[int]=[256, 256], ):
        """Build the residual stack and the output projection.

        Args:
            emb_dim (int): Channel size of the decoder input.
            odim (int): Channel size of the decoder output.
            kernel_size (int): Kernel size used by every residual block.
            dropout_rate (float): Dropout probability of every residual block.
            resblock_kernel_sizes (List[int]): Despite the name, these are the
                output channel sizes of the residual blocks. One extra block
                with the last size is appended internally, so the stack has
                ``len(resblock_kernel_sizes) + 1`` blocks. The list itself is
                not modified.
        """
        super().__init__()

        # Work on a copy. Appending to the argument in place would grow the
        # caller's list and the shared default list on every construction.
        # NOTE(review): callers that reuse the same list afterwards (e.g. to
        # build a CNNPostnet) now see it unchanged, which alters the number of
        # blocks they build compared with the old in-place behaviour.
        out_sizes = list(resblock_kernel_sizes)
        out_sizes.append(out_sizes[-1])
        in_sizes = [emb_dim] + out_sizes[:-1]

        self.residual_blocks = nn.LayerList([
            Conv1dResidualBlock(
                idim=in_channels,
                odim=out_channels,
                kernel_size=kernel_size,
                dropout_rate=dropout_rate, )
            for in_channels, out_channels in zip(in_sizes, out_sizes)
        ])
        self.conv1d = nn.Conv1D(
            in_channels=out_sizes[-1], out_channels=odim, kernel_size=1)

    def forward(self, xs, masks=None):
        """Decode the input sequence.

        Args:
            xs (Tensor): Input tensor (#batch, time, idim).
            masks (Tensor): Mask tensor (#batch, 1, time).
        Returns:
            Tensor: Output tensor (#batch, time, odim).
            Tensor: The ``masks`` argument, returned unchanged.
        """
        # exchange the temporal dimension and the feature dimension
        xs = xs.transpose([0, 2, 1])
        if masks is not None:
            xs = xs * masks

        for layer in self.residual_blocks:
            xs = layer(xs)
            if masks is not None:
                # masks is B * 1 * T; re-apply it after every block
                xs = xs * masks

        outputs = self.conv1d(xs)
        if masks is not None:
            outputs = outputs * masks

        # back to (#batch, time, odim)
        outputs = outputs.transpose([0, 2, 1])
        return outputs, masks
class CNNPostnet(nn.Layer):
    """Postnet made of residual Conv1D blocks followed by a 1x1 convolution.

    Input and output both have ``odim`` channels.
    """

    def __init__(
            self,
            odim: int=80,
            kernel_size: int=5,
            dropout_rate: float=0.2,
            resblock_kernel_sizes: List[int]=[256, 256], ):
        """Build the residual stack and the output projection.

        Args:
            odim (int): Channel size of both the input and the output.
            kernel_size (int): Kernel size used by every residual block.
            dropout_rate (float): Dropout probability of every residual block.
            resblock_kernel_sizes (List[int]): Output channel size of each
                residual block, one block per entry.
        """
        super().__init__()
        block_out_channels = resblock_kernel_sizes
        # Each block consumes the previous block's output; the first one
        # consumes the odim-channel input.
        block_in_channels = [odim] + block_out_channels[:-1]

        blocks = []
        for c_in, c_out in zip(block_in_channels, block_out_channels):
            blocks.append(
                Conv1dResidualBlock(
                    idim=c_in,
                    odim=c_out,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate))
        self.residual_blocks = nn.LayerList(blocks)

        self.conv1d = nn.Conv1D(
            in_channels=block_out_channels[-1],
            out_channels=odim,
            kernel_size=1)

    def forward(self, xs, masks=None):
        """Run the input through the residual stack and the projection.

        Args:
            xs (Tensor): Input tensor (#batch, odim, time).
            masks (Tensor): Mask tensor (#batch, 1, time).
        Returns:
            Tensor: Output tensor (#batch, odim, time).
        """
        hidden = xs
        for block in self.residual_blocks:
            hidden = block(hidden)
            if masks is not None:
                # masks is B * 1 * T; re-apply it after every block
                hidden = hidden * masks

        projected = self.conv1d(hidden)
        if masks is None:
            return projected
        return projected * masks

Loading…
Cancel
Save