From 72fa8176ca21589c596543d6bc4dfe46d3f5085a Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Wed, 13 Jul 2022 11:54:44 +0000
Subject: [PATCH] fix for mix_lang

---
 paddlespeech/t2s/exps/ernie_sat/train.py          |  2 ++
 .../t2s/models/ernie_sat/ernie_sat_updater.py     |  8 ++++++--
 paddlespeech/t2s/modules/losses.py                | 14 +++++++++-----
 paddlespeech/t2s/modules/nets_utils.py            |  7 ++++---
 4 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py
index 977b8fc5..020b0d0f 100644
--- a/paddlespeech/t2s/exps/ernie_sat/train.py
+++ b/paddlespeech/t2s/exps/ernie_sat/train.py
@@ -154,6 +154,7 @@ def train_sp(args, config):
         dataloader=train_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
+        vocab_size=vocab_size,
         output_dir=output_dir)

     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
@@ -163,6 +164,7 @@ def train_sp(args, config):
         dataloader=dev_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
+        vocab_size=vocab_size,
         output_dir=output_dir, )

     if dist.get_rank() == 0:
diff --git a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
index 17cfaae9..219341c8 100644
--- a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
+++ b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
@@ -40,11 +40,13 @@ class ErnieSATUpdater(StandardUpdater):
                  init_state=None,
                  text_masking: bool=False,
                  odim: int=80,
+                 vocab_size: int=100,
                  output_dir: Path=None):
         super().__init__(model, optimizer, dataloader, init_state=None)

         self.scheduler = scheduler
-        self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)

         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
         self.filehandler = logging.FileHandler(str(log_file))
@@ -104,6 +106,7 @@ class ErnieSATEvaluator(StandardEvaluator):
                  dataloader: DataLoader,
                  text_masking: bool=False,
                  odim: int=80,
+                 vocab_size: int=100,
                  output_dir: Path=None):
         super().__init__(model, dataloader)

@@ -113,7 +116,8 @@ class ErnieSATEvaluator(StandardEvaluator):
         self.logger = logger
         self.msg = ""

-        self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)

     def evaluate_core(self, batch):
         self.msg = "Evaluate: "
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index 95f2ff86..b3cf45aa 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -1013,6 +1013,7 @@ class KLDivergenceLoss(nn.Layer):
 class MLMLoss(nn.Layer):
     def __init__(self,
                  odim: int,
+                 vocab_size: int=0,
                  lsm_weight: float=0.1,
                  ignore_id: int=-1,
                  text_masking: bool=False):
@@ -1025,6 +1026,7 @@ class MLMLoss(nn.Layer):
         self.l1_loss_func = nn.L1Loss(reduction='none')
         self.text_masking = text_masking
         self.odim = odim
+        self.vocab_size = vocab_size

     def forward(
             self,
@@ -1059,10 +1061,12 @@ class MLMLoss(nn.Layer):
             assert text is not None
             assert text_outs is not None
             assert text_masked_pos is not None
-            text_mlm_loss = paddle.sum((self.text_mlm_loss(
-                paddle.reshape(text_outs, (-1, self.vocab_size)),
-                paddle.reshape(text, (-1))) * paddle.reshape(
-                    text_masked_pos,
-                    (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
+            text_outs = paddle.reshape(text_outs, [-1, self.vocab_size])
+            text = paddle.reshape(text, [-1])
+            text_mlm_loss = self.text_mlm_loss(text_outs, text)
+            text_masked_pos_reshape = paddle.reshape(text_masked_pos, [-1])
+            text_mlm_loss = paddle.sum(
+                text_mlm_loss *
+                text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10)

         return mlm_loss, text_mlm_loss
diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py
index 608a4742..1490ae83 100644
--- a/paddlespeech/t2s/modules/nets_utils.py
+++ b/paddlespeech/t2s/modules/nets_utils.py
@@ -464,14 +464,15 @@ def phones_text_masking(xs_pad: paddle.Tensor,
                 set(range(length)) - set(masked_phn_idxs[0].tolist()))
             np.random.shuffle(unmasked_phn_idxs)
             masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
-            text_masked_pos[idx][masked_text_idxs] = 1
+            text_masked_pos[idx, masked_text_idxs] = 1
             masked_start = align_start[idx][masked_phn_idxs].tolist()
             masked_end = align_end[idx][masked_phn_idxs].tolist()
             for s, e in zip(masked_start, masked_end):
                 masked_pos[idx, s:e] = 1
-    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
+    non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
     masked_pos = masked_pos * non_eos_mask
-    non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(xs_pad)[:2])
+    non_eos_text_mask = paddle.reshape(
+        text_mask, shape=paddle.shape(text_pad)[:2])
     text_masked_pos = text_masked_pos * non_eos_text_mask
     masked_pos = paddle.cast(masked_pos, 'bool')
     text_masked_pos = paddle.cast(text_masked_pos, 'bool')
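
Note on the losses.py change: the rewritten text-masking branch flattens the logits and targets, takes a per-token cross entropy, zeroes out everything except the masked positions, and normalizes by the masked-token count. Below is a minimal, self-contained sketch of that reduction; the toy shapes, the random tensors, and the names batch and text_len are illustrative stand-ins, not part of the patch.

    import paddle

    # Hypothetical toy dimensions; the real values come from the model config.
    batch, text_len, vocab_size = 2, 5, 100

    # Per-token loss with no reduction, so a mask can be applied afterwards
    # (MLMLoss builds its text criterion the same way, with ignore_id=-1).
    loss_fn = paddle.nn.CrossEntropyLoss(reduction='none', ignore_index=-1)

    text_outs = paddle.randn([batch, text_len, vocab_size])  # token logits
    text = paddle.randint(0, vocab_size, [batch, text_len])  # target token ids
    text_masked_pos = paddle.cast(
        paddle.randint(0, 2, [batch, text_len]), 'float32')  # 1 = masked slot

    # Flatten, score every token, keep only masked positions, then divide by
    # the number of masked tokens (the 1e-10 guards against an all-zero mask).
    flat_loss = loss_fn(
        paddle.reshape(text_outs, [-1, vocab_size]),
        paddle.reshape(text, [-1]))
    mask = paddle.reshape(text_masked_pos, [-1])
    text_mlm_loss = paddle.sum(flat_loss * mask) / (paddle.sum(mask) + 1e-10)
    print(float(text_mlm_loss))

This is the computation that broke in the mix_lang setting: the old branch read self.vocab_size, but __init__ never received or stored a vocab_size, so the patch threads it from train_sp through the updater and evaluator into MLMLoss.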