From 64f0bad5cac1c9d9f1f0f58a6e56c36a62bdd1cb Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 8 Apr 2021 08:33:03 +0000
Subject: [PATCH] refactor data, build vocab; add format data

---
 deepspeech/frontend/utility.py                     |  81 ++++++++---
 deepspeech/io/collator.py                          |   2 +-
 deepspeech/models/u2.py                            |   2 +-
 deepspeech/utils/tensor_utils.py                   |   2 -
 examples/dataset/aishell/aishell.py                |  16 ++-
 .../chime3_background/chime3_background.py         |  13 +-
 examples/dataset/librispeech/librispeech.py        |  10 +-
 .../mini_librispeech/mini_librispeech.py           |  10 +-
 examples/dataset/musan/musan.py                    |  16 ++-
 examples/dataset/rir_noise/rir_noise.py            |  16 ++-
 examples/dataset/voxforge/voxforge.py              |   5 +-
 examples/tiny/s0/local/data.sh                     |  33 ++++-
 utils/build_vocab.py                               | 108 +++++++++++++--
 utils/format_data.py                               | 127 ++++++++++++++++++
 14 files changed, 370 insertions(+), 71 deletions(-)
 create mode 100644 utils/format_data.py
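The refactor standardizes the special tokens and the vocab layout written by
utils/build_vocab.py: <blank> first (id 0, the CTC blank), <unk> second (id 1),
the ordinary units next, and the shared SOS/EOS token last. A minimal sketch of
how such a vocab file resolves to ids (illustrative, not part of this patch):

    # Sketch: load a vocab.txt written by utils/build_vocab.py and check
    # the special-token layout (path and asserts are illustrative).
    def load_vocab(vocab_path):
        vocab = {}
        with open(vocab_path, encoding='utf-8') as fin:
            for line in fin:
                vocab[line.strip()] = len(vocab)
        return vocab

    vocab = load_vocab("data/vocab.txt")
    assert vocab["<blank>"] == 0   # CTC blank
    assert vocab["<unk>"] == 1     # unknown token
    # the shared SOS/EOS token is the last entry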
- """ manifest = [] for json_line in codecs.open(manifest_path, 'r', 'utf-8'): try: json_data = json.loads(json_line) except Exception as e: raise IOError("Error reading manifest: %s" % str(e)) - if (json_data["duration"] <= max_duration and - json_data["duration"] >= min_duration): + feat_len = json_data["feat_shape"][0] + token_len = json_data["token_shape"][0] + conditions = [ + feat_len > min_input_len, + feat_len < max_input_len, + token_len > min_output_len, + token_len < max_output_len, + token_len / feat_len > min_output_input_ratio, + token_len / feat_len < max_output_input_ratio, + ] + if all(conditions): manifest.append(json_data) return manifest + # parser.add_argument('--max_input_len', type=float, + # default=20, + # help='maximum output seq length, in seconds for raw wav, in frame numbers for feature data') + # parser.add_argument('--min_output_len', type=float, + # default=0, help='minimum input seq length, in modeling units') + # parser.add_argument('--max_output_len', type=float, + # default=500, + # help='maximum output seq length, in modeling units') + # parser.add_argument('--min_output_input_ratio', type=float, default=0.05, + # help='minimum output seq length/output seq length ratio') + # parser.add_argument('--max_output_input_ratio', type=float, default=10, + # help='maximum output seq length/output seq length ratio') + def rms_to_db(rms: float): """Root Mean Square to dB. + Args: rms ([float]): root mean square @@ -145,8 +184,10 @@ def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): def _load_json_cmvn(json_cmvn_file): """ Load the json format cmvn stats file and calculate cmvn + Args: json_cmvn_file: cmvn stats file in json format + Returns: a numpy array of [means, vars] """ @@ -168,10 +209,12 @@ def _load_json_cmvn(json_cmvn_file): def _load_kaldi_cmvn(kaldi_cmvn_file): """ Load the kaldi format cmvn stats file and calculate cmvn + Args: kaldi_cmvn_file: kaldi text style global cmvn file, which is generated by: compute-cmvn-stats --binary=false scp:feats.scp global_cmvn + Returns: a numpy array of [means, vars] """ diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index c3a11f0f0..322bba73c 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -17,7 +17,7 @@ import numpy as np from collections import namedtuple from deepspeech.io.utility import pad_sequence -from deepspeech.utils.tensor_utils import IGNORE_ID +from deepspeech.frontend.utility import IGNORE_ID logger = logging.getLogger(__name__) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 1a632a878..8fcc9fca6 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -42,11 +42,11 @@ from deepspeech.modules.decoder import TransformerDecoder from deepspeech.modules.loss import LabelSmoothingLoss from deepspeech.frontend.utility import load_cmvn +from deepspeech.frontend.utility import IGNORE_ID from deepspeech.utils import checkpoint from deepspeech.utils import layer_tools from deepspeech.utils.utility import log_add -from deepspeech.utils.tensor_utils import IGNORE_ID from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.tensor_utils import pad_sequence diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 345194901..68204d8da 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -22,8 +22,6 @@ logger = logging.getLogger(__name__) __all__ = ["pad_sequence", "add_sos_eos", 
"th_accuracy"] -IGNORE_ID = -1 - def pad_sequence(sequences: List[paddle.Tensor], batch_first: bool=False, diff --git a/examples/dataset/aishell/aishell.py b/examples/dataset/aishell/aishell.py index 38d0c28a3..764cc50cd 100644 --- a/examples/dataset/aishell/aishell.py +++ b/examples/dataset/aishell/aishell.py @@ -62,9 +62,9 @@ def create_manifest(data_dir, manifest_path_prefix): transcript_dict[audio_id] = text data_types = ['train', 'dev', 'test'] - for type in data_types: + for dtype in data_types: del json_lines[:] - audio_dir = os.path.join(data_dir, 'wav', type) + audio_dir = os.path.join(data_dir, 'wav', dtype) for subfolder, _, filelist in sorted(os.walk(audio_dir)): for fname in filelist: audio_path = os.path.join(subfolder, fname) @@ -78,12 +78,16 @@ def create_manifest(data_dir, manifest_path_prefix): json_lines.append( json.dumps( { - 'audio_filepath': audio_path, - 'duration': duration, - 'text': text + 'utt': + os.path.splitext(os.path.basename(audio_path))[0], + 'feat': + audio_path, + 'feat_shape': (duration, ), #second + 'text': + text }, ensure_ascii=False)) - manifest_path = manifest_path_prefix + '.' + type + manifest_path = manifest_path_prefix + '.' + dtype with codecs.open(manifest_path, 'w', 'utf-8') as fout: for line in json_lines: fout.write(line + '\n') diff --git a/examples/dataset/chime3_background/chime3_background.py b/examples/dataset/chime3_background/chime3_background.py index 31208d147..3f4fd1dc6 100644 --- a/examples/dataset/chime3_background/chime3_background.py +++ b/examples/dataset/chime3_background/chime3_background.py @@ -95,11 +95,14 @@ def create_manifest(data_dir, manifest_path): audio_data, samplerate = soundfile.read(filepath) duration = float(len(audio_data)) / samplerate json_lines.append( - json.dumps({ - 'audio_filepath': filepath, - 'duration': duration, - 'text': '' - })) + json.dumps( + { + 'utt': os.path.splitext(os.path.basename(filepath))[ + 0], + 'feat': filepath, + 'feat_shape': (duration, ), #second + 'type': 'background' + })) with io.open(manifest_path, mode='w', encoding='utf8') as out_file: for line in json_lines: out_file.write(line + '\n') diff --git a/examples/dataset/librispeech/librispeech.py b/examples/dataset/librispeech/librispeech.py index 4cf0f5541..52c940fa4 100644 --- a/examples/dataset/librispeech/librispeech.py +++ b/examples/dataset/librispeech/librispeech.py @@ -89,9 +89,13 @@ def create_manifest(data_dir, manifest_path): duration = float(len(audio_data)) / samplerate json_lines.append( json.dumps({ - 'audio_filepath': audio_filepath, - 'duration': duration, - 'text': text + 'utt': + os.path.splitext(os.path.basename(audio_filepath))[0], + 'feat': + audio_filepath, + 'feat_shape': (duration, ), #second + 'text': + text })) with codecs.open(manifest_path, 'w', 'utf-8') as out_file: for line in json_lines: diff --git a/examples/dataset/mini_librispeech/mini_librispeech.py b/examples/dataset/mini_librispeech/mini_librispeech.py index 883a322dc..34a1c0dc6 100644 --- a/examples/dataset/mini_librispeech/mini_librispeech.py +++ b/examples/dataset/mini_librispeech/mini_librispeech.py @@ -71,9 +71,13 @@ def create_manifest(data_dir, manifest_path): duration = float(len(audio_data)) / samplerate json_lines.append( json.dumps({ - 'audio_filepath': audio_filepath, - 'duration': duration, - 'text': text + 'utt': + os.path.splitext(os.path.basename(audio_filepath))[0], + 'feat': + audio_filepath, + 'feat_shape': (duration, ), #second + 'text': + text })) with codecs.open(manifest_path, 'w', 'utf-8') as out_file: for line in 
             out_file.write(line + '\n')
diff --git a/examples/dataset/musan/musan.py b/examples/dataset/musan/musan.py
index 87d8e5e10..84322051f 100644
--- a/examples/dataset/musan/musan.py
+++ b/examples/dataset/musan/musan.py
@@ -53,9 +53,9 @@ def create_manifest(data_dir, manifest_path_prefix):
     print("Creating manifest %s ..." % manifest_path_prefix)
     json_lines = []
     data_types = ['music', 'noise', 'speech']
-    for type in data_types:
+    for dtype in data_types:
         del json_lines[:]
-        audio_dir = os.path.join(data_dir, type)
+        audio_dir = os.path.join(data_dir, dtype)
         for subfolder, _, filelist in sorted(os.walk(audio_dir)):
-            print('x, ', subfolder)
+            print("Processing", subfolder)
             for fname in filelist:
                 audio_path = os.path.join(subfolder, fname)
@@ -67,12 +67,16 @@ def create_manifest(data_dir, manifest_path_prefix):
                 json_lines.append(
                     json.dumps(
                         {
-                            'audio_filepath': audio_path,
-                            'duration': duration,
-                            'type': type,
+                            'utt':
+                            os.path.splitext(os.path.basename(audio_path))[0],
+                            'feat':
+                            audio_path,
+                            'feat_shape': (duration, ),  # seconds
+                            'type':
+                            dtype,
                         },
                         ensure_ascii=False))
-        manifest_path = manifest_path_prefix + '.' + type
+        manifest_path = manifest_path_prefix + '.' + dtype
         with codecs.open(manifest_path, 'w', 'utf-8') as fout:
             for line in json_lines:
                 fout.write(line + '\n')
diff --git a/examples/dataset/rir_noise/rir_noise.py b/examples/dataset/rir_noise/rir_noise.py
index 643540c9b..900fc2696 100644
--- a/examples/dataset/rir_noise/rir_noise.py
+++ b/examples/dataset/rir_noise/rir_noise.py
@@ -55,9 +55,9 @@ def create_manifest(data_dir, manifest_path_prefix):
     data_types = [
         'pointsource_noises', 'real_rirs_isotropic_noises', 'simulated_rirs'
     ]
-    for type in data_types:
+    for dtype in data_types:
         del json_lines[:]
-        audio_dir = os.path.join(data_dir, type)
+        audio_dir = os.path.join(data_dir, dtype)
         for subfolder, _, filelist in sorted(os.walk(audio_dir)):
             for fname in filelist:
                 audio_path = os.path.join(subfolder, fname)
@@ -68,12 +68,16 @@ def create_manifest(data_dir, manifest_path_prefix):
                 json_lines.append(
                     json.dumps(
                         {
-                            'audio_filepath': audio_path,
-                            'duration': duration,
-                            'type': type,
+                            'utt':
+                            os.path.splitext(os.path.basename(audio_path))[0],
+                            'feat':
+                            audio_path,
+                            'feat_shape': (duration, ),  # seconds
+                            'type':
+                            dtype,
                         },
                         ensure_ascii=False))
-        manifest_path = manifest_path_prefix + '.' + type
+        manifest_path = manifest_path_prefix + '.' + dtype
         with codecs.open(manifest_path, 'w', 'utf-8') as fout:
             for line in json_lines:
                 fout.write(line + '\n')
diff --git a/examples/dataset/voxforge/voxforge.py b/examples/dataset/voxforge/voxforge.py
index abf1ccff6..c32b783d4 100644
--- a/examples/dataset/voxforge/voxforge.py
+++ b/examples/dataset/voxforge/voxforge.py
@@ -174,8 +174,9 @@ def generate_manifest(data_dir, manifest_path):
         duration = float(len(audio_data)) / samplerate
         json_lines.append(
             json.dumps({
-                'audio_filepath': u,
-                'duration': duration,
+                'utt': os.path.splitext(os.path.basename(u))[0],
+                'feat': u,
+                'feat_shape': (duration, ),  # seconds
                 'text': trans.lower()
             }))
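All dataset scripts now emit the same raw-manifest schema: one JSON object per
line with `utt`, `feat`, `feat_shape` and, for speech corpora, `text`. An
illustrative line (values below are made up, not taken from any corpus):

    import json

    # Illustrative raw-manifest entry in the new schema (values made up).
    entry = {
        "utt": "BAC009S0002W0122",                 # utterance id
        "feat": "wav/train/S0002/BAC009S0002W0122.wav",
        "feat_shape": (10.1, ),                    # seconds for raw wav
        "text": "some transcript",
    }
    print(json.dumps(entry, ensure_ascii=False))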
diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh
index d834ec677..0117218f8 100644
--- a/examples/tiny/s0/local/data.sh
+++ b/examples/tiny/s0/local/data.sh
@@ -15,13 +15,20 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
 
-head -n 64 data/manifest.dev-clean > data/manifest.tiny
+head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw
 
+# bpemode (unigram or bpe)
+nbpe=200
+bpemode=unigram
+bpeprefix="data/bpe_${bpemode}_${nbpe}"
 
 # build vocabulary
 python3 ${MAIN_ROOT}/utils/build_vocab.py \
---count_threshold=0 \
+--unit_type "bpe" \
+--count_threshold=${nbpe} \
+--bpe_mode ${bpemode} \
+--bpe_model_prefix ${bpeprefix} \
 --vocab_path="data/vocab.txt" \
---manifest_paths="data/manifest.tiny"
+--manifest_paths="data/manifest.tiny.raw"
 
 if [ $? -ne 0 ]; then
     echo "Build vocabulary failed. Terminated."
@@ -31,7 +38,7 @@ fi
 
 # compute mean and stddev for normalizer
 python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
---manifest_path="data/manifest.tiny" \
+--manifest_path="data/manifest.tiny.raw" \
 --num_samples=64 \
 --specgram_type="linear" \
 --output_path="data/mean_std.npz"
@@ -41,5 +48,21 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
 
+
+# format manifest with token ids and vocab size
+python3 ${MAIN_ROOT}/utils/format_data.py \
+--feat_type "raw" \
+--unit_type "bpe" \
+--bpe_model_prefix ${bpeprefix} \
+--vocab_path="data/vocab.txt" \
+--manifest_paths="data/manifest.tiny.raw" \
+--output_path="data/manifest.tiny"
+
+
+if [ $? -ne 0 ]; then
+    echo "Format manifest failed. Terminated."
+    exit 1
+fi
+
 echo "LibriSpeech Data preparation done."
-exit 0
+exit 0
\ No newline at end of file
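Both utils/build_vocab.py (below) and utils/format_data.py load manifests
through the refactored read_manifest(); a usage sketch of the new filtering
knobs (the path and thresholds are placeholders, not values from this patch):

    from deepspeech.frontend.utility import read_manifest

    # Keep utterances with 0.5-20 s of audio and at most 400 modeling units;
    # output-side filters only apply once `token_shape` has been added.
    manifest = read_manifest(
        "data/manifest.tiny",
        max_input_len=20.0,
        min_input_len=0.5,
        max_output_len=400.0,
        min_output_len=1.0,
        max_output_input_ratio=10.0,
        min_output_input_ratio=0.05)
    print(len(manifest), "utterances kept")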
diff --git a/utils/build_vocab.py b/utils/build_vocab.py
index cb17de57c..b147c5325 100644
--- a/utils/build_vocab.py
+++ b/utils/build_vocab.py
@@ -17,18 +17,24 @@ Each item in vocabulary file is a character.
 import argparse
 import functools
-import codecs
 import json
 from collections import Counter
-import os.path
+import os
+import tempfile
 
 from deepspeech.frontend.utility import read_manifest
-from deepspeech.utils.utility import add_arguments, print_arguments
+from deepspeech.frontend.utility import UNK
+from deepspeech.frontend.utility import BLANK
+from deepspeech.frontend.utility import SOS
+from deepspeech.utils.utility import add_arguments
+from deepspeech.utils.utility import print_arguments
 
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('count_threshold', int, 0, "Truncation threshold for char counts.")
+add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
+add_arg('count_threshold', int, 0,
+        "Truncation threshold for char/word counts; in bpe mode it is reused "
+        "as the SentencePiece vocab size.")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
         "Filepath to write the vocabulary.")
@@ -38,6 +44,11 @@ add_arg('manifest_paths', str,
         "You can provide multiple manifest files.",
         nargs='+',
         required=True)
+# bpe
+add_arg('bpe_mode', str, 'unigram',
+        "bpe model type, e.g. unigram, bpe, char, word; only needed when "
+        "`unit_type` is bpe")
+add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)",
+        "bpe model prefix; only needed when `unit_type` is bpe")
 # yapf: disable
 
 args = parser.parse_args()
@@ -45,23 +56,96 @@ args = parser.parse_args()
 
 def count_manifest(counter, manifest_path):
     manifest_jsons = read_manifest(manifest_path)
     for line_json in manifest_jsons:
-        for char in line_json['text']:
-            counter.update(char)
+        if args.unit_type == 'character':
+            # updating with a string counts its characters
+            counter.update(line_json['text'])
+        elif args.unit_type == 'word':
+            # count whitespace-separated words (a plain
+            # `counter.update(word)` would count the word's characters)
+            counter.update(line_json['text'].split())
+
+
+def read_text_manifest(fileobj, manifest_path):
+    manifest_jsons = read_manifest(manifest_path)
+    for line_json in manifest_jsons:
+        fileobj.write(line_json['text'] + "\n")
 
 
 def main():
     print_arguments(args)
 
-    counter = Counter()
-    for manifest_path in args.manifest_paths:
-        count_manifest(counter, manifest_path)
+    fout = open(args.vocab_path, 'w', encoding='utf-8')
+    fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
+    fout.write(UNK + '\n')  # <unk> must be 1
 
-    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-    with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
-        fout.write('<unk>' + '\n')
+    if args.unit_type != 'bpe':
+        counter = Counter()
+        for manifest_path in args.manifest_paths:
+            count_manifest(counter, manifest_path)
+
+        count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
         for char, count in count_sorted:
             if count < args.count_threshold:
                 break
             fout.write(char + '\n')
+    else:
+        # tools/spm_train --input=$wave_data/lang_char/input.txt
+        # --vocab_size=${nbpe} --model_type=${bpemode}
+        # --model_prefix=${bpemodel} --input_sentence_size=100000000
+        import sentencepiece as spm
+
+        fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
+        for manifest_path in args.manifest_paths:
+            read_text_manifest(fp, manifest_path)
+        fp.close()
+        # train
+        spm.SentencePieceTrainer.Train(
+            input=fp.name,
+            vocab_size=args.count_threshold,
+            model_type=args.bpe_mode,
+            model_prefix=args.bpe_model_prefix,
+            input_sentence_size=100000000,
+            character_coverage=0.9995)
+        os.unlink(fp.name)
+
+        # encode
+        sp = spm.SentencePieceProcessor()
+        sp.Load(args.bpe_model_prefix + '.model')
+        stats = {"num_empty": 0, "num_filtered": 0}
+
+        def valid(line):
+            return True
+
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        vocabs = set()
+        for manifest_path in args.manifest_paths:
+            manifest_jsons = read_manifest(manifest_path)
+            for line_json in manifest_jsons:
+                enc_line = encode_line(line_json['text'])
+                if enc_line is None:  # empty or filtered line
+                    continue
+                for code in enc_line:
+                    vocabs.add(code)
+        vocabs_sorted = sorted(vocabs)
+        for unit in vocabs_sorted:
+            fout.write(unit + "\n")
+
+        print(f"bpe vocab size: {len(vocabs_sorted)}")
+        print(f"skip {stats['num_empty']} empty lines")
+        print(f"filter {stats['num_filtered']} invalid lines")
+
+    fout.write(SOS + "\n")  # <sos/eos>
+    fout.close()
 
 
 if __name__ == '__main__':
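The BPE model trained by build_vocab.py is reloaded by format_data.py below; a
quick round-trip check (illustrative; the model path follows `bpeprefix` from
examples/tiny/s0/local/data.sh above):

    import sentencepiece as spm

    # Round-trip check of the trained SentencePiece model (illustrative).
    sp = spm.SentencePieceProcessor()
    sp.Load("data/bpe_unigram_200.model")  # bpeprefix + '.model'
    pieces = sp.EncodeAsPieces("THIS IS A TEST")
    print(pieces)                          # e.g. ['▁THIS', '▁IS', '▁A', '▁TEST']
    print(sp.DecodePieces(pieces))         # 'THIS IS A TEST'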
diff --git a/utils/format_data.py b/utils/format_data.py
new file mode 100644
index 000000000..fc10c0385
--- /dev/null
+++ b/utils/format_data.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Format manifest with more metadata."""
+import argparse
+import functools
+import json
+
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.utils.utility import add_arguments
+from deepspeech.utils.utility import print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
+add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
+add_arg('vocab_path', str,
+        'examples/librispeech/data/vocab.txt',
+        "Filepath of the vocabulary (one token per line).")
+add_arg('manifest_paths', str,
+        None,
+        "Filepaths of manifests to format. "
+        "You can provide multiple manifest files.",
+        nargs='+',
+        required=True)
+# bpe
+add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)",
+        "bpe model prefix; only needed when `unit_type` is bpe")
+add_arg('output_path', str, None, "Filepath of the formatted manifest.", required=True)
+# yapf: enable
+args = parser.parse_args()
+
+
+def main():
+    print_arguments(args)
+
+    # read vocab
+    vocab = dict()
+    with open(args.vocab_path, 'r', encoding='utf-8') as fin:
+        for line in fin:
+            token = line.strip()
+            vocab[token] = len(vocab)
+    vocab_size = len(vocab)
+
+    fout = open(args.output_path, 'w', encoding='utf-8')
+
+    if args.unit_type != 'bpe':
+        for manifest_path in args.manifest_paths:
+            manifest_jsons = read_manifest(manifest_path)
+            for line_json in manifest_jsons:
+                tokens = []
+                tokenids = []
+                if args.unit_type == 'character':
+                    for char in line_json['text']:
+                        tokens.append(char)
+                        tokenids.append(vocab[char])
+                elif args.unit_type == 'word':
+                    for word in line_json['text'].split():
+                        tokens.append(word)
+                        tokenids.append(vocab[word])
+                line_json['token'] = tokens
+                line_json['token_id'] = tokenids
+                line_json['token_shape'] = (len(tokenids), vocab_size)
+                fout.write(json.dumps(line_json) + '\n')
+    else:
+        import sentencepiece as spm
+
+        # encode
+        sp = spm.SentencePieceProcessor()
+        sp.Load(args.bpe_model_prefix + '.model')
+        # `encode_line` below updates these counters
+        stats = {"num_empty": 0, "num_filtered": 0}
+
+        def valid(line):
+            return True
+
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        for manifest_path in args.manifest_paths:
+            manifest_jsons = read_manifest(manifest_path)
+            for line_json in manifest_jsons:
+                tokens = []
+                tokenids = []
+                enc_line = encode_line(line_json['text'])
+                if enc_line is None:  # empty or filtered line
+                    continue
+                for code in enc_line:
+                    tokens.append(code)
+                    tokenids.append(vocab[code])
+                line_json['token'] = tokens
+                line_json['token_id'] = tokenids
+                line_json['token_shape'] = (len(tokenids), vocab_size)
+                fout.write(json.dumps(line_json) + '\n')
+
+    fout.close()
+
+
+if __name__ == '__main__':
+    main()
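After format_data.py runs, each manifest line additionally carries the token
metadata that read_manifest() filters on. An illustrative formatted line
(tokens and ids are made up):

    import json

    # Illustrative formatted-manifest entry (values made up).
    entry = {
        "utt": "utt001",
        "feat": "wav/utt001.wav",
        "feat_shape": (3.2, ),     # seconds
        "text": "A TEST",
        "token": ["▁A", "▁TEST"],
        "token_id": [17, 158],     # ids from data/vocab.txt (made up)
        "token_shape": (2, 202),   # (num tokens, vocab size)
    }
    print(json.dumps(entry, ensure_ascii=False))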