|
|
@ -29,12 +29,13 @@ from deepspeech.frontend.utility import BLANK
|
|
|
|
from deepspeech.frontend.utility import SOS
|
|
|
|
from deepspeech.frontend.utility import SOS
|
|
|
|
from deepspeech.utils.utility import add_arguments
|
|
|
|
from deepspeech.utils.utility import add_arguments
|
|
|
|
from deepspeech.utils.utility import print_arguments
|
|
|
|
from deepspeech.utils.utility import print_arguments
|
|
|
|
|
|
|
|
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
add_arg = functools.partial(add_arguments, argparser=parser)
|
|
|
|
add_arg = functools.partial(add_arguments, argparser=parser)
|
|
|
|
# yapf: disable
|
|
|
|
# yapf: disable
|
|
|
|
add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
|
|
|
|
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
|
|
|
|
add_arg('count_threshold', int, 0, "Truncation threshold for char/word/bpe counts.")
|
|
|
|
add_arg('count_threshold', int, 0, "Truncation threshold for char/word/spm counts.")
|
|
|
|
add_arg('vocab_path', str,
|
|
|
|
add_arg('vocab_path', str,
|
|
|
|
'examples/librispeech/data/vocab.txt',
|
|
|
|
'examples/librispeech/data/vocab.txt',
|
|
|
|
"Filepath to write the vocabulary.")
|
|
|
|
"Filepath to write the vocabulary.")
|
|
|
@ -45,10 +46,10 @@ add_arg('manifest_paths', str,
|
|
|
|
nargs='+',
|
|
|
|
nargs='+',
|
|
|
|
required=True)
|
|
|
|
required=True)
|
|
|
|
# bpe
|
|
|
|
# bpe
|
|
|
|
add_arg('bpe_mode', str, 'unigram',
|
|
|
|
add_arg('spm_mode', str, 'unigram',
|
|
|
|
"bpe model type, e.g. unigram, bpe, char, word. only need when `unit_type` is bpe")
|
|
|
|
"spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
|
|
|
|
add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)",
|
|
|
|
add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
|
|
|
|
"bpe model prefix, only need when `unit_type` is bpe")
|
|
|
|
"spm model prefix, only need when `unit_type` is spm")
|
|
|
|
# yapf: disable
|
|
|
|
# yapf: disable
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
@ -56,7 +57,7 @@ args = parser.parse_args()
|
|
|
|
def count_manifest(counter, manifest_path):
|
|
|
|
def count_manifest(counter, manifest_path):
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
if args.unit_type == 'character':
|
|
|
|
if args.unit_type == 'char':
|
|
|
|
for char in line_json['text']:
|
|
|
|
for char in line_json['text']:
|
|
|
|
counter.update(char)
|
|
|
|
counter.update(char)
|
|
|
|
elif args.unit_type == 'word':
|
|
|
|
elif args.unit_type == 'word':
|
|
|
@ -75,7 +76,7 @@ def main():
|
|
|
|
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
|
|
|
|
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
|
|
|
|
fout.write(UNK + '\n') # <unk> must be 1
|
|
|
|
fout.write(UNK + '\n') # <unk> must be 1
|
|
|
|
|
|
|
|
|
|
|
|
if args.unit_type != 'bpe':
|
|
|
|
if args.unit_type != 'spm':
|
|
|
|
counter = Counter()
|
|
|
|
counter = Counter()
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
count_manifest(counter, manifest_path)
|
|
|
|
count_manifest(counter, manifest_path)
|
|
|
@ -98,41 +99,21 @@ def main():
|
|
|
|
spm.SentencePieceTrainer.Train(
|
|
|
|
spm.SentencePieceTrainer.Train(
|
|
|
|
input=fp.name,
|
|
|
|
input=fp.name,
|
|
|
|
vocab_size=args.count_threshold,
|
|
|
|
vocab_size=args.count_threshold,
|
|
|
|
model_type=args.bpe_mode,
|
|
|
|
model_type=args.spm_mode,
|
|
|
|
model_prefix=args.bpe_model_prefix,
|
|
|
|
model_prefix=args.spm_model_prefix,
|
|
|
|
input_sentence_size=100000000,
|
|
|
|
input_sentence_size=100000000,
|
|
|
|
character_coverage=0.9995)
|
|
|
|
character_coverage=0.9995)
|
|
|
|
os.unlink(fp.name)
|
|
|
|
os.unlink(fp.name)
|
|
|
|
|
|
|
|
|
|
|
|
# encode
|
|
|
|
# encode
|
|
|
|
sp = spm.SentencePieceProcessor()
|
|
|
|
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
|
|
|
|
sp.Load(args.bpe_model_prefix + '.model')
|
|
|
|
|
|
|
|
stats = {"num_empty": 0, "num_filtered": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def valid(line):
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode(l):
|
|
|
|
|
|
|
|
return sp.EncodeAsPieces(l)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_line(line):
|
|
|
|
|
|
|
|
line = line.strip()
|
|
|
|
|
|
|
|
if len(line) > 0:
|
|
|
|
|
|
|
|
line = encode(line)
|
|
|
|
|
|
|
|
if valid(line):
|
|
|
|
|
|
|
|
return line
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
stats["num_filtered"] += 1
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
stats["num_empty"] += 1
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocabs = set()
|
|
|
|
vocabs = set()
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
for manifest_path in args.manifest_paths:
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
manifest_jsons = read_manifest(manifest_path)
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
line = line_json['text']
|
|
|
|
line = line_json['text']
|
|
|
|
enc_line = encode_line(line)
|
|
|
|
enc_line = text_feature.spm_tokenize(line)
|
|
|
|
for code in enc_line:
|
|
|
|
for code in enc_line:
|
|
|
|
vocabs.add(code)
|
|
|
|
vocabs.add(code)
|
|
|
|
#print(" ".join(enc_line))
|
|
|
|
#print(" ".join(enc_line))
|
|
|
@ -140,9 +121,7 @@ def main():
|
|
|
|
for unit in vocabs_sorted:
|
|
|
|
for unit in vocabs_sorted:
|
|
|
|
fout.write(unit + "\n")
|
|
|
|
fout.write(unit + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
print(f"bpe vocab size: {len(vocabs_sorted)}")
|
|
|
|
print(f"spm vocab size: {len(vocabs_sorted)}")
|
|
|
|
print(f"skip {stats['num_empty']} empty lines")
|
|
|
|
|
|
|
|
print(f"filter {stats['num_filtered']} invalid lines")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fout.write(SOS + "\n") # <sos/eos>
|
|
|
|
fout.write(SOS + "\n") # <sos/eos>
|
|
|
|
fout.close()
|
|
|
|
fout.close()
|
|
|
|