refactor trainer.py and rm ueseless dir setup code

pull/879/head
Hui Zhang 3 years ago
parent f5ec6e34c6
commit 8b45c3e65e

@ -386,13 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(msg)
self.autolog.report()
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def export(self):
if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained(
@ -409,40 +402,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args):
@ -646,38 +605,6 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_lens = output_lens_handle.copy_to_cpu()
return output_probs, output_lens
def run_test(self):
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(self.args.export_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
def setup_model(self):
super().setup_model()
speedyspeech_config = inference.Config(

@ -551,13 +551,6 @@ class U2Tester(U2Trainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
@ -617,13 +610,6 @@ class U2Tester(U2Trainer):
intervals=tierformat,
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
@ -651,37 +637,3 @@ class U2Tester(U2Trainer):
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -525,13 +525,6 @@ class U2Tester(U2Trainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
@ -591,13 +584,6 @@ class U2Tester(U2Trainer):
intervals=tierformat,
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
@ -626,43 +612,11 @@ class U2Tester(U2Trainer):
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup_dict(self):
# load dictionary for debug log
self.args.char_list = load_dict(self.args.dict_path,
"maskctc" in self.args.model_name)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
super().setup()
self.setup_dict()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -545,13 +545,6 @@ class U2STTester(U2STTrainer):
})
f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
@ -611,13 +604,6 @@ class U2STTester(U2STTrainer):
intervals=tierformat,
output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
@ -645,37 +631,3 @@ class U2STTester(U2STTrainer):
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
import paddle
from paddle import nn

@ -14,6 +14,7 @@
import sys
import time
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path
import paddle
@ -103,14 +104,28 @@ class Trainer():
self.iteration = 0
self.epoch = 0
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self._train = True
# print deps version
all_version()
logger.info(f"Rank: {self.rank}/{dist.get_world_size()}")
logger.info(f"Rank: {self.rank}/{self.world_size}")
# set device
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
# set random seed if needed
if args.seed:
seed_all(args.seed)
logger.info(f"Set seed {args.seed}")
# profiler and benchmark options
if self.args.benchmark_batch_size:
with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size
@ -118,17 +133,18 @@ class Trainer():
logger.info(
f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
@contextmanager
def eval(self):
self._train = False
yield
self._train = True
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
self.setup_output_dir()
self.dump_config()
self.setup_visualizer()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
@ -183,8 +199,8 @@ class Trainer():
if infos:
# just restore ckpt
# lr will resotre from optimizer ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.iteration = infos["step"] + 1
self.epoch = infos["epoch"] + 1
scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
@ -302,37 +318,74 @@ class Trainer():
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
with Timer("Training Done: {}"):
try:
try:
with Timer("Training Done: {}"):
self.train()
except KeyboardInterrupt:
exit(-1)
finally:
self.destory()
except KeyboardInterrupt:
exit(-1)
finally:
self.destory()
def run_test(self):
"""Do Test/Decode"""
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.resume_or_scratch()
self.test()
except KeyboardInterrupt:
exit(-1)
def run_export(self):
"""Do Model Export"""
try:
with Timer("Export Done: {}"):
with self.eval():
self.export()
except KeyboardInterrupt:
exit(-1)
def run_align(self):
"""Do CTC alignment"""
try:
with Timer("Align Done: {}"):
with self.eval():
self.resume_or_scratch()
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
if self.args.output:
output_dir = Path(self.args.output).expanduser()
elif self.args.checkpoint_path:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
self.checkpoint_dir = self.output_dir / "checkpoints"
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.log_dir = output_dir / "log"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint_dir = checkpoint_dir
self.test_dir = output_dir / "test"
self.test_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
self.decode_dir = output_dir / "decode"
self.decode_dir.mkdir(parents=True, exist_ok=True)
self.export_dir = output_dir / "export"
self.export_dir.mkdir(parents=True, exist_ok=True)
self.visual_dir = output_dir / "visual"
self.visual_dir.mkdir(parents=True, exist_ok=True)
self.config_dir = output_dir / "conf"
self.config_dir.mkdir(parents=True, exist_ok=True)
@mp_tools.rank_zero_only
def destory(self):
@ -354,7 +407,7 @@ class Trainer():
unexpected behaviors.
"""
# visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir))
visualizer = SummaryWriter(logdir=str(self.visual_dir))
self.visualizer = visualizer
@mp_tools.rank_zero_only
@ -364,7 +417,14 @@ class Trainer():
It is saved in to ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
config_file = self.config_dir / "config.yaml"
if self._train and config_file.exists():
time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
target_path = self.config_dir / ".".join(
[time_stamp, "config.yaml"])
config_file.rename(target_path)
with open(config_file, 'wt') as f:
print(self.config, file=f)
def train_batch(self):
@ -378,6 +438,24 @@ class Trainer():
"""
raise NotImplementedError("valid should be implemented.")
@paddle.no_grad()
def test(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("test should be implemented.")
@paddle.no_grad()
def export(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("export should be implemented.")
@paddle.no_grad()
def align(self):
"""The align. A subclass should implement this method in Tester.
"""
raise NotImplementedError("align should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.

Loading…
Cancel
Save