commit 97db74ca60
@@ -0,0 +1,91 @@
# This configuration is for Paddle to train Tacotron 2. Compared to the
# original paper, this configuration additionally uses the guided attention
# loss to accelerate the learning of the diagonal attention. It requires
# only a single GPU with 12 GB memory and it takes ~1 day to finish the
# training on Titan V.

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################

fs: 24000          # Sampling rate (Hz).
n_fft: 2048        # FFT size (samples).
n_shift: 300       # Hop size (samples). 12.5ms
win_length: 1200   # Window length (samples). 50ms
                   # If set to null, it will be the same as n_fft.
window: "hann"     # Window function.

# Only used for feats_type != raw

fmin: 80           # Minimum frequency of Mel basis.
fmax: 7600         # Maximum frequency of Mel basis.
n_mels: 80         # The number of mel basis.

###########################################################
#                       DATA SETTING                      #
###########################################################
batch_size: 64
num_workers: 2

###########################################################
#                       MODEL SETTING                     #
###########################################################
model: # keyword arguments for the selected model
    embed_dim: 512            # char or phn embedding dimension
    elayers: 1                # number of blstm layers in encoder
    eunits: 512               # number of blstm units
    econv_layers: 3           # number of convolutional layers in encoder
    econv_chans: 512          # number of channels in convolutional layer
    econv_filts: 5            # filter size of convolutional layer
    atype: location           # attention function type
    adim: 512                 # attention dimension
    aconv_chans: 32           # number of channels in convolutional layer of attention
    aconv_filts: 15           # filter size of convolutional layer of attention
    cumulate_att_w: True      # whether to cumulate attention weight
    dlayers: 2                # number of lstm layers in decoder
    dunits: 1024              # number of lstm units in decoder
    prenet_layers: 2          # number of layers in prenet
    prenet_units: 256         # number of units in prenet
    postnet_layers: 5         # number of layers in postnet
    postnet_chans: 512        # number of channels in postnet
    postnet_filts: 5          # filter size of postnet layer
    output_activation: null   # activation function for the final output
    use_batch_norm: True      # whether to use batch normalization in encoder
    use_concate: True         # whether to concatenate encoder embedding with decoder outputs
    use_residual: False       # whether to use residual connection in encoder
    dropout_rate: 0.5         # dropout rate
    zoneout_rate: 0.1         # zoneout rate
    reduction_factor: 1       # reduction factor
    spk_embed_dim: null       # speaker embedding dimension


###########################################################
#                      UPDATER SETTING                    #
###########################################################
updater:
    use_masking: True              # whether to apply masking for padded part in loss calculation
    bce_pos_weight: 5.0            # weight of positive sample in binary cross entropy calculation
    use_guided_attn_loss: True     # whether to use guided attention loss
    guided_attn_loss_sigma: 0.4    # sigma of guided attention loss
    guided_attn_loss_lambda: 1.0   # strength of guided attention loss
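    # Background: the guided attention loss (Tachibana et al., 2018) penalizes
    # attention mass far from the diagonal. Each attention weight A[t, n] is
    # multiplied by W[t, n] = 1 - exp(-(n/N - t/T)^2 / (2 * sigma^2)) and the
    # masked mean is scaled by lambda before being added to the total loss;
    # a smaller sigma narrows the diagonal band and constrains the alignment
    # more strongly.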


###########################################################
#                    OPTIMIZER SETTING                    #
###########################################################
optimizer:
    optim: adam            # optimizer type
    learning_rate: 1.0e-03 # learning rate
    epsilon: 1.0e-06       # epsilon
    weight_decay: 0.0      # weight decay coefficient

###########################################################
#                     TRAINING SETTING                    #
###########################################################
max_epoch: 200
num_snapshots: 5

###########################################################
#                      OTHER SETTING                      #
###########################################################
seed: 42
@@ -0,0 +1,62 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./baker_alignment_tone \
        --output=durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=baker \
        --rootdir=~/datasets/BZNSYP/ \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --num-cpu=20 \
        --cut-sil=True
fi
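
# After stage 1, `dump` contains (see preprocess.py for the details):
#   dump/phone_id_map.txt         phone -> id lookup
#   dump/speaker_id_map.txt       speaker -> id lookup
#   dump/{train,dev,test}/raw/    metadata.jsonl + data_speech/*_speech.npy
# Stages 2-3 then add dump/train/speech_stats.npy and {train,dev,test}/norm/.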

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone to id; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi
@@ -0,0 +1,20 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3
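
# The two FLAGS below appear throughout PaddleSpeech recipes: they keep
# Paddle's initial GPU memory pre-allocation small (1% of the card here,
# growing on demand), so the acoustic model and vocoder can share one GPU.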

FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
    --am=tacotron2_csmsc \
    --am_config=${config_path} \
    --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
    --am_stat=dump/train/speech_stats.npy \
    --voc=pwgan_csmsc \
    --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
    --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
    --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
    --test_metadata=dump/test/norm/metadata.jsonl \
    --output_dir=${train_output_path}/test \
    --phones_dict=dump/phone_id_map.txt
@@ -0,0 +1,91 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3

stage=0
stop_stage=0

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=tacotron2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=pwgan_csmsc \
        --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
        --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
        --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --inference_dir=${train_output_path}/inference \
        --phones_dict=dump/phone_id_map.txt
fi

# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=tacotron2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=mb_melgan_csmsc \
        --voc_config=mb_melgan_baker_finetune_ckpt_0.5/finetune.yaml \
        --voc_ckpt=mb_melgan_baker_finetune_ckpt_0.5/snapshot_iter_2000000.pdz \
        --voc_stat=mb_melgan_baker_finetune_ckpt_0.5/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --inference_dir=${train_output_path}/inference \
        --phones_dict=dump/phone_id_map.txt
fi

# the pretrained models haven't been released yet
# style melgan
# style melgan's dygraph-to-static-graph conversion is not ready yet
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=tacotron2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=style_melgan_csmsc \
        --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
        --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --phones_dict=dump/phone_id_map.txt
        # --inference_dir=${train_output_path}/inference
fi

# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "in hifigan syn_e2e"
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=tacotron2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=hifigan_csmsc \
        --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
        --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --inference_dir=${train_output_path}/inference \
        --phones_dict=dump/phone_id_map.txt
fi
@@ -0,0 +1,12 @@
#!/bin/bash

config_path=$1
train_output_path=$2

python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
@@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=new_tacotron2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@@ -0,0 +1,37 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz
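# NOTE: with the 9800-utterance Baker training split and batch_size 64
# (drop_last=True), one epoch is 9800 // 64 = 153 iterations, so
# snapshot_iter_153.pdz is the snapshot written after the first epoch;
# point ckpt_name at any later snapshot under
# ${train_output_path}/checkpoints/ when synthesizing.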

# With the following command, you can choose the stage range you want to run,
# such as `./run.sh --stage 0 --stop-stage 0`.
# Note that these options cannot be mixed with positional arguments `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model; all `ckpt` files are saved under `train_output_path/checkpoints/`
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # synthesize_e2e, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1 @@
../transformer_tts/normalize.py
@@ -0,0 +1,328 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.data.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence


def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     sentences: Dict,
                     output_dir: Path,
                     mel_extractor=None,
                     cut_sil: bool=True,
                     spk_emb_dir: Path=None):
    utt_id = fp.stem
    # for vctk
    if utt_id.endswith("_mic2"):
        utt_id = utt_id[:-5]
    record = None
    if utt_id in sentences:
        # reading, resampling may occur
        wav, _ = librosa.load(str(fp), sr=config.fs)
        # skip audio that is not mono-channel or not normalized to
        # [-1.0, 1.0] (i.e. does not look like 16-bit PCM)
        if len(wav.shape) != 1 or np.abs(wav).max() > 1.0:
            return record
        phones = sentences[utt_id][0]
        durations = sentences[utt_id][1]
        speaker = sentences[utt_id][2]
        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
        # slightly less precise than using *.TextGrid directly
        times = librosa.frames_to_time(
            d_cumsum, sr=config.fs, hop_length=config.n_shift)
        if cut_sil:
            start = 0
            end = d_cumsum[-1]
            if phones[0] == "sil" and len(durations) > 1:
                start = times[1]
                durations = durations[1:]
                phones = phones[1:]
            if phones[-1] == 'sil' and len(durations) > 1:
                end = times[-2]
                durations = durations[:-1]
                phones = phones[:-1]
            sentences[utt_id][0] = phones
            sentences[utt_id][1] = durations
            start, end = librosa.time_to_samples([start, end], sr=config.fs)
            wav = wav[start:end]
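            # frames -> seconds -> samples: librosa.frames_to_time(n, sr,
            # hop_length) is n * hop_length / sr, so the trim points are
            # quantized to hop (n_shift) boundaries, which keeps the wav
            # and the remaining durations consistent.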
        # extract mel feats
        logmel = mel_extractor.get_log_mel_fbank(wav)
        # change duration according to mel_length
        compare_duration_and_mel_length(sentences, utt_id, logmel)
        phones = sentences[utt_id][0]
        durations = sentences[utt_id][1]
        num_frames = logmel.shape[0]
        assert sum(durations) == num_frames
        mel_dir = output_dir / "data_speech"
        mel_dir.mkdir(parents=True, exist_ok=True)
        mel_path = mel_dir / (utt_id + "_speech.npy")
        np.save(mel_path, logmel)
        record = {
            "utt_id": utt_id,
            "phones": phones,
            "text_lengths": len(phones),
            "speech_lengths": num_frames,
            "speech": str(mel_path),
            "speaker": speaker
        }
        if spk_emb_dir:
            if speaker in os.listdir(spk_emb_dir):
                embed_name = utt_id + ".npy"
                embed_path = spk_emb_dir / speaker / embed_name
                if embed_path.is_file():
                    record["spk_emb"] = str(embed_path)
                else:
                    return None
    return record


def process_sentences(config,
                      fps: List[Path],
                      sentences: Dict,
                      output_dir: Path,
                      mel_extractor=None,
                      nprocs: int=1,
                      cut_sil: bool=True,
                      spk_emb_dir: Path=None):
    if nprocs == 1:
        results = []
        for fp in fps:
            record = process_sentence(config, fp, sentences, output_dir,
                                      mel_extractor, cut_sil, spk_emb_dir)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(process_sentence, config, fp,
                                         sentences, output_dir, mel_extractor,
                                         cut_sil, spk_emb_dir)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    record = ft.result()
                    if record:
                        results.append(record)

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features.")

    parser.add_argument(
        "--dataset",
        default="baker",
        type=str,
        help="name of dataset, should be in {baker, aishell3, ljspeech, vctk} now")

    parser.add_argument(
        "--rootdir", default=None, type=str, help="directory to the dataset.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--dur-file", default=None, type=str, help="path to durations.txt.")

    parser.add_argument("--config", type=str, help="tacotron2 config file.")

    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of processes.")

    def str2bool(s):
        return s.lower() == 'true'
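    # str2bool is needed because argparse's type=bool would treat any
    # non-empty string (including "False") as True; bool("False") is True.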

    parser.add_argument(
        "--cut-sil",
        type=str2bool,
        default=True,
        help="whether to cut silence at the edges of the audio")

    parser.add_argument(
        "--spk_emb_dir",
        default=None,
        type=str,
        help="directory to speaker embedding files.")
    args = parser.parse_args()

    rootdir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    # use absolute path
    dumpdir = dumpdir.resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)
    dur_file = Path(args.dur_file).expanduser()

    if args.spk_emb_dir:
        spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
    else:
        spk_emb_dir = None

    assert rootdir.is_dir()
    assert dur_file.is_file()

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    if args.verbose > 1:
        print(vars(args))
        print(config)

    sentences, speaker_set = get_phn_dur(dur_file)

    merge_silence(sentences)
    phone_id_map_path = dumpdir / "phone_id_map.txt"
    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
    get_input_token(sentences, phone_id_map_path, args.dataset)
    get_spk_id_map(speaker_set, speaker_id_map_path)

    if args.dataset == "baker":
        wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
        # split data into 3 sections
        num_train = 9800
        num_dev = 100
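        # BZNSYP ships 10000 recordings, so this yields a fixed
        # 9800/100/100 train/dev/test split.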
        train_wav_files = wav_files[:num_train]
        dev_wav_files = wav_files[num_train:num_train + num_dev]
        test_wav_files = wav_files[num_train + num_dev:]
    elif args.dataset == "aishell3":
        sub_num_dev = 5
        wav_dir = rootdir / "train" / "wav"
        train_wav_files = []
        dev_wav_files = []
        test_wav_files = []
        for speaker in os.listdir(wav_dir):
            wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
            if len(wav_files) > 100:
                train_wav_files += wav_files[:-sub_num_dev * 2]
                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
                test_wav_files += wav_files[-sub_num_dev:]
            else:
                train_wav_files += wav_files

    elif args.dataset == "ljspeech":
        wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
        # split data into 3 sections
        num_train = 12900
        num_dev = 100
        train_wav_files = wav_files[:num_train]
        dev_wav_files = wav_files[num_train:num_train + num_dev]
        test_wav_files = wav_files[num_train + num_dev:]
    elif args.dataset == "vctk":
        sub_num_dev = 5
        wav_dir = rootdir / "wav48_silence_trimmed"
        train_wav_files = []
        dev_wav_files = []
        test_wav_files = []
        for speaker in os.listdir(wav_dir):
            wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
            if len(wav_files) > 100:
                train_wav_files += wav_files[:-sub_num_dev * 2]
                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
                test_wav_files += wav_files[-sub_num_dev:]
            else:
                train_wav_files += wav_files

    else:
        # fail fast instead of crashing later with undefined *_wav_files
        raise ValueError(
            "dataset should be in {baker, aishell3, ljspeech, vctk} now!")

    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test" / "raw"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # Extractor
    mel_extractor = LogMelFBank(
        sr=config.fs,
        n_fft=config.n_fft,
        hop_length=config.n_shift,
        win_length=config.win_length,
        window=config.window,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)

    # process for the 3 sections
    if train_wav_files:
        process_sentences(
            config,
            train_wav_files,
            sentences,
            train_dump_dir,
            mel_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir)
    if dev_wav_files:
        process_sentences(
            config,
            dev_wav_files,
            sentences,
            dev_dump_dir,
            mel_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir)
    if test_wav_files:
        process_sentences(
            config,
            test_wav_files,
            sentences,
            test_dump_dir,
            mel_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir)


if __name__ == "__main__":
    main()
@@ -0,0 +1,190 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import tacotron2_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.new_tacotron2 import Tacotron2
from paddlespeech.t2s.models.new_tacotron2 import Tacotron2Evaluator
from paddlespeech.t2s.models.new_tacotron2 import Tacotron2Updater
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer


def train_sp(args, config):
    # decide device type and whether to run in parallel,
    # i.e. set up the running environment correctly
    # (world_size is defined unconditionally so the CPU path works too)
    world_size = paddle.distributed.get_world_size()
    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
        if world_size > 1:
            paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=[
            "text",
            "text_lengths",
            "speech",
            "speech_lengths",
        ],
        converters={
            "speech": np.load,
        }, )
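    # the metadata stores "speech" as the path to a .npy mel file (see
    # preprocess.py); the np.load converter turns it into an array on access.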
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=[
            "text",
            "text_lengths",
            "speech",
            "speech_lengths",
        ],
        converters={
            "speech": np.load,
        }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)

    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=tacotron2_single_spk_batch_fn,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        drop_last=False,
        batch_size=config.batch_size,
        collate_fn=tacotron2_single_spk_batch_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")

    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    odim = config.n_mels
    model = Tacotron2(idim=vocab_size, odim=odim, **config["model"])
    if world_size > 1:
        model = DataParallel(model)
    print("model done!")

    optimizer = build_optimizers(model, **config["optimizer"])
    print("optimizer done!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if dist.get_rank() == 0:
        config_name = args.config.split("/")[-1]
        # copy conf to output_dir
        shutil.copyfile(args.config, output_dir / config_name)

    updater = Tacotron2Updater(
        model=model,
        optimizer=optimizer,
        dataloader=train_dataloader,
        output_dir=output_dir,
        **config["updater"])

    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

    evaluator = Tacotron2Evaluator(
        model, dev_dataloader, output_dir=output_dir, **config["updater"])

    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    trainer.run()


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a Tacotron2 model.")
    parser.add_argument("--config", type=str, help="tacotron2 config file.")
    parser.add_argument("--train-metadata", type=str, help="training data.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")

    args = parser.parse_args()

    with open(args.config) as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
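    # dist.spawn launches `ngpu` worker processes, each of which re-enters
    # train_sp; inside train_sp the world size then reflects the spawned
    # process group and triggers init_parallel_env / DataParallel wrapping.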
    if args.ngpu > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
@@ -0,0 +1,15 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tacotron2 import *
from .tacotron2_updater import *
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tacotron 2 related modules for paddle"""
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from paddlespeech.t2s.modules.nets_utils import initialize
|
||||
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
||||
from paddlespeech.t2s.modules.tacotron2.attentions import AttForward
|
||||
from paddlespeech.t2s.modules.tacotron2.attentions import AttForwardTA
|
||||
from paddlespeech.t2s.modules.tacotron2.attentions import AttLoc
|
||||
from paddlespeech.t2s.modules.tacotron2.decoder import Decoder
|
||||
from paddlespeech.t2s.modules.tacotron2.encoder import Encoder
|
||||
|
||||
|
||||
class Tacotron2(nn.Layer):
|
||||
"""Tacotron2 module for end-to-end text-to-speech.
|
||||
|
||||
This is a module of Spectrogram prediction network in Tacotron2 described
|
||||
in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
|
||||
which converts the sequence of characters into the sequence of Mel-filterbanks.
|
||||
|
||||
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
|
||||
https://arxiv.org/abs/1712.05884
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# network structure related
|
||||
idim: int,
|
||||
odim: int,
|
||||
embed_dim: int=512,
|
||||
elayers: int=1,
|
||||
eunits: int=512,
|
||||
econv_layers: int=3,
|
||||
econv_chans: int=512,
|
||||
econv_filts: int=5,
|
||||
atype: str="location",
|
||||
adim: int=512,
|
||||
aconv_chans: int=32,
|
||||
aconv_filts: int=15,
|
||||
cumulate_att_w: bool=True,
|
||||
dlayers: int=2,
|
||||
dunits: int=1024,
|
||||
prenet_layers: int=2,
|
||||
prenet_units: int=256,
|
||||
postnet_layers: int=5,
|
||||
postnet_chans: int=512,
|
||||
postnet_filts: int=5,
|
||||
output_activation: str=None,
|
||||
use_batch_norm: bool=True,
|
||||
use_concate: bool=True,
|
||||
use_residual: bool=False,
|
||||
reduction_factor: int=1,
|
||||
# extra embedding related
|
||||
spk_num: Optional[int]=None,
|
||||
lang_num: Optional[int]=None,
|
||||
spk_embed_dim: Optional[int]=None,
|
||||
spk_embed_integration_type: str="concat",
|
||||
dropout_rate: float=0.5,
|
||||
zoneout_rate: float=0.1,
|
||||
# training related
|
||||
init_type: str="xavier_uniform", ):
|
||||
"""Initialize Tacotron2 module.
|
||||
Parameters
|
||||
----------
|
||||
idim : int
|
||||
Dimension of the inputs.
|
||||
odim : int
|
||||
Dimension of the outputs.
|
||||
embed_dim : int
|
||||
Dimension of the token embedding.
|
||||
elayers : int
|
||||
Number of encoder blstm layers.
|
||||
eunits : int
|
||||
Number of encoder blstm units.
|
||||
econv_layers : int
|
||||
Number of encoder conv layers.
|
||||
econv_filts : int
|
||||
Number of encoder conv filter size.
|
||||
econv_chans : int
|
||||
Number of encoder conv filter channels.
|
||||
dlayers : int
|
||||
Number of decoder lstm layers.
|
||||
dunits : int
|
||||
Number of decoder lstm units.
|
||||
prenet_layers : int
|
||||
Number of prenet layers.
|
||||
prenet_units : int
|
||||
Number of prenet units.
|
||||
postnet_layers : int
|
||||
Number of postnet layers.
|
||||
postnet_filts : int
|
||||
Number of postnet filter size.
|
||||
postnet_chans : int
|
||||
Number of postnet filter channels.
|
||||
output_activation : str
|
||||
Name of activation function for outputs.
|
||||
adim : int
|
||||
Number of dimension of mlp in attention.
|
||||
aconv_chans : int
|
||||
Number of attention conv filter channels.
|
||||
aconv_filts : int
|
||||
Number of attention conv filter size.
|
||||
cumulate_att_w : bool
|
||||
Whether to cumulate previous attention weight.
|
||||
use_batch_norm : bool
|
||||
Whether to use batch normalization.
|
||||
use_concate : bool
|
||||
Whether to concat enc outputs w/ dec lstm outputs.
|
||||
reduction_factor : int
|
||||
Reduction factor.
|
||||
spk_num : Optional[int]
|
||||
Number of speakers. If set to > 1, assume that the
|
||||
sids will be provided as the input and use sid embedding layer.
|
||||
lang_num : Optional[int]
|
||||
Number of languages. If set to > 1, assume that the
|
||||
lids will be provided as the input and use sid embedding layer.
|
||||
spk_embed_dim : Optional[int]
|
||||
Speaker embedding dimension. If set to > 0,
|
||||
assume that spk_emb will be provided as the input.
|
||||
spk_embed_integration_type : str
|
||||
How to integrate speaker embedding.
|
||||
dropout_rate : float
|
||||
Dropout rate.
|
||||
zoneout_rate : float
|
||||
Zoneout rate.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
# store hyperparameters
|
||||
self.idim = idim
|
||||
self.odim = odim
|
||||
self.eos = idim - 1
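        # NOTE: the last entry of the vocabulary is treated as <eos>;
        # forward() and inference() append this id to every input sequence,
        # so idim must count that extra symbol (here, phone_id_map.txt is
        # expected to list <eos> last).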
        self.cumulate_att_w = cumulate_att_w
        self.reduction_factor = reduction_factor

        # define activation function for the final output
        if output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, output_activation):
            self.output_activation_fn = getattr(F, output_activation)
        else:
            raise ValueError(f"there is no such activation function. "
                             f"({output_activation})")

        # set padding idx
        padding_idx = 0
        self.padding_idx = padding_idx

        # initialize parameters
        initialize(self, init_type)

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=embed_dim,
            elayers=elayers,
            eunits=eunits,
            econv_layers=econv_layers,
            econv_chans=econv_chans,
            econv_filts=econv_filts,
            use_batch_norm=use_batch_norm,
            use_residual=use_residual,
            dropout_rate=dropout_rate,
            padding_idx=padding_idx, )

        self.spk_num = None
        if spk_num is not None and spk_num > 1:
            self.spk_num = spk_num
            self.sid_emb = nn.Embedding(spk_num, eunits)
        self.lang_num = None
        if lang_num is not None and lang_num > 1:
            self.lang_num = lang_num
            self.lid_emb = nn.Embedding(lang_num, eunits)

        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type
        if self.spk_embed_dim is None:
            dec_idim = eunits
        elif self.spk_embed_integration_type == "concat":
            dec_idim = eunits + spk_embed_dim
        elif self.spk_embed_integration_type == "add":
            dec_idim = eunits
            self.projection = nn.Linear(self.spk_embed_dim, eunits)
        else:
            raise ValueError(f"{spk_embed_integration_type} is not supported.")

        if atype == "location":
            att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
        elif atype == "forward":
            att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        elif atype == "forward_ta":
            att = AttForwardTA(dec_idim, dunits, adim, aconv_chans, aconv_filts,
                               odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Support only location, forward, or forward_ta")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=dlayers,
            dunits=dunits,
            prenet_layers=prenet_layers,
            prenet_units=prenet_units,
            postnet_layers=postnet_layers,
            postnet_chans=postnet_chans,
            postnet_filts=postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=use_batch_norm,
            use_concate=use_concate,
            dropout_rate=dropout_rate,
            zoneout_rate=zoneout_rate,
            reduction_factor=reduction_factor, )

        nn.initializer.set_global_initializer(None)

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            spk_emb: Optional[paddle.Tensor]=None,
            spk_id: Optional[paddle.Tensor]=None,
            lang_id: Optional[paddle.Tensor]=None
    ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
        """Calculate forward propagation.

        Parameters
        ----------
        text : Tensor(int64)
            Batch of padded character ids (B, T_text).
        text_lengths : Tensor(int64)
            Batch of lengths of each input batch (B,).
        speech : Tensor
            Batch of padded target features (B, T_feats, odim).
        speech_lengths : Tensor(int64)
            Batch of the lengths of each target (B,).
        spk_emb : Optional[Tensor]
            Batch of speaker embeddings (B, spk_embed_dim).
        spk_id : Optional[Tensor]
            Batch of speaker IDs (B, 1).
        lang_id : Optional[Tensor]
            Batch of language IDs (B, 1).

        Returns
        ----------
        Tensor
            Loss scalar value.
        Dict
            Statistics to be monitored.
        Tensor
            Weight value if not joint training else model outputs.

        """
        text = text[:, :text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        batch_size = paddle.shape(text)[0]

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 0, 0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction
        stop_labels = make_pad_mask(olens - 1)
        # bool tensors cannot be sliced/padded here, so cast to float32 first
        stop_labels = paddle.cast(stop_labels, dtype='float32')
        stop_labels = F.pad(stop_labels, [0, 0, 0, 1], "constant", 1.0)
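        # make_pad_mask(olens - 1) is 1.0 from each utterance's final frame
        # onward, and the extra padded column stays 1.0, so the stop-token
        # target is 0 for all frames before the end and 1 afterwards.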

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spk_emb=spk_emb,
            spk_id=spk_id,
            lang_id=lang_id, )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens - olens % self.reduction_factor
            max_out = max(olens)
            ys = ys[:, :max_out]
            stop_labels = stop_labels[:, :max_out]
            stop_labels = paddle.scatter(stop_labels, 1,
                                         (olens - 1).unsqueeze(1), 1.0)
            olens_in = olens // self.reduction_factor
        else:
            olens_in = olens
        return after_outs, before_outs, logits, ys, stop_labels, olens, att_ws, olens_in

    def _forward(
            self,
            xs: paddle.Tensor,
            ilens: paddle.Tensor,
            ys: paddle.Tensor,
            olens: paddle.Tensor,
            spk_emb: paddle.Tensor,
            spk_id: paddle.Tensor,
            lang_id: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:

        hs, hlens = self.enc(xs, ilens)
        if self.spk_num is not None:
            sid_embs = self.sid_emb(spk_id.reshape([-1]))
            hs = hs + sid_embs.unsqueeze(1)
        if self.lang_num is not None:
            lid_embs = self.lid_emb(lang_id.reshape([-1]))
            hs = hs + lid_embs.unsqueeze(1)
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spk_emb)

        return self.dec(hs, hlens, ys)

    def inference(
            self,
            text: paddle.Tensor,
            speech: Optional[paddle.Tensor]=None,
            spk_emb: Optional[paddle.Tensor]=None,
            spk_id: Optional[paddle.Tensor]=None,
            lang_id: Optional[paddle.Tensor]=None,
            threshold: float=0.5,
            minlenratio: float=0.0,
            maxlenratio: float=10.0,
            use_att_constraint: bool=False,
            backward_window: int=1,
            forward_window: int=3,
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Parameters
        ----------
        text : Tensor(int64)
            Input sequence of characters (T_text,).
        speech : Optional[Tensor]
            Feature sequence to extract style (N, idim).
        spk_emb : Optional[Tensor]
            Speaker embedding (spk_embed_dim,).
        spk_id : Optional[Tensor]
            Speaker ID (1,).
        lang_id : Optional[Tensor]
            Language ID (1,).
        threshold : float
            Threshold in inference.
        minlenratio : float
            Minimum length ratio in inference.
        maxlenratio : float
            Maximum length ratio in inference.
        use_att_constraint : bool
            Whether to apply attention constraint.
        backward_window : int
            Backward window in attention constraint.
        forward_window : int
            Forward window in attention constraint.
        use_teacher_forcing : bool
            Whether to use teacher forcing.

        Returns
        ----------
        Dict[str, Tensor]
            Output dict including the following items:
            * feat_gen (Tensor): Output sequence of features (T_feats, odim).
            * prob (Tensor): Output sequence of stop probabilities (T_feats,).
            * att_w (Tensor): Attention weights (T_feats, T).

        """
        x = text
        y = speech

        # add eos at the last of sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # inference with teacher forcing
        if use_teacher_forcing:
            assert speech is not None, "speech must be provided with teacher forcing."

            xs, ys = x.unsqueeze(0), y.unsqueeze(0)
            spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0)
            ilens = paddle.shape(xs)[1]
            olens = paddle.shape(ys)[1]
            outs, _, _, att_ws = self._forward(
                xs=xs,
                ilens=ilens,
                ys=ys,
                olens=olens,
                spk_emb=spk_emb,
                spk_id=spk_id,
                lang_id=lang_id, )

            return dict(feat_gen=outs[0], att_w=att_ws[0])

        # inference
        h = self.enc.inference(x)
        if self.spk_num is not None:
            sid_emb = self.sid_emb(spk_id.reshape([-1]))
            h = h + sid_emb
        if self.lang_num is not None:
            lid_emb = self.lid_emb(lang_id.reshape([-1]))
            h = h + lid_emb
        if self.spk_embed_dim is not None:
            hs, spk_emb = h.unsqueeze(0), spk_emb.unsqueeze(0)
            h = self._integrate_with_spk_embed(hs, spk_emb)[0]
        out, prob, att_w = self.dec.inference(
            h,
            threshold=threshold,
            minlenratio=minlenratio,
            maxlenratio=maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window, )

        return dict(feat_gen=out, prob=prob, att_w=att_w)

    def _integrate_with_spk_embed(self,
                                  hs: paddle.Tensor,
                                  spk_emb: paddle.Tensor) -> paddle.Tensor:
        """Integrate speaker embedding with hidden states.

        Parameters
        ----------
        hs : Tensor
            Batch of hidden state sequences (B, Tmax, eunits).
        spk_emb : Tensor
            Batch of speaker embeddings (B, spk_embed_dim).

        Returns
        ----------
        Tensor
            Batch of integrated hidden state sequences (B, Tmax, eunits) if
            integration_type is "add" else (B, Tmax, eunits + spk_embed_dim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spk_emb = self.projection(F.normalize(spk_emb))
            hs = hs + spk_emb.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds
            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
                -1, paddle.shape(hs)[1], -1)
            hs = paddle.concat([hs, spk_emb], axis=-1)
        else:
            raise NotImplementedError("support only add or concat.")

        return hs


class Tacotron2Inference(nn.Layer):
    def __init__(self, normalizer, model):
        super().__init__()
        self.normalizer = normalizer
        self.acoustic_model = model

    def forward(self, text, spk_id=None, spk_emb=None):
        out = self.acoustic_model.inference(
            text, spk_id=spk_id, spk_emb=spk_emb)
        normalized_mel = out["feat_gen"]
        logmel = self.normalizer.inverse(normalized_mel)
        return logmel
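
# Tacotron2Inference predicts mels in the normalized domain; the normalizer
# is typically a z-score normalizer built from dump/train/speech_stats.npy
# (--am_stat in local/synthesize.sh), and .inverse() maps the prediction
# back to the log-mel scale the vocoder expects.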
|
@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
|
||||
from paddlespeech.t2s.modules.losses import GuidedAttentionLoss
|
||||
from paddlespeech.t2s.modules.losses import Tacotron2Loss
|
||||
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
|
||||
from paddlespeech.t2s.training.reporter import report
|
||||
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Tacotron2Updater(StandardUpdater):
|
||||
def __init__(self,
|
||||
model: Layer,
|
||||
optimizer: Optimizer,
|
||||
dataloader: DataLoader,
|
||||
init_state=None,
|
||||
use_masking: bool=True,
|
||||
use_weighted_masking: bool=False,
|
||||
bce_pos_weight: float=5.0,
|
||||
loss_type: str="L1+L2",
|
||||
use_guided_attn_loss: bool=True,
|
||||
guided_attn_loss_sigma: float=0.4,
|
||||
guided_attn_loss_lambda: float=1.0,
|
||||
output_dir: Path=None):
|
||||
super().__init__(model, optimizer, dataloader, init_state=None)
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.use_guided_attn_loss = use_guided_attn_loss
|
||||
|
||||
self.taco2_loss = Tacotron2Loss(
|
||||
use_masking=use_masking,
|
||||
use_weighted_masking=use_weighted_masking,
|
||||
bce_pos_weight=bce_pos_weight, )
|
||||
if self.use_guided_attn_loss:
|
||||
self.attn_loss = GuidedAttentionLoss(
|
||||
sigma=guided_attn_loss_sigma,
|
||||
alpha=guided_attn_loss_lambda, )
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def update_core(self, batch):
|
||||
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||
losses_dict = {}
|
||||
# spk_id!=None in multiple spk fastspeech2
|
||||
spk_id = batch["spk_id"] if "spk_id" in batch else None
|
||||
spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
|
||||
if spk_emb is not None:
|
||||
spk_id = None
|
||||
|
||||
after_outs, before_outs, logits, ys, stop_labels, olens, att_ws, olens_in = self.model(
|
||||
text=batch["text"],
|
||||
text_lengths=batch["text_lengths"],
|
||||
speech=batch["speech"],
|
||||
speech_lengths=batch["speech_lengths"],
|
||||
spk_id=spk_id,
|
||||
spk_emb=spk_emb)
|
||||
|
||||
# calculate taco2 loss
|
||||
l1_loss, mse_loss, bce_loss = self.taco2_loss(
|
||||
after_outs=after_outs,
|
||||
before_outs=before_outs,
|
||||
logits=logits,
|
||||
ys=ys,
|
||||
stop_labels=stop_labels,
|
||||
olens=olens)
|
||||
|
||||
if self.loss_type == "L1+L2":
|
||||
loss = l1_loss + mse_loss + bce_loss
|
||||
elif self.loss_type == "L1":
|
||||
loss = l1_loss + bce_loss
|
||||
elif self.loss_type == "L2":
|
||||
loss = mse_loss + bce_loss
|
||||
else:
|
||||
raise ValueError(f"unknown --loss-type {self.loss_type}")
|
||||
|
||||
# calculate attention loss
|
||||
if self.use_guided_attn_loss:
|
||||
# NOTE: length of output for auto-regressive
|
||||
# input will be changed when r > 1
|
||||
attn_loss = self.attn_loss(
|
||||
att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
|
||||
loss = loss + attn_loss
|
||||
|
||||
optimizer = self.optimizer
|
||||
optimizer.clear_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
report("train/l1_loss", float(l1_loss))
|
||||
report("train/mse_loss", float(mse_loss))
|
||||
report("train/bce_loss", float(bce_loss))
|
||||
report("train/attn_loss", float(attn_loss))
|
||||
report("train/loss", float(loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["mse_loss"] = float(mse_loss)
|
||||
losses_dict["bce_loss"] = float(bce_loss)
|
||||
losses_dict["attn_loss"] = float(attn_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
|
||||
|
||||
class Tacotron2Evaluator(StandardEvaluator):
|
||||
def __init__(self,
|
||||
model: Layer,
|
||||
dataloader: DataLoader,
|
||||
use_masking: bool=True,
|
||||
use_weighted_masking: bool=False,
|
||||
bce_pos_weight: float=5.0,
|
||||
loss_type: str="L1+L2",
|
||||
use_guided_attn_loss: bool=True,
|
||||
guided_attn_loss_sigma: float=0.4,
|
||||
guided_attn_loss_lambda: float=1.0,
|
||||
output_dir=None):
|
||||
super().__init__(model, dataloader)
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.use_guided_attn_loss = use_guided_attn_loss
|
||||
|
||||
self.taco2_loss = Tacotron2Loss(
|
||||
use_masking=use_masking,
|
||||
use_weighted_masking=use_weighted_masking,
|
||||
bce_pos_weight=bce_pos_weight, )
|
||||
if self.use_guided_attn_loss:
|
||||
self.attn_loss = GuidedAttentionLoss(
|
||||
sigma=guided_attn_loss_sigma,
|
||||
alpha=guided_attn_loss_lambda, )
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||

    def evaluate_core(self, batch):
        self.msg = "Evaluate: "
        losses_dict = {}
        # spk_id is not None only in the multi-speaker setting
        spk_id = batch["spk_id"] if "spk_id" in batch else None
        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
        if spk_emb is not None:
            spk_id = None

        after_outs, before_outs, logits, ys, stop_labels, olens, att_ws, olens_in = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            spk_id=spk_id,
            spk_emb=spk_emb)

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs=after_outs,
            before_outs=before_outs,
            logits=logits,
            ys=ys,
            stop_labels=stop_labels,
            olens=olens)

        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE: the length of the output used as the auto-regressive
            # input changes when the reduction factor r > 1
            attn_loss = self.attn_loss(
                att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
            loss = loss + attn_loss

        report("eval/l1_loss", float(l1_loss))
        report("eval/mse_loss", float(mse_loss))
        report("eval/bce_loss", float(bce_loss))
        losses_dict["l1_loss"] = float(l1_loss)
        losses_dict["mse_loss"] = float(mse_loss)
        losses_dict["bce_loss"] = float(bce_loss)
        # attn_loss is only defined when guided attention loss is enabled
        if self.use_guided_attn_loss:
            report("eval/attn_loss", float(attn_loss))
            losses_dict["attn_loss"] = float(attn_loss)
        report("eval/loss", float(loss))
        losses_dict["loss"] = float(loss)
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
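
# A minimal construction sketch (names `model`, `dev_dataloader` and
# `output_dir` are placeholders): the keyword values mirror the updater
# config above, and `output_dir` must be a pathlib.Path since it is joined
# with the `/` operator.
#
#     evaluator = Tacotron2Evaluator(
#         model,
#         dev_dataloader,
#         use_masking=True,
#         bce_pos_weight=5.0,
#         loss_type="L1+L2",
#         use_guided_attn_loss=True,
#         guided_attn_loss_sigma=0.4,
#         guided_attn_loss_lambda=1.0,
#         output_dir=output_dir)
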
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention modules for RNN."""
import paddle
import paddle.nn.functional as F
from paddle import nn

from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_pad_mask


def _apply_attention_constraint(e,
                                last_attended_idx,
                                backward_window=1,
                                forward_window=3):
    """Apply monotonic attention constraint.

    This function applies the monotonic attention constraint
    introduced in `Deep Voice 3: Scaling
    Text-to-Speech with Convolutional Sequence Learning`_.

    Parameters
    ----------
    e : Tensor
        Attention energy before applying softmax (1, T).
    last_attended_idx : int
        The index of the input that was last attended, in [0, T].
    backward_window : int, optional
        Backward window size in attention constraint.
    forward_window : int, optional
        Forward window size in attention constraint.

    Returns
    ----------
    Tensor
        Monotonically constrained attention energy (1, T).

    .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
        https://arxiv.org/abs/1710.07654

    """
    if paddle.shape(e)[0] != 1:
        raise NotImplementedError(
            "Batch attention constraining is not yet supported.")
    backward_idx = last_attended_idx - backward_window
    forward_idx = last_attended_idx + forward_window
    if backward_idx > 0:
        e[:, :backward_idx] = -float("inf")
    if forward_idx < paddle.shape(e)[1]:
        e[:, forward_idx:] = -float("inf")
    return e
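
# Usage sketch with hypothetical values: keep the attention window around
# the last attended input index before the softmax is taken.
#
#     e = paddle.randn([1, 100])  # energies for a single utterance
#     e = _apply_attention_constraint(
#         e, last_attended_idx=42, backward_window=1, forward_window=3)
#     # indices 41..44 stay finite; softmax assigns zero weight elsewhere
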

class AttLoc(nn.Layer):
    """Location-aware attention module.

    Reference: Attention-Based Models for Speech Recognition
    (https://arxiv.org/pdf/1506.07503.pdf)

    Parameters
    ----------
    eprojs : int
        projection units of the encoder
    dunits : int
        units of the decoder
    att_dim : int
        attention dimension
    aconv_chans : int
        channels of the attention convolution
    aconv_filts : int
        filter size of the attention convolution
    han_mode : bool
        flag to switch on hierarchical attention mode
        (do not store pre_compute_enc_h)
    """

    def __init__(self,
                 eprojs,
                 dunits,
                 att_dim,
                 aconv_chans,
                 aconv_filts,
                 han_mode=False):
        super().__init__()
        self.mlp_enc = nn.Linear(eprojs, att_dim)
        self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
        self.loc_conv = nn.Conv2D(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias_attr=False, )
        self.gvec = nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
            self,
            enc_hs_pad,
            enc_hs_len,
            dec_z,
            att_prev,
            scaling=2.0,
            last_attended_idx=None,
            backward_window=1,
            forward_window=3, ):
        """Calculate AttLoc forward propagation.

        Parameters
        ----------
        enc_hs_pad : paddle.Tensor
            padded encoder hidden state (B, T_max, D_enc)
        enc_hs_len : paddle.Tensor
            padded encoder hidden state length (B)
        dec_z : paddle.Tensor
            decoder hidden state (B, D_dec)
        att_prev : paddle.Tensor
            previous attention weight (B, T_max)
        scaling : float
            scaling parameter before applying softmax
        last_attended_idx : int
            index of the input that was last attended
        backward_window : int
            backward window size in attention constraint
        forward_window : int
            forward window size in attention constraint

        Returns
        ----------
        paddle.Tensor
            attention weighted encoder state (B, D_enc)
        paddle.Tensor
            previous attention weights (B, T_max)
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            # (utt, frame, hdim)
            self.enc_h = enc_hs_pad
            self.h_length = paddle.shape(self.enc_h)[1]
            # (utt, frame, att_dim)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = paddle.zeros([batch, self.dunits])
        else:
            dec_z = dec_z.reshape([batch, self.dunits])

        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, 0 0-pad goes 0
            att_prev = 1.0 - make_pad_mask(enc_hs_len)
            att_prev = att_prev / enc_hs_len.unsqueeze(-1)

        # att_prev: (utt, frame) -> (utt, 1, 1, frame)
        # -> (utt, att_conv_chans, 1, frame)
        att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
        # att_conv: (utt, att_conv_chans, 1, frame) -> (utt, frame, att_conv_chans)
        att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
        # att_conv: (utt, frame, att_conv_chans) -> (utt, frame, att_dim)
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: (utt, frame, att_dim)
        dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])

        # dot with gvec
        # (utt, frame, att_dim) -> (utt, frame)
        e = self.gvec(
            paddle.tanh(att_conv + self.pre_compute_enc_h +
                        dec_z_tiled)).squeeze(2)

        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = make_pad_mask(enc_hs_len)
        e = masked_fill(e, self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(e, last_attended_idx,
                                            backward_window, forward_window)

        w = F.softmax(scaling * e, axis=1)

        # weighted sum over frames
        # utt x hdim
        c = paddle.sum(
            self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)

        return c, w
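
# Usage sketch with hypothetical shapes (B=2 utterances, T_max=100 frames;
# eprojs/dunits chosen to match the encoder/decoder sizes in the config):
#
#     att = AttLoc(eprojs=512, dunits=1024, att_dim=512,
#                  aconv_chans=32, aconv_filts=15)
#     enc_hs = paddle.randn([2, 100, 512])
#     enc_lens = paddle.to_tensor([100, 80])
#     att.reset()  # clear cached encoder projections between utterances
#     c, w = att(enc_hs, enc_lens, dec_z=None, att_prev=None)
#     # c: (2, 512) context vectors, w: (2, 100) attention weights
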

class AttForward(nn.Layer):
    """Forward attention module.

    Reference
    ----------
    Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
    (https://arxiv.org/pdf/1807.06736.pdf)

    Parameters
    ----------
    eprojs : int
        projection units of the encoder
    dunits : int
        units of the decoder
    att_dim : int
        attention dimension
    aconv_chans : int
        channels of the attention convolution
    aconv_filts : int
        filter size of the attention convolution
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super().__init__()
        self.mlp_enc = nn.Linear(eprojs, att_dim)
        self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
        self.loc_conv = nn.Conv2D(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias_attr=False, )
        self.gvec = nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
            self,
            enc_hs_pad,
            enc_hs_len,
            dec_z,
            att_prev,
            scaling=1.0,
            last_attended_idx=None,
            backward_window=1,
            forward_window=3, ):
        """Calculate AttForward forward propagation.

        Parameters
        ----------
        enc_hs_pad : paddle.Tensor
            padded encoder hidden state (B, T_max, D_enc)
        enc_hs_len : list
            padded encoder hidden state length (B,)
        dec_z : paddle.Tensor
            decoder hidden state (B, D_dec)
        att_prev : paddle.Tensor
            attention weights of previous step (B, T_max)
        scaling : float
            scaling parameter before applying softmax
        last_attended_idx : int
            index of the input that was last attended
        backward_window : int
            backward window size in attention constraint
        forward_window : int
            forward window size in attention constraint

        Returns
        ----------
        paddle.Tensor
            attention weighted encoder state (B, D_enc)
        paddle.Tensor
            previous attention weights (B, T_max)
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = paddle.shape(self.enc_h)[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = paddle.zeros([batch, self.dunits])
        else:
            dec_z = dec_z.reshape([batch, self.dunits])

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = paddle.zeros([*paddle.shape(enc_hs_pad)[:2]])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            paddle.tanh(self.pre_compute_enc_h + dec_z_tiled +
                        att_conv)).squeeze(2)

        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = make_pad_mask(enc_hs_len)
        e = masked_fill(e, self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(e, last_attended_idx,
                                            backward_window, forward_window)

        w = F.softmax(scaling * e, axis=1)

        # forward attention
        att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]

        w = (att_prev + att_prev_shift) * w
        # NOTE: clip is needed to avoid nan gradient
        w = F.normalize(paddle.clip(w, 1e-6), p=1, axis=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE: bmm could be used instead of the elementwise product + sum
        c = paddle.sum(self.enc_h * w.unsqueeze(-1), axis=1)

        return c, w
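
# The distinguishing step of forward attention is the shift-and-combine
# above: the new weight at position n is gated by w_prev(n) + w_prev(n - 1),
# so the attended position can only stay put or move forward. A toy trace
# with hypothetical values:
#
#     att_prev       = [0.7, 0.3, 0.0, 0.0]
#     att_prev_shift = [0.0, 0.7, 0.3, 0.0]  # att_prev shifted right by one
#     # (att_prev + att_prev_shift) is zero at position 3, so no mass can
#     # jump there before position 2 has been attended
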

class AttForwardTA(nn.Layer):
    """Forward attention with transition agent module.

    Reference
    ----------
    Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
    (https://arxiv.org/pdf/1807.06736.pdf)

    Parameters
    ----------
    eunits : int
        units of the encoder
    dunits : int
        units of the decoder
    att_dim : int
        attention dimension
    aconv_chans : int
        channels of the attention convolution
    aconv_filts : int
        filter size of the attention convolution
    odim : int
        output dimension
    """

    def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
        super().__init__()
        self.mlp_enc = nn.Linear(eunits, att_dim)
        self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
        self.mlp_ta = nn.Linear(eunits + dunits + odim, 1)
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
        self.loc_conv = nn.Conv2D(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias_attr=False, )
        self.gvec = nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eunits = eunits
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def forward(
            self,
            enc_hs_pad,
            enc_hs_len,
            dec_z,
            att_prev,
            out_prev,
            scaling=1.0,
            last_attended_idx=None,
            backward_window=1,
            forward_window=3, ):
        """Calculate AttForwardTA forward propagation.

        Parameters
        ----------
        enc_hs_pad : paddle.Tensor
            padded encoder hidden state (B, T_max, eunits)
        enc_hs_len : list paddle.Tensor
            padded encoder hidden state length (B,)
        dec_z : paddle.Tensor
            decoder hidden state (B, dunits)
        att_prev : paddle.Tensor
            attention weights of previous step (B, T_max)
        out_prev : paddle.Tensor
            decoder outputs of previous step (B, odim)
        scaling : float
            scaling parameter before applying softmax
        last_attended_idx : int
            index of the input that was last attended
        backward_window : int
            backward window size in attention constraint
        forward_window : int
            forward window size in attention constraint

        Returns
        ----------
        paddle.Tensor
            attention weighted encoder state (B, dunits)
        paddle.Tensor
            previous attention weights (B, T_max)
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = paddle.shape(self.enc_h)[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = paddle.zeros([batch, self.dunits])
        else:
            dec_z = dec_z.reshape([batch, self.dunits])

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = paddle.zeros([*paddle.shape(enc_hs_pad)[:2]])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            paddle.tanh(att_conv + self.pre_compute_enc_h +
                        dec_z_tiled)).squeeze(2)

        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = make_pad_mask(enc_hs_len)
        e = masked_fill(e, self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(e, last_attended_idx,
                                            backward_window, forward_window)

        w = F.softmax(scaling * e, axis=1)

        # forward attention
        # att_prev_shift = F.pad(att_prev.unsqueeze(0), (1, 0), data_format='NCL').squeeze(0)[:, :-1]
        att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]
        w = (self.trans_agent_prob * att_prev +
             (1 - self.trans_agent_prob) * att_prev_shift) * w
        # NOTE: clip is needed to avoid nan gradient
        w = F.normalize(paddle.clip(w, 1e-6), p=1, axis=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE: bmm could be used instead of the elementwise product + sum
        c = paddle.sum(
            self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)

        # update transition agent prob
        self.trans_agent_prob = F.sigmoid(
            self.mlp_ta(paddle.concat([c, out_prev, dec_z], axis=1)))

        return c, w
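
# Usage sketch with hypothetical shapes (odim=80 mel bins as in the feature
# config; out_prev is the decoder output of the previous step):
#
#     att = AttForwardTA(eunits=512, dunits=1024, att_dim=512,
#                        aconv_chans=32, aconv_filts=15, odim=80)
#     enc_hs = paddle.randn([2, 100, 512])
#     enc_lens = paddle.to_tensor([100, 80])
#     out_prev = paddle.zeros([2, 80])
#     att.reset()  # also resets the transition agent prob to 0.5
#     c, w = att(enc_hs, enc_lens, dec_z=None, att_prev=None,
#                out_prev=out_prev)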