|
|
|
@ -1,9 +1,22 @@
|
|
|
|
|
import kenlm
|
|
|
|
|
import jieba
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
import jieba
|
|
|
|
|
import kenlm
|
|
|
|
|
|
|
|
|
|
language_model_path = sys.argv[1]
|
|
|
|
|
assert os.path.exists(language_model_path)
|
|
|
|
@ -33,7 +46,8 @@ def test_score():
|
|
|
|
|
for i, v in enumerate(model.full_scores(sentence_char_split)):
|
|
|
|
|
print(i, v)
|
|
|
|
|
split_size += 1
|
|
|
|
|
assert split_size == len(sentence_char_split.split()) + 1, "error split size."
|
|
|
|
|
assert split_size == len(
|
|
|
|
|
sentence_char_split.split()) + 1, "error split size."
|
|
|
|
|
|
|
|
|
|
print(sentence_word_split)
|
|
|
|
|
print(model.score(sentence_word_split))
|
|
|
|
@ -47,8 +61,10 @@ def test_full_scores_chars():
|
|
|
|
|
print(sentence_char_split)
|
|
|
|
|
# Show scores and n-gram matches
|
|
|
|
|
words = ['<s>'] + list(sentence) + ['</s>']
|
|
|
|
|
for i, (prob, length, oov) in enumerate(model.full_scores(sentence_char_split)):
|
|
|
|
|
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
|
|
|
|
|
for i, (prob, length,
|
|
|
|
|
oov) in enumerate(model.full_scores(sentence_char_split)):
|
|
|
|
|
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
|
|
|
|
|
i + 2])))
|
|
|
|
|
if oov:
|
|
|
|
|
print('\t"{0}" is an OOV'.format(words[i + 1]))
|
|
|
|
|
|
|
|
|
@ -67,8 +83,10 @@ def test_full_scores_words():
|
|
|
|
|
print(sentence_word_split)
|
|
|
|
|
# Show scores and n-gram matches
|
|
|
|
|
words = ['<s>'] + sentence_word_split.split() + ['</s>']
|
|
|
|
|
for i, (prob, length, oov) in enumerate(model.full_scores(sentence_word_split)):
|
|
|
|
|
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
|
|
|
|
|
for i, (prob, length,
|
|
|
|
|
oov) in enumerate(model.full_scores(sentence_word_split)):
|
|
|
|
|
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
|
|
|
|
|
i + 2])))
|
|
|
|
|
if oov:
|
|
|
|
|
print('\t"{0}" is an OOV'.format(words[i + 1]))
|
|
|
|
|
|
|
|
|
@ -80,7 +98,8 @@ def test_full_scores_words():
|
|
|
|
|
print('"{0}" is an OOV'.format(w))
|
|
|
|
|
oov.append(w)
|
|
|
|
|
# zh_giga.no_cna_cmn.prune01244.klm is chinese charactor LM
|
|
|
|
|
assert oov == ["盘点", "不怕", "网站", "❗", "️", "海淘", "向来", "便宜", "保真", "!"], 'error oov'
|
|
|
|
|
assert oov == ["盘点", "不怕", "网站", "❗", "️", "海淘", "向来", "便宜", "保真", "!"
|
|
|
|
|
], 'error oov'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_full_scores_chars_length():
|
|
|
|
@ -159,9 +178,10 @@ def test_ppl_sentence():
|
|
|
|
|
n2 = model.perplexity(part_char_split2)
|
|
|
|
|
print(n2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
test_score()
|
|
|
|
|
test_full_scores_chars()
|
|
|
|
|
test_full_scores_words()
|
|
|
|
|
test_full_scores_chars_length()
|
|
|
|
|
test_ppl_sentence()
|
|
|
|
|
test_ppl_sentence()
|
|
|
|
|