fix for mix_lang

pull/2117/head
TianYuan 2 years ago
parent 5503c8bd6b
commit 72fa8176ca

@@ -154,6 +154,7 @@ def train_sp(args, config):
         dataloader=train_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
+        vocab_size=vocab_size,
         output_dir=output_dir)

     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

@@ -163,6 +164,7 @@ def train_sp(args, config):
         dataloader=dev_dataloader,
         text_masking=config["model"]["text_masking"],
         odim=odim,
+        vocab_size=vocab_size,
         output_dir=output_dir, )

     if dist.get_rank() == 0:
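
Note: train_sp must now supply the text vocabulary size alongside odim. A minimal sketch of where that value could come from, assuming the script builds a token-to-id map earlier (the name vocab_phones is hypothetical, not part of this diff):

# Hypothetical sketch: derive vocab_size from a token-id map; "vocab_phones"
# is an assumed name, shown here with toy entries.
vocab_phones = {"<pad>": 0, "<unk>": 1, "AH0": 2, "sp": 3}
vocab_size = len(vocab_phones)  # then passed to ErnieSATUpdater / ErnieSATEvaluator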

@@ -40,11 +40,13 @@ class ErnieSATUpdater(StandardUpdater):
                  init_state=None,
                  text_masking: bool=False,
                  odim: int=80,
+                 vocab_size: int=100,
                  output_dir: Path=None):
         super().__init__(model, optimizer, dataloader, init_state=None)
         self.scheduler = scheduler

-        self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)

         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
         self.filehandler = logging.FileHandler(str(log_file))

@@ -104,6 +106,7 @@ class ErnieSATEvaluator(StandardEvaluator):
                  dataloader: DataLoader,
                  text_masking: bool=False,
                  odim: int=80,
+                 vocab_size: int=100,
                  output_dir: Path=None):
         super().__init__(model, dataloader)

@@ -113,7 +116,8 @@ class ErnieSATEvaluator(StandardEvaluator):
         self.logger = logger
         self.msg = ""

-        self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)

     def evaluate_core(self, batch):
         self.msg = "Evaluate: "
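
With the keyword threaded through, both classes build the criterion the same way. A minimal usage sketch, assuming the import path below matches the losses module patched in the next hunks:

from paddlespeech.t2s.modules.losses import MLMLoss  # assumed import path

# Mirrors the updater/evaluator construction; 100 matches their new default.
criterion = MLMLoss(text_masking=True, odim=80, vocab_size=100)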

@@ -1013,6 +1013,7 @@ class KLDivergenceLoss(nn.Layer):
 class MLMLoss(nn.Layer):
     def __init__(self,
                  odim: int,
+                 vocab_size: int=0,
                  lsm_weight: float=0.1,
                  ignore_id: int=-1,
                  text_masking: bool=False):

@@ -1025,6 +1026,7 @@ class MLMLoss(nn.Layer):
         self.l1_loss_func = nn.L1Loss(reduction='none')
         self.text_masking = text_masking
         self.odim = odim
+        self.vocab_size = vocab_size

     def forward(
             self,

@@ -1059,10 +1061,12 @@ class MLMLoss(nn.Layer):
             assert text is not None
             assert text_outs is not None
             assert text_masked_pos is not None
-            text_mlm_loss = paddle.sum((self.text_mlm_loss(
-                paddle.reshape(text_outs, (-1, self.vocab_size)),
-                paddle.reshape(text, (-1))) * paddle.reshape(
-                    text_masked_pos,
-                    (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
+            text_outs = paddle.reshape(text_outs, [-1, self.vocab_size])
+            text = paddle.reshape(text, [-1])
+            text_mlm_loss = self.text_mlm_loss(text_outs, text)
+            text_masked_pos_reshape = paddle.reshape(text_masked_pos, [-1])
+            text_mlm_loss = paddle.sum(
+                text_mlm_loss *
+                text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10)
         return mlm_loss, text_mlm_loss
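
The rewritten body computes the same quantity as before, just in explicit steps: a per-token loss over the text vocabulary, averaged only over masked positions. A runnable sketch with toy tensors, assuming self.text_mlm_loss is a reduction-free cross entropy (which is how the surrounding code uses it):

import paddle
import paddle.nn as nn

# Toy shapes: batch B, text length T, vocabulary size V (all assumed).
B, T, V = 2, 5, 100
text_outs = paddle.randn([B, T, V])                   # logits over the text vocabulary
text = paddle.randint(0, V, shape=[B, T])             # target token ids
text_masked_pos = paddle.randint(0, 2, shape=[B, T])  # 1 where a token was masked

# Per-token loss with no reduction, standing in for self.text_mlm_loss.
ce = nn.CrossEntropyLoss(reduction='none')
flat_loss = ce(paddle.reshape(text_outs, [-1, V]), paddle.reshape(text, [-1]))

# Average only over masked positions; 1e-10 guards against an all-zero mask.
mask = paddle.cast(paddle.reshape(text_masked_pos, [-1]), flat_loss.dtype)
text_mlm_loss = paddle.sum(flat_loss * mask) / (paddle.sum(mask) + 1e-10)
print(float(text_mlm_loss))  # scalar masked-LM loss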

@@ -464,14 +464,15 @@ def phones_text_masking(xs_pad: paddle.Tensor,
                 set(range(length)) - set(masked_phn_idxs[0].tolist()))
             np.random.shuffle(unmasked_phn_idxs)
             masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
-            text_masked_pos[idx][masked_text_idxs] = 1
+            text_masked_pos[idx, masked_text_idxs] = 1
             masked_start = align_start[idx][masked_phn_idxs].tolist()
             masked_end = align_end[idx][masked_phn_idxs].tolist()
             for s, e in zip(masked_start, masked_end):
                 masked_pos[idx, s:e] = 1
-    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
+    non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
     masked_pos = masked_pos * non_eos_mask
-    non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(xs_pad)[:2])
+    non_eos_text_mask = paddle.reshape(
+        text_mask, shape=paddle.shape(text_pad)[:2])
     text_masked_pos = text_masked_pos * non_eos_text_mask
     masked_pos = paddle.cast(masked_pos, 'bool')
     text_masked_pos = paddle.cast(text_masked_pos, 'bool')
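
Two fixes here: text_masked_pos[idx][masked_text_idxs] = 1 chains two indexing ops, so the assignment likely lands in a temporary rather than the original tensor, while combined indexing writes in place; and the text mask was reshaped against the speech tensor xs_pad, whose frame count generally differs from the token count. A runnable sketch of the corrected shape handling, with toy shapes assumed:

import paddle

# Toy shapes, assumed for illustration: more speech frames than text tokens.
B, T_speech, T_text, odim = 2, 120, 30, 80
xs_pad = paddle.randn([B, T_speech, odim])            # padded speech features
text_pad = paddle.randint(0, 100, shape=[B, T_text])  # padded token ids
text_mask = paddle.ones([B, 1, T_text])               # non-pad mask over text positions

# Old: reshaping text_mask to paddle.shape(xs_pad)[:2] targets [2, 120] = 240
# elements, but text_mask has 2 * 30 = 60 -> shape mismatch for mixed lengths.
# Fixed: reshape against the text tensor's own leading dims.
non_eos_text_mask = paddle.reshape(text_mask, shape=paddle.shape(text_pad)[:2])
print(non_eos_text_mask.shape)  # [2, 30]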
