u2 with chainer updater

pull/823/head
Hui Zhang 3 years ago
parent 91bc5959a9
commit 3843372958

@@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype):
 if not hasattr(paddle, 'softmax'):
-    logger.warn("register user softmax to paddle, remove this when fixed!")
+    logger.debug("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)

 if not hasattr(paddle, 'log_softmax'):
-    logger.warn("register user log_softmax to paddle, remove this when fixed!")
+    logger.debug("register user log_softmax to paddle, remove this when fixed!")
     setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)

 if not hasattr(paddle, 'sigmoid'):
-    logger.warn("register user sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)

 if not hasattr(paddle, 'log_sigmoid'):
-    logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
+    logger.debug("register user log_sigmoid to paddle, remove this when fixed!")
     setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)

 if not hasattr(paddle, 'relu'):
-    logger.warn("register user relu to paddle, remove this when fixed!")
+    logger.debug("register user relu to paddle, remove this when fixed!")
     setattr(paddle, 'relu', paddle.nn.functional.relu)

@@ -105,7 +105,7 @@ def cat(xs, dim=0):
 if not hasattr(paddle, 'cat'):
-    logger.warn(
+    logger.debug(
         "override cat of paddle if exists or register, remove this when fixed!")
     paddle.cat = cat

@@ -116,7 +116,7 @@ def item(x: paddle.Tensor):
 if not hasattr(paddle.Tensor, 'item'):
-    logger.warn(
+    logger.debug(
         "override item of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.item = item

@@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor):
 if not hasattr(paddle.Tensor, 'long'):
-    logger.warn(
+    logger.debug(
         "override long of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.long = func_long

 if not hasattr(paddle.Tensor, 'numel'):
-    logger.warn(
+    logger.debug(
         "override numel of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.numel = paddle.numel

@@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor,
 if not hasattr(paddle.Tensor, 'new_full'):
-    logger.warn(
+    logger.debug(
         "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.new_full = new_full

@@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.eq = eq

 if not hasattr(paddle, 'eq'):
-    logger.warn(
+    logger.debug(
         "override eq of paddle if exists or register, remove this when fixed!")
     paddle.eq = eq

@@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'contiguous'):
-    logger.warn(
+    logger.debug(
         "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
     )
     paddle.Tensor.contiguous = contiguous

@@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 #`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
-logger.warn(
+logger.debug(
     "override size of paddle.Tensor "
     "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
 )

@@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'view'):
-    logger.warn("register user view to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user view to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view = view

@@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'view_as'):
-    logger.warn(
+    logger.debug(
         "register user view_as to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.view_as = view_as

@@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor,
 if not hasattr(paddle.Tensor, 'masked_fill'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill = masked_fill

@@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor,
 if not hasattr(paddle.Tensor, 'masked_fill_'):
-    logger.warn(
+    logger.debug(
         "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.masked_fill_ = masked_fill_

@@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'fill_'):
-    logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user fill_ to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.fill_ = fill_

@@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'repeat'):
-    logger.warn(
+    logger.debug(
         "register user repeat to paddle.Tensor, remove this when fixed!")
     paddle.Tensor.repeat = repeat

 if not hasattr(paddle.Tensor, 'softmax'):
-    logger.warn(
+    logger.debug(
         "register user softmax to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)

 if not hasattr(paddle.Tensor, 'sigmoid'):
-    logger.warn(
+    logger.debug(
         "register user sigmoid to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)

 if not hasattr(paddle.Tensor, 'relu'):
-    logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user relu to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)

@@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'type_as'):
-    logger.warn(
+    logger.debug(
         "register user type_as to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'type_as', type_as)

@@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'to'):
-    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user to to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'to', to)

@@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'float'):
-    logger.warn("register user float to paddle.Tensor, remove this when fixed!")
+    logger.debug(
+        "register user float to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'float', func_float)

@@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:
 if not hasattr(paddle.Tensor, 'int'):
-    logger.warn("register user int to paddle.Tensor, remove this when fixed!")
+    logger.debug("register user int to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'int', func_int)

@@ -348,6 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]:
 if not hasattr(paddle.Tensor, 'tolist'):
-    logger.warn(
+    logger.debug(
         "register user tolist to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'tolist', tolist)
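
Editor's note: every hunk above applies the same guarded monkey-patch pattern, with the log level dropped from warn to debug so the shims stop flooding training logs. A minimal sketch of the pattern, with an illustrative helper name that is not part of this commit:

import paddle
from paddle.nn import functional as F

def _patch(obj, name, fn, log=print):
    # Only attach the shim when Paddle does not already provide the attribute,
    # so a future Paddle release that ships the API wins automatically.
    if not hasattr(obj, name):
        log(f"register user {name}, remove this when fixed!")
        setattr(obj, name, fn)

_patch(paddle, 'softmax', F.softmax)
_patch(paddle.Tensor, 'softmax', F.softmax)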

@@ -18,10 +18,12 @@ import os
 from paddle import distributed as dist

 from deepspeech.exps.u2.config import get_cfg_defaults
-from deepspeech.exps.u2.model import U2Trainer as Trainer
+# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.utility import print_arguments
+from deepspeech.exps.u2.model import U2Trainer as Trainer


 def main_sp(config, args):
     exp = Trainer(config, args)

@@ -0,0 +1,219 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains U2 model."""
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader

from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Evaluator
from deepspeech.models.u2 import U2Model
from deepspeech.models.u2 import U2Updater
from deepspeech.training.extensions.snapshot import Snapshot
from deepspeech.training.extensions.visualizer import VisualDL
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.training.updaters.trainer import Trainer as NewTrainer
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class U2Trainer(Trainer):
    def __init__(self, config, args):
        super().__init__(config, args)

    def setup_dataloader(self):
        config = self.config.clone()
        config.defrost()
        config.collator.keep_transcription_text = False

        # train/valid dataset, return token ids
        config.data.manifest = config.data.train_manifest
        train_dataset = ManifestDataset.from_config(config)

        config.data.manifest = config.data.dev_manifest
        dev_dataset = ManifestDataset.from_config(config)

        collate_fn_train = SpeechCollator.from_config(config)

        config.collator.augmentation_config = ""
        collate_fn_dev = SpeechCollator.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
                batch_size=config.collator.batch_size,
                num_replicas=None,
                rank=None,
                shuffle=True,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)
        else:
            batch_sampler = SortagradBatchSampler(
                train_dataset,
                shuffle=True,
                batch_size=config.collator.batch_size,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)

        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.collator.num_workers, )
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.collator.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)

        # test dataset, return raw text
        config.data.manifest = config.data.test_manifest
        # filter test examples, will cause less examples, but no mismatch with training
        # and can use large batch size , save training time, so filter test egs now.
        config.data.min_input_len = 0.0  # second
        config.data.max_input_len = float('inf')  # second
        config.data.min_output_len = 0.0  # tokens
        config.data.max_output_len = float('inf')  # tokens
        config.data.min_output_input_ratio = 0.00
        config.data.max_output_input_ratio = float('inf')

        test_dataset = ManifestDataset.from_config(config)
        # return text ord id
        config.collator.keep_transcription_text = True
        config.collator.augmentation_config = ""
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=SpeechCollator.from_config(config))
        # return text token id
        config.collator.keep_transcription_text = False
        self.align_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=SpeechCollator.from_config(config))
        logger.info("Setup train/valid/test/align Dataloader!")

    def setup_model(self):
        config = self.config
        model_conf = config.model
        model_conf.defrost()
        model_conf.input_dim = self.train_loader.collate_fn.feature_size
        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
        model_conf.freeze()
        model = U2Model.from_config(model_conf)

        if self.parallel:
            model = paddle.DataParallel(model)

        model.train()
        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)

        train_config = config.training
        optim_type = train_config.optim
        optim_conf = train_config.optim_conf
        scheduler_type = train_config.scheduler
        scheduler_conf = train_config.scheduler_conf

        scheduler_args = {
            "learning_rate": optim_conf.lr,
            "verbose": False,
            "warmup_steps": scheduler_conf.warmup_steps,
            "gamma": scheduler_conf.lr_decay,
            "d_model": model_conf.encoder_conf.output_size,
        }
        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
                                                    scheduler_args)

        def optimizer_args(
                config,
                parameters,
                lr_scheduler=None, ):
            train_config = config.training
            optim_type = train_config.optim
            optim_conf = train_config.optim_conf
            scheduler_type = train_config.scheduler
            scheduler_conf = train_config.scheduler_conf
            return {
                "grad_clip": train_config.global_grad_clip,
                "weight_decay": optim_conf.weight_decay,
                "learning_rate": lr_scheduler
                if lr_scheduler else optim_conf.lr,
                "parameters": parameters,
                "epsilon": 1e-9 if optim_type == 'noam' else None,
                "beta1": 0.9 if optim_type == 'noam' else None,
                "beta2": 0.98 if optim_type == 'noam' else None,
            }

        optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
        optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        logger.info("Setup model/optimizer/lr_scheduler!")

    def setup_updater(self):
        output_dir = self.output_dir
        config = self.config.training

        updater = U2Updater(
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.lr_scheduler,
            dataloader=self.train_loader,
            output_dir=output_dir,
            accum_grad=config.accum_grad)

        trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir)

        evaluator = U2Evaluator(self.model, self.valid_loader)
        trainer.extend(evaluator, trigger=(1, "epoch"))

        if dist.get_rank() == 0:
            trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
            num_snapshots = config.checkpoint.kbest_n
            trainer.extend(
                Snapshot(
                    mode='kbest',
                    max_size=num_snapshots,
                    indicator='VALID/LOSS',
                    less_better=True),
                trigger=(1, 'epoch'))
        # print(trainer.extensions)
        # trainer.run()
        self.trainer = trainer

    def run(self):
        """The routine of the experiment after setup. This method is intended
        to be used by the user.
        """
        self.setup_updater()
        with Timer("Training Done: {}"):
            self.trainer.run()
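
Editor's note: the commit title refers to the Chainer-style training loop this file wires up: U2Updater owns a single optimization step, NewTrainer drives it until the stop trigger fires, and the evaluator, VisualDL writer, and Snapshot run as extensions on their own triggers. A conceptual sketch of that loop (illustrative only, not code from this commit):

def run_training(updater, extensions, stop_trigger):
    # extensions: {name: (extension, trigger)}, e.g. the evaluator on (1, 'epoch')
    while not stop_trigger(updater.state):
        updater.update()  # one train step; may accumulate gradients internally
        for name, (extension, trigger) in extensions.items():
            if trigger(updater.state):
                extension(updater)  # evaluator, VisualDL writer, Snapshot, ...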

@@ -0,0 +1,19 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2 import U2InferModel
from .u2 import U2Model
from .updater import U2Evaluator
from .updater import U2Updater
__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"]

@@ -0,0 +1,149 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

import paddle
from paddle import distributed as dist

from deepspeech.training.extensions.evaluator import StandardEvaluator
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.updaters.standard_updater import StandardUpdater
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class U2Evaluator(StandardEvaluator):
    def __init__(self, model, dataloader):
        super().__init__(model, dataloader)
        self.msg = ""
        self.num_seen_utts = 0
        self.total_loss = 0.0

    def evaluate_core(self, batch):
        self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
        losses_dict = {}

        loss, attention_loss, ctc_loss = self.model(*batch[1:])
        if paddle.isfinite(loss):
            num_utts = batch[1].shape[0]
            self.num_seen_utts += num_utts
            self.total_loss += float(loss) * num_utts

            losses_dict['loss'] = float(loss)
            if attention_loss:
                losses_dict['att_loss'] = float(attention_loss)
            if ctc_loss:
                losses_dict['ctc_loss'] = float(ctc_loss)

            for k, v in losses_dict.items():
                report("eval/" + k, v)

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        logger.info(self.msg)
        return self.total_loss, self.num_seen_utts


class U2Updater(StandardUpdater):
    def __init__(self,
                 model,
                 optimizer,
                 scheduler,
                 dataloader,
                 init_state=None,
                 accum_grad=1,
                 **kwargs):
        super().__init__(
            model, optimizer, scheduler, dataloader, init_state=init_state)
        self.accum_grad = accum_grad
        self.forward_count = 0
        self.msg = ""

    def update_core(self, batch):
        """One Step

        Args:
            batch (List[Object]): utts, xs, xlens, ys, ylens
        """
        losses_dict = {}
        self.msg = "Rank: {}, ".format(dist.get_rank())

        # forward
        batch_size = batch[1].shape[0]
        loss, attention_loss, ctc_loss = self.model(*batch[1:])
        # loss div by `batch_size * accum_grad`
        loss /= self.accum_grad

        # loss backward
        if (self.forward_count + 1) != self.accum_grad:
            # Disable gradient synchronizations across DDP processes.
            # Within this context, gradients will be accumulated on module
            # variables, which will later be synchronized.
            context = self.model.no_sync
        else:
            # Used for single gpu training and DDP gradient synchronization
            # processes.
            context = nullcontext

        with context():
            loss.backward()
            layer_tools.print_grads(self.model, print_func=None)

        # loss info
        losses_dict['loss'] = float(loss) * self.accum_grad
        if attention_loss:
            losses_dict['att_loss'] = float(attention_loss)
        if ctc_loss:
            losses_dict['ctc_loss'] = float(ctc_loss)
        # report loss
        for k, v in losses_dict.items():
            report("train/" + k, v)
        # loss msg
        self.msg += "batch size: {}, ".format(batch_size)
        self.msg += "accum: {}, ".format(self.accum_grad)
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())

        # Truncate the graph
        loss.detach()

        # update parameters
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0

        self.optimizer.step()
        self.optimizer.clear_grad()
        self.scheduler.step()

    def update(self):
        # model is default in train mode

        # training for a step is implemented here
        with Timer("data time cost:{}"):
            batch = self.read_batch()
        with Timer("step time cost:{}"):
            self.update_core(batch)

        # #iterations with accum_grad > 1
        # Ref.: https://github.com/espnet/espnet/issues/777
        if self.forward_count == 0:
            self.state.iteration += 1
            if self.updates_per_epoch is not None:
                if self.state.iteration % self.updates_per_epoch == 0:
                    self.state.epoch += 1
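
Editor's note: the core of U2Updater.update_core is gradient accumulation with DDP synchronization suppressed on all but the last micro-batch. A standalone sketch of that pattern, assuming model, optimizer, scheduler, and batches already exist:

from contextlib import nullcontext

forward_count = 0
accum_grad = 4
for batch in batches:
    loss = model(*batch) / accum_grad
    # skip the DDP all-reduce on intermediate micro-batches
    sync_ctx = nullcontext if (forward_count + 1) == accum_grad else model.no_sync
    with sync_ctx():
        loss.backward()
    forward_count += 1
    if forward_count == accum_grad:
        forward_count = 0
        optimizer.step()
        optimizer.clear_grad()
        scheduler.step()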

@@ -46,7 +46,7 @@ class CTCLoss(nn.Layer):
         if grad_norm_type == 'instance':
             self.norm_by_times = True
         if grad_norm_type == 'batch':
-            self.norm_by_times = True
+            self.norm_by_batchsize = True
         if grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True

@@ -13,14 +13,18 @@
 # limitations under the License.
 from typing import Dict

-import extension
 import paddle
+from paddle import distributed as dist
 from paddle.io import DataLoader
 from paddle.nn import Layer

+from . import extension
 from ..reporter import DictSummary
 from ..reporter import report
 from ..reporter import scope
+from ..timer import Timer
+from deepspeech.utils.log import Log

+logger = Log(__name__).getlog()


 class StandardEvaluator(extension.Extension):

@@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension):
     def evaluate_core(self, batch):
         # compute
         self.model(batch) # you may report here
+        return
+
+    def evaluate_sync(self, data):
+        # dist sync `evaluate_core` outputs
+        if data is None:
+            return
+
+        numerator, denominator = data
+        if dist.get_world_size() > 1:
+            numerator = paddle.to_tensor(numerator)
+            denominator = paddle.to_tensor(denominator)
+            # the default operator in all_reduce function is sum.
+            dist.all_reduce(numerator)
+            dist.all_reduce(denominator)
+            value = numerator / denominator
+            value = float(value)
+        else:
+            value = numerator / denominator
+
+        # used for `snapshot` to do kbest save.
+        report("VALID/LOSS", value)
+        logger.info(f"Valid: all-reduce loss {value}")

     def evaluate(self):
         # switch to eval mode

@@ -56,9 +81,13 @@ class StandardEvaluator(extension.Extension):
             with scope(observation):
                 # main evaluation computation here.
                 with paddle.no_grad():
-                    self.evaluate_core(batch)
+                    self.evaluate_sync(self.evaluate_core(batch))
             summary.add(observation)
         summary = summary.compute_mean()

+        # switch to train mode
+        for model in self.models.values():
+            model.train()
         return summary

     def __call__(self, trainer=None):

@@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension):
         # if it is used to extend a trainer, the metrics is reported to
         # to observation of the trainer
         # or otherwise, you can use your own observation
-        summary = self.evaluate()
+        with Timer("Eval Time Cost: {}"):
+            summary = self.evaluate()
         for k, v in summary.items():
             report(k, v)
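
Editor's note: evaluate_sync turns the per-rank (total_loss, num_seen_utts) pair returned by evaluate_core into one world-wide average, so every rank reports the same VALID/LOSS for the kbest snapshot. A minimal sketch of that reduction:

import paddle
from paddle import distributed as dist

def all_reduce_mean(total_loss, num_seen_utts):
    if dist.get_world_size() > 1:
        numerator = paddle.to_tensor(total_loss)
        denominator = paddle.to_tensor(num_seen_utts)
        dist.all_reduce(numerator)    # the default reduction is SUM
        dist.all_reduce(denominator)
        return float(numerator / denominator)
    return total_loss / num_seen_utts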

@@ -20,8 +20,9 @@ from typing import List

 import jsonlines

-from deepspeech.training.extensions import extension
-from deepspeech.training.updaters.trainer import Trainer
+from . import extension
+from ..reporter import get_observations
+from ..updaters.trainer import Trainer
 from deepspeech.utils.log import Log
 from deepspeech.utils.mp_tools import rank_zero_only

@@ -52,8 +53,19 @@ class Snapshot(extension.Extension):
     priority = -100
     default_name = "snapshot"

-    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
+    def __init__(self,
+                 mode='latest',
+                 max_size: int=5,
+                 indicator=None,
+                 less_better=True,
+                 snapshot_on_error: bool=False):
         self.records: List[Dict[str, Any]] = []
+        assert mode in ('latest', 'kbest'), mode
+        if mode == 'kbest':
+            assert indicator is not None
+        self.mode = mode
+        self.indicator = indicator
+        self.less_is_better = less_better
         self.max_size = max_size
         self._snapshot_on_error = snapshot_on_error
         self._save_all = (max_size == -1)

@@ -66,16 +78,17 @@ class Snapshot(extension.Extension):
         # load existing records
         record_path: Path = self.checkpoint_dir / "records.jsonl"
         if record_path.exists():
-            logger.debug("Loading from an existing checkpoint dir")
             self.records = load_records(record_path)
-            trainer.updater.load(self.records[-1]['path'])
+            ckpt_path = self.records[-1]['path']
+            logger.info(f"Loading from an existing checkpoint {ckpt_path}")
+            trainer.updater.load(ckpt_path)

     def on_error(self, trainer, exc, tb):
         if self._snapshot_on_error:
-            self.save_checkpoint_and_update(trainer)
+            self.save_checkpoint_and_update(trainer, 'latest')

     def __call__(self, trainer: Trainer):
-        self.save_checkpoint_and_update(trainer)
+        self.save_checkpoint_and_update(trainer, self.mode)

     def full(self):
         """Whether the number of snapshots it keeps track of is greater

@@ -83,7 +96,7 @@ class Snapshot(extension.Extension):
         return (not self._save_all) and len(self.records) > self.max_size

     @rank_zero_only
-    def save_checkpoint_and_update(self, trainer: Trainer):
+    def save_checkpoint_and_update(self, trainer: Trainer, mode: str):
         """Saving new snapshot and remove the oldest snapshot if needed."""
         iteration = trainer.updater.state.iteration
         epoch = trainer.updater.state.epoch

@@ -97,11 +110,17 @@ class Snapshot(extension.Extension):
             'path': str(path.resolve()),  # use absolute path
             'iteration': iteration,
             'epoch': epoch,
+            'indicator': get_observations()[self.indicator]
         }
         self.records.append(record)

         # remove the earist
         if self.full():
+            if mode == 'kbest':
+                self.records = sorted(
+                    self.records,
+                    key=lambda record: record['indicator'],
+                    reverse=not self.less_is_better)
             eariest_record = self.records[0]
             os.remove(eariest_record["path"])
             self.records.pop(0)
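
Editor's note: with mode='kbest' the record list is sorted by the stored indicator before the overflow entry is deleted, so pruning is driven by checkpoint quality rather than recency. A simplified sketch of the idea (not the exact code path above), assuming a lower indicator is better:

import os

def prune_to_kbest(records, max_size):
    # records: dicts with 'path' and 'indicator' (e.g. VALID/LOSS)
    records = sorted(records, key=lambda r: r['indicator'])  # best first
    for worst in records[max_size:]:
        os.remove(worst['path'])
    return records[:max_size]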

@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from deepspeech.training.extensions import extension
-from deepspeech.training.updaters.trainer import Trainer
+from visualdl import LogWriter
+
+from . import extension
+from ..updaters.trainer import Trainer


 class VisualDL(extension.Extension):

@@ -26,8 +28,8 @@ class VisualDL(extension.Extension):
     default_name = 'visualdl'
     priority = extension.PRIORITY_READER

-    def __init__(self, writer):
-        self.writer = writer
+    def __init__(self, output_dir):
+        self.writer = LogWriter(str(output_dir))

     def __call__(self, trainer: Trainer):
         for k, v in trainer.observation.items():

@@ -171,7 +171,7 @@ class Trainer():
             self.iteration = 0
             self.epoch = 0
             scratch = True
+        logger.info("Restore/Init checkpoint!")
         return scratch

     def new_epoch(self):

@@ -14,12 +14,12 @@
 from typing import Dict
 from typing import Optional

-from paddle import Tensor
+import paddle
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
-from timer import timer
+from paddle.optimizer.lr import LRScheduler

 from deepspeech.training.reporter import report
+from deepspeech.training.timer import Timer
 from deepspeech.training.updaters.updater import UpdaterBase
@@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase):
     def __init__(self,
                  model: Layer,
                  optimizer: Optimizer,
+                 scheduler: LRScheduler,
                  dataloader: DataLoader,
                  init_state: Optional[UpdaterState]=None):
+        super().__init__(init_state)
         # it is designed to hold multiple models
         models = {"main": model}
         self.models: Dict[str, Layer] = models

@@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase):
         self.optimizer = optimizer
         self.optimizers: Dict[str, Optimizer] = optimizers

+        # it is designed to hold multiple scheduler
+        schedulers = {"main": scheduler}
+        self.scheduler = scheduler
+        self.schedulers: Dict[str, LRScheduler] = schedulers
+
         # dataloaders
         self.dataloader = dataloader

-        # init state
-        if init_state is None:
-            self.state = UpdaterState()
-        else:
-            self.state = init_state
-
         self.train_iterator = iter(dataloader)

     def update(self):
@@ -103,8 +104,10 @@ class StandardUpdater(UpdaterBase):
             model.train()

         # training for a step is implemented here
-        batch = self.read_batch()
-        self.update_core(batch)
+        with Timer("data time cost:{}"):
+            batch = self.read_batch()
+        with Timer("step time cost:{}"):
+            self.update_core(batch)

         self.state.iteration += 1
         if self.updates_per_epoch is not None:
@@ -115,13 +118,14 @@ class StandardUpdater(UpdaterBase):
         """A simple case for a training step. Basic assumptions are:
         Single model;
         Single optimizer;
+        Single scheduler, and update learning rate each step;
         A batch from the dataloader is just the input of the model;
         The model return a single loss, or a dict containing serval losses.
         Parameters updates at every batch, no gradient accumulation.
         """
         loss = self.model(*batch)
-        if isinstance(loss, Tensor):
+        if isinstance(loss, paddle.Tensor):
             loss_dict = {"main": loss}
         else:
             # Dict[str, Tensor]

@@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase):
         for name, loss_item in loss_dict.items():
             report(name, float(loss_item))

-        self.optimizer.clear_gradient()
+        self.optimizer.clear_grad()
         loss_dict["main"].backward()
-        self.optimizer.update()
+        self.optimizer.step()
+        self.scheduler.step()

     @property
     def updates_per_epoch(self):
-        """Number of updater per epoch, determined by the length of the
-        dataloader."""
+        """Number of steps per epoch,
+        determined by the length of the dataloader."""
         length_of_dataloader = None
         try:
             length_of_dataloader = len(self.dataloader)

@@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase):
     def read_batch(self):
         """Read a batch from the data loader, auto renew when data is exhausted."""
-        with timer() as t:
-            try:
-                batch = next(self.train_iterator)
-            except StopIteration:
-                self.new_epoch()
-                batch = next(self.train_iterator)
-            logger.debug(
-                f"Read a batch takes {t.elapse}s.")  # replace it with logger
+        try:
+            batch = next(self.train_iterator)
+        except StopIteration:
+            self.new_epoch()
+            batch = next(self.train_iterator)
         return batch

     def state_dict(self):
-        """State dict of a Updater, model, optimizer and updater state are included."""
+        """State dict of a Updater, model, optimizers/schedulers
+        and updater state are included."""
         state_dict = super().state_dict()
         for name, model in self.models.items():
             state_dict[f"{name}_params"] = model.state_dict()

@@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase):
     def set_state_dict(self, state_dict):
         """Set state dict for a Updater. Parameters of models, states for
-        optimizers and UpdaterState are restored."""
+        optimizers/schedulers and UpdaterState are restored."""
         for name, model in self.models.items():
             model.set_state_dict(state_dict[f"{name}_params"])
         for name, optim in self.optimizers.items():

@@ -140,8 +140,8 @@ class Trainer():
         try:
             while not stop_trigger(self):
                 self.observation = {}
-                # set observation as the report target
-                # you can use report freely in Updater.update()
+                # set observation as the `report` target
+                # you can use `report` freely in Updater.update()
                 # updating parameters and state
                 with scope(self.observation):

@@ -52,6 +52,7 @@ class UpdaterBase():
     """

     def __init__(self, init_state=None):
+        # init state
         if init_state is None:
             self.state = UpdaterState()
         else:

@@ -114,13 +114,13 @@ class Checkpoint():
         params_path = checkpoint_path + ".pdparams"
         model_dict = paddle.load(params_path)
         model.set_state_dict(model_dict)
-        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
+        logger.info("Rank {}: Restore model from {}".format(rank, params_path))

         optimizer_path = checkpoint_path + ".pdopt"
         if optimizer and os.path.isfile(optimizer_path):
             optimizer_dict = paddle.load(optimizer_path)
             optimizer.set_state_dict(optimizer_dict)
-            logger.info("Rank {}: loaded optimizer state from {}".format(
+            logger.info("Rank {}: Restore optimizer state from {}".format(
                 rank, optimizer_path))

         info_path = re.sub('.pdparams$', '.json', params_path)

@@ -12,19 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import getpass
-import logging
 import os
 import socket
 import sys

+from loguru import logger
 from paddle import inference

-FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
-DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
-
-logging.basicConfig(
-    level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)


 def find_log_dir(log_dir=None):
     """Returns the most suitable directory to put log files into.

@@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None):


 class Log():
-    """Default Logger for all."""
-
-    log_name = None
-
-    def __init__(self, logger=None):
-        self.logger = logging.getLogger(logger)
-        self.logger.setLevel(logging.DEBUG)
-
-        file_dir = os.getcwd() + '/log'
-        if not os.path.exists(file_dir):
-            os.mkdir(file_dir)
-        self.log_dir = file_dir
-
-        actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
-            program_name=None, log_dir=self.log_dir)
-
-        basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
-        filename = os.path.join(actual_log_dir, basename)
-        if Log.log_name is None:
-            Log.log_name = filename
-
-        # Create a symlink to the log file with a canonical name.
-        symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
-        try:
-            if os.path.islink(symlink):
-                os.unlink(symlink)
-            os.symlink(os.path.basename(Log.log_name), symlink)
-        except EnvironmentError:
-            # If it fails, we're sad but it's no error. Commonly, this
-            # fails because the symlink was created by another user and so
-            # we can't modify it
-            pass
-
-        if not self.logger.hasHandlers():
-            formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
-            fh = logging.FileHandler(Log.log_name)
-            fh.setLevel(logging.DEBUG)
-            fh.setFormatter(formatter)
-            self.logger.addHandler(fh)
-
-            ch = logging.StreamHandler()
-            ch.setLevel(logging.INFO)
-            ch.setFormatter(formatter)
-            self.logger.addHandler(ch)
-
-        # stop propagate for propagating may print
-        # log multiple times
-        self.logger.propagate = False
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        level='INFO',
+        enqueue=True,
+        filter=lambda record: record['level'].no >= 20)
+    _, file_prefix, _ = find_log_dir_and_names()
+    sink_prefix = os.path.join("exp/log", file_prefix)
+    sink_path = sink_prefix[:-3] + "{time}.log"
+    logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB")
+
+    def __init__(self, name=None):
+        pass

     def getlog(self):
-        return self.logger
+        return logger


 class Autolog:
+    """Just used by fullchain project"""

     def __init__(self,
                  batch_size,
                  model_name="DeepSpeech",
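
Editor's note: the rewritten Log class configures loguru once at class-body import time (stdout at INFO and above, plus a rotating file sink under exp/log at DEBUG), so getlog() simply returns the shared loguru logger. A hedged usage sketch:

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()
logger.info("shown on stdout and written to the file sink")
logger.debug("written only to the rotating file sink under exp/log")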

@@ -86,7 +86,7 @@ training:
   lr_decay: 1.0
   log_interval: 1
   checkpoint:
-    kbest_n: 10
+    kbest_n: 2
     latest_n: 1

@@ -1,7 +1,9 @@
 coverage
 gpustat
 jsonlines
+jsonlines
 kaldiio
+loguru
 Pillow
 pre-commit
 pybind11

@@ -14,5 +16,7 @@ SoundFile==0.9.0.post1
 sox
 tensorboardX
 textgrid
+tqdm
 typeguard
+visualdl==2.2.0
 yacs
