@@ -35,7 +35,8 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
-add_arg('count_threshold', int, 0, "Truncation threshold for char/word/spm counts.")
+add_arg('count_threshold', int, 0,
+        "Truncation threshold for char/word counts.Default 0, no truncate.")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
         "Filepath to write the vocabulary.")
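Note: `add_arg` is `functools.partial` over the repo's `add_arguments` helper, so each `add_arg(name, type, default, help)` call registers a `--name` flag on the parser. A minimal sketch of what that helper presumably looks like (modeled on the usual PaddlePaddle-style utility; the exact body in this repo may differ):

    import argparse
    import functools

    def add_arguments(argname, type, default, help, argparser, **kwargs):
        # Sketch of the assumed helper: register --<argname> with a typed default.
        argparser.add_argument(
            "--" + argname,
            default=default,
            type=type,
            help=help + ' Default: %(default)s.',
            **kwargs)

    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)
    add_arg('count_threshold', int, 0, "Truncation threshold for char/word counts.")

    args = parser.parse_args(['--count_threshold', '2'])
    print(args.count_threshold)  # 2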
@@ -46,6 +47,7 @@ add_arg('manifest_paths', str,
         nargs='+',
         required=True)
 # bpe
+add_arg('vocab_size', int, 0, "Vocab size for spm.")
 add_arg('spm_mode', str, 'unigram',
         "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
 add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
@@ -72,18 +74,7 @@ def main():
     fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
     fout.write(UNK + '\n')  # <unk> must be 1
 
-    if args.unit_type != 'spm':
-        text_feature = TextFeaturizer(args.unit_type, args.vocab_path)
-
-        counter = Counter()
-        for manifest_path in args.manifest_paths:
-            count_manifest(counter, text_feature, manifest_path)
-
-        count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-        for char, count in count_sorted:
-            if count < args.count_threshold: break
-            fout.write(char + '\n')
-    else:
+    if args.unit_type == 'spm':
         # tools/spm_train --input=$wave_data/lang_char/input.txt
         # --vocab_size=${nbpe} --model_type=${bpemode}
         # --model_prefix=${bpemodel} --input_sentence_size=100000000
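The flipped condition is the heart of this patch: instead of special-casing char/word counting in the `if` and burying spm handling in the `else`, the new code only trains a SentencePiece model up front when `unit_type` is spm, then lets every unit type fall through to one shared counting pass (next hunk). The shell comments kept in context document the `tools/spm_train` invocation the Python call mirrors; as a self-contained reference, the same training run looks roughly like this (requires the sentencepiece package; `input.txt` and the model prefix are hypothetical stand-ins):

    import sentencepiece as spm

    # Equivalent of: tools/spm_train --input=input.txt --vocab_size=5000
    #                --model_type=unigram --model_prefix=spm_model_demo
    spm.SentencePieceTrainer.Train(
        input='input.txt',              # hypothetical corpus, one sentence per line
        vocab_size=5000,
        model_type='unigram',
        model_prefix='spm_model_demo',  # writes spm_model_demo.model / .vocab
        character_coverage=0.9995)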
@@ -96,39 +87,29 @@ def main():
         # train
         spm.SentencePieceTrainer.Train(
             input=fp.name,
-            vocab_size=args.count_threshold,
+            vocab_size=args.vocab_size,
             model_type=args.spm_mode,
             model_prefix=args.spm_model_prefix,
             input_sentence_size=100000000,
             character_coverage=0.9995)
         os.unlink(fp.name)
 
     # encode
     text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
 
-    # vocabs = set()
-    # for manifest_path in args.manifest_paths:
-    #     manifest_jsons = read_manifest(manifest_path)
-    #     for line_json in manifest_jsons:
-    #         line = line_json['text']
-    #         enc_line = text_feature.spm_tokenize(line)
-    #         for code in enc_line:
-    #             vocabs.add(code)
-    #         #print(" ".join(enc_line))
-    # vocabs_sorted = sorted(vocabs)
-    # for unit in vocabs_sorted:
-    #     fout.write(unit + "\n")
-    counter = Counter()
-    for manifest_path in args.manifest_paths:
-        count_manifest(counter, text_feature, manifest_path)
-    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-    for token, count in count_sorted:
-        fout.write(token + '\n')
-    print(f"spm vocab size: {len(count_sorted)}")
+    counter = Counter()
+    for manifest_path in args.manifest_paths:
+        count_manifest(counter, text_feature, manifest_path)
+
+    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
+    tokens = []
+    for token, count in count_sorted:
+        if count < args.count_threshold: break
+        tokens.append(token)
+
+    tokens = sorted(tokens)
+    for token in tokens:
+        fout.write(token + '\n')
 
     fout.write(SOS + "\n")  # <sos/eos>
     fout.close()
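Taken together, the rewritten tail of main() reduces to the following standalone sketch (a toy `texts` list stands in for the JSON manifests, and `str.split` for `TextFeaturizer`/`count_manifest`; names are otherwise from the patch):

    from collections import Counter

    texts = ["hello world", "hello speech"]   # stand-in for manifest transcripts
    count_threshold = 2                       # tokens seen fewer times are dropped

    counter = Counter()
    for line in texts:
        counter.update(line.split())          # real script tokenizes via TextFeaturizer

    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    tokens = [token for token, count in count_sorted if count >= count_threshold]
    print(sorted(tokens))                     # ['hello']

Because `count_sorted` is in descending frequency order, the patch's early `break` on `count < args.count_threshold` is equivalent to the `count >= count_threshold` filter above; the final `sorted(tokens)` pass makes the emitted vocabulary lexically ordered rather than frequency-ordered.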