add transformerLM config path and dynamic import

pull/931/head
huangyuxin 4 years ago
parent 871fc5b70d
commit ac5824ef1c

@@ -14,6 +14,7 @@
 """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
 import jsonlines
 import paddle
+import yaml
 from yacs.config import CfgNode

 from .beam_search import BatchBeamSearch
@@ -25,6 +26,7 @@ from deepspeech.exps import dynamic_import_tester
 from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.models.asr_interface import ASRInterface
 from deepspeech.models.lm.transformer import TransformerLM
+from deepspeech.models.lm_interface import dynamic_import_lm
 from deepspeech.utils.log import Log
 # from espnet.asr.asr_utils import get_model_conf
 # from espnet.asr.asr_utils import torch_load
@@ -80,15 +82,12 @@ def recog_v2(args):
     if args.rnnlm:
         lm_path = args.rnnlm
-        lm = TransformerLM(
-            n_vocab=5002,
-            pos_enc=None,
-            embed_unit=128,
-            att_unit=512,
-            head=8,
-            unit=2048,
-            layer=16,
-            dropout_rate=0.5, )
+        lm_config_path = args.rnnlm_conf
+        stream = open(lm_config_path, mode='r', encoding="utf-8")
+        lm_config = yaml.load(stream, Loader=yaml.FullLoader)
+        stream.close()
+        lm_class = dynamic_import_lm("transformer")
+        lm = lm_class(**lm_config)
         model_dict = paddle.load(lm_path)
         lm.set_state_dict(model_dict)
         lm.eval()
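The new branch replaces the hard-coded `TransformerLM(...)` constructor call with a config-driven lookup through `dynamic_import_lm`. Below is a minimal sketch of how an ESPnet-style resolver like this typically maps the alias `"transformer"` to a class; the `PREDEFINED_LMS` table and module path shown are illustrative assumptions, not the repository's actual registry:

```python
import importlib

# Illustrative alias table; the real deepspeech.models.lm_interface
# may register different names and module paths.
PREDEFINED_LMS = {
    "transformer": "deepspeech.models.lm.transformer:TransformerLM",
}

def dynamic_import_lm(alias: str):
    """Resolve an LM alias like "transformer" to its class (sketch)."""
    module_path, class_name = PREDEFINED_LMS[alias].split(":")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```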

@@ -0,0 +1,11 @@
+n_vocab: 5002
+pos_enc: null
+embed_unit: 128
+att_unit: 512
+head: 8
+unit: 2048
+layer: 16
+dropout_rate: 0.5
+emb_dropout_rate: 0.0
+att_dropout_rate: 0.0
+tie_weights: False
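Each key in this new `conf/lm/transformer.yaml` corresponds to a keyword argument of the LM constructor, so `recog_v2` can splat the whole mapping into the dynamically imported class. A minimal sketch of that round trip; it uses `yaml.safe_load` instead of the diff's `yaml.load(..., Loader=yaml.FullLoader)`, which parses this file identically:

```python
import yaml

from deepspeech.models.lm_interface import dynamic_import_lm

# Load the LM hyperparameters added in this commit
# (path taken from the run script below).
with open("conf/lm/transformer.yaml", mode='r', encoding="utf-8") as stream:
    lm_config = yaml.safe_load(stream)

# "transformer" resolves to TransformerLM; every YAML key becomes a
# constructor keyword argument (n_vocab=5002, pos_enc=None, ...).
lm_class = dynamic_import_lm("transformer")
lm = lm_class(**lm_config)
```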

@@ -11,9 +11,11 @@ tag=
 decode_config=conf/decode/decode.yaml

 # lm params
-lang_model=rnnlm.model.best
-lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
+lang_model=transformerLM.pdparams
+lmexpdir=exp/transformerLM
 lmtag='nolm'
+rnnlm_config_path=conf/lm/transformer.yaml

 recog_set="test-clean test-other dev-clean dev-other"
 recog_set="test-clean"
@@ -27,7 +29,7 @@ bpemodel=${bpeprefix}.model
 # bin params
 config_path=conf/transformer.yaml
 dict=data/bpe_unigram_5000_units.txt
-ckpt_prefix=
+ckpt_prefix=exp/avg_10

 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -90,9 +92,9 @@ for dmethd in join_ctc; do
         --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
         --result-label ${decode_dir}/data.JOB.json \
         --model-conf ${config_path} \
-        --model ${ckpt_prefix}.pdparams
-        #--rnnlm ${lmexpdir}/${lang_model} \
+        --model ${ckpt_prefix}.pdparams \
+        --rnnlm-conf ${rnnlm_config_path} \
+        --rnnlm ${lmexpdir}/${lang_model} \

 score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
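With this hunk the recognizer receives both the LM weights (`--rnnlm`) and its config (`--rnnlm-conf`). A minimal argparse sketch of how those flags plausibly map onto the `args.rnnlm` / `args.rnnlm_conf` attributes read by `recog_v2` above; the real `asr_recog.py` defines many more options, and the help strings and defaults here are illustrative:

```python
import argparse

parser = argparse.ArgumentParser(
    description="decode with an optional external LM (sketch)")
parser.add_argument("--model-conf", type=str, help="ASR model config (yaml)")
parser.add_argument("--model", type=str,
                    help="ASR checkpoint, e.g. exp/avg_10.pdparams")
parser.add_argument("--rnnlm", type=str, default=None,
                    help="LM weights, e.g. exp/transformerLM/transformerLM.pdparams")
parser.add_argument("--rnnlm-conf", type=str, default=None,
                    help="LM config, e.g. conf/lm/transformer.yaml")
args = parser.parse_args()

# argparse converts --rnnlm-conf to args.rnnlm_conf, which is exactly
# the attribute the new recog_v2() branch reads before loading the LM.
```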
