fix wavernn dygraph to static, test=tts

pull/1379/head
TianYuan 3 years ago
parent 2071774d81
commit 001afee644

@@ -49,3 +49,14 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --output_dir=${train_output_path}/pd_infer_out \
         --phones_dict=dump/phone_id_map.txt
 fi
+
+# wavernn
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    python3 ${BIN_DIR}/../inference.py \
+        --inference_dir=${train_output_path}/inference \
+        --am=fastspeech2_csmsc \
+        --voc=wavernn_csmsc \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/pd_infer_out \
+        --phones_dict=dump/phone_id_map.txt
+fi

@@ -108,5 +108,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
         --lang=zh \
         --text=${BIN_DIR}/../sentences.txt \
         --output_dir=${train_output_path}/test_e2e \
-        --phones_dict=dump/phone_id_map.txt
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
 fi

@@ -54,7 +54,7 @@ def main():
         default='pwgan_csmsc',
         choices=[
             'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
-            'pwgan_vctk'
+            'pwgan_vctk', 'wavernn_csmsc'
         ],
         help='Choose vocoder type of tts task.')
     # other

@@ -76,6 +76,7 @@ class MelResNet(nn.Layer):
             Tensor
                 Output tensor (B, res_out_dims, T).
         '''
+
         x = self.conv_in(x)
         x = self.batch_norm(x)
         x = F.relu(x)
@@ -230,6 +231,7 @@ class WaveRNN(nn.Layer):
         self.rnn1 = nn.GRU(rnn_dims, rnn_dims)
         self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims)
         self._to_flatten += [self.rnn1, self.rnn2]
+
         self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
@@ -326,17 +328,17 @@ class WaveRNN(nn.Layer):
         output = []
         start = time.time()
 
-        rnn1 = self.get_gru_cell(self.rnn1)
-        rnn2 = self.get_gru_cell(self.rnn2)
-
         # pseudo batch
         # (T, C_aux) -> (1, C_aux, T)
         c = paddle.transpose(c, [1, 0]).unsqueeze(0)
+        T = paddle.shape(c)[-1]
 
-        wave_len = (paddle.shape(c)[-1] - 1) * self.hop_length
+        wave_len = (T - 1) * self.hop_length
         # TODO remove two transpose op by modifying function pad_tensor
         c = self.pad_tensor(
             c.transpose([0, 2, 1]), pad=self.aux_context_window,
             side='both').transpose([0, 2, 1])
 
         c, aux = self.upsample(c)
 
         if batched:
@@ -344,7 +346,13 @@ class WaveRNN(nn.Layer):
             c = self.fold_with_overlap(c, target, overlap)
             aux = self.fold_with_overlap(aux, target, overlap)
 
-        b_size, seq_len, _ = paddle.shape(c)
+        # for dygraph-to-static conversion: if the `seq_len` from
+        # `b_size, seq_len, _ = paddle.shape(c)` is used as the loop bound,
+        # the output list will not be converted to a TensorArray, see
+        # https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/04_dygraph_to_static/case_analysis_cn.html#list-lodtensorarray
+        # b_size, seq_len, _ = paddle.shape(c)
+        b_size = paddle.shape(c)[0]
+        seq_len = paddle.shape(c)[1]
 
         h1 = paddle.zeros([b_size, self.rnn_dims])
         h2 = paddle.zeros([b_size, self.rnn_dims])
         x = paddle.zeros([b_size, 1])
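Side note on the hunk above: per the linked dygraph-to-static guide, a list appended to inside a loop only becomes a TensorArray when the loop bound is a tensor the converter can track. A minimal sketch of the working pattern (illustrative, not part of this commit):

    import paddle
    import paddle.nn as nn


    class ListInLoop(nn.Layer):
        def forward(self, c):
            # indexing paddle.shape() keeps seq_len usable as a loop bound;
            # to_static can then convert `output` into a TensorArray
            seq_len = paddle.shape(c)[1]
            output = []
            for i in range(seq_len):
                output.append(c[:, i, :])
            return paddle.stack(output, axis=1)


    net = paddle.jit.to_static(ListInLoop())
    print(net(paddle.randn([2, 5, 3])).shape)  # [2, 5, 3]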
@@ -354,14 +362,20 @@ class WaveRNN(nn.Layer):
         for i in range(seq_len):
             m_t = c[:, i, :]
 
-            a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
+            # for dygraph-to-static conversion, unpack without a generator
+            # a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
+            a1_t = aux_split[0][:, i, :]
+            a2_t = aux_split[1][:, i, :]
+            a3_t = aux_split[2][:, i, :]
+            a4_t = aux_split[3][:, i, :]
 
             x = paddle.concat([x, m_t, a1_t], axis=1)
             x = self.I(x)
-            h1, _ = rnn1(x, h1)
+            # use the GRU's underlying GRUCell here
+            h1, _ = self.rnn1[0].cell(x, h1)
             x = x + h1
 
             inp = paddle.concat([x, a2_t], axis=1)
-            h2, _ = rnn2(inp, h2)
+            # use the GRU's underlying GRUCell here
+            h2, _ = self.rnn2[0].cell(inp, h2)
 
             x = x + h2
             x = paddle.concat([x, a3_t], axis=1)
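For reference, `self.rnn1[0].cell` reaches the GRUCell inside Paddle's nn.GRU (an nn.GRU is a layer list of RNN wrappers, each holding a cell), which replaces the copied-weights cell the removed get_gru_cell helper used to build. A quick sketch, not from this commit, of why one manual cell step matches the full GRU on a one-frame sequence:

    import paddle
    import paddle.nn as nn

    gru = nn.GRU(input_size=4, hidden_size=8)
    x = paddle.randn([2, 1, 4])   # (batch, time=1, features)
    h0 = paddle.zeros([1, 2, 8])  # (num_layers, batch, hidden)

    y_seq, _ = gru(x, h0)                       # full GRU over 1 step
    y_cell, _ = gru[0].cell(x[:, 0, :], h0[0])  # one manual cell step

    print(paddle.allclose(y_seq[:, 0, :], y_cell))  # expect True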
@@ -413,15 +427,6 @@ class WaveRNN(nn.Layer):
         # add the C_out dimension
         return output.unsqueeze(-1)
 
-    def get_gru_cell(self, gru):
-        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
-        gru_cell.weight_hh = gru.weight_hh_l0
-        gru_cell.weight_ih = gru.weight_ih_l0
-        gru_cell.bias_hh = gru.bias_hh_l0
-        gru_cell.bias_ih = gru.bias_ih_l0
-        return gru_cell
-
     def _flatten_parameters(self):
         [m.flatten_parameters() for m in self._to_flatten]
@ -438,7 +443,9 @@ class WaveRNN(nn.Layer):
---------- ----------
Tensor Tensor
''' '''
b, t, c = paddle.shape(x) b, t, _ = paddle.shape(x)
# for dygraph to static graph
c = x.shape[-1]
total = t + 2 * pad if side == 'both' else t + pad total = t + 2 * pad if side == 'both' else t + pad
padded = paddle.zeros([b, total, c]) padded = paddle.zeros([b, total, c])
if side == 'before' or side == 'both': if side == 'before' or side == 'both':
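The split above follows the usual dygraph-to-static rule: `paddle.shape(x)` is resolved at run time, while `x.shape` reads the shape recorded at graph-build time, so a dimension fixed by the model config (the channel dim here) can stay a plain Python int. A minimal sketch (function name and shapes are illustrative):

    import paddle


    @paddle.jit.to_static
    def pad_before(x, pad=2):
        b = paddle.shape(x)[0]  # dynamic: batch size may vary per call
        c = x.shape[-1]         # static: channel count fixed by the model
        zeros = paddle.zeros([b, pad, c])
        return paddle.concat([zeros, x], axis=1)


    print(pad_before(paddle.ones([2, 3, 4])).shape)  # [2, 5, 4]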
@ -516,7 +523,7 @@ class WaveRNN(nn.Layer):
y : Tensor y : Tensor
Batched sequences of audio samples Batched sequences of audio samples
shape=(num_folds, target + 2 * overlap) shape=(num_folds, target + 2 * overlap)
dtype=paddle.float64 dtype=paddle.float32
overlap : int overlap : int
Timesteps for both xfade and rnn warmup Timesteps for both xfade and rnn warmup
@@ -525,7 +532,7 @@ class WaveRNN(nn.Layer):
             Tensor
                 audio samples in a 1d array
                 shape=(total_len)
-                dtype=paddle.float64
+                dtype=paddle.float32
 
         Details
         ----------
@ -545,19 +552,19 @@ class WaveRNN(nn.Layer):
''' '''
# num_folds = (total_len - overlap) // (target + overlap) # num_folds = (total_len - overlap) // (target + overlap)
num_folds, length = y.shape num_folds, length = paddle.shape(y)
target = length - 2 * overlap target = length - 2 * overlap
total_len = num_folds * (target + overlap) + overlap total_len = num_folds * (target + overlap) + overlap
# Need some silence for the run warmup # Need some silence for the run warmup
slience_len = overlap // 2 slience_len = overlap // 2
fade_len = overlap - slience_len fade_len = overlap - slience_len
slience = paddle.zeros([slience_len], dtype=paddle.float64) slience = paddle.zeros([slience_len], dtype=paddle.float32)
linear = paddle.ones([fade_len], dtype=paddle.float64) linear = paddle.ones([fade_len], dtype=paddle.float32)
# Equal power crossfade # Equal power crossfade
# fade_in increase from 0 to 1, fade_out reduces from 1 to 0 # fade_in increase from 0 to 1, fade_out reduces from 1 to 0
t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float64) t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float32)
fade_in = paddle.sqrt(0.5 * (1 + t)) fade_in = paddle.sqrt(0.5 * (1 + t))
fade_out = paddle.sqrt(0.5 * (1 - t)) fade_out = paddle.sqrt(0.5 * (1 - t))
# Concat the silence to the fades # Concat the silence to the fades
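Two quick sanity checks on the unchanged context here. The unfold arithmetic: with target=8000 and overlap=400, each fold carries 8000 + 2*400 = 8800 samples, and N folds unfold to N*8400 + 400 samples, matching total_len = num_folds * (target + overlap) + overlap. The "equal power" comment holds because fade_in(t)^2 + fade_out(t)^2 = 0.5*(1 + t) + 0.5*(1 - t) = 1 across the overlap; a quick check (not from the commit):

    import paddle

    t = paddle.linspace(-1, 1, 8, dtype=paddle.float32)
    fade_in = paddle.sqrt(0.5 * (1 + t))
    fade_out = paddle.sqrt(0.5 * (1 - t))
    # summed energy of the crossfade pair should be 1 everywhere
    print(paddle.allclose(fade_in**2 + fade_out**2, paddle.ones_like(t)))  # True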
@@ -568,7 +575,7 @@ class WaveRNN(nn.Layer):
         y[:, :overlap] *= fade_in
         y[:, -overlap:] *= fade_out
 
-        unfolded = paddle.zeros([total_len], dtype=paddle.float64)
+        unfolded = paddle.zeros([total_len], dtype=paddle.float32)
 
         # Loop to add up all the samples
         for i in range(num_folds):
@@ -606,11 +613,13 @@ class WaveRNNInference(nn.Layer):
             mu_law: bool=True,
             gen_display: bool=False):
         normalized_mel = self.normalizer(logmel)
+
         wav = self.wavernn.generate(
-            normalized_mel,
-            batched=batched,
-            target=target,
-            overlap=overlap,
-            mu_law=mu_law,
-            gen_display=gen_display)
+            normalized_mel, )
+        # batched=batched,
+        # target=target,
+        # overlap=overlap,
+        # mu_law=mu_law,
+        # gen_display=gen_display)
         return wav
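With the keyword arguments commented out, generate() is traced with its default values, which keeps the forward signature simple for export. A hedged sketch of how such a layer might then be saved as a static graph (the instance name, output path, and input spec are illustrative, not from this commit):

    import paddle
    from paddle.static import InputSpec

    # `wavernn_inference` is assumed to be a constructed WaveRNNInference layer
    static_layer = paddle.jit.to_static(
        wavernn_inference,
        input_spec=[InputSpec(shape=[None, 80], dtype='float32')])  # (T, n_mels)
    paddle.jit.save(static_layer, "inference/wavernn_csmsc")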
