import re
import logging
import sys


class RunningAverage:
    """Accumulate values and report their running mean."""

    def __init__(self):
        self.values = []

    def add(self, val):
        self.values.append(val)

    def add_all(self, vals):
        self.values += vals

    def get(self):
        if len(self.values) == 0:
            return None
        return sum(self.values) / len(self.values)

    def flush(self):
        self.values = []
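
# Usage sketch: values are accumulated and averaged lazily; flush() resets
# the buffer, e.g. between validation intervals.
#     avg = RunningAverage()
#     avg.add_all([1.0, 2.0, 3.0])
#     avg.get()    # -> 2.0
#     avg.flush()
#     avg.get()    # -> None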


def wordize_and_map(text):
    """Split text into "words" and build index maps between text positions
    and word positions.

    A run of ASCII letters/digits counts as one word; every other non-space
    character is a one-character word; spaces map to None.
    """
    words = []
    index_map_from_text_to_word = []
    index_map_from_word_to_text = []
    while len(text) > 0:
        # Spaces are not words: they map to None and are skipped.
        match_space = re.match(r'^ +', text)
        if match_space:
            space_str = match_space.group(0)
            index_map_from_text_to_word += [None] * len(space_str)
            text = text[len(space_str):]
            continue

        # A run of ASCII letters/digits is treated as a single word.
        match_en = re.match(r'^[a-zA-Z0-9]+', text)
        if match_en:
            en_word = match_en.group(0)

            word_start_pos = len(index_map_from_text_to_word)
            word_end_pos = word_start_pos + len(en_word)
            index_map_from_word_to_text.append((word_start_pos, word_end_pos))

            index_map_from_text_to_word += [len(words)] * len(en_word)

            words.append(en_word)
            text = text[len(en_word):]
        else:
            # Any other character (e.g. a CJK character or punctuation)
            # becomes a one-character word.
            word_start_pos = len(index_map_from_text_to_word)
            word_end_pos = word_start_pos + 1
            index_map_from_word_to_text.append((word_start_pos, word_end_pos))

            index_map_from_text_to_word += [len(words)]

            words.append(text[0])
            text = text[1:]
    return words, index_map_from_text_to_word, index_map_from_word_to_text
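
# Illustrative example (mixed ASCII and CJK input):
#     wordize_and_map('ab 你')
#     -> (['ab', '你'],          # words
#         [0, 0, None, 1],       # text position -> word index (None = space)
#         [(0, 2), (3, 4)])      # word index -> (start, end) text span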


def tokenize_and_map(tokenizer, text):
    """Tokenize text with a BERT-style subword tokenizer and build index
    maps between text positions and token positions."""
    words, text2word, word2text = wordize_and_map(text)

    tokens = []
    index_map_from_token_to_text = []
    for word, (word_start, word_end) in zip(words, word2text):
        word_tokens = tokenizer.tokenize(word)

        if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
            index_map_from_token_to_text.append((word_start, word_end))
            tokens.append('[UNK]')
        else:
            # Distribute the word's character span across its subword
            # tokens; the '##' continuation prefix consumes no characters.
            current_word_start = word_start
            for word_token in word_tokens:
                word_token_len = len(re.sub(r'^##', '', word_token))
                index_map_from_token_to_text.append(
                    (current_word_start, current_word_start + word_token_len))
                current_word_start = current_word_start + word_token_len
                tokens.append(word_token)

    # Reuse the text-to-word map in place, overwriting word indices with
    # token indices (spaces keep their None entries).
    index_map_from_text_to_token = text2word
    for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
        for token_pos in range(token_start, token_end):
            index_map_from_text_to_token[token_pos] = i

    return tokens, index_map_from_text_to_token, index_map_from_token_to_text
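
# Usage sketch; assumes the HuggingFace `transformers` package, but any
# object with a BERT-style .tokenize() method returning '##'-prefixed
# subwords (or ['[UNK]']) should work:
#     from transformers import BertTokenizer
#     tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
#     tokens, text2token, token2text = tokenize_and_map(tokenizer, text)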


def _load_config(config_path):
    # Execute an arbitrary .py file as a module so that its top-level
    # variables can be read off as config attributes.
    import importlib.util
    spec = importlib.util.spec_from_file_location('__init__', config_path)
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)
    return config


# Defaults used by load_config(..., use_default=True): any attribute (or
# nested dict key) missing from the user config falls back to these values.
default_config_dict = {
    'manual_seed': 1313,
    'model_source': 'bert-base-chinese',
    'window_size': 32,
    'num_workers': 2,
    'use_mask': True,
    'use_char_phoneme': False,
    'use_conditional': True,
    'param_conditional': {
        'affect_location': 'softmax',
        'bias': True,
        'char-linear': True,
        'pos-linear': False,
        'char+pos-second': True,
        'char+pos-second_lowrank': False,
        'lowrank_size': 0,
        'char+pos-second_fm': False,
        'fm_size': 0,
        'fix_mode': None,
        'count_json': 'train.count.json'
    },
    'lr': 5e-5,
    'val_interval': 200,
    'num_iter': 10000,
    'use_focal': False,
    'param_focal': {
        'alpha': 0.0,
        'gamma': 0.7
    },
    'use_pos': True,
    'param_pos': {
        'weight': 0.1,
        'pos_joint_training': True,
        'train_pos_path': 'train.pos',
        'valid_pos_path': 'dev.pos',
        'test_pos_path': 'test.pos'
    }
}


def load_config(config_path, use_default=False):
    """Load a config module from config_path; with use_default=True, fill
    in missing attributes (and missing nested dict keys) from
    default_config_dict."""
    config = _load_config(config_path)
    if use_default:
        for attr, val in default_config_dict.items():
            if not hasattr(config, attr):
                setattr(config, attr, val)
            elif isinstance(val, dict):
                d = getattr(config, attr)
                for dict_k, dict_v in val.items():
                    if dict_k not in d:
                        d[dict_k] = dict_v
    return config
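
# Usage sketch with a hypothetical config file (any top-level variables in
# the .py file become config attributes; the rest fall back to defaults):
#     # my_config.py
#     manual_seed = 42
#     param_pos = {'weight': 0.2}
#
#     config = load_config('my_config.py', use_default=True)
#     config.window_size            # -> 32 (filled from defaults)
#     config.param_pos['weight']    # -> 0.2 (user value kept)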


def get_logger(file_path):
    """Return the root logger, configured to emit DEBUG and above to both
    file_path and stdout."""
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    output_file_handler = logging.FileHandler(file_path)
    stdout_handler = logging.StreamHandler(sys.stdout)

    logger.addHandler(output_file_handler)
    logger.addHandler(stdout_handler)
    return logger
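
# Usage sketch (hypothetical file name). Note that this configures the
# root logger, so calling it more than once adds duplicate handlers:
#     logger = get_logger('train.log')
#     logger.info('logged to both train.log and stdout')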