diff --git a/examples/csmsc/tts3/conf/cnndecoder.yaml b/examples/csmsc/tts3/conf/cnndecoder.yaml
new file mode 100644
index 00000000..8b46fea4
--- /dev/null
+++ b/examples/csmsc/tts3/conf/cnndecoder.yaml
@@ -0,0 +1,107 @@
+# use CNN decoder
+###########################################################
+#                FEATURE EXTRACTION SETTING                #
+###########################################################
+
+fs: 24000          # sr
+n_fft: 2048        # FFT size (samples).
+n_shift: 300       # Hop size (samples). 12.5ms
+win_length: 1200   # Window length (samples). 50ms
+                   # If set to null, it will be the same as fft_size.
+window: "hann"     # Window function.
+
+# Only used for feats_type != raw
+
+fmin: 80           # Minimum frequency of Mel basis.
+fmax: 7600         # Maximum frequency of Mel basis.
+n_mels: 80         # The number of mel basis.
+
+# Only used for the model using pitch features (e.g. FastSpeech2)
+f0min: 80          # Minimum f0 for pitch extraction.
+f0max: 400         # Maximum f0 for pitch extraction.
+
+
+###########################################################
+#                       DATA SETTING                       #
+###########################################################
+batch_size: 64
+num_workers: 4
+
+
+###########################################################
+#                       MODEL SETTING                      #
+###########################################################
+model:
+    adim: 384                                    # attention dimension
+    aheads: 2                                    # number of attention heads
+    elayers: 4                                   # number of encoder layers
+    eunits: 1536                                 # number of encoder ff units
+    dlayers: 4                                   # number of decoder layers
+    dunits: 1536                                 # number of decoder ff units
+    positionwise_layer_type: conv1d              # type of position-wise layer
+    positionwise_conv_kernel_size: 3             # kernel size of position wise conv layer
+    duration_predictor_layers: 2                 # number of layers of duration predictor
+    duration_predictor_chans: 256                # number of channels of duration predictor
+    duration_predictor_kernel_size: 3            # filter size of duration predictor
+    postnet_layers: 5                            # number of layers of postnet
+    postnet_filts: 5                             # filter size of conv layers in postnet
+    postnet_chans: 256                           # number of channels of conv layers in postnet
+    use_scaled_pos_enc: True                     # whether to use scaled positional encoding
+    encoder_normalize_before: True               # whether to perform layer normalization before the encoder input
+    decoder_normalize_before: True               # whether to perform layer normalization before the decoder input
+    reduction_factor: 1                          # reduction factor
+    encoder_type: transformer                    # encoder type
+    decoder_type: cnndecoder                     # decoder type
+    init_type: xavier_uniform                    # initialization type
+    init_enc_alpha: 1.0                          # initial value of alpha of encoder scaled position encoding
+    init_dec_alpha: 1.0                          # initial value of alpha of decoder scaled position encoding
+    transformer_enc_dropout_rate: 0.2            # dropout rate for transformer encoder layer
+    transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
+    transformer_enc_attn_dropout_rate: 0.2       # dropout rate for transformer encoder attention layer
+    cnn_dec_dropout_rate: 0.2                    # dropout rate for cnn decoder layer
+    cnn_postnet_dropout_rate: 0.2                # dropout rate for cnn postnet layer
+    cnn_postnet_resblock_kernel_sizes: [256, 256] # output channel sizes of the residual blocks in cnn decoder & postnet
+    cnn_postnet_kernel_size: 5                   # kernel size of cnn_postnet
+    cnn_decoder_embedding_dim: 256               # embedding dimension of cnn decoder
+    pitch_predictor_layers: 5                    # number of conv layers in pitch predictor
+    pitch_predictor_chans: 256                   # number of channels of conv layers in pitch predictor
+    pitch_predictor_kernel_size: 5               # kernel size of conv layers in pitch predictor
+    pitch_predictor_dropout: 0.5                 # dropout rate in pitch predictor
+    pitch_embed_kernel_size: 1                   # kernel size of conv embedding layer for pitch
+    pitch_embed_dropout: 0.0                     # dropout rate after conv embedding layer for pitch
+    stop_gradient_from_pitch_predictor: True     # whether to stop the gradient from pitch predictor to encoder
+    energy_predictor_layers: 2                   # number of conv layers in energy predictor
+    energy_predictor_chans: 256                  # number of channels of conv layers in energy predictor
+    energy_predictor_kernel_size: 3              # kernel size of conv layers in energy predictor
+    energy_predictor_dropout: 0.5                # dropout rate in energy predictor
+    energy_embed_kernel_size: 1                  # kernel size of conv embedding layer for energy
+    energy_embed_dropout: 0.0                    # dropout rate after conv embedding layer for energy
+    stop_gradient_from_energy_predictor: False   # whether to stop the gradient from energy predictor to encoder
+
+
+
+###########################################################
+#                      UPDATER SETTING                     #
+###########################################################
+updater:
+    use_masking: True     # whether to apply masking for padded part in loss calculation
+
+
+###########################################################
+#                     OPTIMIZER SETTING                    #
+###########################################################
+optimizer:
+    optim: adam           # optimizer type
+    learning_rate: 0.001  # learning rate
+
+###########################################################
+#                      TRAINING SETTING                    #
+###########################################################
+max_epoch: 1000
+num_snapshots: 5
+
+
+###########################################################
+#                       OTHER SETTING                      #
+###########################################################
+seed: 10086
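For orientation only, here is a minimal sketch (not the recipe's actual training script) of how the `model:` section above is expected to reach the `FastSpeech2` constructor; the config path and `vocab_size` are placeholder values, and the real recipe derives the vocabulary size from `dump/phone_id_map.txt`:

```python
import yaml

from paddlespeech.t2s.models.fastspeech2.fastspeech2 import FastSpeech2

# load the config shown above (path is a placeholder)
with open("conf/cnndecoder.yaml") as f:
    config = yaml.safe_load(f)

vocab_size = 268  # placeholder; the recipe reads this from dump/phone_id_map.txt
model = FastSpeech2(
    idim=vocab_size,        # phone vocabulary size
    odim=config["n_mels"],  # 80 mel bins
    **config["model"])      # decoder_type=cnndecoder plus the cnn_* options above
```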
diff --git a/examples/csmsc/tts3/local/synthesize_streaming.sh b/examples/csmsc/tts3/local/synthesize_streaming.sh
new file mode 100755
index 00000000..69bb22df
--- /dev/null
+++ b/examples/csmsc/tts3/local/synthesize_streaming.sh
@@ -0,0 +1,92 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_streaming.py \
+        --am=fastspeech2_csmsc \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_csmsc \
+        --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
+        --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
+        --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
+        --lang=zh \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_streaming.py \
+        --am=fastspeech2_csmsc \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=mb_melgan_csmsc \
+        --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz \
+        --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
+        --lang=zh \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
+fi
+
+# the pretrained models haven't been released yet
+# style melgan
+# style melgan's dygraph-to-static-graph conversion is not ready yet
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_streaming.py \
+        --am=fastspeech2_csmsc \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=style_melgan_csmsc \
+        --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+        --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
+        --lang=zh \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt
+        # --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "in hifigan syn_e2e"
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_streaming.py \
+        --am=fastspeech2_csmsc \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_csmsc \
+        --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
+        --lang=zh \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
+fi
diff --git a/examples/csmsc/tts3/run_cnndecoder.sh b/examples/csmsc/tts3/run_cnndecoder.sh
new file mode 100755
index 00000000..5cccef01
--- /dev/null
+++ b/examples/csmsc/tts3/run_cnndecoder.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/cnndecoder.yaml
+train_output_path=exp/cnndecoder
+ckpt_name=snapshot_iter_153.pdz
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this can not be mixed with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # synthesize, vocoder is pwgan
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # synthesize_e2e, vocoder is pwgan
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # inference with static model
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+    # streaming synthesize_e2e, vocoder is pwgan
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
+
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 73f5498e..1c805051 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -14,6 +14,7 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 """Fastspeech2 related modules for paddle"""
 from typing import Dict
+from typing import List
 from typing import Sequence
 from typing import Tuple
 from typing import Union
@@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
 from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
 from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor
 from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
+from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
+from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
 from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
 from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
 
@@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
             zero_triu: bool=False,
             conformer_enc_kernel_size: int=7,
             conformer_dec_kernel_size: int=31,
+            # for CNN Decoder
+            cnn_dec_dropout_rate: float=0.2,
+            cnn_postnet_dropout_rate: float=0.2,
+            cnn_postnet_resblock_kernel_sizes: List[int]=[256, 256],
+            cnn_postnet_kernel_size: int=5,
+            cnn_decoder_embedding_dim: int=256,
             # duration predictor
             duration_predictor_layers: int=2,
             duration_predictor_chans: int=384,
@@ -392,6 +401,13 @@
                 activation_type=conformer_activation_type,
                 use_cnn_module=use_cnn_in_conformer,
                 cnn_module_kernel=conformer_dec_kernel_size, )
+        elif decoder_type == 'cnndecoder':
+            self.decoder = CNNDecoder(
+                emb_dim=adim,
+                odim=odim,
+                kernel_size=cnn_postnet_kernel_size,
+                dropout_rate=cnn_dec_dropout_rate,
+                resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
         else:
             raise ValueError(f"{decoder_type} is not supported.")
 
@@ -399,14 +415,21 @@
         self.feat_out = nn.Linear(adim, odim * reduction_factor)
 
         # define postnet
-        self.postnet = (None if postnet_layers == 0 else Postnet(
-            idim=idim,
-            odim=odim,
-            n_layers=postnet_layers,
-            n_chans=postnet_chans,
-            n_filts=postnet_filts,
-            use_batch_norm=use_batch_norm,
-            dropout_rate=postnet_dropout_rate, ))
+        if decoder_type == 'cnndecoder':
+            self.postnet = CNNPostnet(
+                odim=odim,
+                kernel_size=cnn_postnet_kernel_size,
+                dropout_rate=cnn_postnet_dropout_rate,
+                resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
+        else:
+            self.postnet = (None if postnet_layers == 0 else Postnet(
+                idim=idim,
+                odim=odim,
+                n_layers=postnet_layers,
+                n_chans=postnet_chans,
+                n_filts=postnet_filts,
+                use_batch_norm=use_batch_norm,
+                dropout_rate=postnet_dropout_rate, ))
 
         nn.initializer.set_global_initializer(None)
 
@@ -562,6 +585,7 @@
                     [olen // self.reduction_factor for olen in olens.numpy()])
             else:
                 olens_in = olens
+            # (B, 1, T)
             h_masks = self._source_mask(olens_in)
         else:
             h_masks = None
@@ -569,8 +593,11 @@
 
         zs, _ = self.decoder(hs, h_masks)
         # (B, Lmax, odim)
-        before_outs = self.feat_out(zs).reshape(
-            (paddle.shape(zs)[0], -1, self.odim))
+        if self.decoder_type == 'cnndecoder':
+            before_outs = zs
+        else:
+            before_outs = self.feat_out(zs).reshape(
+                (paddle.shape(zs)[0], -1, self.odim))
 
         # postnet -> (B, Lmax//r * r, odim)
         if self.postnet is None:
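A small hedged smoke test (not part of the patch) to make the new branches concrete; it assumes the remaining constructor defaults build cleanly and uses toy `idim`/`odim` values:

```python
from paddlespeech.t2s.models.fastspeech2.fastspeech2 import FastSpeech2
from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet

# toy sizes, only to exercise the construction path added above
model = FastSpeech2(idim=100, odim=80, decoder_type="cnndecoder")

# with decoder_type="cnndecoder", both decoder and postnet become the new CNN modules
assert isinstance(model.decoder, CNNDecoder)
assert isinstance(model.postnet, CNNPostnet)
```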
+ """ + # print("input.shape in CNNDecoder:",xs.shape) + # exchange the temporal dimension and the feature dimension + xs = xs.transpose([0, 2, 1]) + if masks is not None: + xs = xs * masks + + for layer in self.residual_blocks: + outputs = layer(xs) + if masks is not None: + # input_mask B * 1 * T + outputs = outputs * masks + xs = outputs + outputs = self.conv1d(outputs) + if masks is not None: + outputs = outputs * masks + outputs = outputs.transpose([0, 2, 1]) + # print("outputs.shape in CNNDecoder:",outputs.shape) + return outputs, masks + + +class CNNPostnet(nn.Layer): + def __init__( + self, + odim: int=80, + kernel_size: int=5, + dropout_rate: float=0.2, + resblock_kernel_sizes: List[int]=[256, 256], ): + super().__init__() + out_sizes = resblock_kernel_sizes + in_sizes = [odim] + out_sizes[:-1] + self.residual_blocks = nn.LayerList([ + Conv1dResidualBlock( + idim=in_channels, + odim=out_channels, + kernel_size=kernel_size, + dropout_rate=dropout_rate) + for in_channels, out_channels in zip(in_sizes, out_sizes) + ]) + self.conv1d = nn.Conv1D( + in_channels=out_sizes[-1], out_channels=odim, kernel_size=1) + + def forward(self, xs, masks=None): + """Encode input sequence. + Args: + xs (Tensor): Input tensor (#batch, odim, time). + masks (Tensor): Mask tensor (#batch, 1, time). + Returns: + Tensor: Output tensor (#batch, odim, time). + """ + # print("xs.shape in CNNPostnet:",xs.shape) + for layer in self.residual_blocks: + outputs = layer(xs) + if masks is not None: + # input_mask B * 1 * T + outputs = outputs * masks + xs = outputs + outputs = self.conv1d(outputs) + if masks is not None: + outputs = outputs * masks + # print("outputs.shape in CNNPostnet:",outputs.shape) + return outputs