|
|
@ -24,6 +24,7 @@ from .utils import add_results_to_json
|
|
|
|
from deepspeech.exps import dynamic_import_tester
|
|
|
|
from deepspeech.exps import dynamic_import_tester
|
|
|
|
from deepspeech.io.reader import LoadInputsAndTargets
|
|
|
|
from deepspeech.io.reader import LoadInputsAndTargets
|
|
|
|
from deepspeech.models.asr_interface import ASRInterface
|
|
|
|
from deepspeech.models.asr_interface import ASRInterface
|
|
|
|
|
|
|
|
from deepspeech.models.lm_interface import dynamic_import_lm
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
@ -31,11 +32,15 @@ logger = Log(__name__).getlog()
|
|
|
|
# NOTE: you need this func to generate our sphinx doc
|
|
|
|
# NOTE: you need this func to generate our sphinx doc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_config(config_path):
|
|
|
|
|
|
|
|
confs = CfgNode(new_allowed=True)
|
|
|
|
|
|
|
|
confs.merge_from_file(config_path)
|
|
|
|
|
|
|
|
return confs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_trained_model(args):
|
|
|
|
def load_trained_model(args):
|
|
|
|
args.nprocs = args.ngpu
|
|
|
|
args.nprocs = args.ngpu
|
|
|
|
confs = CfgNode()
|
|
|
|
confs = get_config(args.model_conf)
|
|
|
|
confs.set_new_allowed(True)
|
|
|
|
|
|
|
|
confs.merge_from_file(args.model_conf)
|
|
|
|
|
|
|
|
class_obj = dynamic_import_tester(args.model_name)
|
|
|
|
class_obj = dynamic_import_tester(args.model_name)
|
|
|
|
exp = class_obj(confs, args)
|
|
|
|
exp = class_obj(confs, args)
|
|
|
|
with exp.eval():
|
|
|
|
with exp.eval():
|
|
|
@ -46,19 +51,11 @@ def load_trained_model(args):
|
|
|
|
return model, char_list, exp, confs
|
|
|
|
return model, char_list, exp, confs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_config(config_path):
|
|
|
|
|
|
|
|
stream = open(config_path, mode='r', encoding="utf-8")
|
|
|
|
|
|
|
|
config = yaml.load(stream, Loader=yaml.FullLoader)
|
|
|
|
|
|
|
|
stream.close()
|
|
|
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_trained_lm(args):
|
|
|
|
def load_trained_lm(args):
|
|
|
|
lm_args = get_config(args.rnnlm_conf)
|
|
|
|
lm_args = get_config(args.rnnlm_conf)
|
|
|
|
# NOTE: for a compatibility with less than 0.5.0 version models
|
|
|
|
lm_model_module = lm_args.model_module
|
|
|
|
lm_model_module = getattr(lm_args, "model_module", "default")
|
|
|
|
|
|
|
|
lm_class = dynamic_import_lm(lm_model_module)
|
|
|
|
lm_class = dynamic_import_lm(lm_model_module)
|
|
|
|
lm = lm_class(lm_args.model)
|
|
|
|
lm = lm_class(**lm_args.model)
|
|
|
|
model_dict = paddle.load(args.rnnlm)
|
|
|
|
model_dict = paddle.load(args.rnnlm)
|
|
|
|
lm.set_state_dict(model_dict)
|
|
|
|
lm.set_state_dict(model_dict)
|
|
|
|
return lm
|
|
|
|
return lm
|
|
|
|