Update data processing: make SPM character coverage configurable and accept list-valued text fields

pull/1064/head
Junkun 3 years ago
parent 4823892169
commit 72a8c9337c

@@ -76,8 +76,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--spm_vocab_size=${nbpe} \ --spm_vocab_size=${nbpe} \
--spm_mode ${bpemode} \ --spm_mode ${bpemode} \
--spm_model_prefix ${bpeprefix} \ --spm_model_prefix ${bpeprefix} \
--spm_character_coverage 1. \
--vocab_path="${dict_dir}/vocab.txt" \ --vocab_path="${dict_dir}/vocab.txt" \
--text_keys 'text' 'text1' \ --text_keys 'text' \
--manifest_paths="data/manifest.train.raw" --manifest_paths="data/manifest.train.raw"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then

@@ -5,7 +5,7 @@ source path.sh
gpus=0,1,2,3 gpus=0,1,2,3
stage=0 stage=0
stop_stage=100 stop_stage=100
conf_path=conf/transformer_joint_noam.yaml conf_path=conf/transformer_mtl_noam.yaml
avg_num=5 avg_num=5
data_path=./TED_EnZh # path to unzipped data data_path=./TED_EnZh # path to unzipped data
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@@ -55,6 +55,8 @@ add_arg('text_keys', str,
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm") add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm") add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm")
add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols")
# yapf: disable # yapf: disable
args = parser.parse_args() args = parser.parse_args()
@@ -66,8 +68,14 @@ def count_manifest(counter, text_feature, manifest_path):
manifest_jsons.append(json_data) manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
if isinstance(line_json['text'], str):
line = text_feature.tokenize(line_json['text'], replace_space=False) line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line) counter.update(line)
else:
assert isinstance(line_json['text'], list)
for text in line_json['text']:
line = text_feature.tokenize(text, replace_space=False)
counter.update(line)
def dump_text_manifest(fileobj, manifest_path, key='text'): def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = [] manifest_jsons = []
@@ -76,7 +84,12 @@ def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons.append(json_data) manifest_jsons.append(json_data)
for line_json in manifest_jsons: for line_json in manifest_jsons:
if isinstance(line_json[key], str):
fileobj.write(line_json[key] + "\n") fileobj.write(line_json[key] + "\n")
else:
assert isinstance(line_json[key], list)
for line in line_json[key]:
fileobj.write(line + "\n")
def main(): def main():
print_arguments(args, globals()) print_arguments(args, globals())
@@ -104,7 +117,7 @@ def main():
model_type=args.spm_mode, model_type=args.spm_mode,
model_prefix=args.spm_model_prefix, model_prefix=args.spm_model_prefix,
input_sentence_size=100000000, input_sentence_size=100000000,
character_coverage=0.9995) character_coverage=args.spm_character_coverage)
os.unlink(fp.name) os.unlink(fp.name)
# encode # encode

Loading…
Cancel
Save