[Hackathon 7th] fix aishell3/vctk vc0/ernie (#3928)

* fix aishell3-vc0

* fix aishell3-vc0

* Apply suggestions from code review

* Apply suggestions from code review
pull/3946/head
张春乔 3 months ago committed by GitHub
parent 5069111e6d
commit 9d16002c23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1115,7 +1115,8 @@ class MLMLoss(nn.Layer):
paddle.reshape(xs_pad, (-1, self.odim))), paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1) axis=-1)
mlm_loss = paddle.sum((loss * paddle.reshape( mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10) mlm_loss_pos,
[-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10)
text_mlm_loss = None text_mlm_loss = None

@ -466,7 +466,7 @@ def phones_masking(xs_pad: paddle.Tensor,
for s, e in zip(masked_start, masked_end): for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1 masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2]) non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
masked_pos = paddle.cast(masked_pos, 'bool') masked_pos = paddle.cast(masked_pos, 'bool')
return masked_pos return masked_pos
@ -550,10 +550,11 @@ def phones_text_masking(xs_pad: paddle.Tensor,
for s, e in zip(masked_start, masked_end): for s, e in zip(masked_start, masked_end):
masked_pos[idx, s:e] = 1 masked_pos[idx, s:e] = 1
non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2]) non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
masked_pos = masked_pos * non_eos_mask masked_pos = masked_pos * non_eos_mask.astype(masked_pos.dtype)
non_eos_text_mask = paddle.reshape( non_eos_text_mask = paddle.reshape(
text_mask, shape=paddle.shape(text_pad)[:2]) text_mask, shape=paddle.shape(text_pad)[:2])
text_masked_pos = text_masked_pos * non_eos_text_mask text_masked_pos = text_masked_pos * non_eos_text_mask.astype(
text_masked_pos.dtype)
masked_pos = paddle.cast(masked_pos, 'bool') masked_pos = paddle.cast(masked_pos, 'bool')
text_masked_pos = paddle.cast(text_masked_pos, 'bool') text_masked_pos = paddle.cast(text_masked_pos, 'bool')

@ -171,7 +171,8 @@ class AttLoc(nn.Layer):
if paddle.sum(att_prev) == 0: if paddle.sum(att_prev) == 0:
# if no bias, 0 0-pad goes 0 # if no bias, 0 0-pad goes 0
att_prev = 1.0 - make_pad_mask(enc_hs_len) att_prev = 1.0 - make_pad_mask(enc_hs_len)
att_prev = att_prev / enc_hs_len.unsqueeze(-1) att_prev = att_prev / enc_hs_len.unsqueeze(-1).astype(
att_prev.dtype)
# att_prev: (utt, frame) -> (utt, 1, 1, frame) # att_prev: (utt, frame) -> (utt, 1, 1, frame)
# -> (utt, att_conv_chans, 1, frame) # -> (utt, att_conv_chans, 1, frame)

@ -162,6 +162,8 @@ class Encoder(nn.Layer):
return xs.transpose([0, 2, 1]) return xs.transpose([0, 2, 1])
if not isinstance(ilens, paddle.Tensor): if not isinstance(ilens, paddle.Tensor):
ilens = paddle.to_tensor(ilens) ilens = paddle.to_tensor(ilens)
if ilens.ndim == 0:
ilens = ilens.unsqueeze(0)
xs = xs.transpose([0, 2, 1]) xs = xs.transpose([0, 2, 1])
# for dygraph to static graph # for dygraph to static graph
# self.blstm.flatten_parameters() # self.blstm.flatten_parameters()

Loading…
Cancel
Save