pull/578/head
Hui Zhang 4 years ago
parent 20f1976899
commit 16b8b9821c

@@ -73,28 +73,37 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         self.model.eval()
         valid_losses = defaultdict(list)
+        num_seen_utts = 1
+        total_loss = 0.0
         for i, batch in enumerate(self.valid_loader):
             loss = self.model(*batch)
-            valid_losses['val_loss'].append(float(loss))
+            if paddle.isfinite(loss):
+                num_utts = batch[0].shape[0]
+                num_seen_utts += num_utts
+                total_loss += float(loss) * num_utts
+                valid_losses['val_loss'].append(float(loss))
 
-        # write visual log
-        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
-
-        # logging
-        msg = f"Valid: Rank: {dist.get_rank()}, "
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in valid_losses.items())
-        logger.info(msg)
-
-        if self.visualizer:
-            for k, v in valid_losses.items():
-                self.visualizer.add_scalar("valid/{}".format(k), v,
-                                           self.iteration)
-
-        return valid_losses
+            if (i + 1) % self.config.training.log_interval == 0:
+                valid_losses['val_history_loss'] = total_loss / num_seen_utts
+
+                # write visual log
+                valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
+
+                # logging
+                msg = f"Valid: Rank: {dist.get_rank()}, "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
+                msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                                 for k, v in valid_losses.items())
+                logger.info(msg)
+
+                if self.visualizer:
+                    for k, v in valid_losses.items():
+                        self.visualizer.add_scalar("valid/{}".format(k), v,
+                                                   self.iteration)
+
+        return total_loss, num_seen_utts
 
     def setup_model(self):
         config = self.config
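
Note (not part of the commit): valid() no longer returns a dict of per-batch means; it accumulates a loss total weighted by the number of utterances in each batch, skips non-finite losses, and returns (total_loss, num_seen_utts) for the caller to average. A minimal standalone sketch of that bookkeeping, with an illustrative function name:

import math

def average_valid_loss(batch_losses, batch_sizes):
    """Per-utterance average loss over all finite batches."""
    num_seen_utts = 1      # starts at 1 in the patch to avoid division by zero
    total_loss = 0.0
    for loss, num_utts in zip(batch_losses, batch_sizes):
        if math.isfinite(loss):
            total_loss += loss * num_utts
            num_seen_utts += num_utts
    return total_loss / num_seen_utts

print(average_valid_loss([2.0, float("inf"), 4.0], [8, 8, 8]))  # (2*8 + 4*8) / 17
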

@@ -81,11 +81,11 @@ class U2Trainer(Trainer):
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
-        losses_np = {
-            'train_loss': float(loss) * train_conf.accum_grad,
-            'train_att_loss': float(attention_loss),
-            'train_ctc_loss': float(ctc_loss),
-        }
+        losses_np = {'loss': float(loss) * train_conf.accum_grad}
+        if attention_loss:
+            losses_np['att_loss'] = float(attention_loss)
+        if ctc_loss:
+            losses_np['ctc_loss'] = float(ctc_loss)
 
         if (batch_index + 1) % train_conf.accum_grad == 0:
             if dist.get_rank() == 0 and self.visualizer:
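
Note (not part of the commit): the training-loss dict is now built conditionally, so a configuration that produces only a CTC loss or only an attention loss does not log empty keys. An illustrative sketch with ad-hoc names:

def build_loss_report(loss, accum_grad, attention_loss=None, ctc_loss=None):
    # Only report the losses that this model configuration actually produced.
    losses_np = {'loss': float(loss) * accum_grad}
    if attention_loss:
        losses_np['att_loss'] = float(attention_loss)
    if ctc_loss:
        losses_np['ctc_loss'] = float(ctc_loss)
    return losses_np

print(build_loss_report(1.5, 2, ctc_loss=3.0))  # {'loss': 3.0, 'ctc_loss': 3.0}
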
@@ -135,6 +135,8 @@ class U2Trainer(Trainer):
                 msg = "Train: Rank: {}, ".format(dist.get_rank())
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
+                msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                len(self.train_loader))
                 msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                 msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
                 self.train_batch(batch_index, batch, msg)
@@ -143,8 +145,9 @@ class U2Trainer(Trainer):
                 logger.error(e)
                 raise e
 
-            valid_losses = self.valid()
-            self.save(tag=self.epoch, infos=valid_losses)
+            total_loss, num_seen_utts = self.valid()
+            self.save(
+                tag=self.epoch, infos={'val_loss': total_loss / num_seen_utts})
             self.new_epoch()
 
     @mp_tools.rank_zero_only
@@ -153,29 +156,42 @@ class U2Trainer(Trainer):
         self.model.eval()
         logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
+        num_seen_utts = 1
+        total_loss = 0.0
         for i, batch in enumerate(self.valid_loader):
-            total_loss, attention_loss, ctc_loss = self.model(*batch)
-
-            valid_losses['val_loss'].append(float(total_loss))
-            valid_losses['val_att_loss'].append(float(attention_loss))
-            valid_losses['val_ctc_loss'].append(float(ctc_loss))
-
-        # write visual log
-        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
-
-        # logging
-        msg = f"Valid: Rank: {dist.get_rank()}, "
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in valid_losses.items())
-        logger.info(msg)
-
-        if self.visualizer:
-            valid_losses_v = valid_losses.copy()
-            valid_losses_v.update({"lr": self.lr_scheduler()})
-            self.visualizer.add_scalars('epoch', valid_losses_v, self.epoch)
-
-        return valid_losses
+            loss, attention_loss, ctc_loss = self.model(*batch)
+            if paddle.isfinite(loss):
+                num_utts = batch[0].shape[0]
+                num_seen_utts += num_utts
+                total_loss += float(loss) * num_utts
+                valid_losses = {'val_loss': float(loss)}
+                if attention_loss:
+                    valid_losses['val_att_loss'] = float(attention_loss)
+                if ctc_loss:
+                    valid_losses['val_ctc_loss'] = float(ctc_loss)
+
+            if (i + 1) % self.config.training.log_interval == 0:
+                valid_losses['val_history_loss'] = total_loss / num_seen_utts
+
+                # write visual log
+                valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
+
+                # logging
+                msg = f"Valid: Rank: {dist.get_rank()}, "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
+                msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                                 for k, v in valid_losses.items())
+                logger.info(msg)
+
+                if self.visualizer:
+                    valid_losses_v = valid_losses.copy()
+                    valid_losses_v.update({"lr": self.lr_scheduler()})
+                    self.visualizer.add_scalars('epoch', valid_losses_v,
+                                                self.epoch)
+
+        return total_loss, num_seen_utts
 
     def setup_dataloader(self):
         config = self.config.clone()

@@ -56,6 +56,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
         global_norm_var = layers.sqrt(global_norm_var)
         # debug log
         logger.debug(f"Grad Global Norm: {float(global_norm_var)}!!!!")
+
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
         clip_var = layers.elementwise_div(
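
Note (not part of the commit): ClipGradByGlobalNormWithLog only adds the debug log of the norm; the clipping itself follows the usual global-norm rule, scaling every gradient by clip_norm / max(global_norm, clip_norm). A plain-Python sketch of that rule (not Paddle's implementation):

import math

def clip_by_global_norm(grads, clip_norm):
    # grads: list of gradient "tensors" represented as flat lists of floats.
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads], global_norm

grads = [[3.0, 4.0], [0.0, 12.0]]                 # global norm = 13
clipped, norm = clip_by_global_norm(grads, clip_norm=5.0)
print(norm, clipped)                              # every gradient scaled by 5/13
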

@@ -53,5 +53,14 @@ class WarmupLR(LRScheduler):
         return self.base_lr * self.warmup_steps**0.5 * min(
             step_num**-0.5, step_num * self.warmup_steps**-1.5)
 
-    def set_step(self, step: int):
-        self.step(step)
+    def set_step(self, step: int=None):
+        '''
+        Update the learning rate in the optimizer according to the current ``epoch``.
+        The new learning rate takes effect on the next ``optimizer.step``.
+
+        Args:
+            step (int, None): specify the current epoch. Default: None, auto-increment from last_epoch=-1.
+        Returns:
+            None
+        '''
+        self.step(epoch=step)
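
Note (not part of the commit): set_step() now forwards to LRScheduler.step(); the schedule it advances is the Noam-style warmup computed by get_lr() above. A standalone sketch of that formula, with illustrative base_lr and warmup_steps values (the clamp to step >= 1 is an assumption for the sketch):

def warmup_lr(step, base_lr=0.001, warmup_steps=25000):
    # lr rises linearly for warmup_steps, then decays proportionally to step**-0.5.
    step = max(step, 1)  # avoid 0**-0.5 at the very first step
    return base_lr * warmup_steps**0.5 * min(step**-0.5,
                                             step * warmup_steps**-1.5)

for s in (1, 12500, 25000, 100000):
    print(s, warmup_lr(s))   # peaks at base_lr when step == warmup_steps
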

@@ -157,6 +157,8 @@ class Trainer():
             self.epoch = infos["epoch"]
             scratch = False
         else:
+            self.iteration = 0
+            self.epoch = 0
             scratch = True
 
         return scratch
@@ -189,6 +191,8 @@ class Trainer():
                 msg = "Train: Rank: {}, ".format(dist.get_rank())
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
+                msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                len(self.train_loader))
                 msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                 msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
                 self.train_batch(batch_index, batch, msg)
@@ -197,8 +201,8 @@ class Trainer():
                 logger.error(e)
                 raise e
 
-            valid_losses = self.valid()
-            self.save(infos=valid_losses)
+            total_loss, num_seen_utts = self.valid()
+            self.save(infos={'val_loss': total_loss / num_seen_utts})
             self.lr_scheduler.step()
             self.new_epoch()
