|
|
@ -40,11 +40,13 @@ class ErnieSATUpdater(StandardUpdater):
|
|
|
|
init_state=None,
|
|
|
|
init_state=None,
|
|
|
|
text_masking: bool=False,
|
|
|
|
text_masking: bool=False,
|
|
|
|
odim: int=80,
|
|
|
|
odim: int=80,
|
|
|
|
|
|
|
|
vocab_size: int=100,
|
|
|
|
output_dir: Path=None):
|
|
|
|
output_dir: Path=None):
|
|
|
|
super().__init__(model, optimizer, dataloader, init_state=None)
|
|
|
|
super().__init__(model, optimizer, dataloader, init_state=None)
|
|
|
|
self.scheduler = scheduler
|
|
|
|
self.scheduler = scheduler
|
|
|
|
|
|
|
|
|
|
|
|
self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
|
|
|
|
self.criterion = MLMLoss(
|
|
|
|
|
|
|
|
text_masking=text_masking, odim=odim, vocab_size=vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
|
|
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
|
|
|
self.filehandler = logging.FileHandler(str(log_file))
|
|
|
|
self.filehandler = logging.FileHandler(str(log_file))
|
|
|
@ -104,6 +106,7 @@ class ErnieSATEvaluator(StandardEvaluator):
|
|
|
|
dataloader: DataLoader,
|
|
|
|
dataloader: DataLoader,
|
|
|
|
text_masking: bool=False,
|
|
|
|
text_masking: bool=False,
|
|
|
|
odim: int=80,
|
|
|
|
odim: int=80,
|
|
|
|
|
|
|
|
vocab_size: int=100,
|
|
|
|
output_dir: Path=None):
|
|
|
|
output_dir: Path=None):
|
|
|
|
super().__init__(model, dataloader)
|
|
|
|
super().__init__(model, dataloader)
|
|
|
|
|
|
|
|
|
|
|
@ -113,7 +116,8 @@ class ErnieSATEvaluator(StandardEvaluator):
|
|
|
|
self.logger = logger
|
|
|
|
self.logger = logger
|
|
|
|
self.msg = ""
|
|
|
|
self.msg = ""
|
|
|
|
|
|
|
|
|
|
|
|
self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
|
|
|
|
self.criterion = MLMLoss(
|
|
|
|
|
|
|
|
text_masking=text_masking, odim=odim, vocab_size=vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_core(self, batch):
|
|
|
|
def evaluate_core(self, batch):
|
|
|
|
self.msg = "Evaluate: "
|
|
|
|
self.msg = "Evaluate: "
|
|
|
|