fix train logic

pull/578/head
Hui Zhang 5 years ago
parent b5bbfc5e24
commit 5ea181b7ab

@@ -45,9 +45,10 @@ class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)

-    def train_batch(self, batch_data):
-        start = time.time()
+    def train_batch(self, batch_data, msg):
         self.model.train()
+        start = time.time()
         loss = self.model(*batch_data)
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
@@ -59,10 +60,8 @@ class DeepSpeech2Trainer(Trainer):
         losses_np = {
             'train_loss': float(loss),
         }
-        msg = "Train: Rank: {}, ".format(dist.get_rank())
-        msg += "epoch: {}, ".format(self.epoch)
-        msg += "step: {}, ".format(self.iteration)
         msg += "time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         self.logger.info(msg)
@@ -71,6 +70,7 @@ class DeepSpeech2Trainer(Trainer):
             for k, v in losses_np.items():
                 self.visualizer.add_scalar("train/{}".format(k), v,
                                            self.iteration)
+        self.iteration += 1

     @mp_tools.rank_zero_only
     @paddle.no_grad()

@@ -13,6 +13,8 @@
 # limitations under the License.
 """Trainer for U2 model."""
+import os
+import cProfile

 from paddle import distributed as dist
 from deepspeech.utils.utility import print_arguments
@@ -52,4 +54,7 @@ if __name__ == "__main__":
     with open(args.dump_config, 'w') as f:
         print(config, file=f)

-    main(config, args)
+    # Setting for profiling
+    pr = cProfile.Profile()
+    pr.runcall(main, config, args)
+    pr.dump_stats(os.path.join('.', 'train.profile'))
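
Note: the dumped stats can be read back offline with the standard-library pstats module; a minimal sketch (the sort key and entry count here are arbitrary choices, not part of this commit):

    import pstats

    # Load the profile written by pr.dump_stats() and print the 20
    # most expensive entries by cumulative time.
    stats = pstats.Stats('train.profile')
    stats.sort_stats('cumulative').print_stats(20)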

@@ -82,40 +82,39 @@ class U2Trainer(Trainer):
         start = time.time()
         loss, attention_loss, ctc_loss = self.model(*batch_data)
+        # loss div by `batch_size * accum_grad`
+        loss /= train_conf.accum_grad
         loss.backward()
         layer_tools.print_grads(self.model, print_func=None)
-        if self.iteration % train_conf.accum_grad == 0:
+        losses_np = {
+            'train_loss': float(loss) * train_conf.accum_grad,
+            'train_att_loss': float(attention_loss),
+            'train_ctc_loss': float(ctc_loss),
+        }
+
+        if (self.iteration + 1) % train_conf.accum_grad == 0:
+            if dist.get_rank() == 0 and self.visualizer:
+                for k, v in losses_np.items():
+                    self.visualizer.add_scalar("train/{}".format(k), v,
+                                               self.iteration)
             self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
+            self.iteration += 1

         iteration_time = time.time() - start

-        losses_np = {
-            'train_loss': float(loss),
-            'train_att_loss': float(attention_loss),
-            'train_ctc_loss': float(ctc_loss),
-        }
-        msg += "time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
-        msg += "accum: {}, ".format(train_conf.accum_grad)
-        msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                         for k, v in losses_np.items())
-        if self.iteration % train_conf.log_interval == 0:
+        if (self.iteration + 1) % train_conf.log_interval == 0:
+            msg += "time: {:>.3f}s, ".format(iteration_time)
+            msg += "batch size: {}, ".format(self.config.data.batch_size)
+            msg += "accum: {}, ".format(train_conf.accum_grad)
+            msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                             for k, v in losses_np.items())
             self.logger.info(msg)
-
-        # display
-        if dist.get_rank() == 0 and self.visualizer:
-            for k, v in losses_np.items():
-                self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
     def train(self):
-        """The training process.
-
-        It includes forward/backward/update and periodical validation and
-        saving.
-        """
+        """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
         # the code to satisfy the script export requirements
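
Note: the accum_grad handling above is the standard gradient-accumulation pattern: scale each micro-batch loss by 1/accum_grad, call backward() every batch so gradients sum, and step the optimizer only every accum_grad batches. A self-contained sketch of the pattern, assuming a generic Paddle model and dataloader (all names here are illustrative, not from this repo):

    import paddle
    import paddle.nn.functional as F

    def run_epoch(model, optimizer, dataloader, accum_grad=4):
        # Minimal gradient-accumulation sketch (illustrative only).
        model.train()
        for i, (inputs, labels) in enumerate(dataloader):
            loss = F.cross_entropy(model(inputs), labels)
            # Dividing keeps the summed gradients equal to one
            # large-batch average over accum_grad micro-batches.
            (loss / accum_grad).backward()
            if (i + 1) % accum_grad == 0:
                optimizer.step()
                optimizer.clear_grad()
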
@@ -124,10 +123,17 @@ class U2Trainer(Trainer):
         #     paddle.jit.save(script_model, script_model_path)

         from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+        self.lr_scheduler.step(self.iteration)
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch < self.config.training.n_epoch:
+        while self.epoch <= self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
@@ -135,9 +141,8 @@ class U2Trainer(Trainer):
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
-                    msg += "lr: {}, ".foramt(self.lr_scheduler())
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
-                    self.iteration += 1
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
             except Exception as e:
@@ -263,12 +268,12 @@ class U2Trainer(Trainer):
             lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
                 learning_rate=optim_conf.lr,
                 gamma=scheduler_conf.lr_decay,
-                verbose=True)
+                verbose=False)
         elif scheduler_type == 'warmuplr':
             lr_scheduler = WarmupLR(
                 learning_rate=optim_conf.lr,
                 warmup_steps=scheduler_conf.warmup_steps,
-                verbose=True)
+                verbose=False)
         else:
             raise ValueError(f"Not support scheduler: {scheduler_type}")
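
Note: WarmupLR here is the codebase's transformer-style warmup schedule. For reference, the usual formula (linear ramp up to warmup_steps, then inverse-square-root decay) looks like the sketch below; treat the exact expression as an assumption about the class, not a guarantee:

    def warmup_lr(step, base_lr=0.002, warmup_steps=25000):
        # Noam-style schedule: lr rises linearly for warmup_steps,
        # then decays proportionally to 1/sqrt(step).
        # Assumed to match WarmupLR; verify against the class itself.
        step = max(step, 1)
        return base_lr * warmup_steps**0.5 * min(
            step**-0.5, step * warmup_steps**-1.5)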

@@ -127,7 +127,7 @@ class Trainer():
             dist.init_parallel_env()

     @mp_tools.rank_zero_only
-    def save(self, infos=None):
+    def save(self, tag=None, infos=None):
         """Save checkpoint (model parameters and optimizer states).
         """
         if infos is None:
@@ -136,8 +136,9 @@ class Trainer():
                 "epoch": self.epoch,
                 "lr": self.optimizer.get_lr(),
             }
-        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
-                                   self.model, self.optimizer, infos)
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration
+                                   if tag is None else tag, self.model,
+                                   self.optimizer, infos)

     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output
@@ -146,6 +147,7 @@ class Trainer():
         If ``args.checkpoint_path`` is not None, load the checkpoint, else
         resume training.
         """
+        scratch = None
         infos = checkpoint.load_parameters(
             self.model,
             self.optimizer,
@@ -155,44 +157,41 @@ class Trainer():
             # restore from ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
-            self.lr_scheduler.step(self.iteration)
-            if self.parallel:
-                self.train_loader.batch_sampler.set_epoch(self.epoch)
-            return False
+            scratch = False
         else:
-            # from scratch, epoch and iteration init with zero
-            # save init model, i.e. 0 epoch
-            self.save()
-            # self.epoch start from 1.
-            self.new_epoch()
-            return True
+            scratch = True
+
+        return scratch

     def new_epoch(self):
-        """Reset the train loader and increment ``epoch``.
+        """Reset the train loader seed and increment `epoch`.
         """
-        self.epoch += 1
         if self.parallel:
+            # batch sampler epoch start from 0
             self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.epoch += 1

     def train(self):
-        """The training process.
-        """
+        """The training process control by epoch."""
         from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # save init model, i.e. 0 epoch
+            self.save(tag='init')
+        self.lr_scheduler.step(self.iteration)
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
         self.logger.info(
             f"Train Total Examples: {len(self.train_loader.dataset)}")
-        while self.epoch <= self.config.training.n_epoch:
+        while self.epoch < self.config.training.n_epoch:
             try:
                 data_start_time = time.time()
                 for batch in self.train_loader:
                     dataload_time = time.time() - data_start_time
+                    # iteration start from 1.
+                    self.iteration += 1
                     msg = "Train: Rank: {}, ".format(dist.get_rank())
                     msg += "epoch: {}, ".format(self.epoch)
                     msg += "step: {}, ".format(self.iteration)
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                     msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
                     self.train_batch(batch, msg)
                     data_start_time = time.time()
@@ -202,7 +201,6 @@ class Trainer():
                 self.valid()
                 self.save()
-                # lr control by epoch
                 self.lr_scheduler.step()
                 self.new_epoch()
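
Note: the reordering in new_epoch (set_epoch before incrementing self.epoch) matters because the distributed batch sampler seeds its shuffle from the epoch number, starting at 0: the seed is identical across ranks but changes each epoch. A toy sketch of that contract (this sampler is illustrative, not Paddle's DistributedBatchSampler):

    import random

    class ShardedSampler:
        # Toy stand-in for a distributed batch sampler.
        def __init__(self, num_samples, rank, world_size):
            self.num_samples = num_samples
            self.rank = rank
            self.world_size = world_size
            self.epoch = 0

        def set_epoch(self, epoch):
            # Same seed on every rank => identical permutation everywhere.
            self.epoch = epoch

        def __iter__(self):
            indices = list(range(self.num_samples))
            random.Random(self.epoch).shuffle(indices)
            # Each rank then reads a disjoint strided shard.
            return iter(indices[self.rank::self.world_size])

    # Two ranks agree on the epoch-0 shuffle but see different shards.
    s0 = ShardedSampler(8, rank=0, world_size=2)
    s1 = ShardedSampler(8, rank=1, world_size=2)
    s0.set_epoch(0); s1.set_epoch(0)
    assert set(s0) | set(s1) == set(range(8))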

@@ -16,6 +16,7 @@ import os
 import logging
 import re
 import json
+from typing import Union

 import paddle
 from paddle import distributed as dist
@@ -79,7 +80,7 @@ def load_parameters(model,
     configs = {}

     if checkpoint_path is not None:
-        iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
+        tag = os.path.basename(checkpoint_path).split(":")[-1]
     elif checkpoint_dir is not None:
         iteration = _load_latest_checkpoint(checkpoint_dir)
         if iteration == -1:
@@ -113,14 +114,14 @@ def load_parameters(model,
 @mp_tools.rank_zero_only
 def save_parameters(checkpoint_dir: str,
-                    iteration: int,
+                    tag_or_iteration: Union[int, str],
                     model: paddle.nn.Layer,
                     optimizer: Optimizer=None,
                     infos: dict=None):
     """Checkpoint the latest trained model parameters.
     Args:
         checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration(step or epoch) number.
+        tag_or_iteration (int or str): the latest iteration(step or epoch) number.
         model (Layer): model to be checkpointed.
         optimizer (Optimizer, optional): optimizer to be checkpointed.
             Defaults to None.
@@ -128,7 +129,8 @@ def save_parameters(checkpoint_dir: str,
     Returns:
         None
     """
-    checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
+    checkpoint_path = os.path.join(checkpoint_dir,
+                                   "{}".format(tag_or_iteration))

     model_dict = model.state_dict()
     params_path = checkpoint_path + ".pdparams"
@@ -142,10 +144,10 @@ def save_parameters(checkpoint_dir: str,
     logger.info("Saved optimzier state to {}".format(optimizer_path))

     info_path = re.sub('.pdparams$', '.json', params_path)
-    if infos is None:
-        infos = {}
+    infos = {} if infos is None else infos
     with open(info_path, 'w') as fout:
         data = json.dumps(infos)
         fout.write(data)
-    _save_checkpoint(checkpoint_dir, iteration)
+    if isinstance(tag_or_iteration, int):
+        _save_checkpoint(checkpoint_dir, tag_or_iteration)
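
Note: after this change save_parameters accepts either an integer step or a string tag, and only integer checkpoints update the latest-checkpoint record, so tagged snapshots such as 'init' are skipped by resume logic. A rough usage sketch (the directory and infos values are illustrative):

    # Integer step: writes ckpt/123.pdparams / .pdopt / .json and
    # records 123 as the latest checkpoint for resuming.
    save_parameters('ckpt', 123, model, optimizer, {'epoch': 1})

    # String tag: writes ckpt/init.pdparams etc., but leaves the
    # latest-checkpoint record untouched.
    save_parameters('ckpt', 'init', model, optimizer)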

@@ -6,7 +6,6 @@ data:
   vocab_filepath: data/vocab.txt
   unit_type: 'char'
   spm_model_prefix: ''
-  mean_std_filepath: ""
   augmentation_config: conf/augmentation.json
   batch_size: 64
   min_input_len: 0.5

@@ -12,7 +12,7 @@ data:
   min_input_len: 0.5
   max_input_len: 20.0
   min_output_len: 0.0
-  max_output_len: 400
+  max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
   raw_wav: True  # use raw_wav or kaldi feature
