diff --git a/examples/csmsc/tts0/local/inference.sh b/examples/csmsc/tts0/local/inference.sh
new file mode 100755
index 000000000..e417d748e
--- /dev/null
+++ b/examples/csmsc/tts0/local/inference.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+train_output_path=$1
+
+stage=0
+stop_stage=0
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    python3 ${BIN_DIR}/../inference.py \
+        --inference_dir=${train_output_path}/inference \
+        --am=tacotron2_csmsc \
+        --voc=pwgan_csmsc \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/pd_infer_out \
+        --phones_dict=dump/phone_id_map.txt
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    python3 ${BIN_DIR}/../inference.py \
+        --inference_dir=${train_output_path}/inference \
+        --am=tacotron2_csmsc \
+        --voc=mb_melgan_csmsc \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/pd_infer_out \
+        --phones_dict=dump/phone_id_map.txt
+fi
+
+# style melgan
+# style melgan's dygraph-to-static-graph export is not ready yet
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    python3 ${BIN_DIR}/../inference.py \
+        --inference_dir=${train_output_path}/inference \
+        --am=tacotron2_csmsc \
+        --voc=style_melgan_csmsc \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/pd_infer_out \
+        --phones_dict=dump/phone_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    python3 ${BIN_DIR}/../inference.py \
+        --inference_dir=${train_output_path}/inference \
+        --am=tacotron2_csmsc \
+        --voc=hifigan_csmsc \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/pd_infer_out \
+        --phones_dict=dump/phone_id_map.txt
+fi
\ No newline at end of file
diff --git a/examples/csmsc/tts0/local/synthesize_e2e.sh b/examples/csmsc/tts0/local/synthesize_e2e.sh
index fe5d11d44..c957df876 100755
--- a/examples/csmsc/tts0/local/synthesize_e2e.sh
+++ b/examples/csmsc/tts0/local/synthesize_e2e.sh
@@ -22,8 +22,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --lang=zh \
         --text=${BIN_DIR}/../sentences.txt \
         --output_dir=${train_output_path}/test_e2e \
-        --inference_dir=${train_output_path}/inference \
-        --phones_dict=dump/phone_id_map.txt
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
+
 fi

 # for more GAN Vocoders
diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py
index 37afd0abc..c3510beaa 100644
--- a/paddlespeech/t2s/exps/inference.py
+++ b/paddlespeech/t2s/exps/inference.py
@@ -33,7 +33,7 @@ def main():
         default='fastspeech2_csmsc',
         choices=[
             'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3',
-            'fastspeech2_vctk'
+            'fastspeech2_vctk', 'tacotron2_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
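With tacotron2_csmsc registered as an --am choice, the exported predictors under --inference_dir can also be driven without the shell script. A minimal sketch, assuming the export in synthesize_e2e.py below has already written tacotron2_csmsc and pwgan_csmsc to exp/default/inference (the paths and phone ids here are illustrative, not part of this patch):

    import numpy as np
    import paddle

    # Load the static-graph programs written by paddle.jit.save during export.
    am = paddle.jit.load("exp/default/inference/tacotron2_csmsc")
    voc = paddle.jit.load("exp/default/inference/pwgan_csmsc")

    # The acoustic model is exported with InputSpec([-1], dtype=paddle.int64),
    # i.e. a 1-D tensor of phone ids (these id values are made up).
    phone_ids = paddle.to_tensor(np.array([12, 7, 31, 4], dtype="int64"))
    mel = am(phone_ids)  # mel spectrogram, roughly (T, odim)
    wav = voc(mel)       # waveform tensor

The real entry point in paddlespeech/t2s/exps/inference.py additionally handles the text frontend, sentence files, and output naming.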
diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py
index 8ebfcfe7f..8fca935a1 100644
--- a/paddlespeech/t2s/exps/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/synthesize_e2e.py
@@ -178,10 +178,7 @@ def evaluate(args):
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[InputSpec([-1], dtype=paddle.int64)])
-            paddle.jit.save(am_inference,
-                            os.path.join(args.inference_dir, args.am))
-            am_inference = paddle.jit.load(
-                os.path.join(args.inference_dir, args.am))
+
         elif am_name == 'speedyspeech':
             if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
                 am_inference = jit.to_static(
@@ -200,10 +197,13 @@
                         InputSpec([-1],
                                   dtype=paddle.int64)
                     ])
-            paddle.jit.save(am_inference,
-                            os.path.join(args.inference_dir, args.am))
-            am_inference = paddle.jit.load(
-                os.path.join(args.inference_dir, args.am))
+        elif am_name == 'tacotron2':
+            am_inference = jit.to_static(
+                am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+
+        paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
+        am_inference = paddle.jit.load(
+            os.path.join(args.inference_dir, args.am))

         # vocoder
         voc_inference = jit.to_static(
diff --git a/paddlespeech/t2s/models/new_tacotron2/tacotron2.py b/paddlespeech/t2s/models/new_tacotron2/tacotron2.py
index 6a6d10735..bd4129fb7 100644
--- a/paddlespeech/t2s/models/new_tacotron2/tacotron2.py
+++ b/paddlespeech/t2s/models/new_tacotron2/tacotron2.py
@@ -432,6 +432,7 @@ class Tacotron2(nn.Layer):

         # inference
         h = self.enc.inference(x)
+
         if self.spk_num is not None:
             sid_emb = self.sid_emb(spk_id.reshape([-1]))
             h = h + sid_emb
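The synthesize_e2e.py refactor above also deduplicates the export code: each acoustic-model branch now only builds am_inference via jit.to_static, and a single paddle.jit.save plus paddle.jit.load round-trip follows the branches, so later synthesis runs on the exact program that was saved. A distilled restatement of that pattern (the helper name and signature are illustrative, not in the patch):

    import os

    import paddle
    from paddle import jit
    from paddle.static import InputSpec

    def export_acoustic_model(am_inference, inference_dir, am_tag):
        # Trace with a 1-D int64 phone-id spec, as the tacotron2 branch does.
        am_inference = jit.to_static(
            am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
        # Save once, then reload so downstream code uses the stored graph.
        path = os.path.join(inference_dir, am_tag)
        paddle.jit.save(am_inference, path)
        return paddle.jit.load(path)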
diff --git a/paddlespeech/t2s/modules/tacotron2/attentions.py b/paddlespeech/t2s/modules/tacotron2/attentions.py
index 710e326d6..af7a94f30 100644
--- a/paddlespeech/t2s/modules/tacotron2/attentions.py
+++ b/paddlespeech/t2s/modules/tacotron2/attentions.py
@@ -157,7 +157,7 @@ class AttLoc(nn.Layer):
         paddle.Tensor
             previous attention weights (B, T_max)
         """
-        batch = len(enc_hs_pad)
+        batch = paddle.shape(enc_hs_pad)[0]
         # pre-compute all h outside the decoder loop
         if self.pre_compute_enc_h is None or self.han_mode:
             # (utt, frame, hdim)
@@ -172,33 +172,30 @@
         dec_z = dec_z.reshape([batch, self.dunits])

         # initialize attention weight with uniform dist.
-        if att_prev is None:
+        if paddle.sum(att_prev) == 0:
             # if no bias, 0 0-pad goes 0
             att_prev = 1.0 - make_pad_mask(enc_hs_len)
             att_prev = att_prev / enc_hs_len.unsqueeze(-1)

         # att_prev: (utt, frame) -> (utt, 1, 1, frame)
         # -> (utt, att_conv_chans, 1, frame)
-
         att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
         # att_conv: (utt, att_conv_chans, 1, frame) -> (utt, frame, att_conv_chans)
         att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
         # att_conv: (utt, frame, att_conv_chans) -> (utt, frame, att_dim)
         att_conv = self.mlp_att(att_conv)
-
-        # dec_z_tiled: (utt, frame, att_dim)
+        # dec_z_tiled: (utt, frame, att_dim)
         dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])

         # dot with gvec
         # (utt, frame, att_dim) -> (utt, frame)
-        e = self.gvec(
-            paddle.tanh(att_conv + self.pre_compute_enc_h +
-                        dec_z_tiled)).squeeze(2)
+        e = paddle.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
+        e = self.gvec(e).squeeze(2)

         # NOTE: consider zero padding when compute w.
         if self.mask is None:
             self.mask = make_pad_mask(enc_hs_len)
+
         e = masked_fill(e, self.mask, -float("inf"))

         # apply monotonic attention constraint (mainly for TTS)
         if last_attended_idx is not None:
@@ -211,7 +208,6 @@
         # utt x hdim
         c = paddle.sum(
             self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)
-
         return c, w
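Both attention changes serve dygraph-to-static conversion. len(enc_hs_pad) and att_prev is None are resolved in Python at trace time, so the traced program would hard-code one batch size and one branch; paddle.shape(enc_hs_pad)[0] keeps the batch dimension dynamic, and an all-zeros tensor stands in for None as the 'no previous attention' sentinel, testable with tensor ops. A minimal sketch of the sentinel idiom, assuming the make_pad_mask helper that attentions.py imports from paddlespeech.t2s.modules.nets_utils:

    import paddle

    from paddlespeech.t2s.modules.nets_utils import make_pad_mask

    def init_att_weight(att_prev, enc_hs_len):
        # An all-zeros att_prev means "first decoder step": fall back to a
        # uniform distribution over the unpadded encoder frames.
        if paddle.sum(att_prev) == 0:
            att_prev = 1.0 - make_pad_mask(enc_hs_len)  # 1.0 on real frames
            att_prev = att_prev / enc_hs_len.unsqueeze(-1)
        return att_prev

The sentinel is unambiguous here because attention weights are softmax outputs: after the first step they sum to one per utterance, never to zero.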
diff --git a/paddlespeech/t2s/modules/tacotron2/decoder.py b/paddlespeech/t2s/modules/tacotron2/decoder.py
index fc15adfda..3622fd7a2 100644
--- a/paddlespeech/t2s/modules/tacotron2/decoder.py
+++ b/paddlespeech/t2s/modules/tacotron2/decoder.py
@@ -15,7 +15,6 @@
 """Tacotron2 decoder related modules."""
 import paddle
 import paddle.nn.functional as F
-import six
 from paddle import nn

 from paddlespeech.t2s.modules.tacotron2.attentions import AttForwardTA
@@ -59,7 +58,7 @@ class Prenet(nn.Layer):
         super().__init__()
         self.dropout_rate = dropout_rate
         self.prenet = nn.LayerList()
-        for layer in six.moves.range(n_layers):
+        for layer in range(n_layers):
             n_inputs = idim if layer == 0 else n_units
             self.prenet.append(
                 nn.Sequential(nn.Linear(n_inputs, n_units), nn.ReLU()))
@@ -78,7 +77,7 @@ class Prenet(nn.Layer):
             Batch of output tensors (B, ..., odim).

         """
-        for i in six.moves.range(len(self.prenet)):
+        for i in range(len(self.prenet)):
            # F.dropout introduces randomness; tacotron2's dropout must not be removed, even at inference
            x = F.dropout(self.prenet[i](x))
         return x
@@ -129,7 +128,7 @@
         """
         super().__init__()
         self.postnet = nn.LayerList()
-        for layer in six.moves.range(n_layers - 1):
+        for layer in range(n_layers - 1):
             ichans = odim if layer == 0 else n_chans
             ochans = odim if layer == n_layers - 1 else n_chans
             if use_batch_norm:
@@ -196,7 +195,7 @@
             Batch of padded output tensor. (B, odim, Tmax).

         """
-        for i in six.moves.range(len(self.postnet)):
+        for i in range(len(self.postnet)):
             xs = self.postnet[i](xs)
         return xs

@@ -360,7 +359,7 @@ class Decoder(nn.Layer):
         # define lstm network
         prenet_units = prenet_units if prenet_layers != 0 else odim
         self.lstm = nn.LayerList()
-        for layer in six.moves.range(dlayers):
+        for layer in range(dlayers):
             iunits = idim + prenet_units if layer == 0 else dunits
             lstm = nn.LSTMCell(iunits, dunits)
             if zoneout_rate > 0.0:
@@ -437,47 +436,50 @@
         # initialize hidden states of decoder
         c_list = [self._zero_state(hs)]
         z_list = [self._zero_state(hs)]
-        for _ in six.moves.range(1, len(self.lstm)):
-            c_list += [self._zero_state(hs)]
-            z_list += [self._zero_state(hs)]
+        for _ in range(1, len(self.lstm)):
+            c_list.append(self._zero_state(hs))
+            z_list.append(self._zero_state(hs))
         prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])

         # initialize attention
-        prev_att_w = None
+        prev_att_ws = []
+        prev_att_w = paddle.zeros(paddle.shape(hlens))
+        prev_att_ws.append(prev_att_w)
         self.att.reset()

         # loop for an output sequence
         outs, logits, att_ws = [], [], []
         for y in ys.transpose([1, 0, 2]):
             if self.use_att_extra_inputs:
-                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w,
+                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_ws[-1],
                                         prev_out)
             else:
-                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
+                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_ws[-1])
             prenet_out = self.prenet(
                 prev_out) if self.prenet is not None else prev_out
             xs = paddle.concat([att_c, prenet_out], axis=1)
             # we only use the second output of LSTMCell in paddle
             _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
             z_list[0], c_list[0] = next_hidden
-            for i in six.moves.range(1, len(self.lstm)):
+            for i in range(1, len(self.lstm)):
                 # we only use the second output of LSTMCell in paddle
                 _, next_hidden = self.lstm[i](z_list[i - 1],
                                               (z_list[i], c_list[i]))
                 z_list[i], c_list[i] = next_hidden
             zcs = (paddle.concat([z_list[-1], att_c], axis=1)
                    if self.use_concate else z_list[-1])
-            outs += [
-                self.feat_out(zcs).reshape([paddle.shape(hs)[0], self.odim, -1])
-            ]
-            logits += [self.prob_out(zcs)]
-            att_ws += [att_w]
+            outs.append(
+                self.feat_out(zcs).reshape([paddle.shape(hs)[0], self.odim, -1
+                                            ]))
+            logits.append(self.prob_out(zcs))
+            att_ws.append(att_w)
             # teacher forcing
             prev_out = y
-            if self.cumulate_att_w and prev_att_w is not None:
+            if self.cumulate_att_w and paddle.sum(prev_att_w) != 0:
                 prev_att_w = prev_att_w + att_w  # Note: error when use +=
             else:
                 prev_att_w = att_w
+            prev_att_ws.append(prev_att_w)

         # (B, Lmax)
         logits = paddle.concat(logits, axis=1)
         # (B, odim, Lmax)
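The prev_att_ws change in forward above applies the same idea as the attention sentinel: rebinding prev_att_w from None to a tensor would change its Python type mid-trace, so the state starts as a zero tensor and is threaded through a plain list, with prev_att_ws[-1] always holding the latest weights. The out-of-place prev_att_w + att_w (see the `error when use +=` note) matters because an in-place update would mutate a tensor that earlier graph nodes still reference. A shape-only illustration (sizes are made up):

    import paddle

    B, T_max = 2, 5
    prev_att_ws = [paddle.zeros([B, T_max])]  # step 0: the zero sentinel
    for att_w in [paddle.rand([B, T_max]) for _ in range(3)]:
        cumulated = prev_att_ws[-1] + att_w   # out-of-place, never +=
        prev_att_ws.append(cumulated)
    assert len(prev_att_ws) == 4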
@@ -552,6 +554,7 @@
         .. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654

         """
         # setup
+        assert len(paddle.shape(h)) == 2
         hs = h.unsqueeze(0)
         ilens = paddle.shape(h)[0]
@@ -561,13 +564,16 @@
         # initialize hidden states of decoder
         c_list = [self._zero_state(hs)]
         z_list = [self._zero_state(hs)]
-        for _ in six.moves.range(1, len(self.lstm)):
-            c_list += [self._zero_state(hs)]
-            z_list += [self._zero_state(hs)]
+        for _ in range(1, len(self.lstm)):
+            c_list.append(self._zero_state(hs))
+            z_list.append(self._zero_state(hs))
         prev_out = paddle.zeros([1, self.odim])

         # initialize attention
-        prev_att_w = None
+        prev_att_ws = []
+        prev_att_w = paddle.zeros([ilens])
+        prev_att_ws.append(prev_att_w)
+
         self.att.reset()

         # setup for attention constraint
@@ -579,6 +585,7 @@
         # loop for an output sequence
         idx = 0
         outs, att_ws, probs = [], [], []
+        prob = paddle.zeros([1])
         while True:
             # updated index
             idx += self.reduction_factor
@@ -589,7 +596,7 @@
                     hs,
                     ilens,
                     z_list[0],
-                    prev_att_w,
+                    prev_att_ws[-1],
                     prev_out,
                     last_attended_idx=last_attended_idx,
                     backward_window=backward_window,
@@ -599,19 +606,20 @@
                     hs,
                     ilens,
                     z_list[0],
-                    prev_att_w,
+                    prev_att_ws[-1],
                     last_attended_idx=last_attended_idx,
                     backward_window=backward_window,
                     forward_window=forward_window, )

-            att_ws += [att_w]
+            att_ws.append(att_w)
             prenet_out = self.prenet(
                 prev_out) if self.prenet is not None else prev_out
             xs = paddle.concat([att_c, prenet_out], axis=1)
             # we only use the second output of LSTMCell in paddle
             _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
+
             z_list[0], c_list[0] = next_hidden
-            for i in six.moves.range(1, len(self.lstm)):
+            for i in range(1, len(self.lstm)):
                 # we only use the second output of LSTMCell in paddle
                 _, next_hidden = self.lstm[i](z_list[i - 1],
                                               (z_list[i], c_list[i]))
@@ -619,38 +627,38 @@
             zcs = (paddle.concat([z_list[-1], att_c], axis=1)
                    if self.use_concate else z_list[-1])
             # [(1, odim, r), ...]
-            outs += [self.feat_out(zcs).reshape([1, self.odim, -1])]
+            outs.append(self.feat_out(zcs).reshape([1, self.odim, -1]))
+
+            prob = F.sigmoid(self.prob_out(zcs))[0]
+            probs.append(prob)

-            # [(r), ...]
-            probs += [F.sigmoid(self.prob_out(zcs))[0]]
             if self.output_activation_fn is not None:
                 prev_out = self.output_activation_fn(
                     outs[-1][:, :, -1])  # (1, odim)
             else:
                 prev_out = outs[-1][:, :, -1]  # (1, odim)
-            if self.cumulate_att_w and prev_att_w is not None:
+            if self.cumulate_att_w and paddle.sum(prev_att_w) != 0:
                 prev_att_w = prev_att_w + att_w  # Note: error when use +=
             else:
                 prev_att_w = att_w
+            prev_att_ws.append(prev_att_w)
             if use_att_constraint:
                 last_attended_idx = int(att_w.argmax())

-            # check whether to finish generation
-            if sum(paddle.cast(probs[-1] >= threshold,
-                               'int64')) > 0 or idx >= maxlen:
+            if prob >= threshold or idx >= maxlen:
                 # check minimum length
                 if idx < minlen:
                     continue
-                # (1, odim, L)
-                outs = paddle.concat(outs, axis=2)
-                if self.postnet is not None:
-                    # (1, odim, L)
-                    outs = outs + self.postnet(outs)
-                # (L, odim)
-                outs = outs.transpose([0, 2, 1]).squeeze(0)
-                probs = paddle.concat(probs, axis=0)
-                att_ws = paddle.concat(att_ws, axis=0)
                 break
+        # (1, odim, L)
+        outs = paddle.concat(outs, axis=2)
+        if self.postnet is not None:
+            # (1, odim, L)
+            outs = outs + self.postnet(outs)
+        # (L, odim)
+        outs = outs.transpose([0, 2, 1]).squeeze(0)
+        probs = paddle.concat(probs, axis=0)
+        att_ws = paddle.concat(att_ws, axis=0)

         if self.output_activation_fn is not None:
             outs = self.output_activation_fn(outs)
@@ -685,9 +693,9 @@
         # initialize hidden states of decoder
         c_list = [self._zero_state(hs)]
         z_list = [self._zero_state(hs)]
-        for _ in six.moves.range(1, len(self.lstm)):
-            c_list += [self._zero_state(hs)]
-            z_list += [self._zero_state(hs)]
+        for _ in range(1, len(self.lstm)):
+            c_list.append(self._zero_state(hs))
+            z_list.append(self._zero_state(hs))
         prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])

         # initialize attention
@@ -702,14 +710,14 @@
                                         prev_out)
             else:
                 att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
-            att_ws += [att_w]
+            att_ws.append(att_w)
             prenet_out = self.prenet(
                 prev_out) if self.prenet is not None else prev_out
             xs = paddle.concat([att_c, prenet_out], axis=1)
             # we only use the second output of LSTMCell in paddle
             _, next_hidden = self.lstm[0](xs, (z_list[0], c_list[0]))
             z_list[0], c_list[0] = next_hidden
-            for i in six.moves.range(1, len(self.lstm)):
+            for i in range(1, len(self.lstm)):
                 z_list[i], c_list[i] = self.lstm[i](z_list[i - 1],
                                                     (z_list[i], c_list[i]))
             # teacher forcing
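In inference above, the stop check no longer reduces the just-appended list entry with sum(paddle.cast(probs[-1] >= threshold, 'int64')) > 0; a scalar prob, initialized before the loop so it is always defined, is compared directly, and the concat/postnet finalization moves out of the loop body so it runs exactly once after break. A toy version of the restructured control flow, with random stand-ins for the network outputs:

    import paddle
    import paddle.nn.functional as F

    threshold, minlen, maxlen, odim = 0.5, 2, 10, 80
    outs, idx = [], 0
    prob = paddle.zeros([1])                    # defined before the first read
    while True:
        idx += 1
        outs.append(paddle.rand([1, odim, 1]))  # stand-in for feat_out(zcs)
        prob = F.sigmoid(paddle.rand([1]))[0]   # stand-in for prob_out(zcs)
        if prob >= threshold or idx >= maxlen:
            if idx < minlen:
                continue
            break
    outs = paddle.concat(outs, axis=2)          # (1, odim, L), built once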
diff --git a/paddlespeech/t2s/modules/tacotron2/encoder.py b/paddlespeech/t2s/modules/tacotron2/encoder.py
index b2ed30d1f..80c213a1a 100644
--- a/paddlespeech/t2s/modules/tacotron2/encoder.py
+++ b/paddlespeech/t2s/modules/tacotron2/encoder.py
@@ -14,7 +14,6 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 """Tacotron2 encoder related modules."""
 import paddle
-import six
 from paddle import nn


@@ -88,7 +87,7 @@ class Encoder(nn.Layer):

         if econv_layers > 0:
             self.convs = nn.LayerList()
-            for layer in six.moves.range(econv_layers):
+            for layer in range(econv_layers):
                 ichans = (embed_dim if layer == 0 and input_layer == "embed"
                           else econv_chans)
                 if use_batch_norm:
@@ -130,6 +129,7 @@ class Encoder(nn.Layer):
                 direction='bidirectional',
                 bias_ih_attr=True,
                 bias_hh_attr=True)
+            self.blstm.flatten_parameters()
         else:
             self.blstm = None

@@ -157,7 +157,7 @@ class Encoder(nn.Layer):
         """
         xs = self.embed(xs).transpose([0, 2, 1])
         if self.convs is not None:
-            for i in six.moves.range(len(self.convs)):
+            for i in range(len(self.convs)):
                 if self.use_residual:
                     xs += self.convs[i](xs)
                 else:
@@ -167,7 +167,8 @@ class Encoder(nn.Layer):
         if not isinstance(ilens, paddle.Tensor):
             ilens = paddle.to_tensor(ilens)
         xs = xs.transpose([0, 2, 1])
-        self.blstm.flatten_parameters()
+        # for dygraph to static graph
+        # self.blstm.flatten_parameters()
         # (B, Tmax, C)
         # see https://www.paddlepaddle.org.cn/documentation/docs/zh/faq/train_cn.html#paddletorch-nn-utils-rnn-pack-padded-sequencetorch-nn-utils-rnn-pad-packed-sequenceapi
         xs, _ = self.blstm(xs, sequence_length=ilens)
@@ -191,6 +192,6 @@ class Encoder(nn.Layer):

         """
         xs = x.unsqueeze(0)
-        ilens = paddle.to_tensor([x.shape[0]])
+        ilens = paddle.shape(x)[0]

         return self.forward(xs, ilens)[0][0]
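The encoder changes round out the export work: flatten_parameters() moves to construction time because the per-forward call does not survive dygraph-to-static conversion (hence the commented-out line), and inference derives the length with paddle.shape(x)[0], which is evaluated on the real input at run time, whereas x.shape[0] on a dynamic axis is the static placeholder -1. A small sketch of that difference, assuming a 1-D dynamic input:

    import paddle
    from paddle.static import InputSpec

    @paddle.jit.to_static(input_spec=[InputSpec([-1], dtype="int64")])
    def seq_len(x):
        # paddle.shape(x) is computed per call; x.shape[0] would be baked
        # into the traced program as -1 for this dynamic axis.
        return paddle.shape(x)[0]

    print(seq_len(paddle.to_tensor([3, 1, 4, 1, 5], dtype="int64")))  # -> 5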