You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/examples/ngram_lm/local/kenlm_score_test.py

188 lines
6.2 KiB

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import jieba
import kenlm
language_model_path = sys.argv[1]
assert os.path.exists(language_model_path)
start = time.time()
model = kenlm.Model(language_model_path)
print(f"load kenLM cost: {time.time() - start}s")
sentence = '盘点不怕被税的海淘网站❗️海淘向来便宜又保真!'
sentence_char_split = ' '.join(list(sentence))
sentence_word_split = ' '.join(jieba.lcut(sentence))
def test_score():
print('Loaded language model: %s' % language_model_path)
print(sentence)
print(model.score(sentence))
print(list(model.full_scores(sentence)))
for i, v in enumerate(model.full_scores(sentence)):
print(i, v)
print(sentence_char_split)
print(model.score(sentence_char_split))
print(list(model.full_scores(sentence_char_split)))
split_size = 0
for i, v in enumerate(model.full_scores(sentence_char_split)):
print(i, v)
split_size += 1
assert split_size == len(
sentence_char_split.split()) + 1, "error split size."
print(sentence_word_split)
print(model.score(sentence_word_split))
print(list(model.full_scores(sentence_word_split)))
for i, v in enumerate(model.full_scores(sentence_word_split)):
print(i, v)
def test_full_scores_chars():
print('Loaded language model: %s' % language_model_path)
print(sentence_char_split)
# Show scores and n-gram matches
words = ['<s>'] + list(sentence) + ['</s>']
for i, (prob, length,
oov) in enumerate(model.full_scores(sentence_char_split)):
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
i + 2])))
if oov:
print('\t"{0}" is an OOV'.format(words[i + 1]))
print("-" * 42)
# Find out-of-vocabulary words
oov = []
for w in words:
if w not in model:
print('"{0}" is an OOV'.format(w))
oov.append(w)
assert oov == ["", "", ""], 'error oov'
def test_full_scores_words():
print('Loaded language model: %s' % language_model_path)
print(sentence_word_split)
# Show scores and n-gram matches
words = ['<s>'] + sentence_word_split.split() + ['</s>']
for i, (prob, length,
oov) in enumerate(model.full_scores(sentence_word_split)):
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
i + 2])))
if oov:
print('\t"{0}" is an OOV'.format(words[i + 1]))
print("-" * 42)
# Find out-of-vocabulary words
oov = []
for w in words:
if w not in model:
print('"{0}" is an OOV'.format(w))
oov.append(w)
# zh_giga.no_cna_cmn.prune01244.klm is chinese charactor LM
assert oov == ["盘点", "不怕", "网站", "", "", "海淘", "向来", "便宜", "保真",
""], 'error oov'
def test_full_scores_chars_length():
"""test bos eos size"""
print('Loaded language model: %s' % language_model_path)
r = list(model.full_scores(sentence_char_split))
n = list(model.full_scores(sentence_char_split, bos=False, eos=False))
print(r)
print(n)
assert len(r) == len(n) + 1
# bos=False, eos=False, input len == output len
print(len(n), len(sentence_char_split.split()))
assert len(n) == len(sentence_char_split.split())
k = list(model.full_scores(sentence_char_split, bos=False, eos=True))
print(k, len(k))
def test_ppl_sentence():
"""测试句子粒度的ppl得分"""
sentence_char_split1 = ' '.join('先救挨饿的人,然后治疗病人。')
sentence_char_split2 = ' '.join('先就挨饿的人,然后治疗病人。')
n = model.perplexity(sentence_char_split1)
print('1', n)
n = model.perplexity(sentence_char_split2)
print(n)
part_char_split1 = ' '.join('先救挨饿的人')
part_char_split2 = ' '.join('先就挨饿的人')
n = model.perplexity(part_char_split1)
print('2', n)
n = model.perplexity(part_char_split2)
print(n)
part_char_split1 = '先救挨'
part_char_split2 = '先就挨'
n1 = model.perplexity(part_char_split1)
print('3', n1)
n2 = model.perplexity(part_char_split2)
print(n2)
assert n1 == n2
part_char_split1 = '先 救 挨'
part_char_split2 = '先 就 挨'
n1 = model.perplexity(part_char_split1)
print('4', n1)
n2 = model.perplexity(part_char_split2)
print(n2)
part_char_split1 = '先 救 挨 饿 的 人'
part_char_split2 = '先 就 挨 饿 的 人'
n1 = model.perplexity(part_char_split1)
print('5', n1)
n2 = model.perplexity(part_char_split2)
print(n2)
part_char_split1 = '先 救 挨 饿 的 人 '
part_char_split2 = '先 就 挨 饿 的 人 '
n1 = model.perplexity(part_char_split1)
print('6', n1)
n2 = model.perplexity(part_char_split2)
print(n2)
part_char_split1 = '先 救 挨 饿 的 人 然 后 治 疗 病 人'
part_char_split2 = '先 就 挨 饿 的 人 然 后 治 疗 病 人'
n1 = model.perplexity(part_char_split1)
print('7', n1)
n2 = model.perplexity(part_char_split2)
print(n2)
part_char_split1 = '先 救 挨 饿 的 人 然 后 治 疗 病 人 。'
part_char_split2 = '先 就 挨 饿 的 人 然 后 治 疗 病 人 。'
n1 = model.perplexity(part_char_split1)
print('8', n1)
n2 = model.perplexity(part_char_split2)
print(n2)
if __name__ == '__main__':
test_score()
test_full_scores_chars()
test_full_scores_words()
test_full_scores_chars_length()
test_ppl_sentence()