From 7af055631b804729fc31d5f0d2b3f02c71d9008e Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 14 May 2021 11:51:22 +0000 Subject: [PATCH] add ngram lm egs --- examples/ngram_lm/local/build_zh_lm.sh | 4 +- examples/ngram_lm/local/kenlm_score_test.py | 4 +- examples/ngram_lm/local/zh_preprocess.py | 61 -- examples/ngram_lm/path.sh | 7 +- .../chinese_text_normalization/README.md | 2 +- .../thrax/src/README.md | 4 +- .../ngram_lm/local => utils}/ngram_train.sh | 0 utils/zh_tn.py | 927 ++++++++++++++++++ 8 files changed, 936 insertions(+), 73 deletions(-) delete mode 100644 examples/ngram_lm/local/zh_preprocess.py rename {examples/ngram_lm/local => utils}/ngram_train.sh (100%) create mode 100644 utils/zh_tn.py diff --git a/examples/ngram_lm/local/build_zh_lm.sh b/examples/ngram_lm/local/build_zh_lm.sh index a2708e194..73eb165ec 100644 --- a/examples/ngram_lm/local/build_zh_lm.sh +++ b/examples/ngram_lm/local/build_zh_lm.sh @@ -27,11 +27,11 @@ arpa=$3 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then # text tn & wordseg preprocess echo "process text." - python3 local/zh_preprocess.py ${type} ${text} ${text}.${type}.tn + python3 ${MAIN_ROOT}/utils/zh_tn.py ${type} ${text} ${text}.${type}.tn fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then # train ngram lm echo "build lm." - bash local/ngram_train.sh --order ${order} --mem ${mem} --prune "${prune}" ${text}.${type}.tn ${arpa} + bash ${MAIN_ROOT}/utils/ngram_train.sh --order ${order} --mem ${mem} --prune "${prune}" ${text}.${type}.tn ${arpa} fi \ No newline at end of file diff --git a/examples/ngram_lm/local/kenlm_score_test.py b/examples/ngram_lm/local/kenlm_score_test.py index 6268f53a1..30bc1e4b1 100644 --- a/examples/ngram_lm/local/kenlm_score_test.py +++ b/examples/ngram_lm/local/kenlm_score_test.py @@ -98,8 +98,8 @@ def test_full_scores_words(): print('"{0}" is an OOV'.format(w)) oov.append(w) # zh_giga.no_cna_cmn.prune01244.klm is chinese charactor LM - assert oov == ["盘点", "不怕", "网站", "❗", "️", "海淘", "向来", "便宜", "保真", "!" - ], 'error oov' + assert oov == ["盘点", "不怕", "网站", "❗", "️", "海淘", "向来", "便宜", "保真", + "!"], 'error oov' def test_full_scores_chars_length(): diff --git a/examples/ngram_lm/local/zh_preprocess.py b/examples/ngram_lm/local/zh_preprocess.py deleted file mode 100644 index 93f98624e..000000000 --- a/examples/ngram_lm/local/zh_preprocess.py +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env python3 -import re -import string -import sys -from typing import List -from typing import Text - -import jieba -from zhon import hanzi - - -def char_token(s: Text) -> List[Text]: - return list(s) - - -def word_token(s: Text) -> List[Text]: - return jieba.lcut(s) - - -def tn(s: Text) -> Text: - s = s.strip() - s = s.replace('*', '') - # rm english punctuations - s = re.sub(f'[re.escape(string.punctuation)]', "", s) - # rm chinese punctuations - s = re.sub(f'[{hanzi.punctuation}]', "", s) - # text normalization - - # rm english - s = ''.join(re.findall(hanzi.sent, s)) - return s - - -def main(infile, outfile, tokenizer=None): - with open(infile, 'rt') as fin, open(outfile, 'wt') as fout: - lines = fin.readlines() - for l in lines: - l = tn(l) - if tokenizer: - l = ' '.join(tokenizer(l)) - fout.write(l) - fout.write('\n') - - -if __name__ == '__main__': - if len(sys.argv) != 4: - print(f"sys.arv[0] [char|word] text text_out ") - exit(-1) - - token_type = sys.argv[1] - text = sys.argv[2] - text_out = sys.argv[3] - - if token_type == 'char': - tokenizer = char_token - elif token_type == 'word': - tokenizer = word_token - else: - tokenizer = None - - main(text, text_out, tokenizer) diff --git a/examples/ngram_lm/path.sh b/examples/ngram_lm/path.sh index c07b11cf0..84e2de7d0 100644 --- a/examples/ngram_lm/path.sh +++ b/examples/ngram_lm/path.sh @@ -1,13 +1,10 @@ export MAIN_ROOT=${PWD}/../../ -export PATH=${MAIN_ROOT}:${PWD}/tools:${PATH} +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C export PYTHONIOENCODING=UTF-8 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ - -MODEL=deepspeech2 -export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin +export LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH} \ No newline at end of file diff --git a/third_party/chinese_text_normalization/README.md b/third_party/chinese_text_normalization/README.md index fd5182594..105e8fd52 100644 --- a/third_party/chinese_text_normalization/README.md +++ b/third_party/chinese_text_normalization/README.md @@ -45,7 +45,7 @@ This project sets up a ready-to-use TN module for **Chinese**. Since my backgrou acknowledgement: the NSW normalization codes are based on [Zhiyang Zhou's work here](https://github.com/Joee1995/chn_text_norm.git) 1. punctuation removal - + For Chinese, it removes punctuation list collected in [Zhon](https://github.com/tsroten/zhon) project, containing * non-stop puncs ``` diff --git a/third_party/chinese_text_normalization/thrax/src/README.md b/third_party/chinese_text_normalization/thrax/src/README.md index a7b2b0242..878ff18fb 100644 --- a/third_party/chinese_text_normalization/thrax/src/README.md +++ b/third_party/chinese_text_normalization/thrax/src/README.md @@ -2,10 +2,10 @@ This repository provides covering grammars for English and Russian text normalization as documented in: - + Gorman, K., and Sproat, R. 2016. Minimally supervised number normalization. _Transactions of the Association for Computational Linguistics_ 4: 507-519. - + Ng, A. H., Gorman, K., and Sproat, R. 2017. Minimally supervised written-to-spoken text normalization. In _ASRU_, pages 665-670. diff --git a/examples/ngram_lm/local/ngram_train.sh b/utils/ngram_train.sh similarity index 100% rename from examples/ngram_lm/local/ngram_train.sh rename to utils/ngram_train.sh diff --git a/utils/zh_tn.py b/utils/zh_tn.py new file mode 100644 index 000000000..4dcf27431 --- /dev/null +++ b/utils/zh_tn.py @@ -0,0 +1,927 @@ +#!/usr/bin/env python3 +# https://github.com/speechio/chinese_text_normalization/blob/master/python/cn_tn.py +import argparse +import re +import string +import sys +from typing import List +from typing import Text + +import jieba +from zhon import hanzi + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = u'零一二三四五六七八九' +BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖' +BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖' +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬' +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载' +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載' +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬' + +ZERO_ALT = u'〇' +ONE_ALT = u'幺' +TWO_ALTS = [u'两', u'兩'] + +POSITIVE = [u'正', u'正'] +NEGATIVE = [u'负', u'負'] +POINT = [u'点', u'點'] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ['low', 'mid', 'high'] + +CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \ + '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)' +CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \ + '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \ + '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \ + '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \ + '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \ + '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)' + +# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CHINESE_PUNC_STOP = '!?。。' +CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏' +CHINESE_PUNC_OTHER = '·〈〉-' +CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP + CHINESE_PUNC_OTHER + + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super().__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return '10^{}'.format(self.power) + + @classmethod + def create(cls, + index, + value, + numbering_type=NUMBERING_TYPES[1], + small_unit=False): + + if small_unit: + return ChineseNumberUnit( + power=index + 1, + simplified=value[0], + traditional=value[1], + big_s=value[1], + big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1]) + else: + raise ValueError('Counting type should be in {0} ({1} provided).'. + format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, + value, + simplified, + traditional, + big_s, + big_t, + alt_s=None, + alt_t=None): + super().__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super().__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(): + """ + 中文数字系统 + """ + pass + + +class MathSymbol(): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. + positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [ + CNU.create(i, v, numbering_type, False) + for i, v in enumerate(all_larger_units) + ] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [ + CNU.create(i, v, small_unit=True) + for i, v in enumerate(all_smaller_units) + ] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, + BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x) + point_cn = CM(POINT[0], POINT[1], '.', + lambda x, y: float(str(x) + '.' + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [ + d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, + d.alt_t + ]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, '' + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], \ + [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance( + integer_symbols[-2], CNU): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance( + result[-i - 1], + CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU( + result[-i - 1].power + current_unit.power, None, + None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. '两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list( + map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = ''.join([str(d.value) for d in dec_part]) + if dec_part: + return '{0}.{1}'.format(int_str, dec_str) + else: + return int_str + + +def num2chn(number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True): + def get_value(value_string, use_zeros=True): + striped_string = value_string.lstrip('0') + + # record nothing if all zeros + if not striped_string: + return [] + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + # recursively record multiple digits + else: + result_unit = next( + u for u in reversed(system.units) + if u.power < len(striped_string)) + result_string = value_string[:-result_unit.power] + return get_value(result_string) + [result_unit] + get_value( + striped_string[-result_unit.power:]) + + system = create_system(numbering_type) + + int_dec = number_string.split('.') + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError("invalid input num string with more than one dot: {}". + format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, + system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + 1] if i < len( + result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance( + previous_symbol, (CNU, type(None))): + # yapf: disable + if next_symbol.power != 1 and ((previous_symbol is None) or + (previous_symbol.power != 1)): + result_symbols[i] = liang + # yapf: enable + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = 'big_' + if traditional: + attr_name += 't' + else: + attr_name += 's' + else: + if traditional: + attr_name = 'traditional' + else: + attr_name = 'simplified' + + result = ''.join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \ + result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]: + result = result[1:] + + return result + + +# ================================================================================ # +# different types of rewriters +# ================================================================================ # +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + if fixed: + sil_parts = self.telephone.split('-') + self.raw_chntext = ''.join([ + num2chn(part, alt_two=False, use_units=False) + for part in sil_parts + ]) + self.chntext = self.raw_chntext.replace('', '') + else: + sp_parts = self.telephone.strip('+').split() + self.raw_chntext = ''.join([ + num2chn(part, alt_two=False, use_units=False) + for part in sp_parts + ]) + self.chntext = self.raw_chntext.replace('', '') + return self.chntext + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split('分之') + return chn2num(numerator) + '/' + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split('/') + return num2chn(denominator) + '分之' + num2chn(numerator) + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split('年', 1) + year = Digit(digit=year).digit2chntext() + '年' + except ValueError: + other = date + year = '' + + if other: + try: + month, day = other.strip().split('月', 1) + month = Cardinal(cardinal=month).cardinal2chntext() + '月' + except ValueError: + day = date + month = '' + + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = '' + day = '' + + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r'(\d+(\.\d+)?)') + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace( + matcher[0], + Cardinal(cardinal=matcher[0]).cardinal2chntext()) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip('百分之')) + '%' + + def percentage2chntext(self): + return '百分之' + num2chn(self.percentage.strip().strip('%')) + + +# ================================================================================ # +# NSW Normalizer +# ================================================================================ # +class NSWNormalizer: + def __init__(self, raw_text): + self.raw_text = '^' + raw_text + '$' + self.norm_text = '' + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + '2' + matcher[2], + 1) + self.norm_text = text + return self.norm_text + + def normalize(self): + text = self.raw_text + + # 规范化日期 + pattern = re.compile( + r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + #print('date') + for matcher in matchers: + text = text.replace( + matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + + r"(\d" + CURRENCY_UNITS + r"?)?)") + matchers = pattern.findall(text) + if matchers: + #print('money') + for matcher in matchers: + text = text.replace( + matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile( + r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + #print('telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext(), + 1) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext( + fixed=True), + 1) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + #print('fraction') + for matcher in matchers: + text = text.replace( + matcher, Fraction(fraction=matcher).fraction2chntext(), 1) + + # 规范化百分数 + text = text.replace('%', '%') + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + #print('percentage') + for matcher in matchers: + text = text.replace( + matcher[0], + Percentage(percentage=matcher[0]).percentage2chntext(), + 1) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + #print('cardinal+quantifier') + for matcher in matchers: + text = text.replace( + matcher[0], + Cardinal(cardinal=matcher[0]).cardinal2chntext(), + 1) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + #print('digit') + for matcher in matchers: + text = text.replace( + matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + #print('cardinal') + for matcher in matchers: + text = text.replace( + matcher[0], + Cardinal(cardinal=matcher[0]).cardinal2chntext(), + 1) + + self.norm_text = text + self._particular() + + return self.norm_text.lstrip('^').rstrip('$') + + +def nsw_test_case(raw_text): + print('I:' + raw_text) + print('O:' + NSWNormalizer(raw_text).normalize()) + print('') + + +def nsw_test(): + nsw_test_case('固话:0595-23865596或23880880。') + nsw_test_case('固话:0595-23865596或23880880。') + nsw_test_case('手机:+86 19859213959或15659451527。') + nsw_test_case('分数:32477/76391。') + nsw_test_case('百分数:80.03%。') + nsw_test_case('编号:31520181154418。') + nsw_test_case('纯数:2983.07克或12345.60米。') + nsw_test_case('日期:1999年2月20日或09年3月15号。') + nsw_test_case('金钱:12块5,34.5元,20.1万') + nsw_test_case('特殊:O2O或B2C。') + nsw_test_case('3456万吨') + nsw_test_case('2938个') + nsw_test_case('938') + nsw_test_case('今天吃了115个小笼包231个馒头') + nsw_test_case('有62%的概率') + + +def char_token(s: Text) -> List[Text]: + """chinese charactor + + Args: + s (Text): [description] + + Returns: + List[Text]: [description] + """ + return list(s) + + +def word_token(s: Text) -> List[Text]: + """chinese word + + Args: + s (Text): [description] + + Returns: + List[Text]: [description] + """ + return jieba.lcut(s) + + +def text_process(s: Text) -> Text: + """do chinese text normaliztion + + Args: + s (Text): [description] + + Returns: + Text: [description] + """ + s = s.replace('*', '') + # NSW(Non-Standard-Word) normalization + s = NSWNormalizer(s).normalize() + # Punctuations removal + s = re.sub(f'[{hanzi.punctuation}{string.punctuation}]', "", s) + # rm english + s = ''.join(re.findall(hanzi.sent, s)) + return s + + +def main(infile, outfile, args): + # tokenizer + token_type = args.token_type + if token_type == 'char': + tokenizer = char_token + elif token_type == 'word': + tokenizer = word_token + else: + tokenizer = None + + with open(infile, 'rt') as fin, open(outfile, 'wt') as fout: + lines = fin.readlines() + n = 0 + for l in lines: + key = '' + text = '' + + if args.has_key: + cols = l.split(maxsplit=1) + key = cols[0] + if len(cols) == 2: + text = cols[1] + else: + text = '' + else: + text = l + + # strip + text = text.strip() + # cases + if args.to_upper and args.to_lower: + sys.stderr.write('to_upper OR to_lower?') + exit(1) + if args.to_upper: + text = text.upper() + if args.to_lower: + text = text.lower() + + # Normalization + text = text_process(text) + if tokenizer: + text = ' '.join(tokenizer(text)) + + if args.has_key: + fout.write(key + '\t' + text + '\n') + else: + if text.strip( + ) != '': # skip empty line in pure text format(without Kaldi's utt key) + fout.write(text + '\n') + + n += 1 + if n % args.log_interval == 0: + print(f"process {n} lines.", file=sys.stderr) + + +if __name__ == '__main__': + p = argparse.ArgumentParser() + p.add_argument('token_type', default=None, help='token type. [char|word]') + p.add_argument('ifile', help='input filename, assume utf-8 encoding') + p.add_argument('ofile', help='output filename') + p.add_argument( + '--to_upper', action='store_true', help='convert to upper case') + p.add_argument( + '--to_lower', action='store_true', help='convert to lower case') + p.add_argument( + '--has_key', + action='store_true', + help="input text has Kaldi's key as first field.") + p.add_argument( + '--log_interval', + type=int, + default=100000, + help='log interval in number of processed lines') + args = p.parse_args() + + main(args.ifile, args.ofile, args)