From 1b0c034134005adbe2f3754dc8b301ca044d6613 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Sat, 29 Jan 2022 03:32:08 +0000 Subject: [PATCH] update wavernn, test=tts --- examples/csmsc/voc6/conf/default.yaml | 11 ++-- paddlespeech/t2s/audio/__init__.py | 1 + paddlespeech/t2s/audio/codec.py | 51 +++++++++++++++++++ paddlespeech/t2s/datasets/vocoder_batch_fn.py | 35 +++---------- paddlespeech/t2s/models/wavernn/wavernn.py | 14 ++--- 5 files changed, 72 insertions(+), 40 deletions(-) create mode 100644 paddlespeech/t2s/audio/codec.py diff --git a/examples/csmsc/voc6/conf/default.yaml b/examples/csmsc/voc6/conf/default.yaml index 2c838fb9..e7696cf4 100644 --- a/examples/csmsc/voc6/conf/default.yaml +++ b/examples/csmsc/voc6/conf/default.yaml @@ -12,7 +12,6 @@ n_mels: 80 # Number of mel basis. fmin: 80 # Minimum freq in mel basis calculation. (Hz) fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) mu_law: True # Recommended to suppress noise if using raw bitsexit() -peak_norm: True ########################################################### @@ -22,13 +21,14 @@ model: rnn_dims: 512 # Hidden dims of RNN Layers. fc_dims: 512 bits: 9 # Bit depth of signal - aux_context_window: 2 + aux_context_window: 2 # Context window size for auxiliary feature. + # If set to 2, previous 2 and future 2 frames will be considered. aux_channels: 80 # Number of channels for auxiliary feature conv. # Must be the same as num_mels. upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size, same with pwgan here - compute_dims: 128 - res_out_dims: 128 - res_blocks: 10 + compute_dims: 128 # Dims of Conv1D in MelResNet. + res_out_dims: 128 # Dims of output in MelResNet. + res_blocks: 10 # Number of residual blocks. mode: RAW # either 'raw'(softmax on raw bits) or 'mold' (sample from mixture of logistics) inference: gen_batched: True # whether to genenate sample in batch mode @@ -42,7 +42,6 @@ inference: batch_size: 64 # Batch size. batch_max_steps: 4500 # Length of each audio in batch. Make sure dividable by hop_size. num_workers: 2 # Number of workers in DataLoader. -valid_size: 50 ########################################################### # OPTIMIZER SETTING # diff --git a/paddlespeech/t2s/audio/__init__.py b/paddlespeech/t2s/audio/__init__.py index 7747b794..0deefc8b 100644 --- a/paddlespeech/t2s/audio/__init__.py +++ b/paddlespeech/t2s/audio/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from .audio import AudioProcessor +from .codec import * from .spec_normalizer import LogMagnitude from .spec_normalizer import NormalizerBase diff --git a/paddlespeech/t2s/audio/codec.py b/paddlespeech/t2s/audio/codec.py new file mode 100644 index 00000000..2a759ce4 --- /dev/null +++ b/paddlespeech/t2s/audio/codec.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import paddle + + +# x: [0: 2**bit-1], return: [-1, 1] +def label_2_float(x, bits): + return 2 * x / (2**bits - 1.) - 1. + + +#x: [-1, 1], return: [0, 2**bits-1] +def float_2_label(x, bits): + assert abs(x).max() <= 1.0 + x = (x + 1.) * (2**bits - 1) / 2 + return x.clip(0, 2**bits - 1) + + +# y: [-1, 1], mu: 2**bits, return: [0, 2**bits-1] +# see https://en.wikipedia.org/wiki/%CE%9C-law_algorithm +# be careful the input `mu` here, which is +1 than that of the link above +def encode_mu_law(x, mu): + mu = mu - 1 + fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) + return np.floor((fx + 1) / 2 * mu + 0.5) + + +# from_labels = True: +# y: [0: 2**bit-1], mu: 2**bits, return: [-1,1] +# from_labels = False: +# y: [-1, 1], return: [-1, 1] +def decode_mu_law(y, mu, from_labels=True): + # TODO: get rid of log2 - makes no sense + if from_labels: + y = label_2_float(y, math.log2(mu)) + mu = mu - 1 + x = paddle.sign(y) / mu * ((1 + mu)**paddle.abs(y) - 1) + return x diff --git a/paddlespeech/t2s/datasets/vocoder_batch_fn.py b/paddlespeech/t2s/datasets/vocoder_batch_fn.py index b1d22db9..d969a1d3 100644 --- a/paddlespeech/t2s/datasets/vocoder_batch_fn.py +++ b/paddlespeech/t2s/datasets/vocoder_batch_fn.py @@ -11,35 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math - import numpy as np import paddle - -def label_2_float(x, bits): - return 2 * x / (2**bits - 1.) - 1. - - -def float_2_label(x, bits): - assert abs(x).max() <= 1.0 - x = (x + 1.) * (2**bits - 1) / 2 - return x.clip(0, 2**bits - 1) - - -def encode_mu_law(x, mu): - mu = mu - 1 - fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) - return np.floor((fx + 1) / 2 * mu + 0.5) - - -def decode_mu_law(y, mu, from_labels=True): - # TODO: get rid of log2 - makes no sense - if from_labels: - y = label_2_float(y, math.log2(mu)) - mu = mu - 1 - x = paddle.sign(y) / mu * ((1 + mu)**paddle.abs(y) - 1) - return x +from paddlespeech.t2s.audio.codec import encode_mu_law +from paddlespeech.t2s.audio.codec import float_2_label +from paddlespeech.t2s.audio.codec import label_2_float class Clip(object): @@ -195,10 +172,12 @@ class WaveRNNClip(Clip): Returns ---------- Tensor - Auxiliary feature batch (B, C, T'), where - T = (T' - 2 * aux_context_window) * hop_size. + Input signal batch (B, 1, T). Tensor Target signal batch (B, 1, T). + Tensor + Auxiliary feature batch (B, C, T'), where + T = (T' - 2 * aux_context_window) * hop_size. """ # check length diff --git a/paddlespeech/t2s/models/wavernn/wavernn.py b/paddlespeech/t2s/models/wavernn/wavernn.py index f30879ed..fcf39a48 100644 --- a/paddlespeech/t2s/models/wavernn/wavernn.py +++ b/paddlespeech/t2s/models/wavernn/wavernn.py @@ -20,7 +20,7 @@ import paddle from paddle import nn from paddle.nn import functional as F -from paddlespeech.t2s.datasets.vocoder_batch_fn import decode_mu_law +from paddlespeech.t2s.audio.codec import decode_mu_law from paddlespeech.t2s.modules.losses import sample_from_discretized_mix_logistic from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.upsample import Stretch2D @@ -28,7 +28,7 @@ from paddlespeech.t2s.modules.upsample import Stretch2D class ResBlock(nn.Layer): def __init__(self, dims): - super(ResBlock, self).__init__() + super().__init__() self.conv1 = nn.Conv1D(dims, dims, kernel_size=1, bias_attr=False) self.conv2 = nn.Conv1D(dims, dims, kernel_size=1, bias_attr=False) self.batch_norm1 = nn.BatchNorm1D(dims) @@ -205,7 +205,7 @@ class WaveRNN(nn.Layer): if self.mode == 'RAW': self.n_classes = 2**bits elif self.mode == 'MOL': - self.n_classes = 30 + self.n_classes = 10 * 3 else: RuntimeError('Unknown model mode value - ', self.mode) @@ -333,7 +333,7 @@ class WaveRNN(nn.Layer): # (T, C_aux) -> (1, C_aux, T) c = paddle.transpose(c, [1, 0]).unsqueeze(0) T = paddle.shape(c)[-1] - wave_len = (T - 1) * self.hop_length + wave_len = T * self.hop_length # TODO remove two transpose op by modifying function pad_tensor c = self.pad_tensor( c.transpose([0, 2, 1]), pad=self.aux_context_window, @@ -396,6 +396,8 @@ class WaveRNN(nn.Layer): posterior = F.softmax(logits, axis=1) distrib = paddle.distribution.Categorical(posterior) # corresponding operate [np.floor((fx + 1) / 2 * mu + 0.5)] in enocde_mu_law + # distrib.sample([1])[0].cast('float32'): [0, 2**bits-1] + # sample: [-1, 1] sample = 2 * distrib.sample([1])[0].cast('float32') / ( self.n_classes - 1.) - 1. output.append(sample) @@ -418,9 +420,9 @@ class WaveRNN(nn.Layer): output = output[0] # Fade-out at the end to avoid signal cutting out suddenly - fade_out = paddle.linspace(1, 0, 20 * self.hop_length) + fade_out = paddle.linspace(1, 0, 10 * self.hop_length) output = output[:wave_len] - output[-20 * self.hop_length:] *= fade_out + output[-10 * self.hop_length:] *= fade_out self.train()