# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""N-gram language-model scorers backed by KenLM."""
from abc import ABC

import kenlm
import paddle

from .scorer_interface import BatchScorerInterface
from .scorer_interface import PartialScorerInterface


class Ngrambase(ABC):
    """Ngram base implemented through ScorerInterface.

    Wraps a KenLM model and maps integer token ids to the word strings that
    KenLM scores.
    """

    def __init__(self, ngram_model, token_list):
        """Initialize Ngrambase.

        Args:
            ngram_model: ngram model path, loaded with ``kenlm.LanguageModel``.
            token_list: token list from dict or model.json; entry ``i`` is the
                string for token id ``i``.
        """
        # Remap the "<eos>" token to KenLM's end-of-sentence symbol "</s>" so
        # that scoring the eos id uses the LM's sentence-end probability.
        self.chardict = [x if x != "<eos>" else "</s>" for x in token_list]
        self.charlen = len(self.chardict)
        self.lm = kenlm.LanguageModel(ngram_model)
        # Scratch output state for per-candidate scoring: the state produced
        # by each candidate is discarded, so one instance is reused.
        self.tmpkenlmstate = kenlm.State()

    def init_state(self, x):
        """Initialize tmp state.

        Args:
            x: encoded feature (unused).

        Returns:
            kenlm.State: a state with no left context (``NullContextWrite``).
        """
        state = kenlm.State()
        self.lm.NullContextWrite(state)
        return state

    def score_partial_(self, y, next_token, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: prefix token ids decoded so far (1-D). When it holds only one
                id (presumably <sos> -- TODO confirm with caller), the LM
                context is started with "<s>"; otherwise it is extended with
                the last id ``y[-1]``.
            next_token: 1-D tensor of candidate token ids to score.
            state: kenlm.State holding the context before ``y[-1]``.
            x: encoded feature; only its dtype is used, for the scores.

        Returns:
            tuple[paddle.Tensor, kenlm.State]: one KenLM score per entry of
            ``next_token`` (same shape, dtype of ``x``), and the KenLM state
            after consuming ``y[-1]`` (or "<s>").
        """
        out_state = kenlm.State()
        # Convert to a Python int explicitly: chardict is a plain list.
        ys = self.chardict[int(y[-1])] if y.shape[0] > 1 else "<s>"
        # Only the resulting context (out_state) is needed; the score of ys
        # itself was already accounted for at the previous step.
        self.lm.BaseScore(state, ys, out_state)
        scores = paddle.empty_like(next_token, dtype=x.dtype)
        for i, token_id in enumerate(next_token.numpy().tolist()):
            scores[i] = self.lm.BaseScore(out_state, self.chardict[token_id],
                                          self.tmpkenlmstate)
        return scores, out_state


class NgramFullScorer(Ngrambase, BatchScorerInterface):
    """Fullscorer for ngram."""

    def score(self, y, state, x):
        """Score every token in the vocabulary.

        Args:
            y: prefix token ids decoded so far (1-D).
            state: kenlm.State holding the context before ``y[-1]``.
            x: encoded feature; only its dtype is used, for the scores.

        Returns:
            tuple[paddle.Tensor, kenlm.State]: scores of shape
            ``(len(token_list),)``, one per token id, and the next state.
        """
        # paddle.to_tensor does not accept a range object; arange builds the
        # same id sequence 0..charlen-1 directly.
        return self.score_partial_(y, paddle.arange(self.charlen), state, x)


class NgramPartScorer(Ngrambase, PartialScorerInterface):
    """Partialscorer for ngram."""

    def score_partial(self, y, next_token, state, x):
        """Score only the given candidate tokens.

        Args:
            y: prefix token ids decoded so far (1-D).
            next_token: 1-D tensor of candidate token ids to score.
            state: kenlm.State holding the context before ``y[-1]``.
            x: encoded feature; only its dtype is used, for the scores.

        Returns:
            tuple[paddle.Tensor, kenlm.State]: one score per candidate and the
            next state.
        """
        return self.score_partial_(y, next_token, state, x)

    def select_state(self, state, i):
        """Empty select state for scorer interface.

        The state returned by ``score_partial_`` does not depend on which
        candidate is chosen (the chosen token is consumed as ``y[-1]`` on the
        next call), so it is returned unchanged and ``i`` is ignored.
        """
        return state