fix style_syn, replace DeepSpeech with PaddleSpeech in readme

pull/992/head
TianYuan 3 years ago
parent f26db2e762
commit 30d09b411d

@@ -10,7 +10,7 @@ stop_stage=100
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
-mkdir download
+mkdir -p download
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # install PaddleGAN

@@ -10,7 +10,7 @@ stop_stage=100
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
-mkdir download
+mkdir -p download
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # install PaddleOCR

@@ -10,7 +10,7 @@ stop_stage=100
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
-mkdir download
+mkdir -p download
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # download pretrained tts models and unzip

@@ -13,6 +13,7 @@
# limitations under the License.
import argparse
from pathlib import Path
+from typing import Union
import numpy as np
import paddle
@@ -31,12 +32,12 @@ from paddlespeech.t2s.modules.normalizer import ZScore
class StyleFastSpeech2Inference(FastSpeech2Inference):
    def __init__(self, normalizer, model, pitch_stats_path, energy_stats_path):
        super().__init__(normalizer, model)
-        self.pitch_mean, self.pitch_std = np.load(pitch_stats_path)
-        self.pitch_mean = paddle.to_tensor(self.pitch_mean)
-        self.pitch_std = paddle.to_tensor(self.pitch_std)
-        self.energy_mean, self.energy_std = np.load(energy_stats_path)
-        self.energy_mean = paddle.to_tensor(self.energy_mean)
-        self.energy_std = paddle.to_tensor(self.energy_std)
+        pitch_mean, pitch_std = np.load(pitch_stats_path)
+        self.pitch_mean = paddle.to_tensor(pitch_mean)
+        self.pitch_std = paddle.to_tensor(pitch_std)
+        energy_mean, energy_std = np.load(energy_stats_path)
+        self.energy_mean = paddle.to_tensor(energy_mean)
+        self.energy_std = paddle.to_tensor(energy_std)

    def denorm(self, data, mean, std):
        return data * std + mean
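The `np.load` calls above assume each stats file holds a length-2 array laid out as `[mean, std]`, which is what makes the tuple unpacking work. A tiny illustration of that convention (file name and numbers are made up for the sketch, not taken from this commit):

```python
import numpy as np
import paddle

# Hypothetical stats file: a (2,)-shaped array storing [mean, std].
np.save("pitch_stats.npy", np.array([5.2, 0.31], dtype="float32"))

# Mirrors the corrected __init__ above: unpack into plain locals,
# then keep only the paddle tensors as attributes.
pitch_mean, pitch_std = np.load("pitch_stats.npy")
pitch_mean = paddle.to_tensor(pitch_mean)
pitch_std = paddle.to_tensor(pitch_std)
print(float(pitch_mean), float(pitch_std))
```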
@@ -45,11 +46,17 @@ class StyleFastSpeech2Inference(FastSpeech2Inference):
        return (data - mean) / std

    def forward(self,
-                text,
-                durations=None,
-                pitch=None,
-                energy=None,
-                robot=False):
+                text: paddle.Tensor,
+                durations: Union[paddle.Tensor, np.ndarray]=None,
+                durations_scale: Union[int, float]=None,
+                durations_bias: Union[int, float]=None,
+                pitch: Union[paddle.Tensor, np.ndarray]=None,
+                pitch_scale: Union[int, float]=None,
+                pitch_bias: Union[int, float]=None,
+                energy: Union[paddle.Tensor, np.ndarray]=None,
+                energy_scale: Union[int, float]=None,
+                energy_bias: Union[int, float]=None,
+                robot: bool=False):
        """
        Parameters
        ----------
@@ -57,15 +64,22 @@ class StyleFastSpeech2Inference(FastSpeech2Inference):
            Input sequence of characters (T,).
        speech : Tensor, optional
            Feature sequence to extract style (N, idim).
-        durations : Tensor, optional (int64)
-            Groundtruth of duration (T,) or
-            float/int (represents ratio)
-        pitch : Tensor, optional
-            Groundtruth of token-averaged pitch (T, 1) or
-            float/int (represents ratio)
-        energy : Tensor, optional
-            Groundtruth of token-averaged energy (T, 1) or
-            float (represents ratio)
+        durations : paddle.Tensor/np.ndarray, optional (int64)
+            Groundtruth of duration (T,), this will overwrite the set of durations_scale and durations_bias
+        durations_scale: int/float, optional
+        durations_bias: int/float, optional
+        pitch : paddle.Tensor/np.ndarray, optional
+            Groundtruth of token-averaged pitch (T, 1), this will overwrite the set of pitch_scale and pitch_bias
+        pitch_scale: int/float, optional
+            In denormed HZ domain.
+        pitch_bias: int/float, optional
+            In denormed HZ domain.
+        energy : paddle.Tensor/np.ndarray, optional
+            Groundtruth of token-averaged energy (T, 1), this will overwrite the set of energy_scale and energy_bias
+        energy_scale: int/float, optional
+            In denormed domain.
+        energy_bias: int/float, optional
+            In denormed domain.
        robot : bool, optional
            Weather output robot style
        Returns
@@ -75,12 +89,16 @@ class StyleFastSpeech2Inference(FastSpeech2Inference):
        """
        normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
            text, durations=None, pitch=None, energy=None)
-        # set duration
-        if isinstance(durations, float):
-            durations = durations * d_outs
+        # priority: groundtruth > scale/bias > previous output
+        # set durations
+        if isinstance(durations, np.ndarray):
+            durations = paddle.to_tensor(durations)
        elif isinstance(durations, paddle.Tensor):
            durations = durations
+        elif durations_scale or durations_bias:
+            durations_scale = durations_scale if durations_scale is not None else 1
+            durations_bias = durations_bias if durations_bias is not None else 0
+            durations = durations_scale * d_outs + durations_bias
        else:
            durations = d_outs
@@ -88,24 +106,32 @@ class StyleFastSpeech2Inference(FastSpeech2Inference):
            # set normed pitch to zeros have the same effect with set denormd ones to mean
            pitch = paddle.zeros(p_outs.shape)
        # set pitch, can overwrite robot set
-        if isinstance(pitch, (int, float)):
+        if isinstance(pitch, np.ndarray):
+            pitch = paddle.to_tensor(pitch)
+        elif isinstance(pitch, paddle.Tensor):
+            pitch = pitch
+        elif pitch_scale or pitch_bias:
+            pitch_scale = pitch_scale if pitch_scale is not None else 1
+            pitch_bias = pitch_bias if pitch_bias is not None else 0
            p_Hz = paddle.exp(
                self.denorm(p_outs, self.pitch_mean, self.pitch_std))
-            p_HZ = pitch * p_Hz
+            p_HZ = pitch_scale * p_Hz + pitch_bias
            pitch = self.norm(paddle.log(p_HZ), self.pitch_mean, self.pitch_std)
-        elif isinstance(pitch, paddle.Tensor):
-            pitch = pitch
        else:
            pitch = p_outs
        # set energy
-        if isinstance(energy, (int, float)):
-            e_dnorm = self.denorm(e_outs, self.energy_mean, self.energy_std)
-            e_dnorm = energy * e_dnorm
-            energy = self.norm(e_dnorm, self.energy_mean, self.energy_std)
+        if isinstance(energy, np.ndarray):
+            energy = paddle.to_tensor(energy)
        elif isinstance(energy, paddle.Tensor):
            energy = energy
+        elif energy_scale or energy_bias:
+            energy_scale = energy_scale if energy_scale is not None else 1
+            energy_bias = energy_bias if energy_bias is not None else 0
+            e_dnorm = self.denorm(e_outs, self.energy_mean, self.energy_std)
+            e_dnorm = energy_scale * e_dnorm + energy_bias
+            energy = self.norm(e_dnorm, self.energy_mean, self.energy_std)
        else:
            energy = e_outs
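Read together, the duration, pitch, and energy branches above all apply the same priority rule spelled out in the new comment: an explicit ground-truth array wins, then any scale/bias pair applied to the model's own prediction, and finally the raw prediction. A minimal self-contained sketch of that rule (the `resolve` helper is ours, for illustration only; in the real class the pitch and energy variants additionally denormalize before scaling):

```python
import numpy as np
import paddle

def resolve(value, scale, bias, predicted):
    # Same branch order as the diff: ground truth > scale/bias > previous output.
    if isinstance(value, np.ndarray):
        return paddle.to_tensor(value)
    if isinstance(value, paddle.Tensor):
        return value
    if scale or bias:
        scale = scale if scale is not None else 1
        bias = bias if bias is not None else 0
        return scale * predicted + bias
    return predicted

d_outs = paddle.to_tensor([2.0, 4.0, 3.0])                    # stand-in for predicted durations
print(resolve(None, 1 / 1.2, None, d_outs))                   # scale only -> 1.2x speed
print(resolve(np.array([3.0, 3.0, 3.0]), 2.0, 0.5, d_outs))   # ground truth overrides scale/bias
```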
@@ -173,23 +199,29 @@ def evaluate(args, fastspeech2_config, pwg_config):
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    styles = ["normal", "robot", "1.2xspeed", "0.8xspeed", "child_voice"]
    for style in styles:
        robot = False
        durations = None
+        durations_scale = None
+        durations_bias = None
        pitch = None
+        pitch_scale = None
+        pitch_bias = None
        energy = None
+        energy_scale = None
+        energy_bias = None
        if style == "robot":
            # all tones in phones be `1`
            # all pitch should be the same, we use mean here
            robot = True
        if style == "1.2xspeed":
-            durations = 1 / 1.2
+            durations_scale = 1 / 1.2
        if style == "0.8xspeed":
-            durations = 1 / 0.8
+            durations_scale = 1 / 0.8
        if style == "child_voice":
-            pitch = 1.3
+            pitch_scale = 1.3
        sub_output_dir = output_dir / style
        sub_output_dir.mkdir(parents=True, exist_ok=True)
        for utt_id, sentence in sentences:
@@ -201,8 +233,14 @@ def evaluate(args, fastspeech2_config, pwg_config):
            mel = fastspeech2_inference(
                phone_ids,
                durations=durations,
+                durations_scale=durations_scale,
+                durations_bias=durations_bias,
                pitch=pitch,
+                pitch_scale=pitch_scale,
+                pitch_bias=pitch_bias,
                energy=energy,
+                energy_scale=energy_scale,
+                energy_bias=energy_bias,
                robot=robot)
            wav = pwg_inference(mel)

@@ -13,7 +13,7 @@ In addition, the training process and the testing process are also introduced.
The arcitecture of the model is shown in Fig.1.
<p align="center">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/ds2onlineModel.png" width=800>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/ds2onlineModel.png" width=800>
<br/>Fig.1 The Arcitecture of deepspeech2 online model
</p>
@@ -160,7 +160,7 @@ The deepspeech2 offline model is similarity to the deepspeech2 online model. The
The arcitecture of the model is shown in Fig.2.
<p align="center">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/ds2offlineModel.png" width=800>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/ds2offlineModel.png" width=800>
<br/>Fig.2 The Arcitecture of deepspeech2 offline model
</p>

@@ -33,8 +33,8 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Siize(stati
:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
Tacotron2|LJSpeech|[tacotron2-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_ljspeech_ckpt_0.3.zip)|||
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/transformer_tts_ljspeech_ckpt_0.4.zip)|||
-SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/speedyspeech_nosil_baker_static_0.5.zip)|12M|
+SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/speedyspeech_nosil_baker_static_0.5.zip)|12MB|
-FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_static_0.4.zip)|157M|
+FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_static_0.4.zip)|157MB|
FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_aishell3_ckpt_0.4.zip)|||
FastSpeech2| LJSpeech |[fastspeech2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts3)|[fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_ljspeech_ckpt_0.5.zip)|||
FastSpeech2| VCTK |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/tts3)|[fastspeech2_nosil_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_vctk_ckpt_0.5.zip)|||
@@ -44,11 +44,11 @@ FastSpeech2| VCTK |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpe
Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size(static)
:-------------:| :------------:| :-----: | :-----:| :-----:| :-----:
WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_ljspeech_ckpt_0.3.zip)|||
-Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_static_0.4.zip)|5.1M|
+Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_static_0.4.zip)|5.1MB|
Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_ljspeech_ckpt_0.5.zip)|||
Parallel WaveGAN|AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_aishell3_ckpt_0.5.zip)|||
Parallel WaveGAN| VCTK |[PWGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1)|[pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_vctk_ckpt_0.5.zip)|||
-|Multi Band MelGAN |CSMSC|[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/mb_melgan_baker_ckpt_0.5.zip)|[mb_melgan_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/mb_melgan_baker_static_0.5.zip) |8.2M|
+|Multi Band MelGAN |CSMSC|[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/mb_melgan_baker_ckpt_0.5.zip)|[mb_melgan_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/mb_melgan_baker_static_0.5.zip) |8.2MB|

### Voice Cloning
Model Type | Dataset| Example Link | Pretrained Models
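Every checkpoint in the tables above is a plain zip archive, so fetching and unpacking one needs only the standard library. A small sketch using one URL from the table (the destination folder name is our own choice):

```python
import urllib.request
import zipfile

url = "https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip"
zip_path, _ = urllib.request.urlretrieve(url, "pwg_baker_ckpt_0.4.zip")  # download the archive
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall("pretrained_models")  # unpack into a local folder
```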

@@ -642,8 +642,7 @@ Audio samples generated by a TTS system. Text is first transformed into spectrog
Multi-Speaker TTS
-------------------
-PaddleSpeech also support Multi-Speaker TTS, we provide the audio demos generated by FastSpeech2 + ParallelWaveGAN, we use AISHELL-3 Multi-Speaker TTS dataset.
+PaddleSpeech also support Multi-Speaker TTS, we provide the audio demos generated by FastSpeech2 + ParallelWaveGAN, we use AISHELL-3 Multi-Speaker TTS dataset. Each line is a different person.

.. raw:: html

@@ -651,10 +650,370 @@ PaddleSpeech also support Multi-Speaker TTS, we provide the audio demos generate
    <div class="table">
    <table border="2" cellspacing="1" cellpadding="1">
    <tr>
-    <th align="center"> Text </th>
-    <th align="center"> Origin </th>
+    <th align="center"> Target Timbre </th>
    <th align="center"> Generated </th>
    </tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/0.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/0_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/1.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/1_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/2.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/2_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/3.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/3_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/4.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/4_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/5.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/5_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/6.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/6_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/7.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/7_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/8.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/8_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/9.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/9_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/10.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/10_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/11.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/11_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/12.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/12_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/13.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/13_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/14.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/14_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/15.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/15_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/16.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/16_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/17.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/17_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/18.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/18_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
<tr>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/target/19.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
<td>
<audio controls="controls">
<source
src="https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/fs2_aishell3_demos/generated/19_002.wav"
type="audio/wav">
Your browser does not support the <code>audio</code> element.
</audio>
</td>
</tr>
    <table>
    <div>
    <br>

@@ -6,4 +6,4 @@ Model | Generator Loss |Discriminator Loss
Parallel Wave GAN| adversial loss <br> Feature Matching | Multi-Scale Discriminator |
Mel GAN |adversial loss <br> Multi-resolution STFT loss | adversial loss|
Multi-Band Mel GAN | adversial loss <br> full band Multi-resolution STFT loss <br> sub band Multi-resolution STFT loss |Multi-Scale Discriminator|
-HiFi GAN |adversial loss <br> Feature Matching <br> Mel-Spectrogram Loss | Multi-Scale Discriminator <br> Multi-Period Discriminato |
+HiFi GAN |adversial loss <br> Feature Matching <br> Mel-Spectrogram Loss | Multi-Scale Discriminator <br> Multi-Period Discriminator|
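The Multi-resolution STFT loss that several rows above rely on compares magnitude spectrograms of the generated and reference waveforms at a few different FFT settings. A rough, hedged sketch of the idea (this is not PaddleSpeech's implementation; it assumes `paddle.signal.stft` from a recent Paddle release, and the three resolutions are values commonly used with Parallel WaveGAN):

```python
import math
import paddle

def stft_magnitude(x, n_fft, hop, win):
    # Hann window built by hand to keep the sketch dependency-free.
    n = paddle.arange(win, dtype="float32")
    window = 0.5 - 0.5 * paddle.cos(2.0 * math.pi * n / win)
    spec = paddle.signal.stft(x, n_fft=n_fft, hop_length=hop, win_length=win, window=window)
    return paddle.abs(spec).clip(min=1e-7)

def multi_resolution_stft_loss(y_hat, y,
                               resolutions=((1024, 120, 600), (2048, 240, 1200), (512, 50, 240))):
    loss = 0.0
    for n_fft, hop, win in resolutions:
        m_hat = stft_magnitude(y_hat, n_fft, hop, win)
        m = stft_magnitude(y, n_fft, hop, win)
        sc = paddle.sqrt(paddle.sum((m - m_hat) ** 2)) / paddle.sqrt(paddle.sum(m ** 2))  # spectral convergence
        mag = paddle.mean(paddle.abs(paddle.log(m) - paddle.log(m_hat)))                  # log-magnitude L1
        loss += sc + mag
    return loss / len(resolutions)

y = paddle.randn([1, 24000])                 # one second of dummy audio at 24 kHz
y_hat = y + 0.05 * paddle.randn([1, 24000])  # pretend generator output
print(multi_resolution_stft_loss(y_hat, y))
```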

@@ -27,14 +27,14 @@ At present, there are two mainstream acoustic model structures.
- Acoustic decoder (N Frames - > N Frames).
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/frame_level_am.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/frame_level_am.png" width=500 /> <br>
</div>
- Sequence to sequence acoustic model:
    - M Tokens - > N Frames.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/seq2seq_am.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/seq2seq_am.png" width=500 /> <br>
</div>
### Tacotron2
@@ -54,7 +54,7 @@ At present, there are two mainstream acoustic model structures.
- CBHG postprocess.
- Vocoder: Griffin-Lim.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/tacotron.png" width=700 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/tacotron.png" width=700 /> <br>
</div>
**Advantage of Tacotron:**
@@ -89,10 +89,10 @@ At present, there are two mainstream acoustic model structures.
- The alignment matrix of previous time is considered at the step `t` of decoder.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/tacotron2.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/tacotron2.png" width=500 /> <br>
</div>
-You can find PaddleSpeech TTS's tacotron2 with LJSpeech dataset example at [examples/ljspeech/tts0](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/tts0).
+You can find PaddleSpeech TTS's tacotron2 with LJSpeech dataset example at [examples/ljspeech/tts0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0).
### TransformerTTS
**Disadvantages of the Tacotrons:**
@@ -118,7 +118,7 @@ Transformer TTS is a combination of Tacotron2 and Transformer.
- Positional Encoding.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/transformer.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/transformer.png" width=500 /> <br>
</div>
#### Transformer TTS
@@ -138,7 +138,7 @@ Transformer TTS is a seq2seq acoustic model based on Transformer and Tacotron2.
- Uniform scale position encoding may have a negative impact on input or output sequences.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/transformer_tts.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/transformer_tts.png" width=500 /> <br>
</div>
**Disadvantages of Transformer TTS:**
@@ -146,7 +146,7 @@ Transformer TTS is a seq2seq acoustic model based on Transformer and Tacotron2.
- The ability to perceive local information is weak, and local information is more related to pronunciation.
- Stability is worse than Tacotron2.
-You can find PaddleSpeech TTS's Transformer TTS with LJSpeech dataset example at [examples/ljspeech/tts1](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/tts1).
+You can find PaddleSpeech TTS's Transformer TTS with LJSpeech dataset example at [examples/ljspeech/tts1](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1).
### FastSpeech2
@@ -184,14 +184,14 @@ Instead of using the encoder-attention-decoder based architecture as adopted by
• Can be generated in parallel (decoding time is less affected by sequence length)
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/fastspeech.png" width=800 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/fastspeech.png" width=800 /> <br>
</div>
#### FastPitch
[FastPitch](https://arxiv.org/abs/2006.06873) follows FastSpeech. A single pitch value is predicted for every temporal location, which improves the overall quality of synthesized speech.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/fastpitch.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/fastpitch.png" width=500 /> <br>
</div>
#### FastSpeech2
@@ -209,10 +209,10 @@ Instead of using the encoder-attention-decoder based architecture as adopted by
FastSpeech2 is similar to FastPitch but introduces more variation information of speech.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/fastspeech2.png" width=800 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/fastspeech2.png" width=800 /> <br>
</div>
-You can find PaddleSpeech TTS's FastSpeech2/FastPitch with CSMSC dataset example at [examples/csmsc/tts3](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/csmsc/tts3), We use token-averaged pitch and energy values introduced in FastPitch rather than frame level ones in FastSpeech2.
+You can find PaddleSpeech TTS's FastSpeech2/FastPitch with CSMSC dataset example at [examples/csmsc/tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3), We use token-averaged pitch and energy values introduced in FastPitch rather than frame level ones in FastSpeech2.
### SpeedySpeech
[SpeedySpeech](https://arxiv.org/abs/2008.03802) simplify the teacher-student architecture of FastSpeech and provide a fast and stable training procedure.
@@ -223,10 +223,10 @@ You can find PaddleSpeech TTS's FastSpeech2/FastPitch with CSMSC dataset example
- Describe a simple data augmentation technique that can be used early in the training to make the teacher network robust to sequential error propagation.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/speedyspeech.png" width=500 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/speedyspeech.png" width=500 /> <br>
</div>
-You can find PaddleSpeech TTS's SpeedySpeech with CSMSC dataset example at [examples/csmsc/tts2](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/csmsc/tts2).
+You can find PaddleSpeech TTS's SpeedySpeech with CSMSC dataset example at [examples/csmsc/tts2](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2).
## Vocoders
In speech synthesis, the main task of the vocoder is to convert the spectral parameters predicted by the acoustic model into the final speech waveform.
@@ -276,7 +276,7 @@ Here, we introduce a Flow-based vocoder WaveFlow and a GAN-based vocoder Paralle
- It is a small-footprint flow-based model for raw audio. It has only 5.9M parameters, which is 15x smalller than WaveGlow (87.9M).
- It is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in [Parallel WaveNet](https://arxiv.org/abs/1711.10433) and [ClariNet](https://openreview.net/pdf?id=HklY120cYm), which simplifies the training pipeline and reduces the cost of development.
-You can find PaddleSpeech TTS's WaveFlow with LJSpeech dataset example at [examples/ljspeech/voc0](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/voc0).
+You can find PaddleSpeech TTS's WaveFlow with LJSpeech dataset example at [examples/ljspeech/voc0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0).
### Parallel WaveGAN
[Parallel WaveGAN](https://arxiv.org/abs/1910.11480) trains a non-autoregressive WaveNet variant as a generator in a GAN based training method.
@@ -289,7 +289,7 @@ You can find PaddleSpeech TTS's WaveFlow with LJSpeech dataset example at [examp
- Multi-resolution STFT loss.
<div align="left">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/DeepSpeech/develop/docs/images/pwg.png" width=600 /> <br>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleSpeech/develop/docs/images/pwg.png" width=600 /> <br>
</div>
-You can find PaddleSpeech TTS's Parallel WaveGAN with CSMSC example at [examples/csmsc/voc1](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/csmsc/voc1).
+You can find PaddleSpeech TTS's Parallel WaveGAN with CSMSC example at [examples/csmsc/voc1](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1).

@@ -18,7 +18,7 @@ The models in PaddleSpeech TTS have the following mapping relationship:
## Quick Start
-Let's take a FastSpeech2 + Parallel WaveGAN with CSMSC dataset for instance. (./examples/csmsc/)(https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/csmsc)
+Let's take a FastSpeech2 + Parallel WaveGAN with CSMSC dataset for instance. (./examples/csmsc/)(https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc)
### Train Parallel WaveGAN with CSMSC
- Go to directory

@@ -1,5 +1,5 @@
# Chinese Rule Based Text Frontend
-A TTS system mainly includes three modules: `Text Frontend`, `Acoustic model` and `Vocoder`. We provide a complete Chinese text frontend module in PaddleSpeech TTS, see exapmle in [examples/other/text_frontend/](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/text_frontend).
+A TTS system mainly includes three modules: `Text Frontend`, `Acoustic model` and `Vocoder`. We provide a complete Chinese text frontend module in PaddleSpeech TTS, see exapmle in [examples/other/text_frontend/](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/text_frontend).
A text frontend module mainly includes:
- Text Segmentation

@@ -174,7 +174,7 @@
"source": [
"# 实践\n",
"<br></br>\n",
-"<font size=4>环境安装请参考: https://github.com/PaddlePaddle/DeepSpeech/blob/develop/docs/source/install.md</font>\n",
+"<font size=4>环境安装请参考: https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md</font>\n",
"\n",
"<br></br>\n",
"\n",
@@ -414,8 +414,8 @@
"<br></br>\n",
"<font size=4>相关 examples:\n",
" \n",
-"https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/tn\n",
-"https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/g2p</font>\n",
+"https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/tn\n",
+"https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/g2p</font>\n",
"\n",
"<br></br>\n",
"<font size=4>(未来计划推出基于深度学习的文本前端模块)</font>"
@@ -624,7 +624,7 @@
"Parallel Wave GAN| adversial loss <br> Feature Matching | Multi-Scale Discriminator |\n",
"Mel GAN |adversial loss <br> Multi-resolution STFT loss | adversial loss|\n",
"Multi-Band Mel GAN | adversial loss <br> full band Multi-resolution STFT loss <br> sub band Multi-resolution STFT loss |Multi-Scale Discriminator|\n",
-"HiFi GAN |adversial loss <br> Feature Matching <br> Mel-Spectrogram Loss | Multi-Scale Discriminator <br> Multi-Period Discriminato |\n"
+"HiFi GAN |adversial loss <br> Feature Matching <br> Mel-Spectrogram Loss | Multi-Scale Discriminator <br> Multi-Period Discriminator|\n"
]
},
{
@@ -800,7 +800,7 @@
"<br></br>\n",
"## 基于 CSMCS 数据集训练 FastSpeech2 模型\n",
"```bash\n",
-"git clone https://github.com/PaddlePaddle/DeepSpeech.git\n",
+"git clone https://github.com/PaddlePaddle/PaddleSpeech.git\n",
"cd examples/csmsc/tts\n",
"```\n",
"<font size=3>根据 README.md, 下载 CSMCS 数据集和其对应的强制对齐文件, 并放置在对应的位置<font>\n",
@@ -849,7 +849,7 @@
"<br></br>\n",
"## 基于 CSMCS 数据集训练 Parallel WaveGAN 模型\n",
"```bash\n",
-"git clone https://github.com/PaddlePaddle/DeepSpeech.git\n",
+"git clone https://github.com/PaddlePaddle/PaddleSpeech.git\n",
"cd examples/csmsc/voc1\n",
"```\n",
"<font size=3>根据 README.md, 下载 CSMCS 数据集和其对应的强制对齐文件, 并放置在对应的位置<font>\n",
@@ -912,7 +912,7 @@
"metadata": {},
"source": [
"# 关注 PaddleSpeech\n",
-"<font size=3>https://github.com/PaddlePaddle/DeepSpeech/<font>"
+"<font size=3>https://github.com/PaddlePaddle/PaddleSpeech/<font>"
]
}
],

@@ -1,8 +1,8 @@
# Tacotron2 + AISHELL-3 Voice Cloning
This example contains code used to train a [Tacotron2 ](https://arxiv.org/abs/1712.05884) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used in Voice Cloning Task, We refer to the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf) . The general steps are as follows:
-1. Speaker Encoder: We use a Speaker Verification to train a speaker encoder. Datasets used in this task are different from those used in Tacotron2, because the transcriptions are not needed, we use more datasets, refer to [ge2e](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/ge2e).
+1. Speaker Encoder: We use a Speaker Verification to train a speaker encoder. Datasets used in this task are different from those used in Tacotron2, because the transcriptions are not needed, we use more datasets, refer to [ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e).
2. Synthesizer: Then, we use the trained speaker encoder to generate utterance embedding for each sentence in AISHELL-3. This embedding is a extra input of Tacotron2 which will be concated with encoder outputs.
-3. Vocoder: We use WaveFlow as the neural Vocoder, refer to [waveflow](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/voc0).
+3. Vocoder: We use WaveFlow as the neural Vocoder, refer to [waveflow](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0).
## Get Started
Assume the path to the dataset is `~/datasets/data_aishell3`.

@@ -39,9 +39,9 @@ There are silence in the edge of AISHELL-3's wavs, and the audio amplitude is ve
We use Montreal Force Aligner 1.0. The label in aishell3 include pinyin, so the lexicon we provided to MFA is pinyin rather than Chinese characters. And the prosody marks(`$` and `%`) need to be removed. You shoud preprocess the dataset into the format which MFA needs, the texts have the same name with wavs and have the suffix `.lab`.
-We use [lexicon.txt](https://github.com/PaddlePaddle/DeepSpeech/blob/develop/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/lexicon.txt) as the lexicon.
+We use [lexicon.txt](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/lexicon.txt) as the lexicon.
-You can download the alignment results from here [alignment_aishell3.tar.gz](https://paddlespeech.bj.bcebos.com/Parakeet/alignment_aishell3.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) (use MFA1.x now) of our repo.
+You can download the alignment results from here [alignment_aishell3.tar.gz](https://paddlespeech.bj.bcebos.com/Parakeet/alignment_aishell3.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) (use MFA1.x now) of our repo.
```bash
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
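The `.lab` requirement mentioned in this hunk (one transcript per wav, same basename) is easy to satisfy with a short script. This is a hedged sketch rather than code from the repo: the metadata file name and its `wav|pinyin` layout are assumptions; only the `.lab` suffix rule and the removal of the `$`/`%` prosody marks come from the README text above.

```python
from pathlib import Path

wav_dir = Path("~/datasets/data_aishell3/train/wav").expanduser()  # assumed layout
metadata = Path("metadata.txt")  # assumed file with lines like "SSB00050001.wav|ni2 hao3 ..."

for line in metadata.read_text(encoding="utf-8").splitlines():
    wav_name, pinyin = line.split("|", 1)
    # MFA only needs the pinyin transcript, so drop the prosody marks noted above.
    pinyin = pinyin.replace("$", "").replace("%", "").strip()
    (wav_dir / wav_name).with_suffix(".lab").write_text(pinyin, encoding="utf-8")
```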

@@ -7,7 +7,7 @@ Download CSMSC from it's [Official Website](https://test.data-baker.com/data/ind
### Get MFA result of CSMSC and Extract it
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for SPEEDYSPEECH.
-You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.

@@ -89,7 +89,7 @@ optional arguments:
6. `--tones-dict` is the path of the tone vocabulary file.
### Synthesize
-We use [parallel wavegan](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/csmsc/voc1) as the neural vocoder.
+We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1) as the neural vocoder.
Download pretrained parallel wavegan model from [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip) and unzip it.
```bash
unzip pwg_baker_ckpt_0.4.zip

@@ -7,7 +7,7 @@ Download CSMSC from it's [Official Website](https://test.data-baker.com/data/ind
### Get MFA result of CSMSC and Extract it
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
-You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.

@@ -87,7 +87,7 @@ optional arguments:
5. `--phones-dict` is the path of the phone vocabulary file.
### Synthesize
-We use [parallel wavegan](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/csmsc/voc1) as the neural vocoder.
+We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1) as the neural vocoder.
Download pretrained parallel wavegan model from [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip) and unzip it.
```bash
unzip pwg_baker_ckpt_0.4.zip

@@ -6,7 +6,7 @@ Download CSMSC from the [official website](https://www.data-baker.com/data/index
### Get MFA results for silence trim
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence in the edge of audio.
-You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.

@@ -75,7 +75,7 @@
config, passing in KEY VALUE pairs
-v, --verbose print msg
```
-**Ps.** You can use [waveflow](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/voc0) as the neural vocoder to synthesize mels to wavs. (Please refer to `synthesize.sh` in our LJSpeech waveflow example)
+**Ps.** You can use [waveflow](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0) as the neural vocoder to synthesize mels to wavs. (Please refer to `synthesize.sh` in our LJSpeech waveflow example)
## Pretrained Models
Pretrained Models can be downloaded from links below. We provide 2 models with different configurations.

@@ -78,7 +78,7 @@ optional arguments:
5. `--phones-dict` is the path of the phone vocabulary file.
## Synthesize
-We use [waveflow](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/voc0) as the neural vocoder.
+We use [waveflow](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0) as the neural vocoder.
Download Pretrained WaveFlow Model with residual channel equals 128 from [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_ljspeech_ckpt_0.3.zip) and unzip it.
```bash
unzip waveflow_ljspeech_ckpt_0.3.zip

@@ -7,7 +7,7 @@ Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech
### Get MFA result of LJSpeech-1.1 and Extract it
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
-You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
@@ -86,7 +86,7 @@ optional arguments:
5. `--phones-dict` is the path of the phone vocabulary file.
### Synthesize
-We use [parallel wavegan](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/ljspeech/voc1) as the neural vocoder.
+We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1) as the neural vocoder.
Download pretrained parallel wavegan model from [pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_ljspeech_ckpt_0.5.zip) and unzip it.
```bash
unzip pwg_ljspeech_ckpt_0.5.zip

@@ -5,7 +5,7 @@ This example contains code used to train a [parallel wavegan](http://arxiv.org/a
Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/).
### Get MFA results for silence trim
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence in the edge of audio.
-You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
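# Illustrative only: downloading and extracting the precomputed alignment
# archive linked above. Assumes `wget` is available; the target directory
# name ljspeech_alignment is an arbitrary choice.
wget https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz
mkdir -p ljspeech_alignment
tar -xzvf ljspeech_alignment.tar.gz -C ljspeech_alignment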

@@ -1,5 +1,5 @@
# Speaker Encoder
-This experiment trains a speaker encoder with speaker verification as its task. It is done as a part of the experiment of transfer learning from speaker verification to multispeaker text-to-speech synthesis, which can be found at [examples/aishell3/vc0](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/aishell3/vc0). The trained speaker encoder is used to extract utterance embeddings from utterances.
+This experiment trains a speaker encoder with speaker verification as its task. It is done as a part of the experiment of transfer learning from speaker verification to multispeaker text-to-speech synthesis, which can be found at [examples/aishell3/vc0](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vc0). The trained speaker encoder is used to extract utterance embeddings from utterances.
## Model
The model used in this experiment is the speaker encoder with text independent speaker verification task in [GENERALIZED END-TO-END LOSS FOR SPEAKER VERIFICATION](https://arxiv.org/pdf/1710.10467.pdf). GE2E-softmax loss is used.
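For background on the GE2E-softmax objective mentioned above, here is a hedged, simplified sketch (not the repo's implementation): per-speaker utterance embeddings are scored against speaker centroids with a scaled cosine similarity and trained with softmax cross-entropy. The learnable scale/bias and the exclusion of each utterance from its own centroid, both part of the paper, are omitted here for brevity.
```python
import paddle
import paddle.nn.functional as F


def ge2e_softmax_loss(embeddings: paddle.Tensor, w: float = 10.0, b: float = -5.0):
    """Simplified GE2E-softmax loss sketch.

    embeddings: (n_speakers, n_utts, dim), assumed L2-normalized.
    Each utterance is scored against every speaker centroid and asked to
    pick its own speaker via softmax cross-entropy.
    """
    n_spk, n_utt, dim = embeddings.shape
    # Speaker centroids, re-normalized to unit length: (n_spk, dim)
    centroids = F.normalize(embeddings.mean(axis=1), axis=-1)
    # Flatten utterances: (n_spk * n_utt, dim)
    utts = embeddings.reshape([n_spk * n_utt, dim])
    # Scaled cosine similarity of every utterance to every centroid: (N, n_spk)
    logits = w * paddle.matmul(utts, centroids, transpose_y=True) + b
    # Each utterance's label is its own speaker index.
    labels = paddle.arange(n_spk).unsqueeze(1).expand([n_spk, n_utt]).reshape([-1])
    return F.cross_entropy(logits, labels)
```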

@@ -7,8 +7,8 @@ Download VCTK-0.92 from the [official website](https://datashare.ed.ac.uk/handle
### Get MFA result of VCTK and Extract it
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
-You can download from here [vctk_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/VCTK-Corpus-0.92/vctk_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [vctk_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/VCTK-Corpus-0.92/vctk_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
-ps: we remove three speakers in VCTK-0.92 (see [reorganize_vctk.py](https://github.com/PaddlePaddle/DeepSpeech/blob/develop/examples/other/use_mfa/local/reorganize_vctk.py)):
+ps: we remove three speakers in VCTK-0.92 (see [reorganize_vctk.py](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/other/use_mfa/local/reorganize_vctk.py)):
1. `p315`, because no txt for it.
2. `p280` and `p362`, because no *_mic2.flac (which is better than *_mic1.flac) for them.
@@ -88,7 +88,7 @@ optional arguments:
4. `--phones-dict` is the path of the phone vocabulary file.
### Synthesize
-We use [parallel wavegan](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/vctk/voc1) as the neural vocoder.
+We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1) as the neural vocoder.
Download pretrained parallel wavegan model from [pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/pwg_vctk_ckpt_0.5.zip)and unzip it.
```bash

@@ -7,8 +7,8 @@ Download VCTK-0.92 from the [official website](https://datashare.ed.ac.uk/handl
### Get MFA results for silence trim
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence in the edge of audio.
-You can download from here [vctk_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/VCTK-Corpus-0.92/vctk_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/DeepSpeech/tree/develop/examples/other/use_mfa) of our repo.
+You can download from here [vctk_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/VCTK-Corpus-0.92/vctk_alignment.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) of our repo.
-ps: we remove three speakers in VCTK-0.92 (see [reorganize_vctk.py](https://github.com/PaddlePaddle/DeepSpeech/blob/develop/examples/other/use_mfa/local/reorganize_vctk.py)):
+ps: we remove three speakers in VCTK-0.92 (see [reorganize_vctk.py](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/other/use_mfa/local/reorganize_vctk.py)):
1. `p315`, because no txt for it.
2. `p280` and `p362`, because no *_mic2.flac (which is better than *_mic1.flac) for them.
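The speaker exclusions listed above can be expressed as a small filter. The sketch below is illustrative only, assumes a standard VCTK-0.92 layout with `*_mic2.flac` files under per-speaker directories, and is not the actual logic of `reorganize_vctk.py`:
```python
from pathlib import Path

# Speakers dropped from VCTK-0.92: p315 has no transcripts,
# p280 and p362 have no *_mic2.flac recordings.
EXCLUDED_SPEAKERS = {"p315", "p280", "p362"}


def iter_mic2_flacs(vctk_root: str):
    """Yield *_mic2.flac files for every speaker except the excluded ones."""
    for flac in sorted(Path(vctk_root).rglob("*_mic2.flac")):
        speaker = flac.name.split("_")[0]  # "p225_001_mic2.flac" -> "p225"
        if speaker not in EXCLUDED_SPEAKERS:
            yield flac
```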

@@ -220,8 +220,8 @@ class Frontend():
sentence: str,
merge_sentences: bool=True,
with_erhua: bool=True,
-print_info: bool=False,
-robot: bool=False) -> List[List[str]]:
+robot: bool=False,
+print_info: bool=False) -> List[List[str]]:
sentences = self.text_normalizer.normalize(sentence)
phonemes = self._g2p(
sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
@@ -251,8 +251,8 @@ class Frontend():
sentence: str,
merge_sentences: bool=True,
get_tone_ids: bool=False,
-print_info: bool=False,
-robot: bool=False) -> Dict[str, List[paddle.Tensor]]:
+robot: bool=False,
+print_info: bool=False) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes(
sentence,
merge_sentences=merge_sentences,
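Because the two hunks above only swap the order of the `print_info` and `robot` keyword parameters, callers that pass them by name are unaffected. A usage sketch under assumptions follows; the import path and constructor arguments are guesses for illustration, not taken from this diff:
```python
from paddlespeech.t2s.frontend.zh_frontend import Frontend  # assumed import path

# Hypothetical construction; the real constructor signature may differ.
frontend = Frontend(phone_vocab_path="phone_id_map.txt")

# Keyword arguments keep working across the print_info/robot reordering.
input_ids = frontend.get_input_ids(
    "你好,欢迎使用语音合成。",
    merge_sentences=True,
    get_tone_ids=False,
    robot=False,
    print_info=True)
```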

@@ -420,9 +420,10 @@ class FastSpeech2(nn.Layer):
if is_inference:
# (B, Tmax)
-d_outs = self.duration_predictor.inference(hs, d_masks)
if ds is not None:
d_outs = ds
+else:
+d_outs = self.duration_predictor.inference(hs, d_masks)
if ps is not None:
p_outs = ps
if es is not None:
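The last hunk reorders inference so that user-supplied durations take precedence and the duration predictor only runs as a fallback. Below is a standalone sketch of that control flow; the wrapper function is illustrative, with only `hs`, `d_masks`, and `ds` taken from the hunk above:
```python
def select_durations(duration_predictor, hs, d_masks, ds=None):
    """Prefer ground-truth durations `ds`; otherwise run the predictor.

    Mirrors the change above: the predictor is no longer called
    unconditionally before `ds` is checked.
    """
    if ds is not None:
        d_outs = ds
    else:
        # (B, Tmax)
        d_outs = duration_predictor.inference(hs, d_masks)
    return d_outs
```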
