fix mbmelgan static

pull/958/head
TianYuan 3 years ago
parent 980944dab1
commit ba978fca98

@ -0,0 +1,185 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from parakeet.frontend.zh_frontend import Frontend
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.melgan import MelGANGenerator
from parakeet.models.melgan import MelGANInference
from parakeet.modules.normalizer import ZScore
def evaluate(args, fastspeech2_config, melgan_config):
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for evaluation
sentences = []
with open(args.text, 'rt') as f:
for line in f:
utt_id, sentence = line.strip().split()
sentences.append((utt_id, sentence))
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = fastspeech2_config.n_mels
model = FastSpeech2(
idim=vocab_size, odim=odim, **fastspeech2_config["model"])
model.set_state_dict(
paddle.load(args.fastspeech2_checkpoint)["main_params"])
model.eval()
vocoder = MelGANGenerator(**melgan_config["generator_params"])
vocoder.set_state_dict(
paddle.load(args.melgan_checkpoint)["generator_params"])
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!")
stat = np.load(args.fastspeech2_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
fastspeech2_normalizer = ZScore(mu, std)
stat = np.load(args.melgan_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
pwg_normalizer = ZScore(mu, std)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
fastspeech2_inference.eval()
fastspeech2_inference = jit.to_static(
fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(fastspeech2_inference,
os.path.join(args.inference_dir, "fastspeech2"))
fastspeech2_inference = paddle.jit.load(
os.path.join(args.inference_dir, "fastspeech2"))
mb_melgan_inference = MelGANInference(pwg_normalizer, vocoder)
mb_melgan_inference.eval()
mb_melgan_inference = jit.to_static(
mb_melgan_inference,
input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(mb_melgan_inference,
os.path.join(args.inference_dir, "mb_melgan"))
mb_melgan_inference = paddle.jit.load(
os.path.join(args.inference_dir, "mb_melgan"))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for utt_id, sentence in sentences:
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"]
flags = 0
for part_phone_ids in phone_ids:
with paddle.no_grad():
mel = fastspeech2_inference(part_phone_ids)
temp_wav = mb_melgan_inference(mel)
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
samplerate=fastspeech2_config.fs)
print(f"{utt_id} done!")
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument(
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--fastspeech2-checkpoint",
type=str,
help="fastspeech2 checkpoint to load.")
parser.add_argument(
"--fastspeech2-stat",
type=str,
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
)
parser.add_argument(
"--melgan-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--melgan-checkpoint",
type=str,
help="parallel wavegan generator parameters to load.")
parser.add_argument(
"--melgan-stat",
type=str,
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
parser.add_argument(
"--phones-dict",
type=str,
default="phone_id_map.txt",
help="phone vocabulary file.")
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
paddle.set_device(args.device)
with open(args.fastspeech2_config) as f:
fastspeech2_config = CfgNode(yaml.safe_load(f))
with open(args.melgan_config) as f:
melgan_config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(fastspeech2_config)
print(melgan_config)
evaluate(args, fastspeech2_config, melgan_config)
if __name__ == "__main__":
main()

@ -263,15 +263,13 @@ class MelGANGenerator(nn.Layer):
Tensor
Output tensor (out_channels*T ** prod(upsample_scales), 1).
"""
if not isinstance(c, paddle.Tensor):
c = paddle.to_tensor(c, dtype="float32")
# pseudo batch
c = c.transpose([1, 0]).unsqueeze(0)
# (B, out_channels, T ** prod(upsample_scales)
out = self.melgan(c)
if self.pqmf is not None:
# (B, 1, out_channels * T ** prod(upsample_scales)
out = self.pqmf.synthesis(out)
out = self.pqmf(out)
out = out.squeeze(0).transpose([1, 0])
return out
@ -551,3 +549,15 @@ class MelGANMultiScaleDiscriminator(nn.Layer):
m.weight.set_value(w)
self.apply(_reset_parameters)
class MelGANInference(nn.Layer):
def __init__(self, normalizer, melgan_generator):
super().__init__()
self.normalizer = normalizer
self.melgan_generator = melgan_generator
def forward(self, logmel):
normalized_mel = self.normalizer(logmel)
wav = self.melgan_generator.inference(normalized_mel)
return wav

@ -0,0 +1,146 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc as collections_abc
import paddle
_i0A = [
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1
]
_i0B = [
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1
]
def piecewise(x, condlist, funclist, *args, **kw):
n2 = len(funclist)
# n = len(condlist)
n = 1
if n == n2 - 1: # compute the "otherwise" condition.
condelse = ~paddle.any(condlist, axis=0, keepdim=True)
condlist = paddle.concat([condlist, condelse], axis=0)
n += 1
elif n != n2:
raise ValueError(
"with {} condition(s), either {} or {} functions are expected"
.format(n, n, n + 1))
y = paddle.zeros(paddle.shape(x), x.dtype)
for k in range(n):
item = funclist[k]
if not isinstance(item, collections_abc.Callable):
y[condlist[k]] = item
else:
temp = condlist[k]
if paddle.shape(x) == paddle.ones([1]):
vals = x
y = item(vals, *args, **kw)
else:
vals = x[temp]
y[temp] = item(vals, *args, **kw)
return y
def _chbevl(x, vals):
b0 = vals[0]
b1 = 0.0
for i in range(1, len(vals)):
b2 = b1
b1 = b0
b0 = x * b1 - b2 + vals[i]
return 0.5 * (b0 - b2)
def _i0_1(x):
out = paddle.exp(x) * _chbevl(x / 2.0 - 2, _i0A)
return paddle.cast(out, dtype="float32")
def _i0_2(x):
out = paddle.exp(x) * _chbevl(32.0 / x - 2.0, _i0B) / paddle.sqrt(x)
return paddle.cast(out, dtype="float32")
def _i0_dispatcher(x):
return (x, )
def i0(x):
x = paddle.abs(x)
condlist = x <= paddle.full([1], 8.0)
condlist = condlist.unsqueeze(0)
return piecewise(x, condlist, [_i0_1, _i0_2])
def _len_guards(M):
"""Handle small or incorrect window lengths"""
if int(M) != M or M < 0:
raise ValueError('Window length M must be a non-negative integer')
return M <= 1
def _extend(M, sym):
"""Extend window by 1 sample if needed for DFT-even symmetry"""
if not sym:
return M + 1, True
else:
return M, False
def _truncate(w, needed):
"""Truncate window by 1 sample if needed for DFT-even symmetry"""
if needed:
return w[:-1]
else:
return w
def kaiser(M, beta, sym=True):
if _len_guards(M):
return paddle.ones(M)
M, needs_trunc = _extend(M, sym)
n = paddle.arange(0, M)
alpha = (M - 1) / 2.0
a = i0(beta * paddle.sqrt(1 - ((n - alpha) / alpha)**2.0))
b = i0(paddle.full([1], beta))
w = a / b
return _truncate(w, needs_trunc)

@ -15,7 +15,8 @@
import numpy as np
import paddle
import paddle.nn.functional as F
from scipy.signal import kaiser
from parakeet.modules.kaiser import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
@ -44,15 +45,12 @@ def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid="ignore"):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (np.arange(taps + 1) - 0.5 * taps))
h_i[taps //
2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
h_i = paddle.sin(omega_c * (paddle.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (paddle.arange(taps + 1) - 0.5 * taps))
h_i[taps // 2] = 1 * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
@ -78,26 +76,25 @@ class PQMF(paddle.nn.Layer):
beta : float
Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# build analysis & synthesis filter coefficients
super().__init__()
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
h_proto_len = paddle.shape(h_proto)[0]
h_analysis = paddle.zeros((subbands, h_proto_len))
h_synthesis = paddle.zeros((subbands, h_proto_len))
for k in range(subbands):
h_analysis[k] = (
2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
np.arange(taps + 1) - (taps / 2)) + (-1)**k * np.pi / 4))
2 * h_proto *
paddle.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
paddle.arange(taps + 1) - (taps / 2)) + (-1)**k * np.pi / 4)
)
h_synthesis[k] = (
2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
np.arange(taps + 1) - (taps / 2)) - (-1)**k * np.pi / 4))
# convert to tensor
self.analysis_filter = paddle.to_tensor(
h_analysis, dtype="float32").unsqueeze(1)
self.synthesis_filter = paddle.to_tensor(
h_synthesis, dtype="float32").unsqueeze(0)
2 * h_proto *
paddle.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
paddle.arange(taps + 1) - (taps / 2)) - (-1)**k * np.pi / 4)
)
self.analysis_filter = h_analysis.unsqueeze(1)
self.synthesis_filter = h_synthesis.unsqueeze(0)
# filter for downsampling & upsampling
updown_filter = paddle.zeros(
(subbands, subbands, subbands), dtype="float32")
@ -105,7 +102,6 @@ class PQMF(paddle.nn.Layer):
updown_filter[k, k, 0] = 1.0
self.updown_filter = updown_filter
self.subbands = subbands
# keep padding info
self.pad_fn = paddle.nn.Pad1D(taps // 2, mode='constant', value=0.0)
@ -134,7 +130,11 @@ class PQMF(paddle.nn.Layer):
Tensor
Output tensor (B, 1, T).
"""
x = F.conv1d_transpose(
x, self.updown_filter * self.subbands, stride=self.subbands)
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
# when converting dygraph to static graph, can not use self.pqmf.synthesis directly
def forward(self, x):
return self.synthesis(x)

Loading…
Cancel
Save