refactor data, build vocab; add format data

pull/578/head
Hui Zhang 5 years ago
parent 12c01f39aa
commit 64f0bad5ca

@@ -29,40 +29,79 @@ logger = logging.getLogger(__name__)
 __all__ = [
     "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs",
-    "mean_dbfs", "gain_db_to_ratio", "normalize_audio"
+    "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK",
+    "BLANK"
 ]
+
+IGNORE_ID = -1
+SOS = "<sos/eos>"
+EOS = SOS
+UNK = "<unk>"
+BLANK = "<blank>"
+
-def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
-    """Load and parse manifest file.
-    Instances with durations outside [min_duration, max_duration] will be
-    filtered out.
-    :param manifest_path: Manifest file to load and parse.
-    :type manifest_path: str
-    :param max_duration: Maximal duration in seconds for instance filter.
-    :type max_duration: float
-    :param min_duration: Minimal duration in seconds for instance filter.
-    :type min_duration: float
-    :return: Manifest parsing results. List of dict.
-    :rtype: list
-    :raises IOError: If failed to parse the manifest.
-    """
+# """Load and parse manifest file.
+# Instances with input/output lengths outside the configured bounds will be
+# filtered out.
+# :param manifest_path: Manifest file to load and parse.
+# :type manifest_path: str
+# :param max_input_len: maximum input seq length, in seconds for raw wav, in frame numbers for feature data.
+# :type max_input_len: float
+# :param min_input_len: minimum input seq length, in seconds for raw wav, in frame numbers for feature data.
+# :type min_input_len: float
+# :param max_output_len: maximum output seq length, in modeling units.
+# :type max_output_len: float
+# :param min_output_len: minimum output seq length, in modeling units.
+# :type min_output_len: float
+# :param max_output_input_ratio: maximum output seq length / input seq length ratio.
+# :type max_output_input_ratio: float
+# :param min_output_input_ratio: minimum output seq length / input seq length ratio.
+# :type min_output_input_ratio: float
+# :return: Manifest parsing results. List of dict.
+# :rtype: list
+# :raises IOError: If failed to parse the manifest.
+# """
+def read_manifest(
+        manifest_path,
+        max_input_len=float('inf'),
+        min_input_len=0.0,
+        max_output_len=500.0,
+        min_output_len=0.0,
+        max_output_input_ratio=10.0,
+        min_output_input_ratio=0.05, ):
     manifest = []
     for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
         try:
             json_data = json.loads(json_line)
         except Exception as e:
             raise IOError("Error reading manifest: %s" % str(e))
-        if (json_data["duration"] <= max_duration and
-                json_data["duration"] >= min_duration):
+        feat_len = json_data["feat_shape"][0]
+        token_len = json_data["token_shape"][0]
+        conditions = [
+            feat_len > min_input_len,
+            feat_len < max_input_len,
+            token_len > min_output_len,
+            token_len < max_output_len,
+            token_len / feat_len > min_output_input_ratio,
+            token_len / feat_len < max_output_input_ratio,
+        ]
+        if all(conditions):
             manifest.append(json_data)
     return manifest
+
+# parser.add_argument('--max_input_len', type=float, default=20,
+#     help='maximum input seq length, in seconds for raw wav, in frame numbers for feature data')
+# parser.add_argument('--min_output_len', type=float, default=0,
+#     help='minimum output seq length, in modeling units')
+# parser.add_argument('--max_output_len', type=float, default=500,
+#     help='maximum output seq length, in modeling units')
+# parser.add_argument('--min_output_input_ratio', type=float, default=0.05,
+#     help='minimum output seq length / input seq length ratio')
+# parser.add_argument('--max_output_input_ratio', type=float, default=10,
+#     help='maximum output seq length / input seq length ratio')

 def rms_to_db(rms: float):
     """Root Mean Square to dB.
     Args:
         rms ([float]): root mean square
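
For reference, a minimal sketch (not part of the commit) of how the rewritten filter treats a single manifest entry; all field values below are hypothetical:

import json

# Hypothetical manifest line in the new format (values made up for illustration).
line = json.loads(
    '{"utt": "116-288045-0000", "feat": "116-288045-0000.flac", '
    '"feat_shape": [12.3], "text": "some transcript", "token_shape": [45, 200]}')
feat_len = line["feat_shape"][0]    # input length: seconds for raw wav
token_len = line["token_shape"][0]  # output length: modeling units
# With the defaults above, the entry is kept only if every bound holds,
# including 0.05 < token_len / feat_len < 10.0 (here the ratio is ~3.7).
keep = (0.0 < feat_len and 0.0 < token_len < 500.0 and
        0.05 < token_len / feat_len < 10.0)
print(keep)  # True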
@@ -145,8 +184,10 @@ def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
 def _load_json_cmvn(json_cmvn_file):
     """ Load the json format cmvn stats file and calculate cmvn
     Args:
         json_cmvn_file: cmvn stats file in json format
     Returns:
         a numpy array of [means, vars]
     """
@@ -168,10 +209,12 @@ def _load_json_cmvn(json_cmvn_file):
 def _load_kaldi_cmvn(kaldi_cmvn_file):
     """ Load the kaldi format cmvn stats file and calculate cmvn
     Args:
         kaldi_cmvn_file: kaldi text style global cmvn file, which
             is generated by:
             compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
     Returns:
         a numpy array of [means, vars]
     """

@@ -17,7 +17,7 @@ import numpy as np
 from collections import namedtuple
 from deepspeech.io.utility import pad_sequence
-from deepspeech.utils.tensor_utils import IGNORE_ID
+from deepspeech.frontend.utility import IGNORE_ID
 logger = logging.getLogger(__name__)

@@ -42,11 +42,11 @@ from deepspeech.modules.decoder import TransformerDecoder
 from deepspeech.modules.loss import LabelSmoothingLoss
 from deepspeech.frontend.utility import load_cmvn
+from deepspeech.frontend.utility import IGNORE_ID
 from deepspeech.utils import checkpoint
 from deepspeech.utils import layer_tools
 from deepspeech.utils.utility import log_add
-from deepspeech.utils.tensor_utils import IGNORE_ID
 from deepspeech.utils.tensor_utils import add_sos_eos
 from deepspeech.utils.tensor_utils import th_accuracy
 from deepspeech.utils.tensor_utils import pad_sequence
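
IGNORE_ID now lives in deepspeech.frontend.utility next to the SOS/EOS/UNK/BLANK symbols. A hedged sketch of how these pieces typically combine in the u2 model, assuming add_sos_eos keeps the usual (ys_pad, sos, eos, ignore_id) signature (the signature itself is not shown in this diff):

import paddle
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.utils.tensor_utils import add_sos_eos

# Two padded label sequences; IGNORE_ID (-1) marks padding positions.
ys_pad = paddle.to_tensor([[3, 7, 9], [4, 5, IGNORE_ID]])
sos_id = eos_id = 199  # <sos/eos> shares one id, e.g. the last vocab entry
ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, IGNORE_ID)
# ys_in prepends sos for the decoder input; ys_out appends eos as the target.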

@@ -22,8 +22,6 @@ logger = logging.getLogger(__name__)
 __all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
-IGNORE_ID = -1
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,

@@ -62,9 +62,9 @@ def create_manifest(data_dir, manifest_path_prefix):
             transcript_dict[audio_id] = text
     data_types = ['train', 'dev', 'test']
-    for type in data_types:
+    for dtype in data_types:
         del json_lines[:]
-        audio_dir = os.path.join(data_dir, 'wav', type)
+        audio_dir = os.path.join(data_dir, 'wav', dtype)
         for subfolder, _, filelist in sorted(os.walk(audio_dir)):
             for fname in filelist:
                 audio_path = os.path.join(subfolder, fname)
@@ -78,12 +78,16 @@ def create_manifest(data_dir, manifest_path_prefix):
                 json_lines.append(
                     json.dumps(
                         {
-                            'audio_filepath': audio_path,
-                            'duration': duration,
-                            'text': text
+                            'utt':
+                            os.path.splitext(os.path.basename(audio_path))[0],
+                            'feat':
+                            audio_path,
+                            'feat_shape': (duration, ),  # second
+                            'text':
+                            text
                         },
                         ensure_ascii=False))
-        manifest_path = manifest_path_prefix + '.' + type
+        manifest_path = manifest_path_prefix + '.' + dtype
         with codecs.open(manifest_path, 'w', 'utf-8') as fout:
             for line in json_lines:
                 fout.write(line + '\n')
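
The same schema change repeats in the dataset scripts below: audio_filepath and duration become feat and feat_shape, and the new utt key carries an utterance id derived from the file name. A quick illustration of one emitted line, with a made-up path:

import json
import os

audio_path = '/data/aishell/wav/train/S0002/BAC009S0002W0122.wav'  # hypothetical
duration = 3.41  # seconds, as computed from the audio by soundfile
print(json.dumps({
    'utt': os.path.splitext(os.path.basename(audio_path))[0],  # 'BAC009S0002W0122'
    'feat': audio_path,
    'feat_shape': (duration, ),  # seconds for raw audio
    'text': 'hypothetical transcript'
}, ensure_ascii=False))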

@@ -95,10 +95,13 @@ def create_manifest(data_dir, manifest_path):
         audio_data, samplerate = soundfile.read(filepath)
         duration = float(len(audio_data)) / samplerate
         json_lines.append(
-            json.dumps({
-                'audio_filepath': filepath,
-                'duration': duration,
-                'text': ''
-            }))
+            json.dumps({
+                'utt': os.path.splitext(os.path.basename(filepath))[0],
+                'feat': filepath,
+                'feat_shape': (duration, ),  # second
+                'type': 'background'
+            }))
     with io.open(manifest_path, mode='w', encoding='utf8') as out_file:
         for line in json_lines:

@@ -89,9 +89,13 @@ def create_manifest(data_dir, manifest_path):
         duration = float(len(audio_data)) / samplerate
         json_lines.append(
-            json.dumps({
-                'audio_filepath': audio_filepath,
-                'duration': duration,
-                'text': text
-            }))
+            json.dumps({
+                'utt': os.path.splitext(os.path.basename(audio_filepath))[0],
+                'feat': audio_filepath,
+                'feat_shape': (duration, ),  # second
+                'text': text
+            }))
     with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
         for line in json_lines:

@@ -71,9 +71,13 @@ def create_manifest(data_dir, manifest_path):
         duration = float(len(audio_data)) / samplerate
         json_lines.append(
-            json.dumps({
-                'audio_filepath': audio_filepath,
-                'duration': duration,
-                'text': text
-            }))
+            json.dumps({
+                'utt': os.path.splitext(os.path.basename(audio_filepath))[0],
+                'feat': audio_filepath,
+                'feat_shape': (duration, ),  # second
+                'text': text
+            }))
     with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
         for line in json_lines:

@@ -53,9 +53,9 @@ def create_manifest(data_dir, manifest_path_prefix):
     print("Creating manifest %s ..." % manifest_path_prefix)
     json_lines = []
     data_types = ['music', 'noise', 'speech']
-    for type in data_types:
+    for dtype in data_types:
         del json_lines[:]
-        audio_dir = os.path.join(data_dir, type)
+        audio_dir = os.path.join(data_dir, dtype)
         for subfolder, _, filelist in sorted(os.walk(audio_dir)):
             print('x, ', subfolder)
             for fname in filelist:
@@ -67,12 +67,16 @@ def create_manifest(data_dir, manifest_path_prefix):
                 json_lines.append(
                     json.dumps(
                         {
-                            'audio_filepath': audio_path,
-                            'duration': duration,
-                            'type': type,
+                            'utt':
+                            os.path.splitext(os.path.basename(audio_path))[0],
+                            'feat':
+                            audio_path,
+                            'feat_shape': (duration, ),  # second
+                            'type':
+                            dtype,
                         },
                         ensure_ascii=False))
-        manifest_path = manifest_path_prefix + '.' + type
+        manifest_path = manifest_path_prefix + '.' + dtype
         with codecs.open(manifest_path, 'w', 'utf-8') as fout:
             for line in json_lines:
                 fout.write(line + '\n')

@@ -55,9 +55,9 @@ def create_manifest(data_dir, manifest_path_prefix):
     data_types = [
         'pointsource_noises', 'real_rirs_isotropic_noises', 'simulated_rirs'
     ]
-    for type in data_types:
+    for dtype in data_types:
         del json_lines[:]
-        audio_dir = os.path.join(data_dir, type)
+        audio_dir = os.path.join(data_dir, dtype)
         for subfolder, _, filelist in sorted(os.walk(audio_dir)):
             for fname in filelist:
                 audio_path = os.path.join(subfolder, fname)
@@ -68,12 +68,16 @@ def create_manifest(data_dir, manifest_path_prefix):
                 json_lines.append(
                     json.dumps(
                         {
-                            'audio_filepath': audio_path,
-                            'duration': duration,
-                            'type': type,
+                            'utt':
+                            os.path.splitext(os.path.basename(audio_path))[0],
+                            'feat':
+                            audio_path,
+                            'feat_shape': (duration, ),  # second
+                            'type':
+                            dtype,
                         },
                         ensure_ascii=False))
-        manifest_path = manifest_path_prefix + '.' + type
+        manifest_path = manifest_path_prefix + '.' + dtype
         with codecs.open(manifest_path, 'w', 'utf-8') as fout:
             for line in json_lines:
                 fout.write(line + '\n')

@@ -174,8 +174,9 @@ def generate_manifest(data_dir, manifest_path):
         duration = float(len(audio_data)) / samplerate
         json_lines.append(
             json.dumps({
-                'audio_filepath': u,
-                'duration': duration,
+                'utt': os.path.splitext(os.path.basename(u))[0],
+                'feat': u,
+                'feat_shape': (duration, ),  # second
                 'text': trans.lower()
             }))

@@ -15,13 +15,20 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
-head -n 64 data/manifest.dev-clean > data/manifest.tiny
+head -n 64 data/manifest.dev-clean > data/manifest.tiny.raw
+
+# bpemode (unigram or bpe)
+nbpe=200
+bpemode=unigram
+bpeprefix="data/bpe_${bpemode}_${nbpe}"

 # build vocabulary
 python3 ${MAIN_ROOT}/utils/build_vocab.py \
---count_threshold=0 \
+--unit_type "bpe" \
+--count_threshold=${nbpe} \
+--bpe_mode ${bpemode} \
+--bpe_model_prefix ${bpeprefix} \
 --vocab_path="data/vocab.txt" \
---manifest_paths="data/manifest.tiny"
+--manifest_paths="data/manifest.tiny.raw"

 if [ $? -ne 0 ]; then
     echo "Build vocabulary failed. Terminated."
@@ -31,7 +38,7 @@ fi
 # compute mean and stddev for normalizer
 python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
---manifest_path="data/manifest.tiny.raw" \
 --num_samples=64 \
 --specgram_type="linear" \
 --output_path="data/mean_std.npz"
@@ -41,5 +48,21 @@ if [ $? -ne 0 ]; then
     exit 1
 fi

+# format manifest with tokenids, vocab size
+python3 ${MAIN_ROOT}/utils/format_data.py \
+--feat_type "raw" \
+--unit_type "bpe" \
+--bpe_model_prefix ${bpeprefix} \
+--vocab_path="data/vocab.txt" \
+--manifest_path="data/manifest.tiny.raw" \
+--output_path="data/manifest.tiny"
+
+if [ $? -ne 0 ]; then
+    echo "Format manifest failed. Terminated."
+    exit 1
+fi
+
 echo "LibriSpeech Data preparation done."
 exit 0
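
The tiny recipe now prepares data in three passes: build a BPE vocabulary from the raw manifest, compute CMVN stats, then write the final manifest with token ids. A hypothetical spot-check of the result (not part of the commit):

import json

# Every line of the formatted manifest should carry the fields that the
# new read_manifest() filter expects (feat_shape, token_shape, ...).
with open('data/manifest.tiny', encoding='utf-8') as f:
    first = json.loads(f.readline())
required = {'utt', 'feat', 'feat_shape', 'text', 'token', 'token_id', 'token_shape'}
assert required <= first.keys(), required - first.keys()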

@@ -17,18 +17,24 @@ Each item in vocabulary file is a character.
 import argparse
 import functools
-import codecs
 import json
 from collections import Counter
-import os.path
+import os
+import copy
+import tempfile

 from deepspeech.frontend.utility import read_manifest
-from deepspeech.utils.utility import add_arguments, print_arguments
+from deepspeech.frontend.utility import UNK
+from deepspeech.frontend.utility import BLANK
+from deepspeech.frontend.utility import SOS
+from deepspeech.utils.utility import add_arguments
+from deepspeech.utils.utility import print_arguments

 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('count_threshold', int, 0, "Truncation threshold for char counts.")
+add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
+add_arg('count_threshold', int, 0, "Truncation threshold for char/word/bpe counts.")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
         "Filepath to write the vocabulary.")
@@ -38,6 +44,11 @@ add_arg('manifest_paths', str,
         "You can provide multiple manifest files.",
         nargs='+',
         required=True)
+# bpe
+add_arg('bpe_mode', str, 'unigram',
+        "BPE model type, e.g. unigram, bpe, char, word. Only needed when `unit_type` is bpe.")
+add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)",
+        "BPE model prefix. Only needed when `unit_type` is bpe.")
 # yapf: disable
 args = parser.parse_args()
@@ -45,23 +56,96 @@ args = parser.parse_args()
 def count_manifest(counter, manifest_path):
     manifest_jsons = read_manifest(manifest_path)
     for line_json in manifest_jsons:
+        if args.unit_type == 'character':
             for char in line_json['text']:
                 counter.update(char)
+        elif args.unit_type == 'word':
+            for word in line_json['text'].split():
+                counter.update(word)
+
+def read_text_manifest(fileobj, manifest_path):
+    manifest_jsons = read_manifest(manifest_path)
+    for line_json in manifest_jsons:
+        fileobj.write(line_json['text'] + "\n")

 def main():
     print_arguments(args)
+    fout = open(args.vocab_path, 'w', encoding='utf-8')
+    fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
+    fout.write(UNK + '\n')  # <unk> must be 1
+
+    if args.unit_type != 'bpe':
         counter = Counter()
         for manifest_path in args.manifest_paths:
             count_manifest(counter, manifest_path)
         count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-    with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
-        fout.write('<unk>' + '\n')
         for char, count in count_sorted:
             if count < args.count_threshold: break
             fout.write(char + '\n')
+    else:
+        # tools/spm_train --input=$wave_data/lang_char/input.txt
+        # --vocab_size=${nbpe} --model_type=${bpemode}
+        # --model_prefix=${bpemodel} --input_sentence_size=100000000
+        import sentencepiece as spm
+
+        fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
+        for manifest_path in args.manifest_paths:
+            read_text_manifest(fp, manifest_path)
+        fp.close()
+        # train
+        spm.SentencePieceTrainer.Train(
+            input=fp.name,
+            vocab_size=args.count_threshold,
+            model_type=args.bpe_mode,
+            model_prefix=args.bpe_model_prefix,
+            input_sentence_size=100000000,
+            character_coverage=0.9995)
+        os.unlink(fp.name)
+
+        # encode
+        sp = spm.SentencePieceProcessor()
+        sp.Load(args.bpe_model_prefix + '.model')
+
+        stats = {"num_empty": 0, "num_filtered": 0}
+
+        def valid(line):
+            return True
+
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        vocabs = set()
+        for manifest_path in args.manifest_paths:
+            manifest_jsons = read_manifest(manifest_path)
+            for line_json in manifest_jsons:
+                line = line_json['text']
+                enc_line = encode_line(line)
+                if enc_line is None:  # empty or filtered line
+                    continue
+                for code in enc_line:
+                    vocabs.add(code)
+                #print(" ".join(enc_line))
+        vocabs_sorted = sorted(vocabs)
+        for unit in vocabs_sorted:
+            fout.write(unit + "\n")
+
+        print(f"bpe vocab size: {len(vocabs_sorted)}")
+        print(f"skip {stats['num_empty']} empty lines")
+        print(f"filter {stats['num_filtered']} invalid lines")
+
+    fout.write(SOS + "\n")  # <sos/eos>
+    fout.close()

 if __name__ == '__main__':
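
The vocabulary file layout is positional: BLANK is written first so it gets CTC id 0, UNK is second (id 1), the counted or BPE units follow, and the shared sos/eos symbol comes last. A minimal sketch of reading it back, assuming a data/vocab.txt produced as above:

# Index equals line order, matching the ids format_data.py will assign.
with open('data/vocab.txt', encoding='utf-8') as fin:
    vocab = {line.strip(): idx for idx, line in enumerate(fin)}
assert vocab['<blank>'] == 0
assert vocab['<unk>'] == 1
assert vocab['<sos/eos>'] == len(vocab) - 1  # one shared id for SOS and EOS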

@@ -0,0 +1,127 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Format manifest with more metadata."""
+import argparse
+import functools
+import json
+
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.frontend.utility import UNK
+from deepspeech.frontend.utility import BLANK
+from deepspeech.frontend.utility import SOS
+from deepspeech.utils.utility import add_arguments
+from deepspeech.utils.utility import print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('feat_type', str, "raw", "Speech feature type, e.g. raw(wav, flac), kaldi")
+add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
+add_arg('vocab_path', str,
+        'examples/librispeech/data/vocab.txt',
+        "Filepath of the vocabulary.")
+add_arg('manifest_paths', str,
+        None,
+        "Filepaths of manifests to format. "
+        "You can provide multiple manifest files.",
+        nargs='+',
+        required=True)
+# bpe
+add_arg('bpe_model_prefix', str, "bpe_model_%(bpe_mode)_%(count_threshold)",
+        "BPE model prefix. Only needed when `unit_type` is bpe.")
+add_arg('output_path', str, None, "Filepath of formatted manifest.", required=True)
+# yapf: disable
+args = parser.parse_args()
+
+def main():
+    print_arguments(args)
+
+    # read vocab: id equals line order in the vocab file
+    vocab = dict()
+    with open(args.vocab_path, 'r', encoding='utf-8') as fin:
+        for line in fin:
+            token = line.strip()
+            vocab[token] = len(vocab)
+    vocab_size = len(vocab)
+
+    fout = open(args.output_path, 'w', encoding='utf-8')
+
+    if args.unit_type != 'bpe':
+        for manifest_path in args.manifest_paths:
+            manifest_jsons = read_manifest(manifest_path)
+            for line_json in manifest_jsons:
+                tokens = []
+                tokenids = []
+                if args.unit_type == 'character':
+                    for char in line_json['text']:
+                        tokens.append(char)
+                        tokenids.append(vocab[char])
+                elif args.unit_type == 'word':
+                    for word in line_json['text'].split():
+                        tokens.append(word)
+                        tokenids.append(vocab[word])
+                line_json['token'] = tokens
+                line_json['token_id'] = tokenids
+                line_json['token_shape'] = (len(tokenids), vocab_size)
+                fout.write(json.dumps(line_json) + '\n')
+    else:
+        import sentencepiece as spm
+
+        # encode
+        sp = spm.SentencePieceProcessor()
+        sp.Load(args.bpe_model_prefix + '.model')
+
+        stats = {"num_empty": 0, "num_filtered": 0}
+
+        def valid(line):
+            return True
+
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        for manifest_path in args.manifest_paths:
+            manifest_jsons = read_manifest(manifest_path)
+            for line_json in manifest_jsons:
+                line = line_json['text']
+                tokens = []
+                tokenids = []
+                enc_line = encode_line(line)
+                if enc_line is None:  # empty or filtered line
+                    continue
+                for code in enc_line:
+                    tokens.append(code)
+                    tokenids.append(vocab[code])
+                    #print(code, vocab[code])
+                line_json['token'] = tokens
+                line_json['token_id'] = tokenids
+                line_json['token_shape'] = (len(tokenids), vocab_size)
+                fout.write(json.dumps(line_json) + '\n')
+
+    fout.close()
+
+if __name__ == '__main__':
+    main()
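
Taken together with build_vocab.py, a hypothetical invocation of the new script, mirroring the data.sh stage above (paths and the 200-unit BPE prefix are illustrative, not part of the commit):

import subprocess

# Tokenize the raw manifest into the final one consumed by training; flags
# mirror the add_arg definitions above.
subprocess.run([
    'python3', 'utils/format_data.py',
    '--feat_type', 'raw',
    '--unit_type', 'bpe',
    '--bpe_model_prefix', 'data/bpe_unigram_200',
    '--vocab_path', 'data/vocab.txt',
    '--manifest_paths', 'data/manifest.tiny.raw',
    '--output_path', 'data/manifest.tiny',
], check=True)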