diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py index 6dea6b701..dae3cd429 100644 --- a/deepspeech/decoders/recog.py +++ b/deepspeech/decoders/recog.py @@ -24,6 +24,7 @@ from .utils import add_results_to_json from deepspeech.exps import dynamic_import_tester from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.models.asr_interface import ASRInterface +from deepspeech.models.lm.transformer import TransformerLM from deepspeech.utils.log import Log # from espnet.asr.asr_utils import get_model_conf # from espnet.asr.asr_utils import torch_load @@ -78,12 +79,18 @@ def recog_v2(args): preprocess_args={"train": False}, ) if args.rnnlm: - lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) - # NOTE: for a compatibility with less than 0.5.0 version models - lm_model_module = getattr(lm_args, "model_module", "default") - lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) - lm = lm_class(len(char_list), lm_args) - torch_load(args.rnnlm, lm) + lm_path = args.rnnlm + lm = TransformerLM( + n_vocab=5002, + pos_enc=None, + embed_unit=128, + att_unit=512, + head=8, + unit=2048, + layer=16, + dropout_rate=0.5, ) + model_dict = paddle.load(lm_path) + lm.set_state_dict(model_dict) lm.eval() else: lm = None diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py index 467c4ab90..3f5a76c52 100644 --- a/deepspeech/models/lm/transformer.py +++ b/deepspeech/models/lm/transformer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Any from typing import List from typing import Tuple @@ -150,7 +151,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): h, _, cache = self.encoder.forward_one_step( emb, self._target_mask(y), cache=state) h = self.decoder(h[:, -1]) - logp = h.log_softmax(axis=-1).squeeze(0) + logp = F.log_softmax(h).squeeze(0) return logp, cache # batch beam search API (see BatchScorerInterface) @@ -193,7 +194,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): h, _, states = self.encoder.forward_one_step( emb, self._target_mask(ys), cache=batch_state) h = self.decoder(h[:, -1]) - logp = h.log_softmax(axi=-1) + logp = F.log_softmax(h) # transpose state of [layer, batch] into [batch, layer] state_list = [[states[i][b] for i in range(n_layers)] @@ -219,7 +220,7 @@ if __name__ == "__main__": # head: int=2, # unit: int=1024, # layer: int=4, - # dropout_rate: float=0.5, + # dropout_rate: float=0.5, # emb_dropout_rate: float = 0.0, # att_dropout_rate: float = 0.0, # tie_weights: bool = False,): @@ -231,14 +232,14 @@ if __name__ == "__main__": #Test the score input2 = np.array([5]) input2 = paddle.to_tensor(input2) - state = (None, None, 0) + state = None output, state = tlm.score(input2, state, None) - input3 = np.array([10]) + input3 = np.array([5, 10]) input3 = paddle.to_tensor(input3) output, state = tlm.score(input3, state, None) - input4 = np.array([0]) + input4 = np.array([5, 10, 0]) input4 = paddle.to_tensor(input4) output, state = tlm.score(input4, state, None) print("output", output) diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index f2c269883..794117712 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -399,7 +399,8 @@ class TransformerEncoder(BaseEncoder): xs, pos_emb, masks = self.embed( xs, masks.astype(xs.dtype), offset=0) else: - xs = self.embed(xs) + xs, pos_emb, masks = self.embed( + xs, masks.astype(xs.dtype), offset=0) #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor masks = masks.astype(paddle.bool)