@@ -15,6 +15,7 @@
 import json
 import math
 import os
+import re
 import time
 from collections import defaultdict
 from collections import OrderedDict
@@ -62,6 +63,19 @@ class Wav2Vec2ASRTrainer(Trainer):
         self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
         self.avg_train_loss += loss / (batch_index + 1)

+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save init model, i.e. 0 epoch
+            self.save(tag='init', infos=None)
+        else:
+            # resume: train next_epoch and next_iteration
+            self.epoch += 1
+            logger.info(
+                f"Resume train: epoch {self.epoch}, step {self.iteration}!")
+
+        self.maybe_batch_sampler_step()
+
     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
         start = time.time()
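
The two context lines at the top of this hunk maintain `avg_train_loss` as an incremental mean: at zero-based batch `i`, `avg -= avg / (i + 1)` followed by `avg += loss / (i + 1)` leaves `avg` equal to the mean of all `i + 1` losses seen so far, without storing the history. A quick self-contained check of the recurrence (plain Python; the loss values are made up):

```python
# Running-mean recurrence: avg_n = avg_{n-1} * (n - 1) / n + x_n / n
losses = [0.9, 0.7, 0.8, 0.4]
avg = 0.0
for batch_index, loss in enumerate(losses):
    avg -= avg / (batch_index + 1)
    avg += loss / (batch_index + 1)
assert abs(avg - sum(losses) / len(losses)) < 1e-12
```
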
@@ -69,14 +83,14 @@ class Wav2Vec2ASRTrainer(Trainer):
         # forward
         utt, wav, wavs_lens, target, target_lens = batch
         wavs_lens_rate = wavs_lens / wav.shape[1]
-        target_lens_rate = target_lens / target.shape[1]
         wav = wav[:, :, 0]
-        if hasattr(train_conf, 'speech_augment'):
+        if hasattr(train_conf, 'audio_augment'):
             wav = self.speech_augmentation(wav, wavs_lens_rate)
-        loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
+        loss = self.model(wav, wavs_lens_rate, target, target_lens)
+
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad

         # update self.avg_train_loss
         self.update_average(batch_index, float(loss))
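
`loss /= train_conf.accum_grad` is the usual gradient-accumulation scaling: `backward()` adds into `.grad`, so dividing each micro-batch loss by `accum_grad` makes the accumulated gradient equal the gradient of the mean loss over the group; the logged value multiplies the factor back out (the `losses_np` line in the next hunk). A minimal sketch of the pattern with a toy Paddle model, not the trainer's real model or data:

```python
import paddle

# Gradient accumulation: one optimizer update per `accum_grad` micro-batches.
model = paddle.nn.Linear(4, 1)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
accum_grad = 4

for batch_index in range(8):
    x = paddle.randn([2, 4])
    loss = model(x).mean()
    loss = loss / accum_grad  # summed grads == grad of the mean micro-batch loss
    loss.backward()           # grads accumulate in .grad until clear_grad()
    if (batch_index + 1) % accum_grad == 0:
        opt.step()
        opt.clear_grad()
```
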
@@ -98,11 +112,17 @@ class Wav2Vec2ASRTrainer(Trainer):
         # optimizer step old
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
-            self.optimizer.clear_grad()
-            self.lr_scheduler.step()
+            self.model_optimizer.step()
+            self.model_optimizer.clear_grad()
+            if not train_conf.freeze_wav2vec2:
+                self.wav2vec2_optimizer.step()
+                self.wav2vec2_optimizer.clear_grad()
+            if self.config.model_scheduler != 'newbobscheduler':
+                self.model_lr_scheduler.step()
+            if self.config.wav2vec2_scheduler != 'newbobscheduler':
+                if not train_conf.freeze_wav2vec2:
+                    self.wav2vec2_lr_scheduler.step()
             self.iteration += 1

         losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
         iteration_time = time.time() - start
         for k, v in losses_np.items():
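
The single optimizer/scheduler is split in two here: `model_optimizer` always steps for the randomly initialised encoder/CTC head, while `wav2vec2_optimizer` steps only when the pretrained encoder is not frozen; `newbobscheduler` instances are skipped because they are stepped once per epoch on validation loss instead (see the `do_train` hunk further down). A toy sketch of the split, with `Linear` layers standing in for the real submodules:

```python
import paddle

encoder = paddle.nn.Linear(8, 8)  # stands in for model.wav2vec2
head = paddle.nn.Linear(8, 2)     # stands in for the enc/ctc layers
head_opt = paddle.optimizer.Adadelta(
    learning_rate=0.9, parameters=head.parameters())
encoder_opt = paddle.optimizer.Adadelta(
    learning_rate=0.9, parameters=encoder.parameters())
freeze_wav2vec2 = False

loss = head(encoder(paddle.randn([4, 8]))).mean()
loss.backward()
head_opt.step()            # the head always trains
head_opt.clear_grad()
if not freeze_wav2vec2:    # frozen encoder: grads exist but are never applied
    encoder_opt.step()
    encoder_opt.clear_grad()
```
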
@@ -114,7 +134,10 @@ class Wav2Vec2ASRTrainer(Trainer):
         if (batch_index + 1) % train_conf.accum_grad == 0:
             if dist.get_rank() == 0 and self.visualizer:
                 losses_np_v = losses_np.copy()
-                losses_np_v.update({"lr": self.lr_scheduler()})
+                losses_np_v.update({
+                    "model_lr": self.model_lr_scheduler(),
+                    "wav2vec2_lr": self.wav2vec2_lr_scheduler()
+                })
                 for key, val in losses_np_v.items():
                     self.visualizer.add_scalar(
                         tag='train/' + key, value=val, step=self.iteration - 1)
@@ -131,11 +154,10 @@ class Wav2Vec2ASRTrainer(Trainer):
         for i, batch in enumerate(self.valid_loader):
             utt, wav, wavs_lens, target, target_lens = batch
             wavs_lens_rate = wavs_lens / wav.shape[1]
-            target_lens_rate = target_lens / target.shape[1]
             wav = wav[:, :, 0]
-            loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
+            loss = self.model(wav, wavs_lens_rate, target, target_lens)

-            if paddle.isfinite(loss):
+            if math.isfinite(float(loss)):
                 num_utts = batch[1].shape[0]
                 num_seen_utts += num_utts
                 total_loss += float(loss) * num_utts
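
Swapping `paddle.isfinite(loss)` for `math.isfinite(float(loss))` makes the guard an explicit host-side scalar check (`paddle.isfinite` returns a boolean tensor); either way, non-finite validation losses are excluded from the running totals. The guard in isolation:

```python
import math

def accumulate(total_loss, num_seen_utts, loss_value, num_utts):
    # skip NaN/inf batches so they cannot poison the validation average
    if math.isfinite(loss_value):
        num_seen_utts += num_utts
        total_loss += loss_value * num_utts
    return total_loss, num_seen_utts

assert accumulate(0.0, 0, float('nan'), 8) == (0.0, 0)  # NaN batch skipped
assert accumulate(0.0, 0, 2.5, 8) == (20.0, 8)
```
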
@@ -160,6 +182,106 @@ class Wav2Vec2ASRTrainer(Trainer):
                 dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

+    @mp_tools.rank_zero_only
+    def save(self, tag=None, infos: dict=None):
+        """Save checkpoint (model parameters and optimizer states).
+
+        Args:
+            tag (int or str, optional): None for step, else using tag, e.g. epoch. Defaults to None.
+            infos (dict, optional): meta data to save. Defaults to None.
+        """
+        infos = infos if infos else dict()
+        infos.update({
+            "epoch": self.epoch,
+            "model_lr": self.model_optimizer.get_lr(),
+            "wav2vec2_lr": self.wav2vec2_optimizer.get_lr()
+        })
+
+        checkpoint_path = os.path.join(
+            self.checkpoint_dir,
+            "{}".format(self.iteration if tag is None else tag))
+
+        model_dict = self.model.state_dict()
+        params_path = checkpoint_path + ".pdparams"
+        paddle.save(model_dict, params_path)
+        logger.info("Saved model to {}".format(params_path))
+
+        model_opt_dict = self.model_optimizer.state_dict()
+        wav2vec2_opt_dict = self.wav2vec2_optimizer.state_dict()
+
+        opt_dict = {'model': model_opt_dict, 'wav2vec2': wav2vec2_opt_dict}
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        paddle.save(opt_dict, optimizer_path)
+        logger.info("Saved optimizer state to {}".format(optimizer_path))
+
+        scheduler_dict = {}
+
+        if self.config.model_scheduler == 'newbobscheduler':
+            scheduler_dict['model'] = self.model_lr_scheduler.save()
+        if self.config.wav2vec2_scheduler == 'newbobscheduler':
+            scheduler_dict['wav2vec2'] = self.wav2vec2_lr_scheduler.save()
+        if scheduler_dict:
+            scheduler_path = checkpoint_path + ".pdlrs"
+            paddle.save(scheduler_dict, scheduler_path)
+            logger.info("Saved scheduler state to {}".format(scheduler_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        infos = {} if infos is None else infos
+        with open(info_path, 'w') as fout:
+            data = json.dumps(infos)
+            fout.write(data)
+
+    def resume_or_scratch(self):
+        """Resume from the latest checkpoint in the output directory or
+        load a specified checkpoint.
+
+        If ``args.checkpoint_path`` is not None, load the checkpoint, else
+        resume training.
+        """
+        scratch = None
+        if self.args.resume:
+            # just restore ckpt
+            # lr will restore from optimizer ckpt
+            resume_json_path = os.path.join(self.checkpoint_dir,
+                                            self.args.resume + '.json')
+            with open(resume_json_path, 'r') as f:
+                resume_json = json.load(f)
+            self.iteration = 0
+            self.epoch = resume_json["epoch"]
+
+            # restore model from *.pdparams
+            params_path = os.path.join(self.checkpoint_dir,
+                                       "{}".format(self.epoch)) + '.pdparams'
+            model_dict = paddle.load(params_path)
+            self.model.set_state_dict(model_dict)
+
+            # restore optimizer from *.pdopt
+            optimizer_path = os.path.join(self.checkpoint_dir,
+                                          "{}".format(self.epoch)) + '.pdopt'
+            optimizer_dict = paddle.load(optimizer_path)
+            self.model_optimizer.set_state_dict(optimizer_dict['model'])
+            self.wav2vec2_optimizer.set_state_dict(optimizer_dict['wav2vec2'])
+
+            # restore lr_scheduler from *.pdlrs
+            scheduler_path = os.path.join(self.checkpoint_dir,
+                                          "{}".format(self.epoch)) + '.pdlrs'
+            if os.path.isfile(scheduler_path):
+                scheduler_dict = paddle.load(scheduler_path)
+                if self.config.model_scheduler == 'newbobscheduler':
+                    self.model_lr_scheduler.load(scheduler_dict['model'])
+                if self.config.wav2vec2_scheduler == 'newbobscheduler':
+                    self.wav2vec2_lr_scheduler.load(scheduler_dict['wav2vec2'])
+            logger.info(
+                f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
+            scratch = False
+        else:
+            self.iteration = 0
+            self.epoch = 0
+            scratch = True
+            logger.info("Init from scratch!")
+        return scratch
+
     def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
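
Per tag, the new `save()` writes up to four sibling files, and `resume_or_scratch()` reads them back by epoch number. A minimal reader for that layout (a sketch of the on-disk format, not the trainer's own API; `load_checkpoint` is a hypothetical helper):

```python
import json
import os

import paddle

def load_checkpoint(checkpoint_dir, tag, model):
    """Read the {tag}.pdparams / .pdopt / .pdlrs / .json files save() emits."""
    prefix = os.path.join(checkpoint_dir, str(tag))
    model.set_state_dict(paddle.load(prefix + ".pdparams"))  # model weights
    opt_dict = paddle.load(prefix + ".pdopt")  # {'model': ..., 'wav2vec2': ...}
    scheduler_dict = None
    if os.path.isfile(prefix + ".pdlrs"):      # only written for newbobscheduler
        scheduler_dict = paddle.load(prefix + ".pdlrs")
    with open(prefix + ".json") as f:
        infos = json.load(f)                   # epoch, model_lr, wav2vec2_lr, ...
    return opt_dict, scheduler_dict, infos
```
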
@@ -170,7 +292,6 @@ class Wav2Vec2ASRTrainer(Trainer):
         # paddle.jit.save(script_model, script_model_path)

         self.before_train()
-
         if not self.use_streamdata:
             logger.info(
                 f"Train Total Examples: {len(self.train_loader.dataset)}")
@@ -187,7 +308,9 @@ class Wav2Vec2ASRTrainer(Trainer):
                             report("Rank", dist.get_rank())
                             report("epoch", self.epoch)
                             report('step', self.iteration)
-                            report("lr", self.lr_scheduler())
+                            report("model_lr", self.model_optimizer.get_lr())
+                            report("wav2vec2_lr",
+                                   self.wav2vec2_optimizer.get_lr())
                             self.train_batch(batch_index, batch, msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
@@ -225,15 +348,25 @@ class Wav2Vec2ASRTrainer(Trainer):
                     cv_loss = float(cv_loss)
                 else:
                     cv_loss = total_loss / num_seen_utts

             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
             if self.visualizer:
                 self.visualizer.add_scalar(
                     tag='eval/cv_loss', value=cv_loss, step=self.epoch)
                 self.visualizer.add_scalar(
-                    tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
+                    tag='eval/model_lr',
+                    value=self.model_lr_scheduler(),
+                    step=self.epoch)
+                self.visualizer.add_scalar(
+                    tag='eval/wav2vec2_lr',
+                    value=self.wav2vec2_lr_scheduler(),
+                    step=self.epoch)
+
+            if self.config.model_scheduler == 'newbobscheduler':
+                self.model_lr_scheduler.step(cv_loss)
+            if self.config.wav2vec2_scheduler == 'newbobscheduler':
+                if not self.config.freeze_wav2vec2:
+                    self.wav2vec2_lr_scheduler.step(cv_loss)
             self.save(tag=self.epoch, infos={'val_loss': cv_loss})
             self.new_epoch()
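
`newbobscheduler` is the epoch-level branch of the split introduced in `train_batch`: it is stepped here with `cv_loss` and anneals the learning rate when validation improvement stalls (newbob-style annealing, as in SpeechBrain, from which this recipe is ported). Roughly, as a sketch of the rule only, with illustrative constants rather than the project's implementation:

```python
class NewBobSketch:
    """Newbob-style annealing: shrink the LR when the relative improvement
    of the validation loss falls below a threshold."""

    def __init__(self, lr, annealing_factor=0.8, improvement_threshold=0.0025):
        self.lr = lr
        self.annealing_factor = annealing_factor
        self.improvement_threshold = improvement_threshold
        self.prev_loss = None

    def step(self, cv_loss):
        if self.prev_loss is not None:
            improvement = (self.prev_loss - cv_loss) / self.prev_loss
            if improvement < self.improvement_threshold:
                self.lr *= self.annealing_factor
        self.prev_loss = cv_loss
        return self.lr

sched = NewBobSketch(lr=0.9)
assert sched.step(10.0) == 0.9               # first epoch: nothing to compare
assert sched.step(9.0) == 0.9                # 10% better: keep the LR
assert abs(sched.step(8.999) - 0.72) < 1e-9  # stalled: anneal by 0.8
```
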
@@ -268,14 +401,11 @@ class Wav2Vec2ASRTrainer(Trainer):
         model_conf.output_dim = self.test_loader.vocab_size

         model = Wav2vec2ASR.from_config(model_conf)

-        model_dict = paddle.load(config.wav2vec2_params_path)
-        model.wav2vec2.set_state_dict(model_dict)
-
+        # load pretrained wav2vec2 model params
+        wav2vec2_dict = paddle.load(config.wav2vec2_params_path)
+        model.wav2vec2.set_state_dict(wav2vec2_dict)
         if self.parallel:
             model = paddle.DataParallel(model, find_unused_parameters=True)

         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
         self.model = model
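
The rename to `wav2vec2_dict` makes the intent clearer: the checkpoint holds only the pretrained encoder, so it is loaded into the `model.wav2vec2` submodule while the decoder/CTC head keeps its fresh initialisation. A toy illustration of submodule-only loading:

```python
import paddle

class ToyASR(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.wav2vec2 = paddle.nn.Linear(4, 4)  # stands in for the encoder
        self.ctc = paddle.nn.Linear(4, 2)       # stands in for the CTC head

pretrained = ToyASR().wav2vec2.state_dict()  # pretend pretrained encoder weights
model = ToyASR()
model.wav2vec2.set_state_dict(pretrained)    # encoder loaded, ctc untouched
```
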
@@ -290,46 +420,74 @@ class Wav2Vec2ASRTrainer(Trainer):
             return

         train_config = config
-        optim_type = train_config.model_optim
-        optim_conf = train_config.model_optim_conf
-        scheduler_type = train_config.scheduler
-        scheduler_conf = train_config.scheduler_conf
-
-        scheduler_args = {
-            "learning_rate": optim_conf.lr,
-            "verbose": False,
-            "warmup_steps": scheduler_conf.warmup_steps,
-            "gamma": scheduler_conf.lr_decay,
-            "d_model": model_conf.dnn_neurons,
-        }
-        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
-                                                    scheduler_args)
+        model_optim_type = train_config.model_optim
+        model_optim_conf = train_config.model_optim_conf
+        wav2vec2_optim_type = train_config.wav2vec2_optim
+        wav2vec2_optim_conf = train_config.wav2vec2_optim_conf
+
+        model_scheduler_type = train_config.model_scheduler
+        model_scheduler_conf = train_config.model_scheduler_conf
+        wav2vec2_scheduler_type = train_config.wav2vec2_scheduler
+        wav2vec2_scheduler_conf = train_config.wav2vec2_scheduler_conf
+
+        model_scheduler_args = dict(
+            **{"learning_rate": model_optim_conf.lr,
+               "verbose": False}, **(dict(model_scheduler_conf)))
+
+        wav2vec2_scheduler_args = dict(
+            **{"learning_rate": wav2vec2_optim_conf.lr,
+               "verbose": False}, **(dict(wav2vec2_scheduler_conf)))
+
+        model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
+                                                          model_scheduler_args)
+        wav2vec2_lr_scheduler = LRSchedulerFactory.from_args(
+            wav2vec2_scheduler_type, wav2vec2_scheduler_args)

         def optimizer_args(
                 config,
+                optim_type,
+                optim_conf,
                 parameters,
                 lr_scheduler=None, ):
             train_config = config
-            optim_type = train_config.model_optim
-            optim_conf = train_config.model_optim_conf
-            scheduler_type = train_config.scheduler
-            scheduler_conf = train_config.scheduler_conf
-            return {
-                "grad_clip": train_config.global_grad_clip,
-                "learning_rate": lr_scheduler
-                if lr_scheduler else optim_conf.lr,
-                "epsilon": optim_conf.epsilon,
-                "rho": optim_conf.rho,
-                "parameters": parameters,
-                "beta1": 0.9 if optim_type == 'noam' else None,
-                "beat2": 0.98 if optim_type == 'noam' else None,
-            }
+            optim_arg = dict(optim_conf)
+            optim_arg.update({
+                "grad_clip":
+                train_config.global_grad_clip,
+                "learning_rate":
+                lr_scheduler if lr_scheduler else optim_conf.lr,
+                "parameters":
+                parameters
+            })
+            return optim_arg

-        optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
-        optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
-
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
+        model_optimizer_args = optimizer_args(config, model_optim_type,
+                                              model_optim_conf, [{
+                                                  'params':
+                                                  model._layers.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model._layers.ctc.parameters()
+                                              }] if self.parallel else [{
+                                                  'params':
+                                                  model.enc.parameters()
+                                              }, {
+                                                  'params':
+                                                  model.ctc.parameters()
+                                              }], model_lr_scheduler)
+        wav2vec2_optimizer_args = optimizer_args(
+            config, wav2vec2_optim_type, wav2vec2_optim_conf,
+            model._layers.wav2vec2.parameters() if self.parallel else
+            model.wav2vec2.parameters(), wav2vec2_lr_scheduler)
+
+        model_optimizer = OptimizerFactory.from_args(model_optim_type,
+                                                     model_optimizer_args)
+        wav2vec2_optimizer = OptimizerFactory.from_args(wav2vec2_optim_type,
+                                                        wav2vec2_optimizer_args)
+
+        self.model_optimizer = model_optimizer
+        self.wav2vec2_optimizer = wav2vec2_optimizer
+        self.model_lr_scheduler = model_lr_scheduler
+        self.wav2vec2_lr_scheduler = wav2vec2_lr_scheduler
         logger.info("Setup optimizer/lr_scheduler!")
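
The rewritten `optimizer_args` no longer hard-codes Adadelta's `epsilon`/`rho` (or noam's betas): it copies the whole `*_optim_conf` from the YAML and only overlays `grad_clip`, `learning_rate`, and `parameters`, so optimizer-specific keys pass straight through. The merge in isolation (plain dicts stand in for the config's conf objects):

```python
def optimizer_args(global_grad_clip, optim_conf, parameters, lr_scheduler=None):
    optim_arg = dict(optim_conf)          # optimizer-specific keys pass through
    optim_arg.update({
        "grad_clip": global_grad_clip,
        "learning_rate": lr_scheduler if lr_scheduler else optim_conf["lr"],
        "parameters": parameters,
    })
    return optim_arg

args = optimizer_args(5.0, {"lr": 0.9, "epsilon": 1.0e-6, "rho": 0.95}, [])
assert args["rho"] == 0.95 and args["grad_clip"] == 5.0
```

Note also that the model optimizer receives its parameters as two groups (`enc` and `ctc`), which leaves room for per-group options later without touching the call site.
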