align ouput of dygraph and static graph

pull/948/head
TianYuan 3 years ago
parent f652ba3a34
commit 79e7a4d44e

@ -215,6 +215,7 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=${BIN_DIR}/../sentences.txt \ --text=${BIN_DIR}/../sentences.txt \
--output-dir=exp/default/test_e2e \ --output-dir=exp/default/test_e2e \
--inference-dir=exp/default/inference \
--device="gpu" \ --device="gpu" \
--phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
``` ```

@ -0,0 +1,9 @@
#!/bin/bash
train_output_path=$1
python3 ${BIN_DIR}/inference.py \
--inference-dir=${train_output_path}/inference \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=${train_output_path}/pd_infer_out \
--phones-dict=dump/phone_id_map.txt

@ -15,5 +15,6 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=${BIN_DIR}/../sentences.txt \ --text=${BIN_DIR}/../sentences.txt \
--output-dir=${train_output_path}/test_e2e \ --output-dir=${train_output_path}/test_e2e \
--inference-dir=${train_output_path}/inference \
--device="gpu" \ --device="gpu" \
--phones-dict=dump/phone_id_map.txt --phones-dict=dump/phone_id_map.txt

@ -0,0 +1,132 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import soundfile as sf
from paddle import inference
from parakeet.frontend.zh_frontend import Frontend
def main():
parser = argparse.ArgumentParser(
description="Paddle Infernce with speedyspeech & parallel wavegan.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--enable-auto-log", action="store_true", help="use auto log")
parser.add_argument(
"--phones-dict",
type=str,
default="phones.txt",
help="phone vocabulary file.")
args, _ = parser.parse_known_args()
frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!")
fastspeech2_config = inference.Config(
str(Path(args.inference_dir) / "fastspeech2.pdmodel"),
str(Path(args.inference_dir) / "fastspeech2.pdiparams"))
fastspeech2_config.enable_use_gpu(50, 0)
fastspeech2_config.enable_memory_optim()
fastspeech2_predictor = inference.create_predictor(fastspeech2_config)
pwg_config = inference.Config(
str(Path(args.inference_dir) / "pwg.pdmodel"),
str(Path(args.inference_dir) / "pwg.pdiparams"))
pwg_config.enable_use_gpu(100, 0)
pwg_config.enable_memory_optim()
pwg_predictor = inference.create_predictor(pwg_config)
if args.enable_auto_log:
import auto_log
os.makedirs("output", exist_ok=True)
pid = os.getpid()
logger = auto_log.AutoLogger(
model_name="fastspeech2",
model_precision='float32',
batch_size=1,
data_shape="dynamic",
save_path="./output/auto_log.log",
inference_config=fastspeech2_config,
pids=pid,
process_name=None,
gpu_ids=0,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = []
with open(args.text, 'rt') as f:
for line in f:
utt_id, sentence = line.strip().split()
sentences.append((utt_id, sentence))
for utt_id, sentence in sentences:
if args.enable_auto_log:
logger.times.start()
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"]
phones = phone_ids[0].numpy()
if args.enable_auto_log:
logger.times.stamp()
input_names = fastspeech2_predictor.get_input_names()
phones_handle = fastspeech2_predictor.get_input_handle(input_names[0])
phones_handle.reshape(phones.shape)
phones_handle.copy_from_cpu(phones)
fastspeech2_predictor.run()
output_names = fastspeech2_predictor.get_output_names()
output_handle = fastspeech2_predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
input_names = pwg_predictor.get_input_names()
mel_handle = pwg_predictor.get_input_handle(input_names[0])
mel_handle.reshape(output_data.shape)
mel_handle.copy_from_cpu(output_data)
pwg_predictor.run()
output_names = pwg_predictor.get_output_names()
output_handle = pwg_predictor.get_output_handle(output_names[0])
wav = output_data = output_handle.copy_to_cpu()
if args.enable_auto_log:
logger.times.stamp()
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
if args.enable_auto_log:
logger.times.end(stamp=True)
print(f"{utt_id} done!")
if args.enable_auto_log:
logger.report()
if __name__ == "__main__":
main()

@ -13,12 +13,15 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import logging import logging
import os
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle import paddle
import soundfile as sf import soundfile as sf
import yaml import yaml
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode from yacs.config import CfgNode
from parakeet.frontend.zh_frontend import Frontend from parakeet.frontend.zh_frontend import Frontend
@ -74,7 +77,21 @@ def evaluate(args, fastspeech2_config, pwg_config):
pwg_normalizer = ZScore(mu, std) pwg_normalizer = ZScore(mu, std)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model) fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
fastspeech2_inference.eval()
fastspeech2_inference = jit.to_static(
fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(fastspeech2_inference,
os.path.join(args.inference_dir, "fastspeech2"))
fastspeech2_inference = paddle.jit.load(
os.path.join(args.inference_dir, "fastspeech2"))
pwg_inference = PWGInference(pwg_normalizer, vocoder) pwg_inference = PWGInference(pwg_normalizer, vocoder)
pwg_inference.eval()
pwg_inference = jit.to_static(
pwg_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
pwg_inference = paddle.jit.load(os.path.join(args.inference_dir, "pwg"))
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
@ -135,6 +152,8 @@ def main():
type=str, type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.") help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument( parser.add_argument(
"--device", type=str, default="gpu", help="device type to use.") "--device", type=str, default="gpu", help="device type to use.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.") parser.add_argument("--verbose", type=int, default=1, help="verbose.")

@ -388,7 +388,6 @@ class FastSpeech2(nn.Layer):
spk_id=None, spk_id=None,
tone_id=None) -> Sequence[paddle.Tensor]: tone_id=None) -> Sequence[paddle.Tensor]:
# forward encoder # forward encoder
bs = xs.shape[0]
x_masks = self._source_mask(ilens) x_masks = self._source_mask(ilens)
# (B, Tmax, adim) # (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks) hs, _ = self.encoder(xs, x_masks)
@ -428,6 +427,7 @@ class FastSpeech2(nn.Layer):
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1)) (0, 2, 1))
hs = hs + e_embs + p_embs hs = hs + e_embs + p_embs
# (B, Lmax, adim) # (B, Lmax, adim)
hs = self.length_regulator(hs, d_outs, alpha) hs = self.length_regulator(hs, d_outs, alpha)
else: else:
@ -438,6 +438,7 @@ class FastSpeech2(nn.Layer):
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
(0, 2, 1)) (0, 2, 1))
hs = hs + e_embs + p_embs hs = hs + e_embs + p_embs
# (B, Lmax, adim) # (B, Lmax, adim)
hs = self.length_regulator(hs, ds) hs = self.length_regulator(hs, ds)
@ -455,7 +456,8 @@ class FastSpeech2(nn.Layer):
zs, _ = self.decoder(hs, h_masks) zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim) # (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape((bs, -1, self.odim)) before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim) # postnet -> (B, Lmax//r * r, odim)
if self.postnet is None: if self.postnet is None:
@ -463,6 +465,7 @@ class FastSpeech2(nn.Layer):
else: else:
after_outs = before_outs + self.postnet( after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
return before_outs, after_outs, d_outs, p_outs, e_outs return before_outs, after_outs, d_outs, p_outs, e_outs
def inference( def inference(

@ -48,10 +48,9 @@ class LengthRegulator(nn.Layer):
encodings: (B, T, C) encodings: (B, T, C)
durations: (B, T) durations: (B, T)
""" """
batch_size, t_enc = durations.shape batch_size, t_enc = paddle.shape(durations)
# durations = durations.numpy() slens = durations.sum(-1)
slens = paddle.sum(durations, -1) t_dec = slens.max()
t_dec = paddle.max(slens)
M = paddle.zeros([batch_size, t_dec, t_enc]) M = paddle.zeros([batch_size, t_dec, t_enc])
for i in range(batch_size): for i in range(batch_size):
k = 0 k = 0
@ -60,7 +59,6 @@ class LengthRegulator(nn.Layer):
if d >= 1: if d >= 1:
M[i, k:k + d, j] = 1 M[i, k:k + d, j] = 1
k += d k += d
M = paddle.to_tensor(M, dtype=encodings.dtype)
encodings = paddle.matmul(M, encodings) encodings = paddle.matmul(M, encodings)
return encodings return encodings

@ -37,7 +37,7 @@ class MultiHeadedAttention(nn.Layer):
def __init__(self, n_head, n_feat, dropout_rate): def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object.""" """Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__() super(MultiHeadedAttention, self).__init__()
# assert n_feat % n_head == 0 assert n_feat % n_head == 0
# We assume d_v always equals d_k # We assume d_v always equals d_k
self.d_k = n_feat // n_head self.d_k = n_feat // n_head
self.h = n_head self.h = n_head
@ -108,7 +108,9 @@ class MultiHeadedAttention(nn.Layer):
if mask is not None: if mask is not None:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
mask = paddle.logical_not(mask) mask = paddle.logical_not(mask)
min_value = float(numpy.finfo("float32").min) # assume scores.dtype==paddle.float32, we only use "float32" here
dtype = str(scores.dtype).split(".")[-1]
min_value = numpy.finfo(dtype).min
scores = masked_fill(scores, mask, min_value) scores = masked_fill(scores, mask, min_value)
# (batch, head, time1, time2) # (batch, head, time1, time2)
self.attn = softmax(scores) self.attn = softmax(scores)

@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer):
Maximum input length. Maximum input length.
reverse : bool reverse : bool
Whether to reverse the input position. Whether to reverse the input position.
type : str
dtype of param
""" """
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): def __init__(self,
d_model,
dropout_rate,
max_len=5000,
dtype="float32",
reverse=False):
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__() super(PositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@ -41,21 +48,21 @@ class PositionalEncoding(nn.Layer):
self.xscale = math.sqrt(self.d_model) self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout_rate) self.dropout = nn.Dropout(p=dropout_rate)
self.pe = None self.pe = None
self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len))) self.dtype = dtype
self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len)))
def extend_pe(self, x): def extend_pe(self, x):
"""Reset the positional encodings.""" """Reset the positional encodings."""
x_shape = paddle.shape(x)
pe = paddle.zeros([paddle.shape(x)[1], self.d_model]) pe = paddle.zeros([x_shape[1], self.d_model])
if self.reverse: if self.reverse:
position = paddle.arange( position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0, x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1)
dtype=paddle.float32).unsqueeze(1)
else: else:
position = paddle.arange( position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) 0, x_shape[1], dtype=self.dtype).unsqueeze(1)
div_term = paddle.exp( div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * paddle.arange(0, self.d_model, 2, dtype=self.dtype) *
-(math.log(10000.0) / self.d_model)) -(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
@ -76,8 +83,8 @@ class PositionalEncoding(nn.Layer):
Encoded tensor (batch, time, `*`). Encoded tensor (batch, time, `*`).
""" """
self.extend_pe(x) self.extend_pe(x)
T = paddle.shape(x)[1]
x = x * self.xscale + self.pe[:, :paddle.shape(x)[1]] x = x * self.xscale + self.pe[:, :T]
return self.dropout(x) return self.dropout(x)
@ -94,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding):
Dropout rate. Dropout rate.
max_len : int max_len : int
Maximum input length. Maximum input length.
dtype : str
dtype of param
""" """
def __init__(self, d_model, dropout_rate, max_len=5000): def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"):
"""Initialize class.""" """Initialize class."""
super().__init__( super().__init__(
d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) d_model=d_model,
x = paddle.ones([1], dtype="float32") dropout_rate=dropout_rate,
max_len=max_len,
dtype=dtype)
x = paddle.ones([1], dtype=self.dtype)
self.alpha = paddle.create_parameter( self.alpha = paddle.create_parameter(
shape=x.shape, shape=x.shape,
dtype="float32", dtype=self.dtype,
default_initializer=paddle.nn.initializer.Assign(x)) default_initializer=paddle.nn.initializer.Assign(x))
def reset_parameters(self): def reset_parameters(self):
"""Reset parameters.""" """Reset parameters."""
self.alpha = paddle.to_tensor(1.0) self.alpha = paddle.ones([1])
def forward(self, x): def forward(self, x):
"""Add positional encoding. """Add positional encoding.
@ -123,5 +135,6 @@ class ScaledPositionalEncoding(PositionalEncoding):
Encoded tensor (batch, time, `*`). Encoded tensor (batch, time, `*`).
""" """
self.extend_pe(x) self.extend_pe(x)
x = x + self.alpha * self.pe[:, :paddle.shape(x)[1]] T = paddle.shape(x)[1]
x = x + self.alpha * self.pe[:, :T]
return self.dropout(x) return self.dropout(x)

@ -87,7 +87,7 @@ class EncoderLayer(nn.Layer):
if cache is None: if cache is None:
x_q = x x_q = x
else: else:
# assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :] x_q = x[:, -1:, :]
residual = residual[:, -1:, :] residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :] mask = None if mask is None else mask[:, -1:, :]

@ -55,10 +55,12 @@ class LayerNorm(paddle.nn.LayerNorm):
orig_perm = list(range(len_dim)) orig_perm = list(range(len_dim))
new_perm = orig_perm[:] new_perm = orig_perm[:]
# Python style item change is not able when converting dygraph to static graph.
# new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
# use C++ style item change here
temp = new_perm[self.dim] temp = new_perm[self.dim]
new_perm[self.dim] = new_perm[len_dim - 1] new_perm[self.dim] = new_perm[len_dim - 1]
new_perm[len_dim - 1] = temp new_perm[len_dim - 1] = temp
# new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
return paddle.transpose( return paddle.transpose(
super(LayerNorm, self).forward(paddle.transpose(x, new_perm)), super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),

@ -25,6 +25,7 @@ def is_broadcastable(shp1, shp2):
return True return True
# assume that len(shp1) == len(shp2)
def broadcast_shape(shp1, shp2): def broadcast_shape(shp1, shp2):
result = [] result = []
for a, b in zip(shp1[::-1], shp2[::-1]): for a, b in zip(shp1[::-1], shp2[::-1]):
@ -35,6 +36,7 @@ def broadcast_shape(shp1, shp2):
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
# comment following line for converting dygraph to static graph.
# assert is_broadcastable(xs.shape, mask.shape) is True # assert is_broadcastable(xs.shape, mask.shape) is True
# bshape = paddle.broadcast_shape(xs.shape, mask.shape) # bshape = paddle.broadcast_shape(xs.shape, mask.shape)
bshape = broadcast_shape(xs.shape, mask.shape) bshape = broadcast_shape(xs.shape, mask.shape)

Loading…
Cancel
Save