From ed793b30b7c9f605ced6bac7c0527eee3993c8da Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 8 Apr 2021 11:49:34 +0000
Subject: [PATCH] refactor build vocab

---
 .../frontend/featurizer/text_featurizer.py |  23 ++--
 examples/tiny/s0/local/data.sh             |  12 +-
 utils/build_vocab.py                       |  49 +++-----
 utils/format_data.py                       | 109 +++++-------
 4 files changed, 59 insertions(+), 134 deletions(-)

diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py
index db9734f1e..13b404e86 100644
--- a/deepspeech/frontend/featurizer/text_featurizer.py
+++ b/deepspeech/frontend/featurizer/text_featurizer.py
@@ -14,7 +14,6 @@
 """Contains the text featurizer class."""
 import os
-import codecs
 
 import sentencepiece as spm
 
 from deepspeech.frontend.utility import UNK
@@ -42,7 +41,7 @@ class TextFeaturizer(object):
         if unit_type == 'spm':
             spm_model = spm_model_prefix + '.model'
             self.sp = spm.SentencePieceProcessor()
-            self.sp.Load(self.spm_model)
+            self.sp.Load(spm_model)
 
     def featurize(self, text):
         """Convert text string to a list of token indices in char-level.Note
@@ -51,14 +50,14 @@
         :param text: Text to process.
         :type text: str
         :return: List of char-level token indices.
-        :rtype: list
+        :rtype: List[int]
         """
-        if unit_type == 'char':
-            tokens = self._char_tokenize(text)
-        elif unit_type == 'word':
-            tokens = self._word_tokenize(text)
+        if self.unit_type == 'char':
+            tokens = self.char_tokenize(text)
+        elif self.unit_type == 'word':
+            tokens = self.word_tokenize(text)
         else:
-            tokens = self._spm_tokenize(text)
+            tokens = self.spm_tokenize(text)
 
         ids = []
         for token in tokens:
@@ -84,15 +83,15 @@
         """
         return self._vocab_list
 
-    def _char_tokenize(self, text):
+    def char_tokenize(self, text):
         """Character tokenizer."""
         return list(text.strip())
 
-    def _word_tokenize(self, text):
+    def word_tokenize(self, text):
         """Word tokenizer, spearte by ."""
         return text.strip().split()
 
-    def _spm_tokenize(self, text):
+    def spm_tokenize(self, text):
         """spm tokenize.
 
         Args:
@@ -127,7 +126,7 @@
     def _load_vocabulary_from_file(self, vocab_filepath):
         """Load vocabulary from file."""
         vocab_lines = []
-        with codecs.open(vocab_filepath, 'r', 'utf-8') as file:
+        with open(vocab_filepath, 'r', encoding='utf-8') as file:
             vocab_lines.extend(file.readlines())
         vocab_list = [line[:-1] for line in vocab_lines]
         vocab_dict = dict(
diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh
index 240655847..9794da349 100644
--- a/examples/tiny/s0/local/data.sh
+++ b/examples/tiny/s0/local/data.sh
@@ -23,10 +23,10 @@ bpemode=unigram
 bpeprefix="data/bpe_${bpemode}_${nbpe}"
 # build vocabulary
 python3 ${MAIN_ROOT}/utils/build_vocab.py \
---unit_type "bpe" \
+--unit_type "spm" \
 --count_threshold=${nbpe} \
---bpe_mode ${bpemode} \
---bpe_model_prefix ${bpeprefix} \
+--spm_mode ${bpemode} \
+--spm_model_prefix ${bpeprefix} \
 --vocab_path="data/vocab.txt" \
 --manifest_paths="data/manifest.tiny.raw"
 
@@ -53,8 +53,8 @@ fi
 python3 ${MAIN_ROOT}/utils/format_data.py \
 --feat_type "raw" \
 --cmvn_path "data/mean_std.npz" \
---unit_type "bpe" \
---bpe_model_prefix ${bpeprefix} \
+--unit_type "spm" \
+--spm_model_prefix ${bpeprefix} \
 --vocab_path="data/vocab.txt" \
 --manifest_path="data/manifest.tiny.raw" \
 --output_path="data/manifest.tiny"
@@ -66,4 +66,4 @@ if [ $? -ne 0 ]; then
 fi
 
 echo "LibriSpeech Data preparation done."
-exit 0
\ No newline at end of file
+exit 0
diff --git a/utils/build_vocab.py b/utils/build_vocab.py
index b147c5325..cbd3339d3 100644
--- a/utils/build_vocab.py
+++ b/utils/build_vocab.py
@@ -29,12 +29,13 @@ from deepspeech.frontend.utility import BLANK
 from deepspeech.frontend.utility import SOS
 from deepspeech.utils.utility import add_arguments
 from deepspeech.utils.utility import print_arguments
+from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
 
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
-add_arg('count_threshold', int, 0, "Truncation threshold for char/word/bpe counts.")
+add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
+add_arg('count_threshold', int, 0, "Truncation threshold for char/word counts; vocab size when `unit_type` is spm.")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
         "Filepath to write the vocabulary.")
@@ -45,10 +46,10 @@ add_arg('manifest_paths', str,
         nargs='+',
         required=True)
 # bpe
-add_arg('bpe_mode', str, 'unigram',
-        "bpe model type, e.g. unigram, bpe, char, word. only need when `unit_type` is bpe")
-add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)",
-        "bpe model prefix, only need when `unit_type` is bpe")
+add_arg('spm_mode', str, 'unigram',
+        "spm model type, e.g. unigram, bpe, char, word. Only needed when `unit_type` is spm.")
+add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
+        "spm model prefix. Only needed when `unit_type` is spm.")
 # yapf: disable
 args = parser.parse_args()
 
@@ -56,7 +57,7 @@ args = parser.parse_args()
 def count_manifest(counter, manifest_path):
     manifest_jsons = read_manifest(manifest_path)
     for line_json in manifest_jsons:
-        if args.unit_type == 'character':
+        if args.unit_type == 'char':
             for char in line_json['text']:
                 counter.update(char)
         elif args.unit_type == 'word':
@@ -75,7 +76,7 @@ def main():
     fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
     fout.write(UNK + '\n')  # must be 1
 
-    if args.unit_type != 'bpe':
+    if args.unit_type != 'spm':
         counter = Counter()
         for manifest_path in args.manifest_paths:
             count_manifest(counter, manifest_path)
@@ -98,41 +99,21 @@
         spm.SentencePieceTrainer.Train(
             input=fp.name,
             vocab_size=args.count_threshold,
-            model_type=args.bpe_mode,
-            model_prefix=args.bpe_model_prefix,
+            model_type=args.spm_mode,
+            model_prefix=args.spm_model_prefix,
             input_sentence_size=100000000,
             character_coverage=0.9995)
         os.unlink(fp.name)
 
         # encode
-        sp = spm.SentencePieceProcessor()
-        sp.Load(args.bpe_model_prefix + '.model')
-        stats = {"num_empty": 0, "num_filtered": 0}
-
-        def valid(line):
-            return True
-
-        def encode(l):
-            return sp.EncodeAsPieces(l)
-
-        def encode_line(line):
-            line = line.strip()
-            if len(line) > 0:
-                line = encode(line)
-                if valid(line):
-                    return line
-                else:
-                    stats["num_filtered"] += 1
-            else:
-                stats["num_empty"] += 1
-            return None
+        text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
 
         vocabs = set()
         for manifest_path in args.manifest_paths:
             manifest_jsons = read_manifest(manifest_path)
             for line_json in manifest_jsons:
                 line = line_json['text']
-                enc_line = encode_line(line)
+                enc_line = text_feature.spm_tokenize(line)
                 for code in enc_line:
                     vocabs.add(code)
                 #print(" ".join(enc_line))
@@ -140,9 +121,7 @@
         vocabs_sorted = sorted(vocabs)
         for unit in vocabs_sorted:
             fout.write(unit + "\n")
-        print(f"bpe vocab size: {len(vocabs_sorted)}")
print(f"skip {stats['num_empty']} empty lines") - print(f"filter {stats['num_filtered']} invalid lines") + print(f"spm vocab size: {len(vocabs_sorted)}") fout.write(SOS + "\n") # fout.close() diff --git a/utils/format_data.py b/utils/format_data.py index 4788f8579..f1744d175 100644 --- a/utils/format_data.py +++ b/utils/format_data.py @@ -27,6 +27,7 @@ from deepspeech.frontend.utility import SOS from deepspeech.frontend.utility import load_cmvn from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import print_arguments +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) @@ -35,7 +36,7 @@ add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kald add_arg('cmvn_path', str, 'examples/librispeech/data/mean_std.npz', "Filepath of cmvn.") -add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe") +add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") add_arg('vocab_path', str, 'examples/librispeech/data/vocab.txt', "Filepath of the vocabulary.") @@ -46,7 +47,8 @@ add_arg('manifest_paths', str, nargs='+', required=True) # bpe -add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)", "bpe model prefix, only need when `unit_type` is bpe") +add_arg('spm_model_prefix', str, None, + "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") add_arg('output_path', str, None, "filepath of formated manifest.", required=True) # yapf: disable args = parser.parse_args() @@ -54,93 +56,38 @@ args = parser.parse_args() def main(): print_arguments(args) + fout = open(args.output_path, 'w', encoding='utf-8') # get feat dim mean, std = load_cmvn(args.cmvn_path, filetype='npz') feat_dim = mean.shape[0] print(f"Feature dim: {feat_dim}") - # read vocab - vocab = dict() - with open(args.vocab_path, 'r', encoding='utf-8') as fin: - for line in fin: - token = line.strip() - vocab[token] = len(vocab) - vocab_size = len(vocab) + text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) + vocab_size = text_feature.vocab_size print(f"Vocab size: {vocab_size}") - fout = open(args.output_path, 'w', encoding='utf-8') - - if args.unit_type != 'bpe': - for manifest_path in args.manifest_paths: - manifest_jsons = read_manifest(manifest_path) - for line_json in manifest_jsons: - tokens = [] - tokenids = [] - if args.unit_type == 'character': - for char in line_json['text']: - tokens.append(char) - tokenids.append(vocab[char]) - elif args.unit_type == 'word': - for word in line_json['text'].split(): - tokens.append(word) - tokenids.append(vocab[word]) - line_json['token'] = tokens - line_json['token_id'] = tokenids - line_json['token_shape'] = (len(tokenids), vocab_size) - feat_shape = line_json['feat_shape'] - assert isinstance(feat_shape, (list, tuple)), type(feat_shape) - if args.feat_type == 'raw': - feat_shape.append(feat_dim) - else: # kaldi - raise NotImplemented('no support kaldi feat now!') - fout.write(json.dumps(line_json) + '\n') - else: - import sentencepiece as spm - - # encode - sp = spm.SentencePieceProcessor() - sp.Load(args.bpe_model_prefix + '.model') - - def valid(line): - return True - - def encode(l): - return sp.EncodeAsPieces(l) - - def encode_line(line): - line = line.strip() - if len(line) > 0: - line = encode(line) - if valid(line): - return line - else: - stats["num_filtered"] += 1 - else: - 
stats["num_empty"] += 1 - return None - - for manifest_path in args.manifest_paths: - manifest_jsons = read_manifest(manifest_path) - for line_json in manifest_jsons: - line = line_json['text'] - tokens = [] - tokenids = [] - enc_line = encode_line(line) - for code in enc_line: - tokens.append(code) - tokenids.append(vocab[code]) - #print(code, vocab[code]) - line_json['token'] = tokens - line_json['token_id'] = tokenids - line_json['token_shape'] = (len(tokenids), vocab_size) - feat_shape = line_json['feat_shape'] - assert isinstance(feat_shape, (list, tuple)), type(feat_shape) - if args.feat_type == 'raw': - feat_shape.append(feat_dim) - else: # kaldi - raise NotImplemented('no support kaldi feat now!') - fout.write(json.dumps(line_json) + '\n') + for manifest_path in args.manifest_paths: + manifest_jsons = read_manifest(manifest_path) + for line_json in manifest_jsons: + line = line_json['text'] + if args.unit_type == 'char': + tokens = text_feature.char_tokenize(line) + elif args.unit_type == 'word': + tokens = text_feature.word_tokenize(line) + else: #spm + tokens = text_feature.spm_tokenize(line) + tokenids = text_feature.featurize(line) + line_json['token'] = tokens + line_json['token_id'] = tokenids + line_json['token_shape'] = (len(tokenids), vocab_size) + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + if args.feat_type == 'raw': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplemented('no support kaldi feat now!') + fout.write(json.dumps(line_json) + '\n') fout.close()