Merge pull request from yt605155624/fs2_tostatic

fix fastspeech2 to static
pull/955/head
Hui Zhang 3 years ago committed by GitHub
commit 04d84a87ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,7 +19,7 @@ Run the command below to
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
- synthesize waveform from text file.
6. inference using static model.
5. inference using static model.
```bash
./run.sh
```

@ -19,6 +19,7 @@ Run the command below to
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
- synthesize waveform from text file.
5. inference using static model.
```bash
./run.sh
```
@ -189,6 +190,13 @@ optional arguments:
5. `--output-dir` is the directory to save synthesized audio files.
6. `--device is` the type of device to run synthesis, 'cpu' and 'gpu' are supported. 'gpu' is recommended for faster synthesis.
### Inference
After Synthesize, we will get static models of fastspeech2 and pwgan in `${train_output_path}/inference`.
`./local/inference.sh` calls `${BIN_DIR}/inference.py`, which provides a paddle static model inference example for fastspeech2 + pwgan synthesize.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
```
## Pretrained Model
Pretrained FastSpeech2 model with no silence in the edge of audios. [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip)
@ -215,6 +223,7 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=exp/default/test_e2e \
--inference-dir=exp/default/inference \
--device="gpu" \
--phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
```

@ -0,0 +1,9 @@
#!/bin/bash
train_output_path=$1
python3 ${BIN_DIR}/inference.py \
--inference-dir=${train_output_path}/inference \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=${train_output_path}/pd_infer_out \
--phones-dict=dump/phone_id_map.txt

@ -15,5 +15,6 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=${train_output_path}/test_e2e \
--inference-dir=${train_output_path}/inference \
--device="gpu" \
--phones-dict=dump/phone_id_map.txt

@ -35,3 +35,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# inference with static model
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
fi

@ -0,0 +1,133 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import soundfile as sf
from paddle import inference
from parakeet.frontend.zh_frontend import Frontend
def main():
parser = argparse.ArgumentParser(
description="Paddle Infernce with speedyspeech & parallel wavegan.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--enable-auto-log", action="store_true", help="use auto log")
parser.add_argument(
"--phones-dict",
type=str,
default="phones.txt",
help="phone vocabulary file.")
args, _ = parser.parse_known_args()
frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!")
fastspeech2_config = inference.Config(
str(Path(args.inference_dir) / "fastspeech2.pdmodel"),
str(Path(args.inference_dir) / "fastspeech2.pdiparams"))
fastspeech2_config.enable_use_gpu(50, 0)
# This line must be commented, if not, it will OOM
# fastspeech2_config.enable_memory_optim()
fastspeech2_predictor = inference.create_predictor(fastspeech2_config)
pwg_config = inference.Config(
str(Path(args.inference_dir) / "pwg.pdmodel"),
str(Path(args.inference_dir) / "pwg.pdiparams"))
pwg_config.enable_use_gpu(100, 0)
pwg_config.enable_memory_optim()
pwg_predictor = inference.create_predictor(pwg_config)
if args.enable_auto_log:
import auto_log
os.makedirs("output", exist_ok=True)
pid = os.getpid()
logger = auto_log.AutoLogger(
model_name="fastspeech2",
model_precision='float32',
batch_size=1,
data_shape="dynamic",
save_path="./output/auto_log.log",
inference_config=fastspeech2_config,
pids=pid,
process_name=None,
gpu_ids=0,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = []
with open(args.text, 'rt') as f:
for line in f:
utt_id, sentence = line.strip().split()
sentences.append((utt_id, sentence))
for utt_id, sentence in sentences:
if args.enable_auto_log:
logger.times.start()
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"]
phones = phone_ids[0].numpy()
if args.enable_auto_log:
logger.times.stamp()
input_names = fastspeech2_predictor.get_input_names()
phones_handle = fastspeech2_predictor.get_input_handle(input_names[0])
phones_handle.reshape(phones.shape)
phones_handle.copy_from_cpu(phones)
fastspeech2_predictor.run()
output_names = fastspeech2_predictor.get_output_names()
output_handle = fastspeech2_predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
input_names = pwg_predictor.get_input_names()
mel_handle = pwg_predictor.get_input_handle(input_names[0])
mel_handle.reshape(output_data.shape)
mel_handle.copy_from_cpu(output_data)
pwg_predictor.run()
output_names = pwg_predictor.get_output_names()
output_handle = pwg_predictor.get_output_handle(output_names[0])
wav = output_data = output_handle.copy_to_cpu()
if args.enable_auto_log:
logger.times.stamp()
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
if args.enable_auto_log:
logger.times.end(stamp=True)
print(f"{utt_id} done!")
if args.enable_auto_log:
logger.report()
if __name__ == "__main__":
main()

@ -13,12 +13,15 @@
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from parakeet.frontend.zh_frontend import Frontend
@ -74,7 +77,21 @@ def evaluate(args, fastspeech2_config, pwg_config):
pwg_normalizer = ZScore(mu, std)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
fastspeech2_inference.eval()
fastspeech2_inference = jit.to_static(
fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(fastspeech2_inference,
os.path.join(args.inference_dir, "fastspeech2"))
fastspeech2_inference = paddle.jit.load(
os.path.join(args.inference_dir, "fastspeech2"))
pwg_inference = PWGInference(pwg_normalizer, vocoder)
pwg_inference.eval()
pwg_inference = jit.to_static(
pwg_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
pwg_inference = paddle.jit.load(os.path.join(args.inference_dir, "pwg"))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@ -135,6 +152,8 @@ def main():
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")

@ -341,6 +341,7 @@ class FastSpeech2(nn.Layer):
Tensor
speech_lengths, modified if reduction_factor > 1
"""
# input of embedding must be int64
xs = paddle.cast(text, 'int64')
ilens = paddle.cast(text_lengths, 'int64')
@ -388,7 +389,6 @@ class FastSpeech2(nn.Layer):
tone_id=None) -> Sequence[paddle.Tensor]:
# forward encoder
x_masks = self._source_mask(ilens)
# (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks)
@ -405,7 +405,6 @@ class FastSpeech2(nn.Layer):
if tone_id is not None:
tone_embs = self.tone_embedding_table(tone_id)
hs = self._integrate_with_tone_embed(hs, tone_embs)
# forward duration predictor and variance predictors
d_masks = make_pad_mask(ilens)
@ -428,6 +427,7 @@ class FastSpeech2(nn.Layer):
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
# (B, Lmax, adim)
hs = self.length_regulator(hs, d_outs, alpha)
else:
@ -438,6 +438,7 @@ class FastSpeech2(nn.Layer):
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
# (B, Lmax, adim)
hs = self.length_regulator(hs, ds)
@ -452,9 +453,11 @@ class FastSpeech2(nn.Layer):
else:
h_masks = None
# (B, Lmax, adim)
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape((zs.shape[0], -1, self.odim))
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
@ -517,8 +520,8 @@ class FastSpeech2(nn.Layer):
d = paddle.cast(durations, 'int64')
p, e = pitch, energy
# setup batch axis
ilens = paddle.to_tensor(
[x.shape[0]], dtype=paddle.int64, place=x.place)
ilens = paddle.shape(x)[0]
xs, ys = x.unsqueeze(0), None
if y is not None:

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Length regulator related modules."""
import numpy as np
import paddle
from paddle import nn
@ -49,11 +48,10 @@ class LengthRegulator(nn.Layer):
encodings: (B, T, C)
durations: (B, T)
"""
batch_size, t_enc = durations.shape
durations = durations.numpy()
slens = np.sum(durations, -1)
t_dec = np.max(slens)
M = np.zeros([batch_size, t_dec, t_enc])
batch_size, t_enc = paddle.shape(durations)
slens = durations.sum(-1)
t_dec = slens.max()
M = paddle.zeros([batch_size, t_dec, t_enc])
for i in range(batch_size):
k = 0
for j in range(t_enc):
@ -61,7 +59,6 @@ class LengthRegulator(nn.Layer):
if d >= 1:
M[i, k:k + d, j] = 1
k += d
M = paddle.to_tensor(M, dtype=encodings.dtype)
encodings = paddle.matmul(M, encodings)
return encodings
@ -82,6 +79,7 @@ class LengthRegulator(nn.Layer):
Tensor
replicated input tensor based on durations (B, T*, D).
"""
if alpha != 1.0:
assert alpha > 0
ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha)

@ -106,13 +106,11 @@ class MultiHeadedAttention(nn.Layer):
n_batch = value.shape[0]
softmax = paddle.nn.Softmax(axis=-1)
if mask is not None:
mask = mask.unsqueeze(1)
mask = paddle.logical_not(mask)
min_value = float(
numpy.finfo(
paddle.to_tensor(0, dtype=scores.dtype).numpy().dtype).min)
# assume scores.dtype==paddle.float32, we only use "float32" here
dtype = str(scores.dtype).split(".")[-1]
min_value = numpy.finfo(dtype).min
scores = masked_fill(scores, mask, min_value)
# (batch, head, time1, time2)
self.attn = softmax(scores)

@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer):
Maximum input length.
reverse : bool
Whether to reverse the input position.
type : str
dtype of param
"""
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
def __init__(self,
d_model,
dropout_rate,
max_len=5000,
dtype="float32",
reverse=False):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
@ -41,20 +48,21 @@ class PositionalEncoding(nn.Layer):
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len)))
self.dtype = dtype
self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len)))
def extend_pe(self, x):
"""Reset the positional encodings."""
pe = paddle.zeros([x.shape[1], self.d_model])
x_shape = paddle.shape(x)
pe = paddle.zeros([x_shape[1], self.d_model])
if self.reverse:
position = paddle.arange(
x.shape[1] - 1, -1, -1.0, dtype=paddle.float32).unsqueeze(1)
x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1)
else:
position = paddle.arange(
0, x.shape[1], dtype=paddle.float32).unsqueeze(1)
0, x_shape[1], dtype=self.dtype).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
paddle.arange(0, self.d_model, 2, dtype=self.dtype) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
@ -75,7 +83,8 @@ class PositionalEncoding(nn.Layer):
Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, :x.shape[1]]
T = paddle.shape(x)[1]
x = x * self.xscale + self.pe[:, :T]
return self.dropout(x)
@ -92,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding):
Dropout rate.
max_len : int
Maximum input length.
dtype : str
dtype of param
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"):
"""Initialize class."""
super().__init__(
d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
x = paddle.ones([1], dtype="float32")
d_model=d_model,
dropout_rate=dropout_rate,
max_len=max_len,
dtype=dtype)
x = paddle.ones([1], dtype=self.dtype)
self.alpha = paddle.create_parameter(
shape=x.shape,
dtype=str(x.numpy().dtype),
dtype=self.dtype,
default_initializer=paddle.nn.initializer.Assign(x))
def reset_parameters(self):
"""Reset parameters."""
self.alpha = paddle.to_tensor(1.0)
self.alpha = paddle.ones([1])
def forward(self, x):
"""Add positional encoding.
@ -115,12 +129,12 @@ class ScaledPositionalEncoding(PositionalEncoding):
----------
x : paddle.Tensor
Input tensor (batch, time, `*`).
Returns
----------
paddle.Tensor
Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x + self.alpha * self.pe[:, :x.shape[1]]
T = paddle.shape(x)[1]
x = x + self.alpha * self.pe[:, :T]
return self.dropout(x)

@ -185,6 +185,7 @@ class Encoder(nn.Layer):
paddle.Tensor
Mask tensor (#batch, time).
"""
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
if self.normalize_before:

@ -44,6 +44,7 @@ class LayerNorm(paddle.nn.LayerNorm):
paddle.Tensor
Normalized tensor.
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
else:
@ -54,9 +55,12 @@ class LayerNorm(paddle.nn.LayerNorm):
orig_perm = list(range(len_dim))
new_perm = orig_perm[:]
new_perm[self.dim], new_perm[len_dim -
1] = new_perm[len_dim -
1], new_perm[self.dim]
# Python style item change is not able when converting dygraph to static graph.
# new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
# use C++ style item change here
temp = new_perm[self.dim]
new_perm[self.dim] = new_perm[len_dim - 1]
new_perm[len_dim - 1] = temp
return paddle.transpose(
super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),

@ -25,12 +25,24 @@ def is_broadcastable(shp1, shp2):
return True
# assume that len(shp1) == len(shp2)
def broadcast_shape(shp1, shp2):
result = []
for a, b in zip(shp1[::-1], shp2[::-1]):
result.append(max(a, b))
return result[::-1]
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
# comment following line for converting dygraph to static graph.
# assert is_broadcastable(xs.shape, mask.shape) is True
# bshape = paddle.broadcast_shape(xs.shape, mask.shape)
bshape = broadcast_shape(xs.shape, mask.shape)
mask.stop_gradient = True
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
mask = mask.cast(dtype=paddle.bool)
xs = paddle.where(mask, trues, xs)

@ -56,7 +56,7 @@ def make_pad_mask(lengths, length_dim=-1):
Parameters
----------
lengths : LongTensor or List
lengths : LongTensor
Batch of lengths (B,).
Returns
@ -77,17 +77,11 @@ def make_pad_mask(lengths, length_dim=-1):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
maxlen = int(max(lengths))
bs = paddle.shape(lengths)[0]
maxlen = lengths.max()
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
seq_length_expand = paddle.to_tensor(
lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask

Loading…
Cancel
Save