refactor data prepare

pull/578/head
Hui Zhang 5 years ago
parent b767486608
commit 553aa35989

@@ -24,7 +24,7 @@ bpeprefix="data/bpe_${bpemode}_${nbpe}"
 # build vocabulary
 python3 ${MAIN_ROOT}/utils/build_vocab.py \
 --unit_type "spm" \
---count_threshold=${nbpe} \
+--vocab_size=${nbpe} \
 --spm_mode ${bpemode} \
 --spm_model_prefix ${bpeprefix} \
 --vocab_path="data/vocab.txt" \

@@ -35,7 +35,8 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
-add_arg('count_threshold', int, 0, "Truncation threshold for char/word/spm counts.")
+add_arg('count_threshold', int, 0,
+        "Truncation threshold for char/word counts. Default 0, no truncation.")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
         "Filepath to write the vocabulary.")
@@ -46,6 +47,7 @@ add_arg('manifest_paths', str,
         nargs='+',
         required=True)
 # bpe
+add_arg('vocab_size', int, 0, "Vocab size for spm.")
 add_arg('spm_mode', str, 'unigram',
         "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
 add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
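
For context, `add_arg` here is a `functools.partial` around the repo's `add_arguments` helper, which forwards to `argparse.ArgumentParser.add_argument`. A minimal self-contained sketch of that pattern (the helper body below is an assumption based on the call sites in this diff, not copied from the repo):

import argparse
import functools

def add_arguments(argname, type, default, help, argparser, **kwargs):
    # Register --<argname> with the given type, default, and help text.
    argparser.add_argument(
        "--" + argname, default=default, type=type,
        help=help + ' Default: %(default)s.', **kwargs)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('vocab_size', int, 0, "Vocab size for spm.")
args = parser.parse_args()

The partial fixes the `argparser` argument once, so each option is declared on a single line with positional (name, type, default, help) arguments, as in the hunks above.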
@@ -72,18 +74,7 @@ def main():
     fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
     fout.write(UNK + '\n')  # <unk> must be 1
-    if args.unit_type != 'spm':
-        text_feature = TextFeaturizer(args.unit_type, args.vocab_path)
-        counter = Counter()
-        for manifest_path in args.manifest_paths:
-            count_manifest(counter, text_feature, manifest_path)
-        count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-        for char, count in count_sorted:
-            if count < args.count_threshold: break
-            fout.write(char + '\n')
-    else:
+    if args.unit_type == 'spm':
         # tools/spm_train --input=$wave_data/lang_char/input.txt
         # --vocab_size=${nbpe} --model_type=${bpemode}
         # --model_prefix=${bpemodel} --input_sentence_size=100000000
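
The commented `tools/spm_train` invocation above corresponds to the Python API call in the next hunk. A standalone sketch of training a SentencePiece model and tokenizing with it, using the same parameters as the diff (the input file and model prefix names here are hypothetical):

import sentencepiece as spm

# Train a subword model on a file with one sentence per line.
# 'input.txt' and the 'bpe_unigram_5000' prefix are made-up names.
spm.SentencePieceTrainer.Train(
    input='input.txt',
    vocab_size=5000,
    model_type='unigram',
    model_prefix='bpe_unigram_5000',
    input_sentence_size=100000000,
    character_coverage=0.9995)

# Load the trained model and split a sentence into subword pieces.
sp = spm.SentencePieceProcessor()
sp.Load('bpe_unigram_5000.model')
print(sp.EncodeAsPieces('HELLO WORLD'))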
@@ -96,39 +87,29 @@ def main():
         # train
         spm.SentencePieceTrainer.Train(
             input=fp.name,
-            vocab_size=args.count_threshold,
+            vocab_size=args.vocab_size,
             model_type=args.spm_mode,
             model_prefix=args.spm_model_prefix,
             input_sentence_size=100000000,
             character_coverage=0.9995)
         os.unlink(fp.name)
     # encode
     text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
-    # vocabs = set()
-    # for manifest_path in args.manifest_paths:
-    #     manifest_jsons = read_manifest(manifest_path)
-    #     for line_json in manifest_jsons:
-    #         line = line_json['text']
-    #         enc_line = text_feature.spm_tokenize(line)
-    #         for code in enc_line:
-    #             vocabs.add(code)
-    #         #print(" ".join(enc_line))
-    # vocabs_sorted = sorted(vocabs)
-    # for unit in vocabs_sorted:
-    #     fout.write(unit + "\n")
-    counter = Counter()
-    for manifest_path in args.manifest_paths:
-        count_manifest(counter, text_feature, manifest_path)
+    counter = Counter()
+    for manifest_path in args.manifest_paths:
+        count_manifest(counter, text_feature, manifest_path)
     count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-    for token, count in count_sorted:
-        fout.write(token + '\n')
-    print(f"spm vocab size: {len(count_sorted)}")
+    tokens = []
+    for token, count in count_sorted:
+        if count < args.count_threshold: break
+        tokens.append(token)
+    tokens = sorted(tokens)
+    for token in tokens:
+        fout.write(token + '\n')
     fout.write(SOS + "\n")  # <sos/eos>
     fout.close()
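
After this refactor, the same count-threshold pruning applies to every unit type: tokens are counted across all manifests, sorted by frequency, cut off at `count_threshold`, then written out in sorted order. The core logic in isolation (a sketch; the token stream is made up):

from collections import Counter

count_threshold = 2  # tokens seen fewer times than this are dropped
counter = Counter(['a', 'b', 'a', 'c', 'a', 'b'])  # made-up token counts

count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = []
for token, count in count_sorted:
    if count < count_threshold:
        break  # counts are descending, so every later token is also below threshold
    tokens.append(token)
print(sorted(tokens))  # ['a', 'b']; 'c' occurs once and is pruned

The early `break` works because the items are sorted by count in descending order, so the first token below the threshold marks the cutoff for all that follow.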

@@ -67,16 +67,12 @@ def main():
     vocab_size = text_feature.vocab_size
     print(f"Vocab size: {vocab_size}")

+    count = 0
     for manifest_path in args.manifest_paths:
         manifest_jsons = read_manifest(manifest_path)
         for line_json in manifest_jsons:
             line = line_json['text']
-            if args.unit_type == 'char':
-                tokens = text_feature.char_tokenize(line)
-            elif args.unit_type == 'word':
-                tokens = text_feature.word_tokenize(line)
-            else: #spm
-                tokens = text_feature.spm_tokenize(line)
+            tokens = text_feature.tokenize(line)
             tokenids = text_feature.featurize(line)
             line_json['token'] = tokens
             line_json['token_id'] = tokenids
@@ -88,7 +84,9 @@ def main():
             else: # kaldi
                 raise NotImplemented('no support kaldi feat now!')
             fout.write(json.dumps(line_json) + '\n')
+            count += 1
+    print(f"Examples number: {count}")
     fout.close()
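
Each manifest entry is one JSON object per line, and the loop above augments it with `token` and `token_id` fields before writing it back out. A sketch of what one output line might look like (field values are purely illustrative):

import json

# Illustrative manifest entry after tokenization; values are made up.
line_json = {
    'text': 'HELLO WORLD',
    'token': ['▁HE', 'LLO', '▁WORLD'],
    'token_id': [52, 17, 203],
}
print(json.dumps(line_json))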
