refactor trainer.py and rm useless dir setup code

pull/879/head
Hui Zhang 3 years ago
parent f5ec6e34c6
commit 8b45c3e65e

@ -386,13 +386,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(msg) logger.info(msg)
self.autolog.report() self.autolog.report()
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def export(self): def export(self):
if self.args.model_type == 'offline': if self.args.model_type == 'offline':
infer_model = DeepSpeech2InferModel.from_pretrained( infer_model = DeepSpeech2InferModel.from_pretrained(
@ -409,40 +402,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
logger.info(f"Export code: {static_model.forward.code}") logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path) paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
class DeepSpeech2ExportTester(DeepSpeech2Tester): class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args): def __init__(self, config, args):
@ -646,38 +605,6 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_lens = output_lens_handle.copy_to_cpu() output_lens = output_lens_handle.copy_to_cpu()
return output_probs, output_lens return output_probs, output_lens
def run_test(self):
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(self.args.export_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
def setup_model(self): def setup_model(self):
super().setup_model() super().setup_model()
speedyspeech_config = inference.Config( speedyspeech_config = inference.Config(

@ -551,13 +551,6 @@ class U2Tester(U2Trainer):
}) })
f.write(data + '\n') f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
if self.config.decoding.batch_size > 1: if self.config.decoding.batch_size > 1:
@ -617,13 +610,6 @@ class U2Tester(U2Trainer):
intervals=tierformat, intervals=tierformat,
output=str(textgrid_path)) output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -651,37 +637,3 @@ class U2Tester(U2Trainer):
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}") logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path) paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -525,13 +525,6 @@ class U2Tester(U2Trainer):
}) })
f.write(data + '\n') f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
if self.config.decoding.batch_size > 1: if self.config.decoding.batch_size > 1:
@ -591,13 +584,6 @@ class U2Tester(U2Trainer):
intervals=tierformat, intervals=tierformat,
output=str(textgrid_path)) output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -626,43 +612,11 @@ class U2Tester(U2Trainer):
logger.info(f"Export code: {static_model.forward.code}") logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path) paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup_dict(self): def setup_dict(self):
# load dictionary for debug log # load dictionary for debug log
self.args.char_list = load_dict(self.args.dict_path, self.args.char_list = load_dict(self.args.dict_path,
"maskctc" in self.args.model_name) "maskctc" in self.args.model_name)
def setup(self): def setup(self):
"""Setup the experiment. super().setup()
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.setup_dict() self.setup_dict()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -545,13 +545,6 @@ class U2STTester(U2STTrainer):
}) })
f.write(data + '\n') f.write(data + '\n')
def run_test(self):
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad() @paddle.no_grad()
def align(self): def align(self):
if self.config.decoding.batch_size > 1: if self.config.decoding.batch_size > 1:
@ -611,13 +604,6 @@ class U2STTester(U2STTrainer):
intervals=tierformat, intervals=tierformat,
output=str(textgrid_path)) output=str(textgrid_path))
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
@ -645,37 +631,3 @@ class U2STTester(U2STTrainer):
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}") logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path) paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
self.export()
except KeyboardInterrupt:
sys.exit(-1)
def setup(self):
"""Setup the experiment.
"""
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
self.setup_output_dir()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
from functools import partial
import paddle import paddle
from paddle import nn from paddle import nn

@ -14,6 +14,7 @@
import sys import sys
import time import time
from collections import OrderedDict from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
import paddle import paddle
@ -103,14 +104,28 @@ class Trainer():
self.iteration = 0 self.iteration = 0
self.epoch = 0 self.epoch = 0
self.rank = dist.get_rank() self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self._train = True
# print deps version
all_version() all_version()
logger.info(f"Rank: {self.rank}/{dist.get_world_size()}") logger.info(f"Rank: {self.rank}/{self.world_size}")
# set device
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
# set random seed if needed
if args.seed: if args.seed:
seed_all(args.seed) seed_all(args.seed)
logger.info(f"Set seed {args.seed}") logger.info(f"Set seed {args.seed}")
# profiler and benchmark options
if self.args.benchmark_batch_size: if self.args.benchmark_batch_size:
with UpdateConfig(self.config): with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size self.config.collator.batch_size = self.args.benchmark_batch_size
@ -118,17 +133,18 @@ class Trainer():
logger.info( logger.info(
f"Benchmark reset batch-size: {self.args.benchmark_batch_size}") f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")
@contextmanager
def eval(self):
self._train = False
yield
self._train = True
def setup(self): def setup(self):
"""Setup the experiment. """Setup the experiment.
""" """
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
if self.parallel:
self.init_parallel()
self.setup_output_dir() self.setup_output_dir()
self.dump_config() self.dump_config()
self.setup_visualizer() self.setup_visualizer()
self.setup_checkpointer()
self.setup_dataloader() self.setup_dataloader()
self.setup_model() self.setup_model()
@ -183,8 +199,8 @@ class Trainer():
if infos: if infos:
# just restore ckpt # just restore ckpt
# lr will resotre from optimizer ckpt # lr will resotre from optimizer ckpt
self.iteration = infos["step"] self.iteration = infos["step"] + 1
self.epoch = infos["epoch"] self.epoch = infos["epoch"] + 1
scratch = False scratch = False
logger.info( logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
@ -302,37 +318,74 @@ class Trainer():
"""The routine of the experiment after setup. This method is intended """The routine of the experiment after setup. This method is intended
to be used by the user. to be used by the user.
""" """
with Timer("Training Done: {}"): try:
try: with Timer("Training Done: {}"):
self.train() self.train()
except KeyboardInterrupt: except KeyboardInterrupt:
exit(-1) exit(-1)
finally: finally:
self.destory() self.destory()
def run_test(self):
"""Do Test/Decode"""
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.resume_or_scratch()
self.test()
except KeyboardInterrupt:
exit(-1)
def run_export(self):
"""Do Model Export"""
try:
with Timer("Export Done: {}"):
with self.eval():
self.export()
except KeyboardInterrupt:
exit(-1)
def run_align(self):
"""Do CTC alignment"""
try:
with Timer("Align Done: {}"):
with self.eval():
self.resume_or_scratch()
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def setup_output_dir(self): def setup_output_dir(self):
"""Create a directory used for output. """Create a directory used for output.
""" """
# output dir if self.args.output:
output_dir = Path(self.args.output).expanduser() output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True) elif self.args.checkpoint_path:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
self.output_dir = output_dir self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
def setup_checkpointer(self): self.checkpoint_dir = self.output_dir / "checkpoints"
"""Create a directory used to save checkpoints into. self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
It is "checkpoints" inside the output directory. self.log_dir = output_dir / "log"
""" self.log_dir.mkdir(parents=True, exist_ok=True)
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir self.test_dir = output_dir / "test"
self.test_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint = Checkpoint( self.decode_dir = output_dir / "decode"
kbest_n=self.config.training.checkpoint.kbest_n, self.decode_dir.mkdir(parents=True, exist_ok=True)
latest_n=self.config.training.checkpoint.latest_n)
self.export_dir = output_dir / "export"
self.export_dir.mkdir(parents=True, exist_ok=True)
self.visual_dir = output_dir / "visual"
self.visual_dir.mkdir(parents=True, exist_ok=True)
self.config_dir = output_dir / "conf"
self.config_dir.mkdir(parents=True, exist_ok=True)
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
def destory(self): def destory(self):
@ -354,7 +407,7 @@ class Trainer():
unexpected behaviors. unexpected behaviors.
""" """
# visualizer # visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir)) visualizer = SummaryWriter(logdir=str(self.visual_dir))
self.visualizer = visualizer self.visualizer = visualizer
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@ -364,7 +417,14 @@ class Trainer():
It is saved in to ``config.yaml`` in the output directory at the It is saved in to ``config.yaml`` in the output directory at the
beginning of the experiment. beginning of the experiment.
""" """
with open(self.output_dir / "config.yaml", 'wt') as f: config_file = self.config_dir / "config.yaml"
if self._train and config_file.exists():
time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
target_path = self.config_dir / ".".join(
[time_stamp, "config.yaml"])
config_file.rename(target_path)
with open(config_file, 'wt') as f:
print(self.config, file=f) print(self.config, file=f)
def train_batch(self): def train_batch(self):
@ -378,6 +438,24 @@ class Trainer():
""" """
raise NotImplementedError("valid should be implemented.") raise NotImplementedError("valid should be implemented.")
@paddle.no_grad()
def test(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("test should be implemented.")
@paddle.no_grad()
def export(self):
"""The test. A subclass should implement this method in Tester.
"""
raise NotImplementedError("export should be implemented.")
@paddle.no_grad()
def align(self):
"""The align. A subclass should implement this method in Tester.
"""
raise NotImplementedError("align should be implemented.")
def setup_model(self): def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should """Setup model, criterion and optimizer, etc. A subclass should
implement this method. implement this method.

Loading…
Cancel
Save