diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md index 735ef6d16..219bec794 100644 --- a/examples/csmsc/tts3/README.md +++ b/examples/csmsc/tts3/README.md @@ -215,6 +215,7 @@ python3 ${BIN_DIR}/synthesize_e2e.py \ --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --text=${BIN_DIR}/../sentences.txt \ --output-dir=exp/default/test_e2e \ + --inference-dir=exp/default/inference \ --device="gpu" \ --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt ``` diff --git a/examples/csmsc/tts3/inference.sh b/examples/csmsc/tts3/inference.sh new file mode 100755 index 000000000..cab72547c --- /dev/null +++ b/examples/csmsc/tts3/inference.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +train_output_path=$1 + +python3 ${BIN_DIR}/inference.py \ + --inference-dir=${train_output_path}/inference \ + --text=${BIN_DIR}/../sentences.txt \ + --output-dir=${train_output_path}/pd_infer_out \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/tts3/local/synthesize_e2e.sh b/examples/csmsc/tts3/local/synthesize_e2e.sh index 8c9755dd0..b65427431 100755 --- a/examples/csmsc/tts3/local/synthesize_e2e.sh +++ b/examples/csmsc/tts3/local/synthesize_e2e.sh @@ -15,5 +15,6 @@ python3 ${BIN_DIR}/synthesize_e2e.py \ --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --text=${BIN_DIR}/../sentences.txt \ --output-dir=${train_output_path}/test_e2e \ + --inference-dir=${train_output_path}/inference \ --device="gpu" \ --phones-dict=dump/phone_id_map.txt diff --git a/parakeet/exps/fastspeech2/inference.py b/parakeet/exps/fastspeech2/inference.py new file mode 100644 index 000000000..9926541cb --- /dev/null +++ b/parakeet/exps/fastspeech2/inference.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from pathlib import Path + +import soundfile as sf +from paddle import inference + +from parakeet.frontend.zh_frontend import Frontend + + +def main(): + parser = argparse.ArgumentParser( + description="Paddle Infernce with speedyspeech & parallel wavegan.") + parser.add_argument( + "--inference-dir", type=str, help="dir to save inference models") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line") + parser.add_argument("--output-dir", type=str, help="output dir") + parser.add_argument( + "--enable-auto-log", action="store_true", help="use auto log") + parser.add_argument( + "--phones-dict", + type=str, + default="phones.txt", + help="phone vocabulary file.") + + args, _ = parser.parse_known_args() + + frontend = Frontend(phone_vocab_path=args.phones_dict) + print("frontend done!") + + fastspeech2_config = inference.Config( + str(Path(args.inference_dir) / "fastspeech2.pdmodel"), + str(Path(args.inference_dir) / "fastspeech2.pdiparams")) + fastspeech2_config.enable_use_gpu(50, 0) + fastspeech2_config.enable_memory_optim() + fastspeech2_predictor = inference.create_predictor(fastspeech2_config) + + pwg_config = inference.Config( + str(Path(args.inference_dir) / "pwg.pdmodel"), + str(Path(args.inference_dir) / "pwg.pdiparams")) + pwg_config.enable_use_gpu(100, 0) + pwg_config.enable_memory_optim() + pwg_predictor = inference.create_predictor(pwg_config) + + if args.enable_auto_log: + import auto_log + os.makedirs("output", exist_ok=True) + pid = os.getpid() + logger = auto_log.AutoLogger( + model_name="fastspeech2", + model_precision='float32', + batch_size=1, + data_shape="dynamic", + save_path="./output/auto_log.log", + inference_config=fastspeech2_config, + pids=pid, + process_name=None, + gpu_ids=0, + time_keys=['preprocess_time', 'inference_time', 'postprocess_time'], + warmup=0) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + sentences = [] + + with open(args.text, 'rt') as f: + for line in f: + utt_id, sentence = line.strip().split() + sentences.append((utt_id, sentence)) + + for utt_id, sentence in sentences: + if args.enable_auto_log: + logger.times.start() + input_ids = frontend.get_input_ids(sentence, merge_sentences=True) + phone_ids = input_ids["phone_ids"] + phones = phone_ids[0].numpy() + + if args.enable_auto_log: + logger.times.stamp() + + input_names = fastspeech2_predictor.get_input_names() + phones_handle = fastspeech2_predictor.get_input_handle(input_names[0]) + + phones_handle.reshape(phones.shape) + phones_handle.copy_from_cpu(phones) + + fastspeech2_predictor.run() + output_names = fastspeech2_predictor.get_output_names() + output_handle = fastspeech2_predictor.get_output_handle(output_names[0]) + output_data = output_handle.copy_to_cpu() + + input_names = pwg_predictor.get_input_names() + mel_handle = pwg_predictor.get_input_handle(input_names[0]) + mel_handle.reshape(output_data.shape) + mel_handle.copy_from_cpu(output_data) + + pwg_predictor.run() + output_names = pwg_predictor.get_output_names() + output_handle = pwg_predictor.get_output_handle(output_names[0]) + wav = output_data = output_handle.copy_to_cpu() + + if args.enable_auto_log: + logger.times.stamp() + + sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + + if args.enable_auto_log: + logger.times.end(stamp=True) + print(f"{utt_id} done!") + + if args.enable_auto_log: + logger.report() + + +if __name__ == "__main__": + main() diff --git a/parakeet/exps/fastspeech2/synthesize_e2e.py b/parakeet/exps/fastspeech2/synthesize_e2e.py index dd1b57c8a..9c036e9fc 100644 --- a/parakeet/exps/fastspeech2/synthesize_e2e.py +++ b/parakeet/exps/fastspeech2/synthesize_e2e.py @@ -13,12 +13,15 @@ # limitations under the License. import argparse import logging +import os from pathlib import Path import numpy as np import paddle import soundfile as sf import yaml +from paddle import jit +from paddle.static import InputSpec from yacs.config import CfgNode from parakeet.frontend.zh_frontend import Frontend @@ -74,7 +77,21 @@ def evaluate(args, fastspeech2_config, pwg_config): pwg_normalizer = ZScore(mu, std) fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model) + fastspeech2_inference.eval() + fastspeech2_inference = jit.to_static( + fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) + paddle.jit.save(fastspeech2_inference, + os.path.join(args.inference_dir, "fastspeech2")) + fastspeech2_inference = paddle.jit.load( + os.path.join(args.inference_dir, "fastspeech2")) pwg_inference = PWGInference(pwg_normalizer, vocoder) + pwg_inference.eval() + pwg_inference = jit.to_static( + pwg_inference, input_spec=[ + InputSpec([-1, 80], dtype=paddle.float32), + ]) + paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg")) + pwg_inference = paddle.jit.load(os.path.join(args.inference_dir, "pwg")) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -135,6 +152,8 @@ def main(): type=str, help="text to synthesize, a 'utt_id sentence' pair per line.") parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--inference-dir", type=str, help="dir to save inference models") parser.add_argument( "--device", type=str, default="gpu", help="device type to use.") parser.add_argument("--verbose", type=int, default=1, help="verbose.") diff --git a/parakeet/models/fastspeech2/fastspeech2.py b/parakeet/models/fastspeech2/fastspeech2.py index 21c2d2c3f..0dbbb7bd9 100644 --- a/parakeet/models/fastspeech2/fastspeech2.py +++ b/parakeet/models/fastspeech2/fastspeech2.py @@ -388,7 +388,6 @@ class FastSpeech2(nn.Layer): spk_id=None, tone_id=None) -> Sequence[paddle.Tensor]: # forward encoder - bs = xs.shape[0] x_masks = self._source_mask(ilens) # (B, Tmax, adim) hs, _ = self.encoder(xs, x_masks) @@ -428,6 +427,7 @@ class FastSpeech2(nn.Layer): e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) hs = hs + e_embs + p_embs + # (B, Lmax, adim) hs = self.length_regulator(hs, d_outs, alpha) else: @@ -438,6 +438,7 @@ class FastSpeech2(nn.Layer): e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( (0, 2, 1)) hs = hs + e_embs + p_embs + # (B, Lmax, adim) hs = self.length_regulator(hs, ds) @@ -455,7 +456,8 @@ class FastSpeech2(nn.Layer): zs, _ = self.decoder(hs, h_masks) # (B, Lmax, odim) - before_outs = self.feat_out(zs).reshape((bs, -1, self.odim)) + before_outs = self.feat_out(zs).reshape( + (paddle.shape(zs)[0], -1, self.odim)) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: @@ -463,6 +465,7 @@ class FastSpeech2(nn.Layer): else: after_outs = before_outs + self.postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + return before_outs, after_outs, d_outs, p_outs, e_outs def inference( diff --git a/parakeet/modules/fastspeech2_predictor/length_regulator.py b/parakeet/modules/fastspeech2_predictor/length_regulator.py index e413812d2..a4d508add 100644 --- a/parakeet/modules/fastspeech2_predictor/length_regulator.py +++ b/parakeet/modules/fastspeech2_predictor/length_regulator.py @@ -48,10 +48,9 @@ class LengthRegulator(nn.Layer): encodings: (B, T, C) durations: (B, T) """ - batch_size, t_enc = durations.shape - # durations = durations.numpy() - slens = paddle.sum(durations, -1) - t_dec = paddle.max(slens) + batch_size, t_enc = paddle.shape(durations) + slens = durations.sum(-1) + t_dec = slens.max() M = paddle.zeros([batch_size, t_dec, t_enc]) for i in range(batch_size): k = 0 @@ -60,7 +59,6 @@ class LengthRegulator(nn.Layer): if d >= 1: M[i, k:k + d, j] = 1 k += d - M = paddle.to_tensor(M, dtype=encodings.dtype) encodings = paddle.matmul(M, encodings) return encodings diff --git a/parakeet/modules/fastspeech2_transformer/attention.py b/parakeet/modules/fastspeech2_transformer/attention.py index 8cef0023c..0bac47426 100644 --- a/parakeet/modules/fastspeech2_transformer/attention.py +++ b/parakeet/modules/fastspeech2_transformer/attention.py @@ -37,7 +37,7 @@ class MultiHeadedAttention(nn.Layer): def __init__(self, n_head, n_feat, dropout_rate): """Construct an MultiHeadedAttention object.""" super(MultiHeadedAttention, self).__init__() - # assert n_feat % n_head == 0 + assert n_feat % n_head == 0 # We assume d_v always equals d_k self.d_k = n_feat // n_head self.h = n_head @@ -108,7 +108,9 @@ class MultiHeadedAttention(nn.Layer): if mask is not None: mask = mask.unsqueeze(1) mask = paddle.logical_not(mask) - min_value = float(numpy.finfo("float32").min) + # assume scores.dtype==paddle.float32, we only use "float32" here + dtype = str(scores.dtype).split(".")[-1] + min_value = numpy.finfo(dtype).min scores = masked_fill(scores, mask, min_value) # (batch, head, time1, time2) self.attn = softmax(scores) diff --git a/parakeet/modules/fastspeech2_transformer/embedding.py b/parakeet/modules/fastspeech2_transformer/embedding.py index 888a209a5..1dfd6dfdc 100644 --- a/parakeet/modules/fastspeech2_transformer/embedding.py +++ b/parakeet/modules/fastspeech2_transformer/embedding.py @@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer): Maximum input length. reverse : bool Whether to reverse the input position. + type : str + dtype of param """ - def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + def __init__(self, + d_model, + dropout_rate, + max_len=5000, + dtype="float32", + reverse=False): """Construct an PositionalEncoding object.""" super(PositionalEncoding, self).__init__() self.d_model = d_model @@ -41,21 +48,21 @@ class PositionalEncoding(nn.Layer): self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout_rate) self.pe = None - self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len))) + self.dtype = dtype + self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len))) def extend_pe(self, x): """Reset the positional encodings.""" - - pe = paddle.zeros([paddle.shape(x)[1], self.d_model]) + x_shape = paddle.shape(x) + pe = paddle.zeros([x_shape[1], self.d_model]) if self.reverse: position = paddle.arange( - paddle.shape(x)[1] - 1, -1, -1.0, - dtype=paddle.float32).unsqueeze(1) + x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1) else: position = paddle.arange( - 0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1) + 0, x_shape[1], dtype=self.dtype).unsqueeze(1) div_term = paddle.exp( - paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * + paddle.arange(0, self.d_model, 2, dtype=self.dtype) * -(math.log(10000.0) / self.d_model)) pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) @@ -76,8 +83,8 @@ class PositionalEncoding(nn.Layer): Encoded tensor (batch, time, `*`). """ self.extend_pe(x) - - x = x * self.xscale + self.pe[:, :paddle.shape(x)[1]] + T = paddle.shape(x)[1] + x = x * self.xscale + self.pe[:, :T] return self.dropout(x) @@ -94,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding): Dropout rate. max_len : int Maximum input length. + dtype : str + dtype of param """ - def __init__(self, d_model, dropout_rate, max_len=5000): + def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"): """Initialize class.""" super().__init__( - d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) - x = paddle.ones([1], dtype="float32") + d_model=d_model, + dropout_rate=dropout_rate, + max_len=max_len, + dtype=dtype) + x = paddle.ones([1], dtype=self.dtype) self.alpha = paddle.create_parameter( shape=x.shape, - dtype="float32", + dtype=self.dtype, default_initializer=paddle.nn.initializer.Assign(x)) def reset_parameters(self): """Reset parameters.""" - self.alpha = paddle.to_tensor(1.0) + self.alpha = paddle.ones([1]) def forward(self, x): """Add positional encoding. @@ -123,5 +135,6 @@ class ScaledPositionalEncoding(PositionalEncoding): Encoded tensor (batch, time, `*`). """ self.extend_pe(x) - x = x + self.alpha * self.pe[:, :paddle.shape(x)[1]] + T = paddle.shape(x)[1] + x = x + self.alpha * self.pe[:, :T] return self.dropout(x) diff --git a/parakeet/modules/fastspeech2_transformer/encoder_layer.py b/parakeet/modules/fastspeech2_transformer/encoder_layer.py index 298e13f88..d8f89d677 100644 --- a/parakeet/modules/fastspeech2_transformer/encoder_layer.py +++ b/parakeet/modules/fastspeech2_transformer/encoder_layer.py @@ -87,7 +87,7 @@ class EncoderLayer(nn.Layer): if cache is None: x_q = x else: - # assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) x_q = x[:, -1:, :] residual = residual[:, -1:, :] mask = None if mask is None else mask[:, -1:, :] diff --git a/parakeet/modules/layer_norm.py b/parakeet/modules/layer_norm.py index f91b49ae6..a1c775fc8 100644 --- a/parakeet/modules/layer_norm.py +++ b/parakeet/modules/layer_norm.py @@ -55,10 +55,12 @@ class LayerNorm(paddle.nn.LayerNorm): orig_perm = list(range(len_dim)) new_perm = orig_perm[:] + # Python style item change is not able when converting dygraph to static graph. + # new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim] + # use C++ style item change here temp = new_perm[self.dim] new_perm[self.dim] = new_perm[len_dim - 1] new_perm[len_dim - 1] = temp - # new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim] return paddle.transpose( super(LayerNorm, self).forward(paddle.transpose(x, new_perm)), diff --git a/parakeet/modules/masked_fill.py b/parakeet/modules/masked_fill.py index e42a3cc0d..b32222547 100644 --- a/parakeet/modules/masked_fill.py +++ b/parakeet/modules/masked_fill.py @@ -25,6 +25,7 @@ def is_broadcastable(shp1, shp2): return True +# assume that len(shp1) == len(shp2) def broadcast_shape(shp1, shp2): result = [] for a, b in zip(shp1[::-1], shp2[::-1]): @@ -35,6 +36,7 @@ def broadcast_shape(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): + # comment following line for converting dygraph to static graph. # assert is_broadcastable(xs.shape, mask.shape) is True # bshape = paddle.broadcast_shape(xs.shape, mask.shape) bshape = broadcast_shape(xs.shape, mask.shape)