From ac5824ef1cb7a76ebdf952f039caf6626d91159f Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Mon, 25 Oct 2021 09:06:21 +0000
Subject: [PATCH] add transformerLM config path and dynamic import

---
 deepspeech/decoders/recog.py                  | 17 ++++++++---------
 .../librispeech/s2/conf/lm/transformer.yaml   | 11 +++++++++++
 examples/librispeech/s2/local/recog.sh        | 14 ++++++++------
 3 files changed, 27 insertions(+), 15 deletions(-)
 create mode 100644 examples/librispeech/s2/conf/lm/transformer.yaml

diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py
index dae3cd429..5eef20c5d 100644
--- a/deepspeech/decoders/recog.py
+++ b/deepspeech/decoders/recog.py
@@ -14,6 +14,7 @@
 """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
 import jsonlines
 import paddle
+import yaml
 from yacs.config import CfgNode
 
 from .beam_search import BatchBeamSearch
@@ -25,6 +26,7 @@ from deepspeech.exps import dynamic_import_tester
 from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.models.asr_interface import ASRInterface
 from deepspeech.models.lm.transformer import TransformerLM
+from deepspeech.models.lm_interface import dynamic_import_lm
 from deepspeech.utils.log import Log
 # from espnet.asr.asr_utils import get_model_conf
 # from espnet.asr.asr_utils import torch_load
@@ -80,15 +82,12 @@ def recog_v2(args):
 
     if args.rnnlm:
         lm_path = args.rnnlm
-        lm = TransformerLM(
-            n_vocab=5002,
-            pos_enc=None,
-            embed_unit=128,
-            att_unit=512,
-            head=8,
-            unit=2048,
-            layer=16,
-            dropout_rate=0.5, )
+        lm_config_path = args.rnnlm_conf
+        stream = open(lm_config_path, mode='r', encoding="utf-8")
+        lm_config = yaml.load(stream, Loader=yaml.FullLoader)
+        stream.close()
+        lm_class = dynamic_import_lm("transformer")
+        lm = lm_class(**lm_config)
         model_dict = paddle.load(lm_path)
         lm.set_state_dict(model_dict)
         lm.eval()
diff --git a/examples/librispeech/s2/conf/lm/transformer.yaml b/examples/librispeech/s2/conf/lm/transformer.yaml
new file mode 100644
index 000000000..a9f069484
--- /dev/null
+++ b/examples/librispeech/s2/conf/lm/transformer.yaml
@@ -0,0 +1,11 @@
+n_vocab: 5002
+pos_enc: null
+embed_unit: 128
+att_unit: 512
+head: 8
+unit: 2048
+layer: 16
+dropout_rate: 0.5
+emb_dropout_rate: 0.0
+att_dropout_rate: 0.0
+tie_weights: False
diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh
index df3846c02..f377950e9 100755
--- a/examples/librispeech/s2/local/recog.sh
+++ b/examples/librispeech/s2/local/recog.sh
@@ -11,9 +11,11 @@ tag=
 decode_config=conf/decode/decode.yaml
 
 # lm params
-lang_model=rnnlm.model.best
-lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
+lang_model=transformerLM.pdparams
+lmexpdir=exp/transformerLM
 lmtag='nolm'
+rnnlm_config_path=conf/lm/transformer.yaml
+
 recog_set="test-clean test-other dev-clean dev-other"
 recog_set="test-clean"
 
@@ -27,7 +29,7 @@ bpemodel=${bpeprefix}.model
 # bin params
 config_path=conf/transformer.yaml
 dict=data/bpe_unigram_5000_units.txt
-ckpt_prefix=
+ckpt_prefix=exp/avg_10
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
 
@@ -90,9 +92,9 @@ for dmethd in join_ctc; do
             --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
             --result-label ${decode_dir}/data.JOB.json \
             --model-conf ${config_path} \
-            --model ${ckpt_prefix}.pdparams
-
-            #--rnnlm ${lmexpdir}/${lang_model} \
+            --model ${ckpt_prefix}.pdparams \
+            --rnnlm-conf ${rnnlm_config_path} \
+            --rnnlm ${lmexpdir}/${lang_model} \
 
     score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
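
For reference, the LM-loading path this patch introduces can be exercised on its own with a minimal sketch like the one below. The config path and the "transformer" key are taken from the patch; the checkpoint path is illustrative, assembled from the lmexpdir and lang_model values set in recog.sh.

import paddle
import yaml

from deepspeech.models.lm_interface import dynamic_import_lm

# Read the LM hyperparameters from the YAML config added by this patch,
# instead of hard-coding them in recog.py.
with open("conf/lm/transformer.yaml", mode='r', encoding="utf-8") as stream:
    lm_config = yaml.load(stream, Loader=yaml.FullLoader)

# Resolve the LM class by name and build it from the config values.
lm_class = dynamic_import_lm("transformer")
lm = lm_class(**lm_config)

# Load the trained weights and switch the model to inference mode.
model_dict = paddle.load("exp/transformerLM/transformerLM.pdparams")  # illustrative path
lm.set_state_dict(model_dict)
lm.eval()

In the updated recog.sh, asr_recog.py receives both --rnnlm-conf ${rnnlm_config_path} and --rnnlm ${lmexpdir}/${lang_model}, which populate args.rnnlm_conf and args.rnnlm consumed by the new block in recog_v2.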