# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# PaddleSpeech/cloud/split_data.py
#
# 54 lines
# 1.7 KiB

"""This tool is used for splitting data into each node of
paddle cloud by total trainer count and current trainer id.
The meaning of trainer is an instance of the k8s cluster.
This script should be called in paddle cloud.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import argparse
def _build_parser():
    """Create the command-line parser for this script.

    Every option is a string-valued flag with a default, so the options are
    declared as (flag, default, help) rows and registered in a single loop.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    option_table = (
        ("--in_manifest_path", './cloud.train.manifest',
         "Input manifest path. (default: %(default)s)"),
        ("--data_tar_path", './cloud.train.tar',
         "Data tar file path. (default: %(default)s)"),
        ("--out_manifest_path", './local.train.manifest',
         "Out manifest file path. (default: %(default)s)"),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(
            flag, default=default_value, type=str, help=help_text)
    return cli


# The parser and the parsed arguments stay module-level names because the
# script entry point at the bottom of the file reads `args` directly.
parser = _build_parser()
args = parser.parse_args()
def _read_int_file(path):
    """Return the integer written on the first line of the file at `path`."""
    with open(path, "r") as f:
        # strip() rather than chopping the last character: the file may or
        # may not end with a newline, and blindly dropping one character
        # would corrupt the value (e.g. "12" -> 1) when it does not.
        return int(f.readline().strip())


def split_data(in_manifest,
               tar_path,
               out_manifest,
               trainer_id_path="/trainer_id",
               trainer_count_path="/trainer_count"):
    """Write this trainer's share of a manifest to a local manifest file.

    Lines of `in_manifest` are dealt out round-robin: the line with 0-based
    index `i` is kept when `i % trainer_count == trainer_id`. For every kept
    line, 'audio_filepath' is rewritten to "tar:<absolute tar_path>#<original
    audio_filepath>" and the entry is written as one JSON line to
    `out_manifest`.

    Args:
        in_manifest: Path of the input manifest, one JSON object per line,
            each containing an 'audio_filepath' key.
        tar_path: Path of the data tar file; it is made absolute before being
            embedded in the rewritten 'audio_filepath'.
        out_manifest: Path of the manifest file to write (overwritten).
        trainer_id_path: File whose first line holds this trainer's id.
        trainer_count_path: File whose first line holds the trainer count.

    Raises:
        ValueError: If the trainer id or trainer count file does not hold an
            integer, or a selected manifest line is not valid JSON.
        KeyError: If a selected manifest entry has no 'audio_filepath' key.
    """
    trainer_id = _read_int_file(trainer_id_path)
    trainer_count = _read_int_file(trainer_count_path)
    tar_path = os.path.abspath(tar_path)

    # Collect the output fully before opening `out_manifest` so the input is
    # completely read (and closed) before anything is written.
    result = []
    with open(in_manifest) as manifest_in:
        for index, json_line in enumerate(manifest_in):
            if index % trainer_count != trainer_id:
                continue
            json_data = json.loads(json_line)
            json_data['audio_filepath'] = "tar:%s#%s" % (
                tar_path, json_data['audio_filepath'])
            result.append("%s\n" % json.dumps(json_data))

    with open(out_manifest, 'w') as manifest:
        manifest.writelines(result)
if __name__ == '__main__':
    # Script entry point: hand the parsed command-line paths to split_data,
    # naming each argument so the mapping from option to parameter is explicit.
    split_data(
        in_manifest=args.in_manifest_path,
        tar_path=args.data_tar_path,
        out_manifest=args.out_manifest_path)