fix train logic

pull/578/head
Hui Zhang 5 years ago
parent b5bbfc5e24
commit 5ea181b7ab

@@ -45,9 +45,10 @@ class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_data):
start = time.time()
def train_batch(self, batch_data, msg):
self.model.train()
start = time.time()
loss = self.model(*batch_data)
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
@@ -59,10 +60,8 @@ class DeepSpeech2Trainer(Trainer):
losses_np = {
'train_loss': float(loss),
}
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
self.logger.info(msg)
@@ -71,6 +70,7 @@ class DeepSpeech2Trainer(Trainer):
for k, v in losses_np.items():
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
self.iteration += 1
@mp_tools.rank_zero_only
@paddle.no_grad()
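
Note on the DeepSpeech2Trainer hunks above: train_batch now receives the log prefix from the epoch loop, so the dataloader time (measured outside) and the batch compute time (measured inside) land in the same log line, and the step counter is advanced inside train_batch. A minimal sketch of that calling convention, with illustrative names rather than the exact PaddleSpeech attributes:

import time

def run_epoch(trainer, loader, rank=0):
    # The caller builds the message prefix: rank, epoch, step, dataloader time.
    data_start = time.time()
    for batch in loader:
        dataload_time = time.time() - data_start
        msg = "Train: Rank: {}, ".format(rank)
        msg += "epoch: {}, step: {}, ".format(trainer.epoch, trainer.iteration)
        msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
        # train_batch appends batch time and losses, logs, and bumps iteration.
        trainer.train_batch(batch, msg)
        data_start = time.time()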

@@ -13,6 +13,8 @@
# limitations under the License.
"""Trainer for U2 model."""
import os
import cProfile
from paddle import distributed as dist
from deepspeech.utils.utility import print_arguments
@@ -52,4 +54,7 @@ if __name__ == "__main__":
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join('.', 'train.profile'))
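
The U2 training entry point now runs main under cProfile and dumps the raw stats to train.profile. A self-contained sketch of the same profiling technique, using a throwaway workload (profile_me is an illustrative stand-in for main(config, args)):

import cProfile
import pstats

def profile_me(n):
    # Stand-in workload for main(config, args).
    return sum(i * i for i in range(n))

pr = cProfile.Profile()
pr.runcall(profile_me, 1_000_000)   # run the callable under the profiler
pr.dump_stats('train.profile')      # binary stats file, as in the patch

# Inspect the dump afterwards, e.g. the 20 hottest calls by cumulative time.
pstats.Stats('train.profile').sort_stats('cumulative').print_stats(20)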

@@ -82,40 +82,39 @@ class U2Trainer(Trainer):
start = time.time()
loss, attention_loss, ctc_loss = self.model(*batch_data)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
if self.iteration % train_conf.accum_grad == 0:
losses_np = {
'train_loss': float(loss) * train_conf.accum_grad,
'train_att_loss': float(attention_loss),
'train_ctc_loss': float(ctc_loss),
}
if (self.iteration + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.iteration += 1
iteration_time = time.time() - start
losses_np = {
'train_loss': float(loss),
'train_att_loss': float(attention_loss),
'train_ctc_loss': float(ctc_loss),
}
msg += "time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
if self.iteration % train_conf.log_interval == 0:
if (self.iteration + 1) % train_conf.log_interval == 0:
msg += "time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.data.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
self.logger.info(msg)
# display
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
def train(self):
"""The training process.
It includes forward/backward/update and periodical validation and
saving.
"""
"""The training process control by step."""
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
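
The train_batch rewrite above is standard gradient accumulation: every batch contributes loss / accum_grad and a backward pass, but the optimizer step, clear_grad and LR scheduler only run on every accum_grad-th batch. A tiny self-contained illustration of that schedule with a hand-written linear model (numpy only, not the PaddleSpeech model or optimizer):

import numpy as np

w, lr, accum_grad = 0.0, 1.0, 4        # accum_grad plays train_conf.accum_grad
grad_buffer = 0.0
data = [(x, 3.0 * x) for x in np.linspace(-1.0, 1.0, 64)]  # fit y = 3x

for iteration, (x, y) in enumerate(data):
    # loss and gradient are divided by accum_grad, as in `loss /= accum_grad`
    grad_buffer += (w * x - y) * x / accum_grad    # backward: gradients add up
    if (iteration + 1) % accum_grad == 0:          # step every accum_grad batches
        w -= lr * grad_buffer                      # optimizer.step()
        grad_buffer = 0.0                          # optimizer.clear_grad()

print("fitted w:", round(w, 3))                    # close to 3.0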
@@ -124,10 +123,17 @@ class U2Trainer(Trainer):
# paddle.jit.save(script_model, script_model_path)
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch <= self.config.training.n_epoch:
while self.epoch < self.config.training.n_epoch:
try:
data_start_time = time.time()
for batch in self.train_loader:
@@ -135,9 +141,8 @@ class U2Trainer(Trainer):
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "lr: {}, ".foramt(self.lr_scheduler())
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
self.iteration += 1
self.train_batch(batch, msg)
data_start_time = time.time()
except Exception as e:
@@ -263,12 +268,12 @@ class U2Trainer(Trainer):
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=optim_conf.lr,
gamma=scheduler_conf.lr_decay,
verbose=True)
verbose=False)
elif scheduler_type == 'warmuplr':
lr_scheduler = WarmupLR(
learning_rate=optim_conf.lr,
warmup_steps=scheduler_conf.warmup_steps,
verbose=True)
verbose=False)
else:
raise ValueError(f"Not support scheduler: {scheduler_type}")
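
Both schedulers now pass verbose=False, silencing the per-step learning-rate prints. For orientation, WarmupLR is the usual Noam-style warm-up schedule; the formula below is a hedged, framework-free sketch of that family, not a verbatim copy of the deepspeech class:

def warmup_lr(base_lr: float, step: int, warmup_steps: int = 25000) -> float:
    # Linear warm-up to base_lr at step == warmup_steps, then 1/sqrt(step) decay.
    # Assumed Noam-style form; the real WarmupLR may differ in details.
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)

print(warmup_lr(0.002, 1000), warmup_lr(0.002, 25000), warmup_lr(0.002, 100000))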

@@ -127,7 +127,7 @@ class Trainer():
dist.init_parallel_env()
@mp_tools.rank_zero_only
def save(self, infos=None):
def save(self, tag=None, infos=None):
"""Save checkpoint (model parameters and optimizer states).
"""
if infos is None:
@@ -136,8 +136,9 @@ class Trainer():
"epoch": self.epoch,
"lr": self.optimizer.get_lr(),
}
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer, infos)
checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
@@ -146,6 +147,7 @@ class Trainer():
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
scratch = None
infos = checkpoint.load_parameters(
self.model,
self.optimizer,
@@ -155,44 +157,41 @@ class Trainer():
# restore from ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
return False
scratch = False
else:
# from scratch, epoch and iteration init with zero
# save init model, i.e. 0 epoch
self.save()
# self.epoch start from 1.
self.new_epoch()
return True
scratch = True
return scratch
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.
"""Reset the train loader seed and increment `epoch`.
"""
self.epoch += 1
if self.parallel:
# batch sampler epoch start from 0
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.epoch += 1
def train(self):
"""The training process.
"""
"""The training process control by epoch."""
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch <= self.config.training.n_epoch:
while self.epoch < self.config.training.n_epoch:
try:
data_start_time = time.time()
for batch in self.train_loader:
dataload_time = time.time() - data_start_time
# iteration start from 1.
self.iteration += 1
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch, msg)
data_start_time = time.time()
@@ -202,7 +201,6 @@ class Trainer():
self.valid()
self.save()
# lr control by epoch
self.lr_scheduler.step()
self.new_epoch()
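
The Trainer hunks above change the resume logic: resume_or_scratch now only reports whether training starts fresh, while the caller saves an 'init' checkpoint, fast-forwards the LR scheduler with lr_scheduler.step(self.iteration), and loops while epoch < n_epoch. A compact stand-in class showing that control flow (TinyTrainer is illustrative, not the real Trainer):

class TinyTrainer:
    def __init__(self, n_epoch, saved_infos=None):
        self.n_epoch = n_epoch
        self.saved_infos = saved_infos      # pretend checkpoint contents, or None
        self.iteration = 0
        self.epoch = 0
        self.saved_tags = []

    def save(self, tag=None, infos=None):
        # Tagged saves ('init') sit next to step-numbered saves.
        self.saved_tags.append(tag if tag is not None else self.iteration)

    def resume_or_scratch(self):
        # Only reports scratch/resume; no saving or scheduler stepping here.
        if self.saved_infos:
            self.iteration = self.saved_infos["step"]
            self.epoch = self.saved_infos["epoch"]
            return False
        return True

    def train(self):
        if self.resume_or_scratch():
            self.save(tag='init')           # snapshot of the untrained model
        # self.lr_scheduler.step(self.iteration) would fast-forward the LR here.
        while self.epoch < self.n_epoch:    # epoch now counts from 0
            self.epoch += 1                 # new_epoch()

fresh = TinyTrainer(n_epoch=2)
fresh.train()
print(fresh.epoch, fresh.saved_tags)        # 2 ['init']

resumed = TinyTrainer(n_epoch=2, saved_infos={"step": 500, "epoch": 1})
resumed.train()
print(resumed.epoch, resumed.saved_tags)    # 2 []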

@@ -16,6 +16,7 @@ import os
import logging
import re
import json
from typing import Union
import paddle
from paddle import distributed as dist
@@ -79,7 +80,7 @@ def load_parameters(model,
configs = {}
if checkpoint_path is not None:
iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
tag = os.path.basename(checkpoint_path).split(":")[-1]
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == -1:
@@ -113,14 +114,14 @@ def load_parameters(model,
@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir: str,
iteration: int,
tag_or_iteration: Union[int, str],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration(step or epoch) number.
tag_or_iteration (int or str): the latest iteration(step or epoch) number.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None.
@@ -128,7 +129,8 @@ def save_parameters(checkpoint_dir: str,
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration))
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
@@ -142,10 +144,10 @@ def save_parameters(checkpoint_dir: str,
logger.info("Saved optimzier state to {}".format(optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if infos is None:
infos = {}
infos = {} if infos is None else infos
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
_save_checkpoint(checkpoint_dir, iteration)
if isinstance(tag_or_iteration, int):
_save_checkpoint(checkpoint_dir, tag_or_iteration)
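
save_parameters now accepts either an integer step or a string tag such as 'init', and only integer values advance the latest-checkpoint record. A small sketch of just the naming and guard logic (no actual paddle.save calls; the .pdopt extension is an assumption):

import os
from typing import Union

def checkpoint_paths(checkpoint_dir: str,
                     tag_or_iteration: Union[int, str]) -> dict:
    base = os.path.join(checkpoint_dir, "{}".format(tag_or_iteration))
    return {
        "params": base + ".pdparams",
        "optimizer": base + ".pdopt",   # assumed extension for optimizer state
        "info": base + ".json",
        # Mirrors the isinstance(tag_or_iteration, int) guard in the patch:
        # only numbered checkpoints update the 'latest checkpoint' pointer.
        "update_latest": isinstance(tag_or_iteration, int),
    }

print(checkpoint_paths("exp/checkpoints", 1000))    # step checkpoint
print(checkpoint_paths("exp/checkpoints", "init"))  # tagged checkpoint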

@@ -6,7 +6,6 @@ data:
vocab_filepath: data/vocab.txt
unit_type: 'char'
spm_model_prefix: ''
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5

@@ -12,7 +12,7 @@ data:
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature
