Ref: pull/751/head
Author: Hui Zhang, 3 years ago
Parent: 820b4db287
Commit: e76123d418

@@ -41,8 +41,6 @@ from deepspeech.utils import mp_tools
 from deepspeech.utils import text_grid
 from deepspeech.utils import utility
 from deepspeech.utils.log import Log
-# from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
-# from deepspeech.training.scheduler import WarmupLR
 logger = Log(__name__).getlog()
@@ -324,25 +322,6 @@ class U2Trainer(Trainer):
         lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
                                                      scheduler_args)
-        # if scheduler_type == 'expdecaylr':
-        #     lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-        #         learning_rate=optim_conf.lr,
-        #         gamma=scheduler_conf.lr_decay,
-        #         verbose=False)
-        # elif scheduler_type == 'warmuplr':
-        #     lr_scheduler = WarmupLR(
-        #         learning_rate=optim_conf.lr,
-        #         warmup_steps=scheduler_conf.warmup_steps,
-        #         verbose=False)
-        # elif scheduler_type == 'noam':
-        #     lr_scheduler = paddle.optimizer.lr.NoamDecay(
-        #         learning_rate=optim_conf.lr,
-        #         d_model=model_conf.encoder_conf.output_size,
-        #         warmup_steps=scheduler_conf.warmup_steps,
-        #         verbose=False)
-        # else:
-        #     raise ValueError(f"Not support scheduler: {scheduler_type}")
         def optimizer_args(
                 config,
                 parameters,
@@ -366,17 +345,6 @@ class U2Trainer(Trainer):
         optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
         optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
-        # grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
-        # weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
-        # if optim_type == 'adam':
-        #     optimizer = paddle.optimizer.Adam(
-        #         learning_rate=lr_scheduler,
-        #         parameters=model.parameters(),
-        #         weight_decay=weight_decay,
-        #         grad_clip=grad_clip)
-        # else:
-        #     raise ValueError(f"Not support optim: {optim_type}")
         self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
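For reference, the deleted branches above did by hand what the two remaining factory calls now do: pick a scheduler or optimizer class from scheduler_type / optim_type and construct it from the config. Below is a minimal sketch of a registry-style from_args factory in that spirit. It is an illustration only: the dict registry, the class name LRSchedulerFactorySketch, and the example argument values are assumptions, not the actual code in deepspeech.training.

# Illustration only: a registry-style factory in the spirit of
# LRSchedulerFactory.from_args; this is not the implementation shipped in
# deepspeech.training, and the registry and class name here are assumptions.
import paddle


class LRSchedulerFactorySketch:
    # Hypothetical registry mapping scheduler_type strings to constructors.
    # The project-specific WarmupLR would also be registered here.
    _registry = {
        'expdecaylr': paddle.optimizer.lr.ExponentialDecay,
        'noam': paddle.optimizer.lr.NoamDecay,
    }

    @classmethod
    def from_args(cls, scheduler_type, scheduler_args):
        try:
            scheduler_cls = cls._registry[scheduler_type]
        except KeyError:
            raise ValueError(f"Not support scheduler: {scheduler_type}")
        # scheduler_args is a dict of constructor kwargs (learning_rate,
        # d_model, warmup_steps, ...), so the per-type if/elif chain is gone.
        return scheduler_cls(**scheduler_args)


# Example call shaped like the one kept in the diff (values are made up):
lr_scheduler = LRSchedulerFactorySketch.from_args(
    'noam', {'learning_rate': 0.002, 'd_model': 256, 'warmup_steps': 25000})

OptimizerFactory.from_args(optim_type, optimzer_args) follows the same shape, with the gradient clipping and weight decay from the deleted Adam branch presumably handled behind that factory.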
