You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.7 KiB
54 lines
1.7 KiB
"""This tool is used for splitting data into each node of
|
|
paddle cloud by total trainer count and current trainer id.
|
|
A trainer is an instance of the k8s cluster.
|
|
This script should be called in paddle cloud.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import json
|
|
import argparse
|
|
|
|
# Command-line interface. Every option is a string-valued path whose help
# text echoes its default through argparse's %(default)s placeholder.
parser = argparse.ArgumentParser(description=__doc__)
for _flag, _default, _help in (
        ("--in_manifest_path", './cloud.train.manifest',
         "Input manifest path. (default: %(default)s)"),
        ("--data_tar_path", './cloud.train.tar',
         "Data tar file path. (default: %(default)s)"),
        ("--out_manifest_path", './local.train.manifest',
         "Out manifest file path. (default: %(default)s)"), ):
    parser.add_argument(_flag, default=_default, type=str, help=_help)
# Drop the loop temporaries so the module namespace only gains `parser`
# and `args`.
del _flag, _default, _help
args = parser.parse_args()
|
|
|
|
|
|
def split_data(in_manifest,
               tar_path,
               out_manifest,
               trainer_id_path="/trainer_id",
               trainer_count_path="/trainer_count"):
    """Write this trainer's share of a manifest, with audio paths in a tar.

    The line at zero-based index ``i`` of ``in_manifest`` is kept when
    ``i % trainer_count == trainer_id``. Each kept line is parsed as JSON,
    its ``audio_filepath`` value is rewritten to
    ``tar:<absolute tar_path>#<original value>``, and the entry is written
    to ``out_manifest`` as one JSON document per line.

    Args:
        in_manifest: Path of the input manifest (one JSON object per line).
        tar_path: Path of the tar file; made absolute before being embedded.
        out_manifest: Path of the manifest to write (overwritten).
        trainer_id_path: File whose first line holds this trainer's id.
        trainer_count_path: File whose first line holds the trainer count.

    Raises:
        IOError/OSError: If any of the files cannot be opened.
        ValueError: If an id/count file does not hold an integer, or a
            selected manifest line is not valid JSON.
        KeyError: If a selected entry has no ``audio_filepath`` key.
        ZeroDivisionError: If the trainer count is 0.
    """
    # strip() instead of slicing off the last character: the value must
    # survive a file that has no trailing newline ("12" must not become 1).
    with open(trainer_id_path, "r") as f:
        trainer_id = int(f.readline().strip())
    with open(trainer_count_path, "r") as f:
        trainer_count = int(f.readline().strip())

    tar_path = os.path.abspath(tar_path)
    result = []
    # The input is read completely (and closed) before the output is opened,
    # so passing the same path for both manifests keeps working.
    with open(in_manifest) as manifest_in:
        for index, json_line in enumerate(manifest_in):
            if (index % trainer_count) == trainer_id:
                json_data = json.loads(json_line)
                json_data['audio_filepath'] = "tar:%s#%s" % (
                    tar_path, json_data['audio_filepath'])
                result.append("%s\n" % json.dumps(json_data))
    with open(out_manifest, 'w') as manifest:
        manifest.writelines(result)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: forward the parsed command-line paths to
    # split_data, naming each argument explicitly.
    split_data(
        in_manifest=args.in_manifest_path,
        tar_path=args.data_tar_path,
        out_manifest=args.out_manifest_path)
|