diff --git a/examples/ted_en_zh/st0/conf/transformer_joint_noam.yaml b/examples/ted_en_zh/st0/conf/transformer_mtl_noam.yaml similarity index 100% rename from examples/ted_en_zh/st0/conf/transformer_joint_noam.yaml rename to examples/ted_en_zh/st0/conf/transformer_mtl_noam.yaml diff --git a/examples/ted_en_zh/st0/local/data.sh b/examples/ted_en_zh/st0/local/data.sh index 7ea185db7..c4de1749e 100755 --- a/examples/ted_en_zh/st0/local/data.sh +++ b/examples/ted_en_zh/st0/local/data.sh @@ -76,8 +76,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --spm_vocab_size=${nbpe} \ --spm_mode ${bpemode} \ --spm_model_prefix ${bpeprefix} \ + --spm_character_coverage 1. \ --vocab_path="${dict_dir}/vocab.txt" \ - --text_keys 'text' 'text1' \ + --text_keys 'text' \ --manifest_paths="data/manifest.train.raw" if [ $? -ne 0 ]; then diff --git a/examples/ted_en_zh/st0/run.sh b/examples/ted_en_zh/st0/run.sh index fb4bc3388..bc5ee4e60 100755 --- a/examples/ted_en_zh/st0/run.sh +++ b/examples/ted_en_zh/st0/run.sh @@ -5,7 +5,7 @@ source path.sh gpus=0,1,2,3 stage=0 stop_stage=100 -conf_path=conf/transformer_joint_noam.yaml +conf_path=conf/transformer_mtl_noam.yaml avg_num=5 data_path=./TED_EnZh # path to unzipped data source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; diff --git a/utils/build_vocab.py b/utils/build_vocab.py index f832cbbc3..e364e821e 100755 --- a/utils/build_vocab.py +++ b/utils/build_vocab.py @@ -55,6 +55,8 @@ add_arg('text_keys', str, add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm") add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm") +add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols") + # yapf: disable args = parser.parse_args() @@ -66,8 +68,14 @@ def count_manifest(counter, text_feature, manifest_path): manifest_jsons.append(json_data) for line_json in manifest_jsons: - line = text_feature.tokenize(line_json['text'], replace_space=False) - counter.update(line) + if isinstance(line_json['text'], str): + line = text_feature.tokenize(line_json['text'], replace_space=False) + counter.update(line) + else: + assert isinstance(line_json['text'], list) + for text in line_json['text']: + line = text_feature.tokenize(text, replace_space=False) + counter.update(line) def dump_text_manifest(fileobj, manifest_path, key='text'): manifest_jsons = [] @@ -76,7 +84,12 @@ def dump_text_manifest(fileobj, manifest_path, key='text'): manifest_jsons.append(json_data) for line_json in manifest_jsons: - fileobj.write(line_json[key] + "\n") + if isinstance(line_json[key], str): + fileobj.write(line_json[key] + "\n") + else: + assert isinstance(line_json[key], list) + for line in line_json[key]: + fileobj.write(line + "\n") def main(): print_arguments(args, globals()) @@ -104,7 +117,7 @@ def main(): model_type=args.spm_mode, model_prefix=args.spm_model_prefix, input_sentence_size=100000000, - character_coverage=0.9995) + character_coverage=args.spm_character_coverage) os.unlink(fp.name) # encode