add CNNDecoder, test=tts

pull/1634/head
TianYuan 3 years ago
parent b5315657ff
commit 0fc79f474d

@@ -0,0 +1,107 @@
# use CNNDecoder as the decoder of FastSpeech2
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # Sampling rate (Hz).
n_fft: 2048 # FFT size (samples).
n_shift: 300 # Hop size (samples). 12.5ms
win_length: 1200 # Window length (samples). 50ms
# If set to null, it will be the same as n_fft.
window: "hann" # Window function.
# Only used for feats_type != raw
fmin: 80 # Minimum frequency of Mel basis.
fmax: 7600 # Maximum frequency of Mel basis.
n_mels: 80 # The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80 # Minimum f0 for pitch extraction.
f0max: 400 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 4
###########################################################
# MODEL SETTING #
###########################################################
model:
    adim: 384                                     # attention dimension
    aheads: 2                                     # number of attention heads
    elayers: 4                                    # number of encoder layers
    eunits: 1536                                  # number of encoder ff units
    dlayers: 4                                    # number of decoder layers
    dunits: 1536                                  # number of decoder ff units
    positionwise_layer_type: conv1d               # type of position-wise layer
    positionwise_conv_kernel_size: 3              # kernel size of position-wise conv layer
    duration_predictor_layers: 2                  # number of layers of duration predictor
    duration_predictor_chans: 256                 # number of channels of duration predictor
    duration_predictor_kernel_size: 3             # filter size of duration predictor
    postnet_layers: 5                             # number of layers of postnet
    postnet_filts: 5                              # filter size of conv layers in postnet
    postnet_chans: 256                            # number of channels of conv layers in postnet
    use_scaled_pos_enc: True                      # whether to use scaled positional encoding
    encoder_normalize_before: True                # whether to perform layer normalization before encoder blocks
    decoder_normalize_before: True                # whether to perform layer normalization before decoder blocks
    reduction_factor: 1                           # reduction factor
    encoder_type: transformer                     # encoder type
    decoder_type: cnndecoder                      # decoder type
    init_type: xavier_uniform                     # initialization type
    init_enc_alpha: 1.0                           # initial value of alpha of encoder scaled positional encoding
    init_dec_alpha: 1.0                           # initial value of alpha of decoder scaled positional encoding
    transformer_enc_dropout_rate: 0.2             # dropout rate for transformer encoder layer
    transformer_enc_positional_dropout_rate: 0.2  # dropout rate for transformer encoder positional encoding
    transformer_enc_attn_dropout_rate: 0.2        # dropout rate for transformer encoder attention layer
    cnn_dec_dropout_rate: 0.2                     # dropout rate for cnn decoder layer
    cnn_postnet_dropout_rate: 0.2                 # dropout rate for cnn postnet layer
    cnn_postnet_resblock_kernel_sizes: [256, 256] # output channels of the residual blocks in cnn decoder/postnet (channel sizes, despite the name)
    cnn_postnet_kernel_size: 5                    # kernel size of conv layers in cnn postnet
    cnn_decoder_embedding_dim: 256                # embedding dimension of cnn decoder
    pitch_predictor_layers: 5                     # number of conv layers in pitch predictor
    pitch_predictor_chans: 256                    # number of channels of conv layers in pitch predictor
    pitch_predictor_kernel_size: 5                # kernel size of conv layers in pitch predictor
    pitch_predictor_dropout: 0.5                  # dropout rate in pitch predictor
    pitch_embed_kernel_size: 1                    # kernel size of conv embedding layer for pitch
    pitch_embed_dropout: 0.0                      # dropout rate after conv embedding layer for pitch
    stop_gradient_from_pitch_predictor: True      # whether to stop the gradient from pitch predictor to encoder
    energy_predictor_layers: 2                    # number of conv layers in energy predictor
    energy_predictor_chans: 256                   # number of channels of conv layers in energy predictor
    energy_predictor_kernel_size: 3               # kernel size of conv layers in energy predictor
    energy_predictor_dropout: 0.5                 # dropout rate in energy predictor
    energy_embed_kernel_size: 1                   # kernel size of conv embedding layer for energy
    energy_embed_dropout: 0.0                     # dropout rate after conv embedding layer for energy
    stop_gradient_from_energy_predictor: False    # whether to stop the gradient from energy predictor to encoder
###########################################################
# UPDATER SETTING #
###########################################################
updater:
    use_masking: True # whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
    optim: adam           # optimizer type
    learning_rate: 0.001  # learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 1000
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086
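
For reference, a minimal sketch of how this config is typically consumed: the `model:` section is unpacked into the `FastSpeech2` constructor shown in the diff below. The import path and `vocab_size` value here are assumptions for illustration; the repo's actual training entry point may wire this differently.

```python
# Illustrative sketch only: import path and vocab_size are assumptions.
import yaml

from paddlespeech.t2s.models.fastspeech2 import FastSpeech2

with open("conf/cnndecoder.yaml") as f:
    config = yaml.safe_load(f)

vocab_size = 268  # placeholder; the real value comes from dump/phone_id_map.txt
model = FastSpeech2(
    idim=vocab_size,        # number of phone ids
    odim=config["n_mels"],  # 80-band mel spectrogram output
    **config["model"])      # decoder_type: cnndecoder selects CNNDecoder + CNNPostnet
```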

@@ -0,0 +1,92 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_streaming.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=pwgan_csmsc \
        --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
        --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
        --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --phones_dict=dump/phone_id_map.txt \
        --inference_dir=${train_output_path}/inference
fi
# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_streaming.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=mb_melgan_csmsc \
        --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz \
        --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --phones_dict=dump/phone_id_map.txt \
        --inference_dir=${train_output_path}/inference
fi
# the pretrained models haven't been released yet
# style melgan
# style melgan's dygraph-to-static-graph conversion is not ready yet
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_streaming.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=style_melgan_csmsc \
        --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
        --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --phones_dict=dump/phone_id_map.txt
        # --inference_dir=${train_output_path}/inference
fi
# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "in hifigan syn_e2e"
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_streaming.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=hifigan_csmsc \
        --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
        --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --phones_dict=dump/phone_id_map.txt \
        --inference_dir=${train_output_path}/inference
fi

@@ -0,0 +1,48 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/cnndecoder.yaml
train_output_path=exp/cnndecoder
ckpt_name=snapshot_iter_153.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with the use of positional arguments `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all ckpts are saved under the `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # synthesize_e2e, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # inference with static model
    CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    # streaming synthesize_e2e, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Fastspeech2 related modules for paddle"""
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union
@@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
@@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
            zero_triu: bool=False,
            conformer_enc_kernel_size: int=7,
            conformer_dec_kernel_size: int=31,
            # for CNN Decoder
            cnn_dec_dropout_rate: float=0.2,
            cnn_postnet_dropout_rate: float=0.2,
            cnn_postnet_resblock_kernel_sizes: List[int]=[256, 256],
            cnn_postnet_kernel_size: int=5,
            cnn_decoder_embedding_dim: int=256,
            # duration predictor
            duration_predictor_layers: int=2,
            duration_predictor_chans: int=384,
@@ -392,6 +401,13 @@ class FastSpeech2(nn.Layer):
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size, )
        elif decoder_type == 'cnndecoder':
            self.decoder = CNNDecoder(
                emb_dim=adim,
                odim=odim,
                kernel_size=cnn_postnet_kernel_size,
                dropout_rate=cnn_dec_dropout_rate,
                resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
        else:
            raise ValueError(f"{decoder_type} is not supported.")
@@ -399,6 +415,13 @@ class FastSpeech2(nn.Layer):
        self.feat_out = nn.Linear(adim, odim * reduction_factor)
        # define postnet
        if decoder_type == 'cnndecoder':
            self.postnet = CNNPostnet(
                odim=odim,
                kernel_size=cnn_postnet_kernel_size,
                dropout_rate=cnn_postnet_dropout_rate,
                resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
        else:
            self.postnet = (None if postnet_layers == 0 else Postnet(
                idim=idim,
                odim=odim,
@@ -562,6 +585,7 @@ class FastSpeech2(nn.Layer):
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            # (B, 1, T)
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
@@ -569,6 +593,9 @@ class FastSpeech2(nn.Layer):
        zs, _ = self.decoder(hs, h_masks)
        # (B, Lmax, odim)
        if self.decoder_type == 'cnndecoder':
            before_outs = zs
        else:
            before_outs = self.feat_out(zs).reshape(
                (paddle.shape(zs)[0], -1, self.odim))
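
Why the new branch can skip `self.feat_out`: `CNNDecoder` ends in a 1x1 convolution that already projects to `odim`, whereas the transformer decoder outputs `adim`-dimensional features. A shape walk-through with made-up sizes (an illustrative sketch, not code from this commit):

```python
# Illustrative sketch, not code from this commit.
import paddle

from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder

hs = paddle.randn([2, 100, 384])  # (B, Tmax, adim), as produced by the length regulator
decoder = CNNDecoder(emb_dim=384, odim=80, resblock_kernel_sizes=[256, 256])
zs, _ = decoder(hs)
print(zs.shape)  # [2, 100, 80]: already (B, Tmax, odim), so feat_out is not needed
```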

@@ -515,3 +515,136 @@ class ConformerEncoder(BaseEncoder):
        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks
class Conv1dResidualBlock(nn.Layer):
    """
    Residual 1D-convolution block used by the simplified CNN decoder
    and postnet.
    """

    def __init__(self,
                 idim: int=256,
                 odim: int=256,
                 kernel_size: int=5,
                 dropout_rate: float=0.2):
        super().__init__()
        self.main_block = nn.Sequential(
            nn.Conv1D(
                idim, odim, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
            nn.BatchNorm1D(odim),
            nn.Dropout(p=dropout_rate))
        # 1x1 conv on the residual path to match channel counts
        self.conv1d_residual = nn.Conv1D(idim, odim, kernel_size=1)

    def forward(self, xs):
        """Encode input sequence.

        Args:
            xs (Tensor): Input tensor (#batch, idim, T).
        Returns:
            Tensor: Output tensor (#batch, odim, T).
        """
        outputs = self.main_block(xs)
        outputs = self.conv1d_residual(xs) + outputs
        return outputs
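
A quick shape check for the block above (an illustrative sketch, assuming `Conv1dResidualBlock` is importable from this module): the 1x1 convolution on the residual path matches the channel counts, and `padding=kernel_size // 2` keeps the time axis unchanged for odd kernel sizes.

```python
import paddle

block = Conv1dResidualBlock(idim=384, odim=256, kernel_size=5)
xs = paddle.randn([2, 384, 100])  # (B, idim, T)
ys = block(xs)
print(ys.shape)  # [2, 256, 100]: channels projected, time length preserved
```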
class CNNDecoder(nn.Layer):
    """
    Much simplified decoder compared with the original Transformer
    decoder with Prenet.
    """

    def __init__(
            self,
            emb_dim: int=256,
            odim: int=80,
            kernel_size: int=5,
            dropout_rate: float=0.2,
            resblock_kernel_sizes: List[int]=[256, 256], ):
        super().__init__()
        # copy before appending, so the caller's (or the default) list
        # is not mutated in place
        out_sizes = list(resblock_kernel_sizes)
        out_sizes.append(out_sizes[-1])
        in_sizes = [emb_dim] + out_sizes[:-1]
        self.residual_blocks = nn.LayerList([
            Conv1dResidualBlock(
                idim=in_channels,
                odim=out_channels,
                kernel_size=kernel_size,
                dropout_rate=dropout_rate, )
            for in_channels, out_channels in zip(in_sizes, out_sizes)
        ])
        self.conv1d = nn.Conv1D(
            in_channels=out_sizes[-1], out_channels=odim, kernel_size=1)

    def forward(self, xs, masks=None):
        """Decode the length-regulated hidden sequence.

        Args:
            xs (Tensor): Input tensor (#batch, time, idim).
            masks (Tensor): Mask tensor (#batch, 1, time).
        Returns:
            Tensor: Output tensor (#batch, time, odim).
        """
        # exchange the temporal dimension and the feature dimension
        xs = xs.transpose([0, 2, 1])
        if masks is not None:
            xs = xs * masks
        for layer in self.residual_blocks:
            outputs = layer(xs)
            if masks is not None:
                # input_mask: (B, 1, T)
                outputs = outputs * masks
            xs = outputs
        outputs = self.conv1d(outputs)
        if masks is not None:
            outputs = outputs * masks
        outputs = outputs.transpose([0, 2, 1])
        return outputs, masks
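
Masking behavior, sketched with made-up values: the (B, 1, T) mask broadcasts over the channel axis, and because it is re-applied after every residual block and after the final 1x1 projection, padded frames come out exactly zero even though the 5-wide convolutions smear values into them:

```python
import paddle

decoder = CNNDecoder(emb_dim=256, odim=80)
xs = paddle.randn([2, 7, 256])  # (B, T, emb_dim)
masks = paddle.to_tensor(       # (B, 1, T); 0 marks padded frames
    [[[1., 1., 1., 1., 1., 0., 0.]],
     [[1., 1., 1., 0., 0., 0., 0.]]])
outs, _ = decoder(xs, masks)
print(outs.shape)                      # [2, 7, 80]
print(float(outs[0, 5:].abs().sum()))  # 0.0: padded frames stay zeroed
```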
class CNNPostnet(nn.Layer):
    """
    Postnet built from Conv1dResidualBlocks, used together with CNNDecoder.
    """

    def __init__(
            self,
            odim: int=80,
            kernel_size: int=5,
            dropout_rate: float=0.2,
            resblock_kernel_sizes: List[int]=[256, 256], ):
        super().__init__()
        out_sizes = resblock_kernel_sizes
        in_sizes = [odim] + out_sizes[:-1]
        self.residual_blocks = nn.LayerList([
            Conv1dResidualBlock(
                idim=in_channels,
                odim=out_channels,
                kernel_size=kernel_size,
                dropout_rate=dropout_rate)
            for in_channels, out_channels in zip(in_sizes, out_sizes)
        ])
        self.conv1d = nn.Conv1D(
            in_channels=out_sizes[-1], out_channels=odim, kernel_size=1)

    def forward(self, xs, masks=None):
        """Refine the decoder output.

        Args:
            xs (Tensor): Input tensor (#batch, odim, time).
            masks (Tensor): Mask tensor (#batch, 1, time).
        Returns:
            Tensor: Output tensor (#batch, odim, time).
        """
        for layer in self.residual_blocks:
            outputs = layer(xs)
            if masks is not None:
                # input_mask: (B, 1, T)
                outputs = outputs * masks
            xs = outputs
        outputs = self.conv1d(outputs)
        if masks is not None:
            outputs = outputs * masks
        return outputs
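
How `FastSpeech2` is expected to apply this postnet, sketched under the assumption that it mirrors the Tacotron2-style residual pattern already used for the default `Postnet` (the masked variant would pass the (B, 1, T) mask as the second argument):

```python
import paddle

postnet = CNNPostnet(odim=80, resblock_kernel_sizes=[256, 256])
before_outs = paddle.randn([2, 100, 80])              # (B, Tmax, odim) from CNNDecoder
residual = postnet(before_outs.transpose([0, 2, 1]))  # postnet works on (B, odim, Tmax)
after_outs = before_outs + residual.transpose([0, 2, 1])
print(after_outs.shape)  # [2, 100, 80]
```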
