diff --git a/examples/csmsc/tts1/README.md b/examples/csmsc/tts1/README.md new file mode 100644 index 000000000..0725eda2d --- /dev/null +++ b/examples/csmsc/tts1/README.md @@ -0,0 +1,198 @@ +# TransformerTTS with CSMSC +## Dataset +### Download and Extract +Download CSMSC from it's [Official Website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP` and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. + ++ +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. + - synthesize waveform from text file. +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below. +```text +dump +├── dev +│ ├── norm +│ └── raw +├── phone_id_map.txt +├── speaker_id_map.txt +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── speech_stats.npy +``` +The dataset is split into 3 parts, namely `train`, `dev`, and` test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains the speech feature of each utterance, while the norm folder contains normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/speech_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, speech_lengths, the path of speech features, speaker, and id of each utterance. + +### Model Training +`./local/train.sh` calls `${BIN_DIR}/train.py`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +Here's the complete help message. +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] [--phones-dict PHONES_DICT] + +Train a TransformerTTS model with LJSpeech TTS dataset. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG TransformerTTS config file. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. + --phones-dict PHONES_DICT + phone vocabulary file. +``` +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. +5. `--phones-dict` is the path of the phone vocabulary file. + +## Synthesizing +We use [waveflow](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0) as the neural vocoder. +Download Pretrained WaveFlow Model with residual channel equals 128 from [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip) and unzip it. +```bash +unzip waveflow_ljspeech_ckpt_0.3.zip +``` +WaveFlow checkpoint contains files listed below. +```text +waveflow_ljspeech_ckpt_0.3 +├── config.yaml # default config used to train waveflow +└── step-2000000.pdparams # model parameters of waveflow +``` +`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--transformer-tts-config TRANSFORMER_TTS_CONFIG] + [--transformer-tts-checkpoint TRANSFORMER_TTS_CHECKPOINT] + [--transformer-tts-stat TRANSFORMER_TTS_STAT] + [--waveflow-config WAVEFLOW_CONFIG] + [--waveflow-checkpoint WAVEFLOW_CHECKPOINT] + [--phones-dict PHONES_DICT] + [--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] + +Synthesize with transformer tts & waveflow. + +optional arguments: + -h, --help show this help message and exit + --transformer-tts-config TRANSFORMER_TTS_CONFIG + transformer tts config file. + --transformer-tts-checkpoint TRANSFORMER_TTS_CHECKPOINT + transformer tts checkpoint to load. + --transformer-tts-stat TRANSFORMER_TTS_STAT + mean and standard deviation used to normalize + spectrogram when training transformer tts. + --waveflow-config WAVEFLOW_CONFIG + waveflow config file. + --waveflow-checkpoint WAVEFLOW_CHECKPOINT + waveflow checkpoint to load. + --phones-dict PHONES_DICT + phone vocabulary file. + --test-metadata TEST_METADATA + test metadata. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` +`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveform from text file. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize_e2e.py [-h] + [--transformer-tts-config TRANSFORMER_TTS_CONFIG] + [--transformer-tts-checkpoint TRANSFORMER_TTS_CHECKPOINT] + [--transformer-tts-stat TRANSFORMER_TTS_STAT] + [--waveflow-config WAVEFLOW_CONFIG] + [--waveflow-checkpoint WAVEFLOW_CHECKPOINT] + [--phones-dict PHONES_DICT] [--text TEXT] + [--output-dir OUTPUT_DIR] [--ngpu NGPU] + +Synthesize with transformer tts & waveflow. + +optional arguments: + -h, --help show this help message and exit + --transformer-tts-config TRANSFORMER_TTS_CONFIG + transformer tts config file. + --transformer-tts-checkpoint TRANSFORMER_TTS_CHECKPOINT + transformer tts checkpoint to load. + --transformer-tts-stat TRANSFORMER_TTS_STAT + mean and standard deviation used to normalize + spectrogram when training transformer tts. + --waveflow-config WAVEFLOW_CONFIG + waveflow config file. + --waveflow-checkpoint WAVEFLOW_CHECKPOINT + waveflow checkpoint to load. + --phones-dict PHONES_DICT + phone vocabulary file. + --text TEXT text to synthesize, a 'utt_id sentence' pair per line. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` +1. `--transformer-tts-config`, `--transformer-tts-checkpoint`, `--transformer-tts-stat` and `--phones-dict` are arguments for transformer_tts, which correspond to the 4 files in the transformer_tts pretrained model. +2. `--waveflow-config`, `--waveflow-checkpoint` are arguments for waveflow, which correspond to the 2 files in the waveflow pretrained model. +3. `--test-metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder. +4. `--text` is the text file, which contains sentences to synthesize. +5. `--output-dir` is the directory to save synthesized audio files. +6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +## Pretrained Model +Pretrained Model can be downloaded here: +- [transformer_tts_csmsc_ckpt.zip](https://pan.baidu.com/s/1jan_ZXCGKI7DHvS2jxWIEw?pwd=9i0t) + +TransformerTTS checkpoint contains files listed below. +```text +transformer_tts_csmsc_ckpt +├── default.yaml # default config used to train transformer_tts +├── phone_id_map.txt # phone vocabulary file when training transformer_tts +├── snapshot_iter_1118250.pdz # model parameters and optimizer states +└── speech_stats.npy # statistics used to normalize spectrogram when training transformer_tts +``` +You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained transformer_tts and waveflow models. +```bash +source path.sh + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize_e2e.py \ + --transformer-tts-config=transformer_tts_csmsc_ckpt/default.yaml \ + --transformer-tts-checkpoint=transformer_tts_csmsc_ckpt/snapshot_iter_1118250.pdz \ + --transformer-tts-stat=transformer_tts_csmsc_ckpt/speech_stats.npy \ + --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ + --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ + --text=${BIN_DIR}/../sentences.txt \ + --output-dir=exp/default/test_e2e \ + --phones-dict=transformer_tts_csmsc_ckpt/phone_id_map.txt +``` diff --git a/examples/csmsc/tts1/conf/default.yaml b/examples/csmsc/tts1/conf/default.yaml new file mode 100644 index 000000000..456b6a1e3 --- /dev/null +++ b/examples/csmsc/tts1/conf/default.yaml @@ -0,0 +1,92 @@ + +fs : 22050 # Hz, sample rate +n_fft : 1024 # FFT size (samples). +win_length : 1024 # Window length (samples). 46.4ms +n_shift : 256 # Hop size (samples). 11.6ms +fmin : 0 # Hz, min frequency when converting to mel +fmax : 8000 # Hz, max frequency when converting to mel +n_mels : 80 # mel bands +window: "hann" # Window function. + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 16 +num_workers: 2 + +########################################################## +# TTS MODEL SETTING # +########################################################## +tts: transformertts # model architecture +model: # keyword arguments for the selected model + embed_dim: 0 # embedding dimension in encoder prenet + eprenet_conv_layers: 0 # number of conv layers in encoder prenet + # if set to 0, no encoder prenet will be used + eprenet_conv_filts: 0 # filter size of conv layers in encoder prenet + eprenet_conv_chans: 0 # number of channels of conv layers in encoder prenet + dprenet_layers: 2 # number of layers in decoder prenet + dprenet_units: 256 # number of units in decoder prenet + adim: 512 # attention dimension + aheads: 8 # number of attention heads + elayers: 6 # number of encoder layers + eunits: 1024 # number of encoder ff units + dlayers: 6 # number of decoder layers + dunits: 1024 # number of decoder ff units + positionwise_layer_type: conv1d # type of position-wise layer + positionwise_conv_kernel_size: 1 # kernel size of position wise conv layer + postnet_layers: 5 # number of layers of postnset + postnet_filts: 5 # filter size of conv layers in postnet + postnet_chans: 256 # number of channels of conv layers in postnet + use_scaled_pos_enc: True # whether to use scaled positional encoding + encoder_normalize_before: True # whether to perform layer normalization before the input + decoder_normalize_before: True # whether to perform layer normalization before the input + reduction_factor: 1 # reduction factor + init_type: xavier_uniform # initialization type + init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding + init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding + eprenet_dropout_rate: 0.0 # dropout rate for encoder prenet + dprenet_dropout_rate: 0.5 # dropout rate for decoder prenet + postnet_dropout_rate: 0.5 # dropout rate for postnet + transformer_enc_dropout_rate: 0.1 # dropout rate for transformer encoder layer + transformer_enc_positional_dropout_rate: 0.1 # dropout rate for transformer encoder positional encoding + transformer_enc_attn_dropout_rate: 0.1 # dropout rate for transformer encoder attention layer + transformer_dec_dropout_rate: 0.1 # dropout rate for transformer decoder layer + transformer_dec_positional_dropout_rate: 0.1 # dropout rate for transformer decoder positional encoding + transformer_dec_attn_dropout_rate: 0.1 # dropout rate for transformer decoder attention layer + transformer_enc_dec_attn_dropout_rate: 0.1 # dropout rate for transformer encoder-decoder attention layer + num_heads_applied_guided_attn: 2 # number of heads to apply guided attention loss + num_layers_applied_guided_attn: 2 # number of layers to apply guided attention loss + + + +########################################################### +# UPDATER SETTING # +########################################################### +updater: + use_masking: True # whether to apply masking for padded part in loss calculation + loss_type: L1 + use_guided_attn_loss: True # whether to use guided attention loss + guided_attn_loss_sigma: 0.4 # sigma in guided attention loss + guided_attn_loss_lambda: 10.0 # lambda in guided attention loss + modules_applied_guided_attn: ["encoder-decoder"] # modules to apply guided attention loss + bce_pos_weight: 5.0 # weight of positive sample in binary cross entropy calculation + + +########################################################## +# OPTIMIZER & SCHEDULER SETTING # +########################################################## +optimizer: + optim: adam # optimizer type + learning_rate: 0.001 # learning rate + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 500 +num_snapshots: 5 + + +########################################################### +# OTHER SETTING # +########################################################### +seed: 10086 \ No newline at end of file diff --git a/examples/csmsc/tts1/local/preprocess.sh b/examples/csmsc/tts1/local/preprocess.sh new file mode 100644 index 000000000..e1acc8e83 --- /dev/null +++ b/examples/csmsc/tts1/local/preprocess.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +stage=1 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=ljspeech \ + --rootdir=~/datasets/LJSpeech-1.1/ \ + --dumpdir=dump \ + --config-path=conf/default.yaml \ + --num-cpu=8 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="speech" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize and covert phone to id, dev and test should use train's stats + echo "Normalize ..." + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/csmsc/tts1/local/synthesize.sh b/examples/csmsc/tts1/local/synthesize.sh new file mode 100644 index 000000000..9d1c47b39 --- /dev/null +++ b/examples/csmsc/tts1/local/synthesize.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ + --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/tts1/local/synthesize_e2e.sh b/examples/csmsc/tts1/local/synthesize_e2e.sh new file mode 100644 index 000000000..25a862f90 --- /dev/null +++ b/examples/csmsc/tts1/local/synthesize_e2e.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize_e2e.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ + --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ + --text=${BIN_DIR}/../sentences_en.txt \ + --output-dir=${train_output_path}/test_e2e \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/tts1/local/train.sh b/examples/csmsc/tts1/local/train.sh new file mode 100644 index 000000000..5e255fb8d --- /dev/null +++ b/examples/csmsc/tts1/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=2 \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/tts1/path.sh b/examples/csmsc/tts1/path.sh new file mode 100644 index 000000000..32eecd857 --- /dev/null +++ b/examples/csmsc/tts1/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=transformer_tts +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} diff --git a/examples/csmsc/tts1/run.sh b/examples/csmsc/tts1/run.sh new file mode 100644 index 000000000..48c4c9151 --- /dev/null +++ b/examples/csmsc/tts1/run.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_403.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # synthesize_e2e, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/synthesize_e2e1.py b/examples/synthesize_e2e1.py deleted file mode 100644 index 0c01199c3..000000000 --- a/examples/synthesize_e2e1.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import logging -from pathlib import Path - -import numpy as np -import paddle -import soundfile as sf -import yaml -from yacs.config import CfgNode - -from paddlespeech.t2s.frontend import English,Chinese -from paddlespeech.t2s.models.transformer_tts import TransformerTTS -from paddlespeech.t2s.models.transformer_tts import TransformerTTSInference -from paddlespeech.t2s.models.waveflow import ConditionalWaveFlow -from paddlespeech.t2s.modules.normalizer import ZScore -from paddlespeech.t2s.utils import layer_tools -from paddlespeech.t2s.exps.syn_utils import get_voc_inference2 -from paddlespeech.t2s.exps.syn_utils import get_am_inference - -def evaluate(args, acoustic_model_config, vocoder_config): - # dataloader has been too verbose - logging.getLogger("DataLoader").disabled = True - - # construct dataset for evaluation - sentences = [] - with open(args.text, 'rt') as f: - for line in f: - line_list = line.strip().split() - utt_id = line_list[0] - sentence = " ".join(line_list[1:]) - sentences.append((utt_id, sentence)) - - with open(args.phones_dict, "r") as f: - phn_id = [line.strip().split() for line in f.readlines()] - - vocab_size = len(phn_id) - - phone_id_map = {} - for phn, id in phn_id: - phone_id_map[phn] = int(id) - print("vocab_size:", vocab_size) - odim = acoustic_model_config.n_mels - model = TransformerTTS(idim=vocab_size, odim=odim, **acoustic_model_config["model"]) - model.set_state_dict(paddle.load(args.transformer_tts_checkpoint)["main_params"]) - model.eval() - - vocoder_checkpoint_path = args.waveflow_checkpoint[:-9] if args.waveflow_checkpoint.endswith(".pdparams") else args.waveflow_checkpoint - vocoder = ConditionalWaveFlow.from_pretrained(vocoder_config,vocoder_checkpoint_path) - layer_tools.recursively_remove_weight_norm(vocoder) - vocoder.eval() - print("model done!") - - # vocoder = get_voc_inference2( - # voc=args.voc, - # voc_config=vocoder_config, - # voc_ckpt=args.voc_ckpt, - # voc_stat=args.voc_stat) - - frontend = Chinese() - print("frontend done!") - - stat = np.load(args.transformer_tts_stat) - mu, std = stat - mu = paddle.to_tensor(mu) - std = paddle.to_tensor(std) - transformer_tts_normalizer = ZScore(mu, std) - transformer_tts_inference = TransformerTTSInference(transformer_tts_normalizer, model) - - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - for utt_id, sentence in sentences: - phones = frontend.phoneticize(sentence) - print(sentence) - print(phones) - # remove start_symbol and end_symbol - phones = phones[1:-1] - phones = [phn for phn in phones if not phn.isspace()] - phones = [phn if phn in phone_id_map else "," for phn in phones] - phone_ids = [phone_id_map[phn] for phn in phones] - print('1',phone_ids) - with paddle.no_grad(): - tensor_phone_ids=paddle.to_tensor(phone_ids) - mel = transformer_tts_inference(tensor_phone_ids) - # mel shape is (T, feats) and waveflow's input shape is (batch, feats, T) - mel = mel.unsqueeze(0).transpose([0, 2, 1]) - # wavflow's output shape is (B, T) - wav = vocoder.infer(mel)[0] - #wav = vocoder(mel)[0] - sf.write(str(output_dir / (utt_id + ".wav")),wav.numpy(),samplerate=acoustic_model_config.fs) - #sf.write(str(output_dir / (utt_id + ".wav")), wav.numpy(), samplerate=24000) - print(f"{utt_id} done!") - - -def main(): - parser = argparse.ArgumentParser( - description="Synthesize with transformer tts & waveflow.") - parser.add_argument( - "--transformer-tts-config", - default="./out_put/default.yaml", - type=str, - help="transformer tts config file.") - parser.add_argument( - "--transformer-tts-checkpoint", - default="./out_put/checkpoints/snapshot_iter_1113750.pdz", - type=str, - help="transformer tts checkpoint to load.") - parser.add_argument( - "--transformer-tts-stat", - default="./dump/speech_stats.npy", - type=str, - help="mean and standard deviation used to normalize spectrogram when training transformer tts." - ) - - - parser.add_argument( - "--waveflow-config", default="./waveflow_ljspeech_ckpt_0.3/config.yaml", type=str, help="waveflow config file.") - # not normalize when training waveflow - parser.add_argument( - "--waveflow-checkpoint", default="./waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams", type=str, - help="waveflow checkpoint to load.") - - parser.add_argument( - "--phones-dict", type=str, default="./dump/phone_id_map.txt", help="phone vocabulary file.") - parser.add_argument( - "--text", - default="./sentences.txt", - type=str, - help="text to synthesize, a 'utt_id sentence' pair per line.") - parser.add_argument("--output-dir", default="./output222", type=str, help="output dir.") - parser.add_argument("--ngpu", type=int, default=0, help="if ngpu == 0, use cpu.") - - args = parser.parse_args() - - if args.ngpu == 0: - paddle.set_device("cpu") - elif args.ngpu > 0: - paddle.set_device("gpu") - else: - print("ngpu should >= 0 !") - - with open(args.transformer_tts_config) as f: - transformer_tts_config = CfgNode(yaml.safe_load(f)) - with open(args.waveflow_config) as f: - waveflow_config = CfgNode(yaml.safe_load(f)) - - print("========Args========") - print(yaml.safe_dump(vars(args))) - print("========Config========") - print(transformer_tts_config) - print(waveflow_config) - - evaluate(args, transformer_tts_config, waveflow_config) - - -if __name__ == "__main__": - main() diff --git a/paddlespeech/t2s/models/transformer_tts/transformer_tts.py b/paddlespeech/t2s/models/transformer_tts/transformer_tts.py index 355fceb16..4e008798b 100644 --- a/paddlespeech/t2s/models/transformer_tts/transformer_tts.py +++ b/paddlespeech/t2s/models/transformer_tts/transformer_tts.py @@ -509,7 +509,7 @@ class TransformerTTS(nn.Layer): spk_emb: paddle.Tensor=None, threshold: float=0.5, minlenratio: float=0.0, - maxlenratio: float=10.0, + maxlenratio: float=100.0, use_teacher_forcing: bool=False, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Generate the sequence of features given the sequences of characters.