Merge pull request #1634 from yt605155624/cnn_decoder

[TTS]Cnn decoder
r0.2
TianYuan 2 years ago committed by GitHub
commit 5bff096715
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -226,6 +226,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
Pretrained FastSpeech2 model with no silence in the edge of audios:
- [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)
- [fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip)
The static model can be downloaded here:
- [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip)

@ -0,0 +1,107 @@
# use CNND
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # sr
n_fft: 2048 # FFT size (samples).
n_shift: 300 # Hop size (samples). 12.5ms
win_length: 1200 # Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
# Only used for feats_type != raw
fmin: 80 # Minimum frequency of Mel basis.
fmax: 7600 # Maximum frequency of Mel basis.
n_mels: 80 # The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80 # Minimum f0 for pitch extraction.
f0max: 400 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 4
###########################################################
# MODEL SETTING #
###########################################################
model:
adim: 384 # attention dimension
aheads: 2 # number of attention heads
elayers: 4 # number of encoder layers
eunits: 1536 # number of encoder ff units
dlayers: 4 # number of decoder layers
dunits: 1536 # number of decoder ff units
positionwise_layer_type: conv1d # type of position-wise layer
positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer
duration_predictor_layers: 2 # number of layers of duration predictor
duration_predictor_chans: 256 # number of channels of duration predictor
duration_predictor_kernel_size: 3 # filter size of duration predictor
postnet_layers: 5 # number of layers of postnset
postnet_filts: 5 # filter size of conv layers in postnet
postnet_chans: 256 # number of channels of conv layers in postnet
use_scaled_pos_enc: True # whether to use scaled positional encoding
encoder_normalize_before: True # whether to perform layer normalization before the input
decoder_normalize_before: True # whether to perform layer normalization before the input
reduction_factor: 1 # reduction factor
encoder_type: transformer # encoder type
decoder_type: cnndecoder # decoder type
init_type: xavier_uniform # initialization type
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer
cnn_dec_dropout_rate: 0.2 # dropout rate for cnn decoder layer
cnn_postnet_dropout_rate: 0.2
cnn_postnet_resblock_kernel_sizes: [256, 256] # kernel sizes for residual block of cnn_postnet
cnn_postnet_kernel_size: 5 # kernel size of cnn_postnet
cnn_decoder_embedding_dim: 256
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers: 2 # number of conv layers in energy predictor
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
###########################################################
# UPDATER SETTING #
###########################################################
updater:
use_masking: True # whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
optim: adam # optimizer type
learning_rate: 0.001 # learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 1000
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

@ -0,0 +1,92 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_streaming.py \
--am=fastspeech2_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=pwgan_csmsc \
--voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
--voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e_streaming \
--phones_dict=dump/phone_id_map.txt \
--am_streaming=True
fi
# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_streaming.py \
--am=fastspeech2_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=mb_melgan_csmsc \
--voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\
--voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e_streaming \
--phones_dict=dump/phone_id_map.txt \
--am_streaming=True
fi
# the pretrained models haven't release now
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_streaming.py \
--am=fastspeech2_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=style_melgan_csmsc \
--voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
--voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e_streaming \
--phones_dict=dump/phone_id_map.txt \
--am_streaming=True
fi
# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "in hifigan syn_e2e"
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_streaming.py \
--am=fastspeech2_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=hifigan_csmsc \
--voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
--voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e_streaming \
--phones_dict=dump/phone_id_map.txt \
--am_streaming=True
fi

@ -0,0 +1,48 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/cnndecoder.yaml
train_output_path=exp/cnndecoder
ckpt_name=snapshot_iter_153.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# inference with static model
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

@ -0,0 +1,274 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
import yaml
from timer import timer
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import model_alias
from paddlespeech.t2s.utils import str2bool
def denorm(data, mean, std):
return data * std + mean
def get_chunks(data, chunk_size, pad_size):
data_len = data.shape[1]
chunks = []
n = math.ceil(data_len / chunk_size)
for i in range(n):
start = max(0, i * chunk_size - pad_size)
end = min((i + 1) * chunk_size + pad_size, data_len)
chunks.append(data[:, start:end, :])
return chunks
def evaluate(args):
# Init body.
with open(args.am_config) as f:
am_config = CfgNode(yaml.safe_load(f))
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(am_config)
print(voc_config)
sentences = get_sentences(args)
# frontend
frontend = get_frontend(args)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
# acoustic model, only support fastspeech2 here now!
# am_inference, am_name, am_dataset = get_am_inference(args, am_config)
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
odim = am_config.n_mels
am_class = dynamic_import(am_name, model_alias)
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(args.am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = True
N = 0
T = 0
chunk_size = args.chunk_size
pad_size = args.pad_size
for utt_id, sentence in sentences:
with timer() as t:
get_tone_ids = False
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in be 'zh' here!")
# merge_sentences=True here, so we only use the first item of phone_ids
phone_ids = phone_ids[0]
with paddle.no_grad():
# acoustic model
orig_hs, h_masks = am.encoder_infer(phone_ids)
if args.am_streaming:
hss = get_chunks(orig_hs, chunk_size, pad_size)
chunk_num = len(hss)
mel_list = []
for i, hs in enumerate(hss):
before_outs, _ = am.decoder(hs)
after_outs = before_outs + am.postnet(
before_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
normalized_mel = after_outs[0]
sub_mel = denorm(normalized_mel, am_mu, am_std)
# clip output part of pad
if i == 0:
sub_mel = sub_mel[:-pad_size]
elif i == chunk_num - 1:
# 最后一块的右侧一定没有 pad 够
sub_mel = sub_mel[pad_size:]
else:
# 倒数几块的右侧也可能没有 pad 够
sub_mel = sub_mel[pad_size:(chunk_size + pad_size) -
sub_mel.shape[0]]
mel_list.append(sub_mel)
mel = paddle.concat(mel_list, axis=0)
else:
before_outs, _ = am.decoder(orig_hs)
after_outs = before_outs + am.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
normalized_mel = after_outs[0]
mel = denorm(normalized_mel, am_mu, am_std)
# vocoder
wav = voc_inference(mel)
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = am_config.fs / speed
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(
str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs)
print(f"{utt_id} done!")
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")
def parse_args():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder")
# acoustic model
parser.add_argument(
'--am',
type=str,
default='fastspeech2_csmsc',
choices=['fastspeech2_csmsc'],
help='Choose acoustic model type of tts task.')
parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model. Use deault config when it is None.')
parser.add_argument(
'--am_ckpt',
type=str,
default=None,
help='Checkpoint file of acoustic model.')
parser.add_argument(
"--am_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training acoustic model."
)
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--tones_dict", type=str, default=None, help="tone vocabulary file.")
# vocoder
parser.add_argument(
'--voc',
type=str,
default='pwgan_csmsc',
choices=[
'pwgan_csmsc',
'pwgan_ljspeech',
'pwgan_aishell3',
'pwgan_vctk',
'mb_melgan_csmsc',
'style_melgan_csmsc',
'hifigan_csmsc',
'hifigan_ljspeech',
'hifigan_aishell3',
'hifigan_vctk',
'wavernn_csmsc',
],
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc. Use deault config when it is None.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(
"--voc_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training voc."
)
# other
parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en')
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument(
"--am_streaming",
type=str2bool,
default=False,
help="whether use streaming acoustic model")
parser.add_argument(
"--chunk_size", type=int, default=42, help="chunk size of am streaming")
parser.add_argument(
"--pad_size", type=int, default=12, help="pad size of am streaming")
parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
evaluate(args)
if __name__ == "__main__":
main()

@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Fastspeech2 related modules for paddle"""
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union
@ -32,6 +33,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredic
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
@ -97,6 +100,12 @@ class FastSpeech2(nn.Layer):
zero_triu: bool=False,
conformer_enc_kernel_size: int=7,
conformer_dec_kernel_size: int=31,
# for CNN Decoder
cnn_dec_dropout_rate: float=0.2,
cnn_postnet_dropout_rate: float=0.2,
cnn_postnet_resblock_kernel_sizes: List[int]=[256, 256],
cnn_postnet_kernel_size: int=5,
cnn_decoder_embedding_dim: int=256,
# duration predictor
duration_predictor_layers: int=2,
duration_predictor_chans: int=384,
@ -392,6 +401,13 @@ class FastSpeech2(nn.Layer):
activation_type=conformer_activation_type,
use_cnn_module=use_cnn_in_conformer,
cnn_module_kernel=conformer_dec_kernel_size, )
elif decoder_type == 'cnndecoder':
self.decoder = CNNDecoder(
emb_dim=adim,
odim=odim,
kernel_size=cnn_postnet_kernel_size,
dropout_rate=cnn_dec_dropout_rate,
resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
else:
raise ValueError(f"{decoder_type} is not supported.")
@ -399,14 +415,21 @@ class FastSpeech2(nn.Layer):
self.feat_out = nn.Linear(adim, odim * reduction_factor)
# define postnet
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, ))
if decoder_type == 'cnndecoder':
self.postnet = CNNPostnet(
odim=odim,
kernel_size=cnn_postnet_kernel_size,
dropout_rate=cnn_postnet_dropout_rate,
resblock_kernel_sizes=cnn_postnet_resblock_kernel_sizes)
else:
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, ))
nn.initializer.set_global_initializer(None)
@ -486,6 +509,7 @@ class FastSpeech2(nn.Layer):
ps: paddle.Tensor=None,
es: paddle.Tensor=None,
is_inference: bool=False,
return_after_enc=False,
alpha: float=1.0,
spk_emb=None,
spk_id=None,
@ -562,15 +586,21 @@ class FastSpeech2(nn.Layer):
[olen // self.reduction_factor for olen in olens.numpy()])
else:
olens_in = olens
# (B, 1, T)
h_masks = self._source_mask(olens_in)
else:
h_masks = None
# (B, Lmax, adim)
if return_after_enc:
return hs, h_masks
# (B, Lmax, adim)
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
if self.decoder_type == 'cnndecoder':
before_outs = zs
else:
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
@ -581,10 +611,42 @@ class FastSpeech2(nn.Layer):
return before_outs, after_outs, d_outs, p_outs, e_outs
def encoder_infer(
self,
text: paddle.Tensor,
alpha: float=1.0,
spk_emb=None,
spk_id=None,
tone_id=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
# input of embedding must be int64
x = paddle.cast(text, 'int64')
# setup batch axis
ilens = paddle.shape(x)[0]
xs = x.unsqueeze(0)
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
if tone_id is not None:
tone_id = tone_id.unsqueeze(0)
# (1, L, odim)
hs, h_masks = self._forward(
xs,
ilens,
is_inference=True,
return_after_enc=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id,
tone_id=tone_id)
return hs, h_masks
def inference(
self,
text: paddle.Tensor,
speech: paddle.Tensor=None,
durations: paddle.Tensor=None,
pitch: paddle.Tensor=None,
energy: paddle.Tensor=None,
@ -598,7 +660,6 @@ class FastSpeech2(nn.Layer):
Args:
text(Tensor(int64)): Input sequence of characters (T,).
speech(Tensor, optional): Feature sequence to extract style (N, idim).
durations(Tensor, optional (int64)): Groundtruth of duration (T,).
pitch(Tensor, optional): Groundtruth of token-averaged pitch (T, 1).
energy(Tensor, optional): Groundtruth of token-averaged energy (T, 1).
@ -615,15 +676,11 @@ class FastSpeech2(nn.Layer):
"""
# input of embedding must be int64
x = paddle.cast(text, 'int64')
y = speech
d, p, e = durations, pitch, energy
# setup batch axis
ilens = paddle.shape(x)[0]
xs, ys = x.unsqueeze(0), None
if y is not None:
ys = y.unsqueeze(0)
xs = x.unsqueeze(0)
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
@ -641,7 +698,6 @@ class FastSpeech2(nn.Layer):
_, outs, d_outs, p_outs, e_outs = self._forward(
xs,
ilens,
ys,
ds=ds,
ps=ps,
es=es,
@ -654,7 +710,6 @@ class FastSpeech2(nn.Layer):
_, outs, d_outs, p_outs, e_outs = self._forward(
xs,
ilens,
ys,
is_inference=True,
alpha=alpha,
spk_emb=spk_emb,
@ -802,7 +857,6 @@ class StyleFastSpeech2Inference(FastSpeech2Inference):
Args:
text(Tensor(int64)): Input sequence of characters (T,).
speech(Tensor, optional): Feature sequence to extract style (N, idim).
durations(paddle.Tensor/np.ndarray, optional (int64)): Groundtruth of duration (T,), this will overwrite the set of durations_scale and durations_bias
durations_scale(int/float, optional):
durations_bias(int/float, optional):

@ -515,3 +515,132 @@ class ConformerEncoder(BaseEncoder):
if self.intermediate_layers is not None:
return xs, masks, intermediate_outputs
return xs, masks
class Conv1dResidualBlock(nn.Layer):
"""
Special module for simplified version of Encoder class.
"""
def __init__(self,
idim: int=256,
odim: int=256,
kernel_size: int=5,
dropout_rate: float=0.2):
super().__init__()
self.main_block = nn.Sequential(
nn.Conv1D(
idim, odim, kernel_size=kernel_size, padding=kernel_size // 2),
nn.ReLU(),
nn.BatchNorm1D(odim),
nn.Dropout(p=dropout_rate))
self.conv1d_residual = nn.Conv1D(idim, odim, kernel_size=1)
def forward(self, xs):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, idim, T).
Returns:
Tensor: Output tensor (#batch, odim, T).
"""
outputs = self.main_block(xs)
outputs = self.conv1d_residual(xs) + outputs
return outputs
class CNNDecoder(nn.Layer):
"""
Much simplified decoder than the original one with Prenet.
"""
def __init__(
self,
emb_dim: int=256,
odim: int=80,
kernel_size: int=5,
dropout_rate: float=0.2,
resblock_kernel_sizes: List[int]=[256, 256], ):
super().__init__()
input_shape = emb_dim
out_sizes = resblock_kernel_sizes
out_sizes.append(out_sizes[-1])
in_sizes = [input_shape] + out_sizes[:-1]
self.residual_blocks = nn.LayerList([
Conv1dResidualBlock(
idim=in_channels,
odim=out_channels,
kernel_size=kernel_size,
dropout_rate=dropout_rate, )
for in_channels, out_channels in zip(in_sizes, out_sizes)
])
self.conv1d = nn.Conv1D(
in_channels=out_sizes[-1], out_channels=odim, kernel_size=1)
def forward(self, xs, masks=None):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, time, idim).
masks (Tensor): Mask tensor (#batch, 1, time).
Returns:
Tensor: Output tensor (#batch, time, odim).
"""
# exchange the temporal dimension and the feature dimension
xs = xs.transpose([0, 2, 1])
if masks is not None:
xs = xs * masks
for layer in self.residual_blocks:
outputs = layer(xs)
if masks is not None:
# input_mask B * 1 * T
outputs = outputs * masks
xs = outputs
outputs = self.conv1d(outputs)
if masks is not None:
outputs = outputs * masks
outputs = outputs.transpose([0, 2, 1])
return outputs, masks
class CNNPostnet(nn.Layer):
def __init__(
self,
odim: int=80,
kernel_size: int=5,
dropout_rate: float=0.2,
resblock_kernel_sizes: List[int]=[256, 256], ):
super().__init__()
out_sizes = resblock_kernel_sizes
in_sizes = [odim] + out_sizes[:-1]
self.residual_blocks = nn.LayerList([
Conv1dResidualBlock(
idim=in_channels,
odim=out_channels,
kernel_size=kernel_size,
dropout_rate=dropout_rate)
for in_channels, out_channels in zip(in_sizes, out_sizes)
])
self.conv1d = nn.Conv1D(
in_channels=out_sizes[-1], out_channels=odim, kernel_size=1)
def forward(self, xs, masks=None):
"""Encode input sequence.
Args:
xs (Tensor): Input tensor (#batch, odim, time).
masks (Tensor): Mask tensor (#batch, 1, time).
Returns:
Tensor: Output tensor (#batch, odim, time).
"""
for layer in self.residual_blocks:
outputs = layer(xs)
if masks is not None:
# input_mask B * 1 * T
outputs = outputs * masks
xs = outputs
outputs = self.conv1d(outputs)
if masks is not None:
outputs = outputs * masks
return outputs

Loading…
Cancel
Save