|
|
|
@ -109,7 +109,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
msg += "accum: {}, ".format(train_conf.accum_grad)
|
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
|
|
|
|
for k, v in losses_np.items())
|
|
|
|
|
self.logger.info(msg)
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
|
"""The training process control by step."""
|
|
|
|
@ -129,8 +129,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
if self.parallel:
|
|
|
|
|
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
|
|
|
|
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f"Train Total Examples: {len(self.train_loader.dataset)}")
|
|
|
|
|
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
|
|
|
|
|
while self.epoch < self.config.training.n_epoch:
|
|
|
|
|
self.model.train()
|
|
|
|
|
try:
|
|
|
|
@ -145,7 +144,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
self.train_batch(batch_index, batch, msg)
|
|
|
|
|
data_start_time = time.time()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.logger.error(e)
|
|
|
|
|
logger.error(e)
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
valid_losses = self.valid()
|
|
|
|
@ -156,8 +155,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def valid(self):
|
|
|
|
|
self.model.eval()
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
|
|
|
|
|
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
|
|
|
|
|
valid_losses = defaultdict(list)
|
|
|
|
|
for i, batch in enumerate(self.valid_loader):
|
|
|
|
|
total_loss, attention_loss, ctc_loss = self.model(*batch)
|
|
|
|
@ -175,7 +173,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
|
|
|
|
for k, v in valid_losses.items())
|
|
|
|
|
self.logger.info(msg)
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
|
|
if self.visualizer:
|
|
|
|
|
for k, v in valid_losses.items():
|
|
|
|
@ -239,7 +237,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=SpeechCollator(keep_transcription_text=True))
|
|
|
|
|
self.logger.info("Setup train/valid/test Dataloader!")
|
|
|
|
|
logger.info("Setup train/valid/test Dataloader!")
|
|
|
|
|
|
|
|
|
|
def setup_model(self):
|
|
|
|
|
config = self.config
|
|
|
|
@ -253,7 +251,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
if self.parallel:
|
|
|
|
|
model = paddle.DataParallel(model)
|
|
|
|
|
|
|
|
|
|
layer_tools.print_params(model, self.logger.info)
|
|
|
|
|
layer_tools.print_params(model, logger.info)
|
|
|
|
|
|
|
|
|
|
train_config = config.training
|
|
|
|
|
optim_type = train_config.optim
|
|
|
|
@ -289,7 +287,7 @@ class U2Trainer(Trainer):
|
|
|
|
|
self.model = model
|
|
|
|
|
self.optimizer = optimizer
|
|
|
|
|
self.lr_scheduler = lr_scheduler
|
|
|
|
|
self.logger.info("Setup model/optimizer/lr_scheduler!")
|
|
|
|
|
logger.info("Setup model/optimizer/lr_scheduler!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class U2Tester(U2Trainer):
|
|
|
|
@ -367,11 +365,10 @@ class U2Tester(U2Trainer):
|
|
|
|
|
num_ins += 1
|
|
|
|
|
if fout:
|
|
|
|
|
fout.write(result + "\n")
|
|
|
|
|
self.logger.info(
|
|
|
|
|
"\nTarget Transcription: %s\nOutput Transcription: %s" %
|
|
|
|
|
(target, result))
|
|
|
|
|
self.logger.info("Current error rate [%s] = %f" % (
|
|
|
|
|
cfg.error_rate_type, error_rate_func(target, result)))
|
|
|
|
|
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
|
|
|
|
|
(target, result))
|
|
|
|
|
logger.info("Current error rate [%s] = %f" %
|
|
|
|
|
(cfg.error_rate_type, error_rate_func(target, result)))
|
|
|
|
|
|
|
|
|
|
return dict(
|
|
|
|
|
errors_sum=errors_sum,
|
|
|
|
@ -385,8 +382,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
def test(self):
|
|
|
|
|
assert self.args.result_file
|
|
|
|
|
self.model.eval()
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f"Test Total Examples: {len(self.test_loader.dataset)}")
|
|
|
|
|
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
|
|
|
|
|
|
|
|
|
|
error_rate_type = None
|
|
|
|
|
errors_sum, len_refs, num_ins = 0.0, 0, 0
|
|
|
|
@ -398,9 +394,8 @@ class U2Tester(U2Trainer):
|
|
|
|
|
len_refs += metrics['len_refs']
|
|
|
|
|
num_ins += metrics['num_ins']
|
|
|
|
|
error_rate_type = metrics['error_rate_type']
|
|
|
|
|
self.logger.info(
|
|
|
|
|
"Error rate [%s] (%d/?) = %f" %
|
|
|
|
|
(error_rate_type, num_ins, errors_sum / len_refs))
|
|
|
|
|
logger.info("Error rate [%s] (%d/?) = %f" %
|
|
|
|
|
(error_rate_type, num_ins, errors_sum / len_refs))
|
|
|
|
|
|
|
|
|
|
# logging
|
|
|
|
|
msg = "Test: "
|
|
|
|
@ -408,7 +403,7 @@ class U2Tester(U2Trainer):
|
|
|
|
|
msg += "step: {}, ".format(self.iteration)
|
|
|
|
|
msg += ", Final error rate [%s] (%d/%d) = %f" % (
|
|
|
|
|
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
|
|
|
|
|
self.logger.info(msg)
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
|
|
|
|
|
def run_test(self):
|
|
|
|
|
self.resume_or_scratch()
|
|
|
|
@ -459,7 +454,6 @@ class U2Tester(U2Trainer):
|
|
|
|
|
|
|
|
|
|
self.setup_output_dir()
|
|
|
|
|
self.setup_checkpointer()
|
|
|
|
|
self.setup_logger()
|
|
|
|
|
|
|
|
|
|
self.setup_dataloader()
|
|
|
|
|
self.setup_model()
|
|
|
|
@ -480,25 +474,3 @@ class U2Tester(U2Trainer):
|
|
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
self.output_dir = output_dir
|
|
|
|
|
|
|
|
|
|
def setup_logger(self):
    """Configure logging for the experiment and attach the logger.

    Sets the module-level ``logger`` to the ``INFO`` level, configures the
    root logger through ``logging.basicConfig`` and stores ``logger`` on
    the instance as ``self.logger``.

    ``stdout`` is hard-coded to ``True`` here, so the root logger is
    configured at ``DEBUG`` level and ``filename`` is ``None``, i.e. no
    log file is opened by this method.

    NOTE(review): ``logging.basicConfig`` does nothing when the root
    logger already has handlers, so a configuration made earlier in the
    process takes precedence -- confirm this is the intended behavior.
    """
    # Shared layout for every record emitted through the root logger.
    log_format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
    date_format = '%Y/%m/%d %H:%M:%S'

    logger.setLevel("INFO")

    # global logger
    stdout = True
    save_path = ""
    logging.basicConfig(
        level=logging.DEBUG if stdout else logging.INFO,
        format=log_format,
        datefmt=date_format,
        filename=save_path if not stdout else None)
    self.logger = logger
|
|
|
|
|