pull/613/head
Hui Zhang 4 years ago
parent 98253954cb
commit 853e5acb13

@@ -1,9 +1,22 @@
-import kenlm
-import jieba
-import time
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 import sys
+import time
+
+import jieba
+import kenlm

 language_model_path = sys.argv[1]
 assert os.path.exists(language_model_path)
@@ -33,7 +46,8 @@ def test_score():
     for i, v in enumerate(model.full_scores(sentence_char_split)):
         print(i, v)
         split_size += 1
-    assert split_size == len(sentence_char_split.split()) + 1, "error split size."
+    assert split_size == len(
+        sentence_char_split.split()) + 1, "error split size."

     print(sentence_word_split)
     print(model.score(sentence_word_split))
@@ -47,8 +61,10 @@ def test_full_scores_chars():
     print(sentence_char_split)
     # Show scores and n-gram matches
     words = ['<s>'] + list(sentence) + ['</s>']
-    for i, (prob, length, oov) in enumerate(model.full_scores(sentence_char_split)):
-        print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
+    for i, (prob, length,
+            oov) in enumerate(model.full_scores(sentence_char_split)):
+        print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
+                                                                 i + 2])))
         if oov:
             print('\t"{0}" is an OOV'.format(words[i + 1]))
@@ -67,8 +83,10 @@ def test_full_scores_words():
     print(sentence_word_split)
     # Show scores and n-gram matches
     words = ['<s>'] + sentence_word_split.split() + ['</s>']
-    for i, (prob, length, oov) in enumerate(model.full_scores(sentence_word_split)):
-        print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
+    for i, (prob, length,
+            oov) in enumerate(model.full_scores(sentence_word_split)):
+        print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
+                                                                 i + 2])))
         if oov:
             print('\t"{0}" is an OOV'.format(words[i + 1]))
@@ -80,7 +98,8 @@ def test_full_scores_words():
             print('"{0}" is an OOV'.format(w))
             oov.append(w)
     # zh_giga.no_cna_cmn.prune01244.klm is a Chinese character LM
-    assert oov == ["盘点", "不怕", "网站", "", "", "海淘", "向来", "便宜", "保真", ""], 'error oov'
+    assert oov == ["盘点", "不怕", "网站", "", "", "海淘", "向来", "便宜", "保真", ""
+                   ], 'error oov'


 def test_full_scores_chars_length():
@@ -159,6 +178,7 @@ def test_ppl_sentence():
     n2 = model.perplexity(part_char_split2)
     print(n2)

+
 if __name__ == '__main__':
     test_score()
     test_full_scores_chars()
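For reference, below is a minimal sketch of the kenlm Python API these tests exercise: score(), full_scores(), and perplexity(). The model path and sample sentence are placeholders, not part of this change. full_scores() yields one (log10 probability, n-gram length, OOV flag) tuple per token plus a final one for </s>, which is why test_score() asserts split_size == len(tokens) + 1.

import kenlm

# placeholder path; the tests above read it from sys.argv[1]
model = kenlm.Model('zh_giga.no_cna_cmn.prune01244.klm')

sentence = '盘点不怕网站'
sentence_char_split = ' '.join(list(sentence))  # "盘 点 不 怕 网 站"

# total log10 probability, with <s> and </s> added by default
print(model.score(sentence_char_split))

# one (log10 prob, n-gram length, oov) tuple per token, plus one for </s>
words = ['<s>'] + sentence_char_split.split() + ['</s>']
for i, (prob, length, oov) in enumerate(model.full_scores(sentence_char_split)):
    print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
    if oov:
        print('\t"{0}" is an OOV'.format(words[i + 1]))

# perplexity over the same token sequence
print(model.perplexity(sentence_char_split))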

@@ -1,23 +1,27 @@
 #!/usr/bin/env python3
-from typing import List, Text
+import re
+import string
 import sys
+from typing import List
+from typing import Text
+
 import jieba
-import string
-import re
+
 from zhon import hanzi

 def char_token(s: Text) -> List[Text]:
     return list(s)

 def word_token(s: Text) -> List[Text]:
     return jieba.lcut(s)

 def tn(s: Text) -> Text:
     s = s.strip()
     s = s.replace('*', '')
     # rm english punctuations
-    s = re.sub(f'[{re.escape(string.punctuation)}]' , "", s)
+    s = re.sub(f'[{re.escape(string.punctuation)}]', "", s)
     # rm chinese punctuations
     s = re.sub(f'[{hanzi.punctuation}]', "", s)
     # text normalization
@@ -26,6 +30,7 @@ def tn(s: Text) -> Text:
     s = ''.join(re.findall(hanzi.sent, s))
     return s

+
 def main(infile, outfile, tokenizer=None):
     with open(infile, 'rt') as fin, open(outfile, 'wt') as fout:
         lines = fin.readlines()
@@ -36,6 +41,7 @@ def main(infile, outfile, tokenizer=None):
             fout.write(l)
             fout.write('\n')

+
 if __name__ == '__main__':
     if len(sys.argv) != 4:
         print(f"{sys.argv[0]} [char|word] text text_out")
