You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
115 lines
3.6 KiB
115 lines
3.6 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""N-gram language model scorers implemented with KenLM."""
|
|
from abc import ABC
|
|
|
|
import kenlm
|
|
import paddle
|
|
|
|
from .scorer_interface import BatchScorerInterface
|
|
from .scorer_interface import PartialScorerInterface
|
|
|
|
|
|
class Ngrambase(ABC):
    """Ngram base implemented through ScorerInterface.

    Wraps a KenLM ``LanguageModel`` and scores tokens from ``token_list``
    conditioned on the most recently decoded token.
    """

    def __init__(self, ngram_model, token_list):
        """Initialize Ngrambase.

        Args:
            ngram_model: ngram model path, passed to ``kenlm.LanguageModel``.
            token_list: token list from dict or model.json. Any ``"<eos>"``
                entry is renamed to ``"</s>"`` so it matches KenLM's
                end-of-sentence symbol.
        """
        self.chardict = [x if x != "<eos>" else "</s>" for x in token_list]
        self.charlen = len(self.chardict)
        self.lm = kenlm.LanguageModel(ngram_model)
        # Scratch output state for BaseScore calls whose resulting state is
        # never used (scoring candidate next tokens).
        self.tmpkenlmstate = kenlm.State()

    def init_state(self, x):
        """Initialize tmp state.

        Args:
            x: encoded feature (unused).

        Returns:
            kenlm.State: a fresh state filled by ``NullContextWrite``.
        """
        state = kenlm.State()
        self.lm.NullContextWrite(state)
        return state

    def score_partial_(self, y, next_token, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: previous token ids; only the last one is fed to the LM. When
                ``y`` has at most one element, ``"<s>"`` is fed instead.
            next_token: candidate token ids (a paddle Tensor) to be scored.
            state: kenlm.State to advance with the last token of ``y``.
            x: encoded feature; only its dtype is used for the scores.

        Returns:
            tuple[paddle.Tensor, kenlm.State]: one KenLM ``BaseScore`` value
            (log10 probability) per entry of ``next_token``, in the same
            order, and the LM state after consuming the last token of ``y``.
        """
        out_state = kenlm.State()
        # int() turns the tensor element into a plain index usable on a list.
        ys = self.chardict[int(y[-1])] if y.shape[0] > 1 else "<s>"
        self.lm.BaseScore(state, ys, out_state)
        # Collect plain floats and build the tensor once, instead of writing
        # each element through tensor __setitem__ inside the Python loop.
        scores = [
            self.lm.BaseScore(out_state, self.chardict[int(j)],
                              self.tmpkenlmstate) for j in next_token
        ]
        return (paddle.to_tensor(
            scores, dtype=x.dtype, place=next_token.place), out_state)
|
|
|
|
|
|
class NgramFullScorer(Ngrambase, BatchScorerInterface):
    """Fullscorer for ngram: scores every token of the vocabulary."""

    def score(self, y, state, x):
        """Score every vocabulary token as the next token.

        Args:
            y: previous token ids; only the last one conditions the LM.
            state: previous kenlm.State.
            x: encoded feature; only its dtype is used for the scores.

        Returns:
            tuple[paddle.Tensor, kenlm.State]: one score per vocabulary entry
            (``self.charlen`` values, in token-id order) and the next LM state.
        """
        # paddle.to_tensor rejects a bare ``range`` object (it accepts
        # scalar/list/tuple/ndarray/Tensor), so build the ids with arange.
        all_tokens = paddle.arange(self.charlen, dtype="int64")
        return self.score_partial_(y, all_tokens, state, x)
|
|
|
|
|
|
class NgramPartScorer(Ngrambase, PartialScorerInterface):
    """Partial scorer for ngram: scores only the requested candidate tokens."""

    def score_partial(self, y, next_token, state, x):
        """Score the candidate tokens listed in ``next_token``.

        Thin wrapper that forwards every argument to ``self.score_partial_``.

        Args:
            y: previous token ids.
            next_token: candidate token ids to be scored.
            state: previous LM state.
            x: encoded feature.

        Returns:
            Whatever ``self.score_partial_`` returns: the candidate scores
            and the next LM state.
        """
        return self.score_partial_(
            y=y, next_token=next_token, state=state, x=x)

    def select_state(self, state, i):
        """Return ``state`` unchanged; ``i`` is ignored.

        Present only to satisfy the partial scorer interface.
        """
        selected = state
        return selected
|