You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
3.5 KiB
106 lines
3.5 KiB
#!/usr/bin/python
|
|
import argparse
|
|
import os
|
|
|
|
import yaml
|
|
|
|
|
|
def change_device(yamlfile: str, engine: str, device: str):
    """Set the inference device of one speech engine in a server config file.

    The YAML document in ``yamlfile`` is loaded, the ``device`` entry (or
    entries) belonging to ``engine`` is overwritten, the updated document is
    printed, and the file is rewritten in place.

    Args:
        yamlfile (str): Path of the server configuration file,
            e.g. ``./conf/application.yaml``.
        engine (str): One of ``asr_python``, ``tts_python``,
            ``asr_inference`` or ``tts_inference``.
        device (str): ``cpu`` or ``gpu``; ``gpu`` is written as ``gpu:0``.

    Raises:
        ValueError: If ``device`` or ``engine`` is not one of the values
            above. Both are checked before the file is read or written, so an
            invalid request leaves the config file untouched.
    """
    device_values = {'cpu': 'cpu', 'gpu': 'gpu:0'}
    # Sub-sections of y[engine] whose 'device' key must be set; None means
    # the engine section itself.
    device_sections = {
        'asr_python': [None],
        'tts_python': [None],
        'asr_inference': ['am_predictor_conf'],
        'tts_inference': ['am_predictor_conf', 'voc_predictor_conf'],
    }

    # Validate first: previously a bad device left `set_device` unbound
    # after the file had already been truncated, wiping the config.
    if device not in device_values:
        raise ValueError("Please set correct device: cpu or gpu.")
    if engine not in device_sections:
        raise ValueError(
            "Please set correct engine: asr_python, tts_python, asr_inference, tts_inference."
        )
    set_device = device_values[device]

    # Load the whole document before opening the file for writing, so a
    # parse error cannot leave the config truncated.
    with open(yamlfile, encoding="utf-8") as f:
        y = yaml.safe_load(f)

    for sub_key in device_sections[engine]:
        target = y[engine] if sub_key is None else y[engine][sub_key]
        target['device'] = set_device

    print(yaml.dump(y, default_flow_style=False, sort_keys=False))
    # Serialize before truncating so a dump failure keeps the old file.
    content = yaml.dump(y, allow_unicode=True)
    with open(yamlfile, "w", encoding="utf-8") as fw:
        fw.write(content)
    print("Change %s successfully." % (yamlfile))
|
|
|
|
|
|
def change_engine_type(yamlfile: str, engine_type: str):
    """Make ``engine_type`` the only engine of its speech task in a config file.

    The speech task is the part of ``engine_type`` before the first
    underscore (``asr`` for ``asr_inference``). Every entry of the file's
    ``engine_list`` that belongs to the same task is dropped, entries of
    other tasks are kept in their original order, and ``engine_type`` is
    appended. The updated document is printed and written back to
    ``yamlfile``.

    Args:
        yamlfile (str): Path of the server configuration file.
        engine_type (str): Engine to enable, e.g. ``asr_python`` or
            ``tts_inference``.
    """
    speech_task = engine_type.split("_")[0]

    # Load the whole document before opening the file for writing, so a
    # parse error cannot leave the config truncated.
    with open(yamlfile, encoding="utf-8") as f:
        y = yaml.safe_load(f)

    # Build a new list rather than calling remove() while iterating, which
    # skipped the element following each removed entry. Entries are matched
    # by the same "<task>_..." prefix rule used to derive speech_task.
    engine_list = [
        engine for engine in (y.get('engine_list') or [])
        if engine.split("_")[0] != speech_task
    ]
    engine_list.append(engine_type)
    y['engine_list'] = engine_list

    print(yaml.dump(y, default_flow_style=False, sort_keys=False))
    # Serialize before truncating so a dump failure keeps the old file.
    content = yaml.dump(y, allow_unicode=True)
    with open(yamlfile, "w", encoding="utf-8") as fw:
        fw.write(content)
    print("Change %s successfully." % (yamlfile))
|
|
|
|
|
|
if __name__ == "__main__":
    # Engines this script can reconfigure, in the order their choices are
    # listed in --help.
    engine_names = ('asr_python', 'asr_inference', 'tts_python',
                    'tts_inference')
    task_choices = ['enginetype-%s' % name for name in engine_names]
    task_choices.extend('device-%s-%s' % (name, hw)
                        for name in engine_names for hw in ('cpu', 'gpu'))

    cli = argparse.ArgumentParser()
    cli.add_argument('--config_file',
                     type=str,
                     default='./conf/application.yaml',
                     help='server yaml file.')
    cli.add_argument('--change_task',
                     type=str,
                     default=None,
                     help='Change task',
                     choices=task_choices,
                     required=True)
    opts = cli.parse_args()

    # change_task is "enginetype-<engine>" or "device-<engine>-<cpu|gpu>".
    action, *params = opts.change_task.split("-")
    handlers = {
        "enginetype": lambda p: change_engine_type(opts.config_file, p[0]),
        "device": lambda p: change_device(opts.config_file, p[0], p[1]),
    }
    handler = handlers.get(action)
    if handler is None:
        print("Error change task, please check change_task.")
    else:
        handler(params)
|