|
|
@ -14,6 +14,7 @@
|
|
|
|
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
|
|
|
|
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
|
|
|
|
import jsonlines
|
|
|
|
import jsonlines
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
from .beam_search import BatchBeamSearch
|
|
|
|
from .beam_search import BatchBeamSearch
|
|
|
@ -25,6 +26,7 @@ from deepspeech.exps import dynamic_import_tester
|
|
|
|
from deepspeech.io.reader import LoadInputsAndTargets
|
|
|
|
from deepspeech.io.reader import LoadInputsAndTargets
|
|
|
|
from deepspeech.models.asr_interface import ASRInterface
|
|
|
|
from deepspeech.models.asr_interface import ASRInterface
|
|
|
|
from deepspeech.models.lm.transformer import TransformerLM
|
|
|
|
from deepspeech.models.lm.transformer import TransformerLM
|
|
|
|
|
|
|
|
from deepspeech.models.lm_interface import dynamic_import_lm
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
# from espnet.asr.asr_utils import get_model_conf
|
|
|
|
# from espnet.asr.asr_utils import get_model_conf
|
|
|
|
# from espnet.asr.asr_utils import torch_load
|
|
|
|
# from espnet.asr.asr_utils import torch_load
|
|
|
@ -80,15 +82,12 @@ def recog_v2(args):
|
|
|
|
|
|
|
|
|
|
|
|
if args.rnnlm:
|
|
|
|
if args.rnnlm:
|
|
|
|
lm_path = args.rnnlm
|
|
|
|
lm_path = args.rnnlm
|
|
|
|
lm = TransformerLM(
|
|
|
|
lm_config_path = args.rnnlm_conf
|
|
|
|
n_vocab=5002,
|
|
|
|
stream = open(lm_config_path, mode='r', encoding="utf-8")
|
|
|
|
pos_enc=None,
|
|
|
|
lm_config = yaml.load(stream, Loader=yaml.FullLoader)
|
|
|
|
embed_unit=128,
|
|
|
|
stream.close()
|
|
|
|
att_unit=512,
|
|
|
|
lm_class = dynamic_import_lm("transformer")
|
|
|
|
head=8,
|
|
|
|
lm = lm_class(**lm_config)
|
|
|
|
unit=2048,
|
|
|
|
|
|
|
|
layer=16,
|
|
|
|
|
|
|
|
dropout_rate=0.5, )
|
|
|
|
|
|
|
|
model_dict = paddle.load(lm_path)
|
|
|
|
model_dict = paddle.load(lm_path)
|
|
|
|
lm.set_state_dict(model_dict)
|
|
|
|
lm.set_state_dict(model_dict)
|
|
|
|
lm.eval()
|
|
|
|
lm.eval()
|
|
|
|