From 0ab299a8423463c17c10aca204c03b974708b100 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 19 Aug 2021 02:54:09 +0000 Subject: [PATCH] test bin --- deepspeech/exps/u2_kaldi/bin/test.py | 23 +++++++++++++++-------- deepspeech/exps/u2_kaldi/bin/train.py | 4 ++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/deepspeech/exps/u2_kaldi/bin/test.py b/deepspeech/exps/u2_kaldi/bin/test.py index 48244a54..06504827 100644 --- a/deepspeech/exps/u2_kaldi/bin/test.py +++ b/deepspeech/exps/u2_kaldi/bin/test.py @@ -14,16 +14,19 @@ """Evaluation for U2 model.""" import cProfile -from deepspeech.exps.u2.model import get_cfg_defaults -from deepspeech.exps.u2.model import U2Tester as Tester from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.utility import print_arguments -# TODO(hui zhang): dynamic load +model_alias = { + "u2": "deepspeech.exps.u2.model:U2Tester", + "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", +} def main_sp(config, args): - exp = Tester(config, args) + class_obj = dynamic_import(args.model_name, model_alias) + exp = class_obj(config, args) exp.setup() exp.run_test() @@ -34,13 +37,17 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument( + '--model-name', + type=str, + default='u2_kaldi', + help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') args = parser.parse_args() print_arguments(args, globals()) - # https://yaml.org/type/float.html - config = get_cfg_defaults() - if args.config: - config.merge_from_file(args.config) + config = CfgNode() + config.set_new_allowed(True) + config.merge_from_file(args.config) if args.opts: config.merge_from_list(args.opts) config.freeze() diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py index 45ad3dba..3a240b80 100644 --- a/deepspeech/exps/u2_kaldi/bin/train.py +++ b/deepspeech/exps/u2_kaldi/bin/train.py @@ -29,8 +29,8 @@ model_alias = { def main_sp(config, args): - trainer_cls = dynamic_import(args.model_name, model_alias) - exp = trainer_cls(config, args) + class_obj = dynamic_import(args.model_name, model_alias) + exp = class_obj(config, args) exp.setup() exp.run()