From 6fb281ca8a02ecd1eaa1a2d7eb30a5657416f08f Mon Sep 17 00:00:00 2001 From: liangym Date: Mon, 16 Jan 2023 08:20:16 +0000 Subject: [PATCH] fix pitch_mask --- .../t2s/models/diffsinger/diffsinger.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/paddlespeech/t2s/models/diffsinger/diffsinger.py b/paddlespeech/t2s/models/diffsinger/diffsinger.py index 8abf13b00..24c5d4ee8 100644 --- a/paddlespeech/t2s/models/diffsinger/diffsinger.py +++ b/paddlespeech/t2s/models/diffsinger/diffsinger.py @@ -663,21 +663,11 @@ class DiffSinger(nn.Layer): hs = self._integrate_with_tone_embed(hs, tone_embs) # forward duration predictor and variance predictors d_masks = make_pad_mask(ilens) - # forward decoder - if olens is not None and not is_inference: - if self.reduction_factor > 1: - olens_in = paddle.to_tensor( - [olen // self.reduction_factor for olen in olens.numpy()]) - else: - olens_in = olens - # (B, 1, T) - h_masks = self._source_mask(olens_in) - pitch_masks = h_masks.transpose((0, 2, 1)) + if olens is not None: + pitch_masks = make_pad_mask(olens).unsqueeze(-1) else: - h_masks = None - pitch_masks = h_masks + pitch_masks = None - if is_inference: # (B, Tmax) if ds is not None: @@ -727,10 +717,22 @@ class DiffSinger(nn.Layer): (0, 2, 1)) hs = hs + e_embs + p_embs + + # forward decoder + if olens is not None and not is_inference: + if self.reduction_factor > 1: + olens_in = paddle.to_tensor( + [olen // self.reduction_factor for olen in olens.numpy()]) + else: + olens_in = olens + # (B, 1, T) + h_masks = self._source_mask(olens_in) + else: + h_masks = None + if return_after_enc: return hs, h_masks - if self.decoder_type == 'cnndecoder': # remove output masks for dygraph to static graph zs = self.decoder(hs, h_masks) @@ -1135,14 +1137,13 @@ class DiffSingerLoss(nn.Layer): duration_weights /= ds.shape[0] # apply weight - l1_loss = l1_loss.multiply(out_weights) l1_loss = l1_loss.masked_select( out_masks.broadcast_to(l1_loss.shape)).sum() duration_loss = (duration_loss.multiply(duration_weights) .masked_select(duration_masks).sum()) pitch_masks = out_masks - pitch_weights = duration_weights.unsqueeze(-1) + pitch_weights = out_weights pitch_loss = pitch_loss.multiply(pitch_weights) pitch_loss = pitch_loss.masked_select( pitch_masks.broadcast_to(pitch_loss.shape)).sum()