|
|
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Build vocabulary from manifest files.
|
|
|
|
Each item in vocabulary file is a character.
|
|
|
|
"""
|
|
|
|
import argparse
|
|
|
|
import functools
|
|
|
|
import os
|
|
|
|
import tempfile
|
|
|
|
from collections import Counter
|
|
|
|
|
|
|
|
import jsonlines
|
|
|
|
|
|
|
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
|
|
|
from paddlespeech.s2t.frontend.utility import BLANK
|
|
|
|
from paddlespeech.s2t.frontend.utility import SOS
|
|
|
|
from paddlespeech.s2t.frontend.utility import SPACE
|
|
|
|
from paddlespeech.s2t.frontend.utility import UNK
|
|
|
|
from paddlespeech.utils.argparse import add_arguments
|
|
|
|
from paddlespeech.utils.argparse import print_arguments
|
|
|
|
|
|
|
|
|
|
|
|
def count_manifest(counter, text_feature, manifest_path):
|
|
|
|
manifest_jsons = []
|
|
|
|
with jsonlines.open(manifest_path, 'r') as reader:
|
|
|
|
for json_data in reader:
|
|
|
|
manifest_jsons.append(json_data)
|
|
|
|
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
if isinstance(line_json['text'], str):
|
|
|
|
tokens = text_feature.tokenize(
|
|
|
|
line_json['text'], replace_space=False)
|
|
|
|
|
|
|
|
counter.update(tokens)
|
|
|
|
else:
|
|
|
|
assert isinstance(line_json['text'], list)
|
|
|
|
for text in line_json['text']:
|
|
|
|
tokens = text_feature.tokenize(text, replace_space=False)
|
|
|
|
counter.update(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
def dump_text_manifest(fileobj, manifest_path, key='text'):
|
|
|
|
manifest_jsons = []
|
|
|
|
with jsonlines.open(manifest_path, 'r') as reader:
|
|
|
|
for json_data in reader:
|
|
|
|
manifest_jsons.append(json_data)
|
|
|
|
|
|
|
|
for line_json in manifest_jsons:
|
|
|
|
if isinstance(line_json[key], str):
|
|
|
|
fileobj.write(line_json[key] + "\n")
|
|
|
|
else:
|
|
|
|
assert isinstance(line_json[key], list)
|
|
|
|
for line in line_json[key]:
|
|
|
|
fileobj.write(line + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
def build_vocab(manifest_paths="",
|
|
|
|
vocab_path="examples/librispeech/data/vocab.txt",
|
|
|
|
unit_type="char",
|
|
|
|
count_threshold=0,
|
|
|
|
text_keys='text',
|
|
|
|
spm_mode="unigram",
|
|
|
|
spm_vocab_size=0,
|
|
|
|
spm_model_prefix="",
|
|
|
|
spm_character_coverage=0.9995):
|
|
|
|
manifest_paths = [manifest_paths] if isinstance(manifest_paths,
|
|
|
|
str) else manifest_paths
|
|
|
|
|
|
|
|
fout = open(vocab_path, 'w', encoding='utf-8')
|
|
|
|
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
|
|
|
|
fout.write(UNK + '\n') # <unk> must be 1
|
|
|
|
|
|
|
|
if unit_type == 'spm':
|
|
|
|
# tools/spm_train --input=$wave_data/lang_char/input.txt
|
|
|
|
# --vocab_size=${nbpe} --model_type=${bpemode}
|
|
|
|
# --model_prefix=${bpemodel} --input_sentence_size=100000000
|
|
|
|
import sentencepiece as spm
|
|
|
|
|
|
|
|
fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
|
|
|
|
for manifest_path in manifest_paths:
|
|
|
|
_text_keys = [text_keys] if type(
|
|
|
|
text_keys) is not list else text_keys
|
|
|
|
for text_key in _text_keys:
|
|
|
|
dump_text_manifest(fp, manifest_path, key=text_key)
|
|
|
|
fp.close()
|
|
|
|
# train
|
|
|
|
spm.SentencePieceTrainer.Train(
|
|
|
|
input=fp.name,
|
|
|
|
vocab_size=spm_vocab_size,
|
|
|
|
model_type=spm_mode,
|
|
|
|
model_prefix=spm_model_prefix,
|
|
|
|
input_sentence_size=100000000,
|
|
|
|
character_coverage=spm_character_coverage)
|
|
|
|
os.unlink(fp.name)
|
|
|
|
|
|
|
|
# encode
|
|
|
|
text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
|
|
|
|
counter = Counter()
|
|
|
|
|
|
|
|
for manifest_path in manifest_paths:
|
|
|
|
count_manifest(counter, text_feature, manifest_path)
|
|
|
|
|
|
|
|
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
tokens = []
|
|
|
|
for token, count in count_sorted:
|
|
|
|
if count < count_threshold:
|
|
|
|
break
|
|
|
|
# replace space by `<space>`
|
|
|
|
token = SPACE if token == ' ' else token
|
|
|
|
tokens.append(token)
|
|
|
|
|
|
|
|
tokens = sorted(tokens)
|
|
|
|
for token in tokens:
|
|
|
|
fout.write(token + '\n')
|
|
|
|
|
|
|
|
fout.write(SOS + "\n") # <sos/eos>
|
|
|
|
fout.close()
|
|
|
|
|
|
|
|
|
|
|
|
def define_argparse():
|
|
|
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
|
|
add_arg = functools.partial(add_arguments, argparser=parser)
|
|
|
|
|
|
|
|
# yapf: disable
|
|
|
|
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
|
|
|
|
add_arg('count_threshold', int, 0,
|
|
|
|
"Truncation threshold for char/word counts.Default 0, no truncate.")
|
|
|
|
add_arg('vocab_path', str,
|
|
|
|
'examples/librispeech/data/vocab.txt',
|
|
|
|
"Filepath to write the vocabulary.")
|
|
|
|
add_arg('manifest_paths', str,
|
|
|
|
None,
|
|
|
|
"Filepaths of manifests for building vocabulary. "
|
|
|
|
"You can provide multiple manifest files.",
|
|
|
|
nargs='+',
|
|
|
|
required=True)
|
|
|
|
add_arg('text_keys', str,
|
|
|
|
'text',
|
|
|
|
"keys of the text in manifest for building vocabulary. "
|
|
|
|
"You can provide multiple k.",
|
|
|
|
nargs='+')
|
|
|
|
# bpe
|
|
|
|
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
|
|
|
|
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
|
|
|
|
add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm")
|
|
|
|
add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols")
|
|
|
|
# yapf: disable
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
return args
|
|
|
|
|
|
|
|
def main():
|
|
|
|
args = define_argparse()
|
|
|
|
print_arguments(args, globals())
|
|
|
|
build_vocab(**vars(args))
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
main()
|