fix some bug and complete the recog.py

pull/930/head
huangyuxin 3 years ago
parent c0295aa131
commit e4a9328c40

@ -29,6 +29,7 @@ from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.utils.log import Log
from deepspeech.models.lm.transformer import TransformerLM
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.nets.lm_interface import dynamic_import_lm
@ -83,12 +84,18 @@ def recog_v2(args):
)
if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
lm = lm_class(len(char_list), lm_args)
torch_load(args.rnnlm, lm)
lm_path = args.rnnlm
lm = TransformerLM(
n_vocab=5002,
pos_enc=None,
embed_unit=128,
att_unit=512,
head=8,
unit=2048,
layer=16,
dropout_rate=0.5, )
model_dict = paddle.load(lm_path)
lm.set_state_dict(model_dict)
lm.eval()
else:
lm = None

@ -23,9 +23,9 @@ import paddle.nn.functional as F
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import
#LMInterface
from deepspeech.models.lm_interface import LMInterface
import logging
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__(
self,
@ -36,7 +36,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
head: int=2,
unit: int=1024,
layer: int=4,
dropout_rate: float=0.5,
dropout_rate: float=0.5,
emb_dropout_rate: float = 0.0,
att_dropout_rate: float = 0.0,
tie_weights: bool = False,):
@ -84,6 +84,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
), "Tie Weights: True need embedding and final dimensions to match"
self.decoder.weight = self.embed.weight
def _target_mask(self, ys_in_pad):
ys_mask = ys_in_pad != 0
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
@ -151,7 +153,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
emb, self._target_mask(y), cache=state
)
h = self.decoder(h[:, -1])
logp = h.log_softmax(axis=-1).squeeze(0)
logp = F.log_softmax(h).squeeze(0)
return logp, cache
# batch beam search API (see BatchScorerInterface)
@ -194,7 +196,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
emb, self._target_mask(ys), cache=batch_state
)
h = self.decoder(h[:, -1])
logp = h.log_softmax(axi=-1)
logp = F.log_softmax(h)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
@ -219,7 +221,7 @@ if __name__ == "__main__":
# head: int=2,
# unit: int=1024,
# layer: int=4,
# dropout_rate: float=0.5,
# dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0,
# tie_weights: bool = False,):
@ -231,14 +233,14 @@ if __name__ == "__main__":
#Test the score
input2 = np.array([5])
input2 = paddle.to_tensor(input2)
state = (None, None, 0)
state = None
output, state = tlm.score(input2, state, None)
input3 = np.array([10])
input3 = np.array([5,10])
input3 = paddle.to_tensor(input3)
output, state = tlm.score(input3, state, None)
input4 = np.array([0])
input4 = np.array([5,10,0])
input4 = paddle.to_tensor(input4)
output, state = tlm.score(input4, state, None)
print("output", output)
@ -256,4 +258,4 @@ if __name__ == "__main__":
print("output", output)
#print("cache", cache)
#np.save("output_pd.npy", output)
"""
"""

@ -399,7 +399,7 @@ class TransformerEncoder(BaseEncoder):
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
else:
xs = self.embed(xs)
xs , pos_emb, masks= self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)

Loading…
Cancel
Save