add transformerLM config path and dynamic import

pull/931/head
huangyuxin 4 years ago
parent 871fc5b70d
commit ac5824ef1c

@@ -14,6 +14,7 @@
 """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
 import jsonlines
 import paddle
+import yaml
 from yacs.config import CfgNode
 from .beam_search import BatchBeamSearch
@@ -25,6 +26,7 @@ from deepspeech.exps import dynamic_import_tester
 from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.models.asr_interface import ASRInterface
 from deepspeech.models.lm.transformer import TransformerLM
+from deepspeech.models.lm_interface import dynamic_import_lm
 from deepspeech.utils.log import Log
 # from espnet.asr.asr_utils import get_model_conf
 # from espnet.asr.asr_utils import torch_load
@@ -80,15 +82,12 @@ def recog_v2(args):
     if args.rnnlm:
         lm_path = args.rnnlm
-        lm = TransformerLM(
-            n_vocab=5002,
-            pos_enc=None,
-            embed_unit=128,
-            att_unit=512,
-            head=8,
-            unit=2048,
-            layer=16,
-            dropout_rate=0.5, )
+        lm_config_path = args.rnnlm_conf
+        stream = open(lm_config_path, mode='r', encoding="utf-8")
+        lm_config = yaml.load(stream, Loader=yaml.FullLoader)
+        stream.close()
+        lm_class = dynamic_import_lm("transformer")
+        lm = lm_class(**lm_config)
         model_dict = paddle.load(lm_path)
         lm.set_state_dict(model_dict)
         lm.eval()
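
For reference, a minimal standalone sketch of the config-driven LM loading this hunk introduces. The calls to dynamic_import_lm("transformer"), paddle.load, and set_state_dict are taken from the patch itself; the helper name and the example paths are assumptions for illustration only.

# Minimal sketch of the config-driven LM loading added in this patch.
# Assumes the deepspeech package from this repo is importable; paths are examples.
import paddle
import yaml

from deepspeech.models.lm_interface import dynamic_import_lm


def load_transformer_lm(config_path: str, ckpt_path: str):
    """Build a transformer LM from a YAML config and load its weights."""
    with open(config_path, mode='r', encoding='utf-8') as stream:
        lm_config = yaml.safe_load(stream)  # keys must match the LM constructor kwargs
    lm_class = dynamic_import_lm("transformer")
    lm = lm_class(**lm_config)
    lm.set_state_dict(paddle.load(ckpt_path))
    lm.eval()
    return lm


# Example, following the recipe defaults in this PR:
# lm = load_transformer_lm("conf/lm/transformer.yaml",
#                          "exp/transformerLM/transformerLM.pdparams")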

@@ -0,0 +1,11 @@
+n_vocab: 5002
+pos_enc: null
+embed_unit: 128
+att_unit: 512
+head: 8
+unit: 2048
+layer: 16
+dropout_rate: 0.5
+emb_dropout_rate: 0.0
+att_dropout_rate: 0.0
+tie_weights: False
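
The keys in this file are passed straight through as keyword arguments via lm_class(**lm_config), so they have to line up with the LM constructor's parameter names. A small sanity check, assuming the config sits at the recipe's default path:

# Sanity check: every key in the LM YAML must be an accepted constructor argument,
# since recog_v2 calls lm_class(**lm_config). Assumes deepspeech is importable.
import inspect
import yaml

from deepspeech.models.lm.transformer import TransformerLM

with open("conf/lm/transformer.yaml", encoding="utf-8") as f:
    lm_config = yaml.safe_load(f)

accepted = set(inspect.signature(TransformerLM.__init__).parameters) - {"self"}
unknown = set(lm_config) - accepted
assert not unknown, f"unexpected LM config keys: {unknown}"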

@@ -11,9 +11,11 @@ tag=
 decode_config=conf/decode/decode.yaml
 # lm params
-lang_model=rnnlm.model.best
-lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
+lang_model=transformerLM.pdparams
+lmexpdir=exp/transformerLM
 lmtag='nolm'
+rnnlm_config_path=conf/lm/transformer.yaml
-recog_set="test-clean test-other dev-clean dev-other"
+recog_set="test-clean"
@@ -27,7 +29,7 @@ bpemodel=${bpeprefix}.model
 # bin params
 config_path=conf/transformer.yaml
 dict=data/bpe_unigram_5000_units.txt
-ckpt_prefix=
+ckpt_prefix=exp/avg_10
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -90,9 +92,9 @@ for dmethd in join_ctc; do
         --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
         --result-label ${decode_dir}/data.JOB.json \
         --model-conf ${config_path} \
-        --model ${ckpt_prefix}.pdparams
-        #--rnnlm ${lmexpdir}/${lang_model} \
+        --model ${ckpt_prefix}.pdparams \
+        --rnnlm-conf ${rnnlm_config_path} \
+        --rnnlm ${lmexpdir}/${lang_model} \
    score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
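
On the Python side, recog_v2 only reads args.rnnlm and args.rnnlm_conf; the argument parser itself lives in the asr_recog entry point and is not part of this diff, so the following is only an assumed sketch of how the two shell flags above map onto it.

# Hypothetical sketch of the two decoder options used above. The real parser is
# not shown in this diff; option names are mirrored from args.rnnlm / args.rnnlm_conf.
import argparse

parser = argparse.ArgumentParser(description="decode with an external transformer LM")
parser.add_argument("--rnnlm", type=str, default=None,
                    help="LM checkpoint, e.g. exp/transformerLM/transformerLM.pdparams")
parser.add_argument("--rnnlm-conf", type=str, default=None,
                    help="YAML with the LM constructor kwargs, e.g. conf/lm/transformer.yaml")

args = parser.parse_args(["--rnnlm", "exp/transformerLM/transformerLM.pdparams",
                          "--rnnlm-conf", "conf/lm/transformer.yaml"])
assert args.rnnlm_conf == "conf/lm/transformer.yaml"  # argparse maps --rnnlm-conf to rnnlm_conf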
