You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/demos/speech_web/speech_server/src/ernie_sat.py

199 lines
7.5 KiB

import os
from .util import get_ngpu
from .util import MAIN_ROOT
from .util import run_cmd
class SAT:
def __init__(self):
# pretrain model path
self.zh_pretrain_model_path = os.path.realpath(
"source/model/erniesat_aishell3_ckpt_1.2.0")
self.en_pretrain_model_path = os.path.realpath(
"source/model/erniesat_vctk_ckpt_1.2.0")
self.cross_pretrain_model_path = os.path.realpath(
"source/model/erniesat_aishell3_vctk_ckpt_1.2.0")
self.zh_voc_model_path = os.path.realpath(
"source/model/hifigan_aishell3_ckpt_0.2.0")
self.eb_voc_model_path = os.path.realpath(
"source/model/hifigan_vctk_ckpt_0.2.0")
self.cross_voc_model_path = os.path.realpath(
"source/model/hifigan_aishell3_ckpt_0.2.0")
self.BIN_DIR = os.path.join(MAIN_ROOT,
"paddlespeech/t2s/exps/ernie_sat")
def zh_synthesize_edit(self,
old_str: str,
new_str: str,
input_name: os.PathLike,
output_name: os.PathLike,
task_name: str="synthesize",
erniesat_ckpt_name: str="snapshot_iter_289500.pdz"):
if task_name not in ['synthesize', 'edit']:
print("task name only in ['edit', 'synthesize']")
return None
# 推理文件配置
config_path = os.path.join(self.zh_pretrain_model_path, "default.yaml")
phones_dict = os.path.join(self.zh_pretrain_model_path,
"phone_id_map.txt")
erniesat_ckpt = os.path.join(self.zh_pretrain_model_path,
erniesat_ckpt_name)
erniesat_stat = os.path.join(self.zh_pretrain_model_path,
"speech_stats.npy")
voc = "hifigan_aishell3"
voc_config = os.path.join(self.zh_voc_model_path, "default.yaml")
voc_ckpt = os.path.join(self.zh_voc_model_path,
"snapshot_iter_2500000.pdz")
voc_stat = os.path.join(self.zh_voc_model_path, "feats_stats.npy")
cmd = self.get_cmd(
task_name=task_name,
input_name=input_name,
old_str=old_str,
new_str=new_str,
config_path=config_path,
phones_dict=phones_dict,
erniesat_ckpt=erniesat_ckpt,
erniesat_stat=erniesat_stat,
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
output_name=output_name,
source_lang="zh",
target_lang="zh")
return run_cmd(cmd, output_name)
def crossclone(self,
old_str: str,
new_str: str,
input_name: os.PathLike,
output_name: os.PathLike,
source_lang: str,
target_lang: str,
erniesat_ckpt_name: str="snapshot_iter_489000.pdz"):
# 推理文件配置
config_path = os.path.join(self.cross_pretrain_model_path,
"default.yaml")
phones_dict = os.path.join(self.cross_pretrain_model_path,
"phone_id_map.txt")
erniesat_ckpt = os.path.join(self.cross_pretrain_model_path,
erniesat_ckpt_name)
erniesat_stat = os.path.join(self.cross_pretrain_model_path,
"speech_stats.npy")
voc = "hifigan_aishell3"
voc_config = os.path.join(self.cross_voc_model_path, "default.yaml")
voc_ckpt = os.path.join(self.cross_voc_model_path,
"snapshot_iter_2500000.pdz")
voc_stat = os.path.join(self.cross_voc_model_path, "feats_stats.npy")
task_name = "synthesize"
cmd = self.get_cmd(
task_name=task_name,
input_name=input_name,
old_str=old_str,
new_str=new_str,
config_path=config_path,
phones_dict=phones_dict,
erniesat_ckpt=erniesat_ckpt,
erniesat_stat=erniesat_stat,
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
output_name=output_name,
source_lang=source_lang,
target_lang=target_lang)
return run_cmd(cmd, output_name)
def en_synthesize_edit(self,
old_str: str,
new_str: str,
input_name: os.PathLike,
output_name: os.PathLike,
task_name: str="synthesize",
erniesat_ckpt_name: str="snapshot_iter_199500.pdz"):
# 推理文件配置
config_path = os.path.join(self.en_pretrain_model_path, "default.yaml")
phones_dict = os.path.join(self.en_pretrain_model_path,
"phone_id_map.txt")
erniesat_ckpt = os.path.join(self.en_pretrain_model_path,
erniesat_ckpt_name)
erniesat_stat = os.path.join(self.en_pretrain_model_path,
"speech_stats.npy")
voc = "hifigan_aishell3"
voc_config = os.path.join(self.zh_voc_model_path, "default.yaml")
voc_ckpt = os.path.join(self.zh_voc_model_path,
"snapshot_iter_2500000.pdz")
voc_stat = os.path.join(self.zh_voc_model_path, "feats_stats.npy")
cmd = self.get_cmd(
task_name=task_name,
input_name=input_name,
old_str=old_str,
new_str=new_str,
config_path=config_path,
phones_dict=phones_dict,
erniesat_ckpt=erniesat_ckpt,
erniesat_stat=erniesat_stat,
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
output_name=output_name,
source_lang="en",
target_lang="en")
return run_cmd(cmd, output_name)
def get_cmd(self,
task_name: str,
input_name: str,
old_str: str,
new_str: str,
config_path: str,
phones_dict: str,
erniesat_ckpt: str,
erniesat_stat: str,
voc: str,
voc_config: str,
voc_ckpt: str,
voc_stat: str,
output_name: str,
source_lang: str,
target_lang: str):
ngpu = get_ngpu()
cmd = f"""
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 {self.BIN_DIR}/synthesize_e2e.py \
--task_name={task_name} \
--wav_path={input_name} \
--old_str='{old_str}' \
--new_str='{new_str}' \
--source_lang={source_lang} \
--target_lang={target_lang} \
--erniesat_config={config_path} \
--phones_dict={phones_dict} \
--erniesat_ckpt={erniesat_ckpt} \
--erniesat_stat={erniesat_stat} \
--voc={voc} \
--voc_config={voc_config} \
--voc_ckpt={voc_ckpt} \
--voc_stat={voc_stat} \
--output_name={output_name} \
--ngpu={ngpu}
"""
return cmd