pull/768/head
Hui Zhang 3 years ago
parent ab23eb5710
commit 0ab299a842

@ -14,16 +14,19 @@
"""Evaluation for U2 model.""" """Evaluation for U2 model."""
import cProfile import cProfile
from deepspeech.exps.u2.model import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
# TODO(hui zhang): dynamic load model_alias = {
"u2": "deepspeech.exps.u2.model:U2Tester",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
}
def main_sp(config, args): def main_sp(config, args):
exp = Tester(config, args) class_obj = dynamic_import(args.model_name, model_alias)
exp = class_obj(config, args)
exp.setup() exp.setup()
exp.run_test() exp.run_test()
@ -34,13 +37,17 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
parser.add_argument(
'--model-name',
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())
# https://yaml.org/type/float.html config = CfgNode()
config = get_cfg_defaults() config.set_new_allowed(True)
if args.config: config.merge_from_file(args.config)
config.merge_from_file(args.config)
if args.opts: if args.opts:
config.merge_from_list(args.opts) config.merge_from_list(args.opts)
config.freeze() config.freeze()

@ -29,8 +29,8 @@ model_alias = {
def main_sp(config, args): def main_sp(config, args):
trainer_cls = dynamic_import(args.model_name, model_alias) class_obj = dynamic_import(args.model_name, model_alias)
exp = trainer_cls(config, args) exp = class_obj(config, args)
exp.setup() exp.setup()
exp.run() exp.run()

Loading…
Cancel
Save