|
|
|
@ -54,17 +54,13 @@ add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def count_manifest(counter, manifest_path):
|
|
|
|
|
def count_manifest(counter, text_feature, manifest_path):
|
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
|
if args.unit_type == 'char':
|
|
|
|
|
for char in line_json['text']:
|
|
|
|
|
counter.update(char)
|
|
|
|
|
elif args.unit_type == 'word':
|
|
|
|
|
for word in line_json['text'].split():
|
|
|
|
|
counter.update(word)
|
|
|
|
|
|
|
|
|
|
def read_text_manifest(fileobj, manifest_path):
|
|
|
|
|
line = text_feature.tokenize(line_json['text'])
|
|
|
|
|
counter.update(line)
|
|
|
|
|
|
|
|
|
|
def dump_text_manifest(fileobj, manifest_path):
|
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
|
fileobj.write(line_json['text'] + "\n")
|
|
|
|
@ -77,9 +73,11 @@ def main():
|
|
|
|
|
fout.write(UNK + '\n') # <unk> must be 1
|
|
|
|
|
|
|
|
|
|
if args.unit_type != 'spm':
|
|
|
|
|
text_feature = TextFeaturizer(args.unit_type, args.vocab_path)
|
|
|
|
|
counter = Counter()
|
|
|
|
|
|
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
|
count_manifest(counter, manifest_path)
|
|
|
|
|
count_manifest(counter, text_feature, manifest_path)
|
|
|
|
|
|
|
|
|
|
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
|
for char, count in count_sorted:
|
|
|
|
@ -93,7 +91,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
|
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
|
read_text_manifest(fp, manifest_path)
|
|
|
|
|
dump_text_manifest(fp, manifest_path)
|
|
|
|
|
fp.close()
|
|
|
|
|
# train
|
|
|
|
|
spm.SentencePieceTrainer.Train(
|
|
|
|
@ -108,20 +106,29 @@ def main():
|
|
|
|
|
# encode
|
|
|
|
|
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
|
|
|
|
|
|
|
|
|
|
vocabs = set()
|
|
|
|
|
# vocabs = set()
|
|
|
|
|
# for manifest_path in args.manifest_paths:
|
|
|
|
|
# manifest_jsons = read_manifest(manifest_path)
|
|
|
|
|
# for line_json in manifest_jsons:
|
|
|
|
|
# line = line_json['text']
|
|
|
|
|
# enc_line = text_feature.spm_tokenize(line)
|
|
|
|
|
# for code in enc_line:
|
|
|
|
|
# vocabs.add(code)
|
|
|
|
|
# #print(" ".join(enc_line))
|
|
|
|
|
# vocabs_sorted = sorted(vocabs)
|
|
|
|
|
# for unit in vocabs_sorted:
|
|
|
|
|
# fout.write(unit + "\n")
|
|
|
|
|
|
|
|
|
|
counter = Counter()
|
|
|
|
|
|
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
|
line = line_json['text']
|
|
|
|
|
enc_line = text_feature.spm_tokenize(line)
|
|
|
|
|
for code in enc_line:
|
|
|
|
|
vocabs.add(code)
|
|
|
|
|
#print(" ".join(enc_line))
|
|
|
|
|
vocabs_sorted = sorted(vocabs)
|
|
|
|
|
for unit in vocabs_sorted:
|
|
|
|
|
fout.write(unit + "\n")
|
|
|
|
|
|
|
|
|
|
print(f"spm vocab size: {len(vocabs_sorted)}")
|
|
|
|
|
count_manifest(counter, text_feature, manifest_path)
|
|
|
|
|
|
|
|
|
|
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
|
for token, count in count_sorted:
|
|
|
|
|
fout.write(token + '\n')
|
|
|
|
|
|
|
|
|
|
print(f"spm vocab size: {len(count_sorted)}")
|
|
|
|
|
|
|
|
|
|
fout.write(SOS + "\n") # <sos/eos>
|
|
|
|
|
fout.close()
|
|
|
|
|