fix egs of transformer lm usage

pull/936/head
Hui Zhang 3 years ago
parent eb65793769
commit c89820e7b2

@ -24,6 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm_interface import dynamic_import_lm
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
@ -31,11 +32,15 @@ logger = Log(__name__).getlog()
# NOTE: you need this func to generate our sphinx doc
def get_config(config_path):
confs = CfgNode(new_allowed=True)
confs.merge_from_file(config_path)
return confs
def load_trained_model(args):
args.nprocs = args.ngpu
confs = CfgNode()
confs.set_new_allowed(True)
confs.merge_from_file(args.model_conf)
confs = get_config(args.model_conf)
class_obj = dynamic_import_tester(args.model_name)
exp = class_obj(confs, args)
with exp.eval():
@ -46,19 +51,11 @@ def load_trained_model(args):
return model, char_list, exp, confs
def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close()
return config
def load_trained_lm(args):
lm_args = get_config(args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_model_module = lm_args.model_module
lm_class = dynamic_import_lm(lm_model_module)
lm = lm_class(lm_args.model)
lm = lm_class(**lm_args.model)
model_dict = paddle.load(args.rnnlm)
lm.set_state_dict(model_dict)
return lm

@ -11,9 +11,9 @@ tag=
decode_config=conf/decode/decode.yaml
# lm params
lang_model=transformerLM.pdparams
lmexpdir=exp/lm/transformer
rnnlm_config_path=conf/lm/transformer.yaml
lmexpdir=exp/lm
lang_model=rnnlm.pdparams
lmtag='transformer'
train_set=train_960
@ -53,6 +53,9 @@ if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
fi
echo "chunk mode: ${chunk_mode}"
echo "decode conf: ${decode_config}"
echo "lm conf: ${rnnlm_config_path}"
echo "lm model: ${lmexpdir}/${lang_model}"
# download language model
#bash local/download_lm_en.sh
@ -61,6 +64,13 @@ echo "decode conf: ${decode_config}"
#fi
# download rnnlm
mkdir -p ${lmexpdir}
if [ ! -f ${lmexpdir}/${lang_model} ]; then
wget -c -O ${lmexpdir}/${lang_model} https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams
fi
pids=() # initialize pids
for dmethd in join_ctc; do

@ -37,12 +37,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] && ${use_lm} == true; then
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# join ctc decoder, use transformerlm to score
if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then
wget https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams exp/lm/transformer/
fi
bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
./local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then

Loading…
Cancel
Save