|
|
|
@ -663,21 +663,11 @@ class DiffSinger(nn.Layer):
|
|
|
|
|
hs = self._integrate_with_tone_embed(hs, tone_embs)
|
|
|
|
|
# forward duration predictor and variance predictors
|
|
|
|
|
d_masks = make_pad_mask(ilens)
|
|
|
|
|
# forward decoder
|
|
|
|
|
if olens is not None and not is_inference:
|
|
|
|
|
if self.reduction_factor > 1:
|
|
|
|
|
olens_in = paddle.to_tensor(
|
|
|
|
|
[olen // self.reduction_factor for olen in olens.numpy()])
|
|
|
|
|
else:
|
|
|
|
|
olens_in = olens
|
|
|
|
|
# (B, 1, T)
|
|
|
|
|
h_masks = self._source_mask(olens_in)
|
|
|
|
|
pitch_masks = h_masks.transpose((0, 2, 1))
|
|
|
|
|
if olens is not None:
|
|
|
|
|
pitch_masks = make_pad_mask(olens).unsqueeze(-1)
|
|
|
|
|
else:
|
|
|
|
|
h_masks = None
|
|
|
|
|
pitch_masks = h_masks
|
|
|
|
|
pitch_masks = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_inference:
|
|
|
|
|
# (B, Tmax)
|
|
|
|
|
if ds is not None:
|
|
|
|
@ -727,10 +717,22 @@ class DiffSinger(nn.Layer):
|
|
|
|
|
(0, 2, 1))
|
|
|
|
|
hs = hs + e_embs + p_embs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# forward decoder
|
|
|
|
|
if olens is not None and not is_inference:
|
|
|
|
|
if self.reduction_factor > 1:
|
|
|
|
|
olens_in = paddle.to_tensor(
|
|
|
|
|
[olen // self.reduction_factor for olen in olens.numpy()])
|
|
|
|
|
else:
|
|
|
|
|
olens_in = olens
|
|
|
|
|
# (B, 1, T)
|
|
|
|
|
h_masks = self._source_mask(olens_in)
|
|
|
|
|
else:
|
|
|
|
|
h_masks = None
|
|
|
|
|
|
|
|
|
|
if return_after_enc:
|
|
|
|
|
return hs, h_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.decoder_type == 'cnndecoder':
|
|
|
|
|
# remove output masks for dygraph to static graph
|
|
|
|
|
zs = self.decoder(hs, h_masks)
|
|
|
|
@ -1135,14 +1137,13 @@ class DiffSingerLoss(nn.Layer):
|
|
|
|
|
duration_weights /= ds.shape[0]
|
|
|
|
|
|
|
|
|
|
# apply weight
|
|
|
|
|
|
|
|
|
|
l1_loss = l1_loss.multiply(out_weights)
|
|
|
|
|
l1_loss = l1_loss.masked_select(
|
|
|
|
|
out_masks.broadcast_to(l1_loss.shape)).sum()
|
|
|
|
|
duration_loss = (duration_loss.multiply(duration_weights)
|
|
|
|
|
.masked_select(duration_masks).sum())
|
|
|
|
|
pitch_masks = out_masks
|
|
|
|
|
pitch_weights = duration_weights.unsqueeze(-1)
|
|
|
|
|
pitch_weights = out_weights
|
|
|
|
|
pitch_loss = pitch_loss.multiply(pitch_weights)
|
|
|
|
|
pitch_loss = pitch_loss.masked_select(
|
|
|
|
|
pitch_masks.broadcast_to(pitch_loss.shape)).sum()
|
|
|
|
|