pull/613/head
Hui Zhang 4 years ago
parent 98253954cb
commit 853e5acb13

@ -52,4 +52,4 @@ DeepSpeech is provided under the [Apache-2.0 License](./LICENSE).
## Acknowledgement
We depend on many open source repos. See [References](doc/src/reference.md) for more information.

@ -50,4 +50,4 @@ DeepSpeech is released under the [Apache-2.0 License](./LICENSE).
## Acknowledgement
Development referenced a number of excellent repositories; see [References](doc/src/reference.md) for details.

@ -1,9 +1,22 @@
import kenlm
import jieba
import time
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import jieba
import kenlm
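# the first CLI argument points to a KenLM binary model (e.g. a .klm file)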
language_model_path = sys.argv[1]
assert os.path.exists(language_model_path)
@ -33,7 +46,8 @@ def test_score():
for i, v in enumerate(model.full_scores(sentence_char_split)):
print(i, v)
split_size += 1
assert split_size == len(sentence_char_split.split()) + 1, "error split size."
assert split_size == len(
sentence_char_split.split()) + 1, "error split size."
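# score() returns the sentence's total log10 probability (<s> and </s> added by default)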
print(sentence_word_split)
print(model.score(sentence_word_split))
@ -47,8 +61,10 @@ def test_full_scores_chars():
print(sentence_char_split)
# Show scores and n-gram matches
words = ['<s>'] + list(sentence) + ['</s>']
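# full_scores() yields one (log10 prob, matched n-gram length, oov flag)
# per token, plus a final entry for </s>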
for i, (prob, length, oov) in enumerate(model.full_scores(sentence_char_split)):
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
for i, (prob, length,
oov) in enumerate(model.full_scores(sentence_char_split)):
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
i + 2])))
if oov:
print('\t"{0}" is an OOV'.format(words[i + 1]))
@ -67,8 +83,10 @@ def test_full_scores_words():
print(sentence_word_split)
# Show scores and n-gram matches
words = ['<s>'] + sentence_word_split.split() + ['</s>']
for i, (prob, length, oov) in enumerate(model.full_scores(sentence_word_split)):
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:i + 2])))
for i, (prob, length,
oov) in enumerate(model.full_scores(sentence_word_split)):
print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i + 2 - length:
i + 2])))
if oov:
print('\t"{0}" is an OOV'.format(words[i + 1]))
@ -80,7 +98,8 @@ def test_full_scores_words():
print('"{0}" is an OOV'.format(w))
oov.append(w)
# zh_giga.no_cna_cmn.prune01244.klm is a Chinese character-level LM, so jieba word tokens are OOV
assert oov == ["盘点", "不怕", "网站", "", "", "海淘", "向来", "便宜", "保真", ""], 'error oov'
assert oov == ["盘点", "不怕", "网站", "", "", "海淘", "向来", "便宜", "保真", ""
], 'error oov'
def test_full_scores_chars_length():
@ -159,9 +178,10 @@ def test_ppl_sentence():
n2 = model.perplexity(part_char_split2)
print(n2)
if __name__ == '__main__':
test_score()
test_full_scores_chars()
test_full_scores_words()
test_full_scores_chars_length()
test_ppl_sentence()
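For reference, a minimal sketch of the kenlm calls these tests exercise; the model path is reused from the comment above, and any KenLM binary model would work:

import kenlm

model = kenlm.Model('zh_giga.no_cna_cmn.prune01244.klm')
# total log10 probability of the tokenized sentence (<s>/</s> included by default)
print(model.score('中国 人民', bos=True, eos=True))
# one (log10 prob, n-gram length, oov) tuple per token, plus one for </s>
for prob, length, oov in model.full_scores('中国 人民'):
    print(prob, length, oov)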

@ -1,31 +1,36 @@
#!/usr/bin/env python3
from typing import List, Text
import re
import string
import sys
from typing import List
from typing import Text
import jieba
import string
import re
from zhon import hanzi
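# zhon's hanzi module supplies Chinese punctuation and sentence regex patterns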
def char_token(s: Text) -> List[Text]:
return list(s)
def word_token(s: Text) -> List[Text]:
return jieba.lcut(s)
def tn(s: Text) -> Text:
s = s.strip()
s = s.replace('*', '')
# strip English punctuation; note the braces, so re.escape(...) is
# interpolated into the character class
s = re.sub(f'[{re.escape(string.punctuation)}]', "", s)
# strip Chinese punctuation
s = re.sub(f'[{hanzi.punctuation}]', "", s)
# keep only Chinese sentence text (removes remaining English)
s = ''.join(re.findall(hanzi.sent, s))
return s
def main(infile, outfile, tokenizer=None):
with open(infile, 'rt') as fin, open(outfile, 'wt') as fout:
lines = fin.readlines()
@ -36,6 +41,7 @@ def main(infile, outfile, tokenizer=None):
fout.write(l)
fout.write('\n')
if __name__ == '__main__':
if len(sys.argv) != 4:
print(f"sys.arv[0] [char|word] text text_out ")
@ -52,4 +58,4 @@ if __name__ == '__main__':
else:
tokenizer = None
main(text, text_out, tokenizer)
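A minimal usage sketch, assuming the script above is saved as zh_text_norm.py (a hypothetical name) so its helpers can be imported:

from zh_text_norm import char_token, main, word_token  # hypothetical module name

# paths are placeholders; one normalized, tokenized line is written per input line
main('text.txt', 'text.char.txt', tokenizer=char_token)  # character tokens
main('text.txt', 'text.word.txt', tokenizer=word_token)  # jieba word tokens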
