fix egs of transformer lm usage

pull/936/head
Hui Zhang 3 years ago
parent eb65793769
commit c89820e7b2

@ -24,6 +24,7 @@ from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface from deepspeech.models.asr_interface import ASRInterface
from deepspeech.models.lm_interface import dynamic_import_lm
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -31,11 +32,15 @@ logger = Log(__name__).getlog()
# NOTE: you need this func to generate our sphinx doc # NOTE: you need this func to generate our sphinx doc
def get_config(config_path):
confs = CfgNode(new_allowed=True)
confs.merge_from_file(config_path)
return confs
def load_trained_model(args): def load_trained_model(args):
args.nprocs = args.ngpu args.nprocs = args.ngpu
confs = CfgNode() confs = get_config(args.model_conf)
confs.set_new_allowed(True)
confs.merge_from_file(args.model_conf)
class_obj = dynamic_import_tester(args.model_name) class_obj = dynamic_import_tester(args.model_name)
exp = class_obj(confs, args) exp = class_obj(confs, args)
with exp.eval(): with exp.eval():
@ -46,19 +51,11 @@ def load_trained_model(args):
return model, char_list, exp, confs return model, char_list, exp, confs
def get_config(config_path):
stream = open(config_path, mode='r', encoding="utf-8")
config = yaml.load(stream, Loader=yaml.FullLoader)
stream.close()
return config
def load_trained_lm(args): def load_trained_lm(args):
lm_args = get_config(args.rnnlm_conf) lm_args = get_config(args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models lm_model_module = lm_args.model_module
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module) lm_class = dynamic_import_lm(lm_model_module)
lm = lm_class(lm_args.model) lm = lm_class(**lm_args.model)
model_dict = paddle.load(args.rnnlm) model_dict = paddle.load(args.rnnlm)
lm.set_state_dict(model_dict) lm.set_state_dict(model_dict)
return lm return lm

@ -11,9 +11,9 @@ tag=
decode_config=conf/decode/decode.yaml decode_config=conf/decode/decode.yaml
# lm params # lm params
lang_model=transformerLM.pdparams
lmexpdir=exp/lm/transformer
rnnlm_config_path=conf/lm/transformer.yaml rnnlm_config_path=conf/lm/transformer.yaml
lmexpdir=exp/lm
lang_model=rnnlm.pdparams
lmtag='transformer' lmtag='transformer'
train_set=train_960 train_set=train_960
@ -53,6 +53,9 @@ if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
fi fi
echo "chunk mode: ${chunk_mode}" echo "chunk mode: ${chunk_mode}"
echo "decode conf: ${decode_config}" echo "decode conf: ${decode_config}"
echo "lm conf: ${rnnlm_config_path}"
echo "lm model: ${lmexpdir}/${lang_model}"
# download language model # download language model
#bash local/download_lm_en.sh #bash local/download_lm_en.sh
@ -61,6 +64,13 @@ echo "decode conf: ${decode_config}"
#fi #fi
# download rnnlm
mkdir -p ${lmexpdir}
if [ ! -f ${lmexpdir}/${lang_model} ]; then
wget -c -O ${lmexpdir}/${lang_model} https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams
fi
pids=() # initialize pids pids=() # initialize pids
for dmethd in join_ctc; do for dmethd in join_ctc; do

@ -37,12 +37,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] && ${use_lm} == true; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# join ctc decoder, use transformerlm to score # join ctc decoder, use transformerlm to score
if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then ./local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
wget https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams exp/lm/transformer/
fi
bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt}
fi fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then

Loading…
Cancel
Save