|
|
import argparse
|
|
|
import base64
|
|
|
import datetime
|
|
|
import json
|
|
|
import os
|
|
|
from typing import List
|
|
|
|
|
|
import aiofiles
|
|
|
import librosa
|
|
|
import soundfile as sf
|
|
|
import uvicorn
|
|
|
from fastapi import FastAPI
|
|
|
from fastapi import UploadFile
|
|
|
from pydantic import BaseModel
|
|
|
from src.ernie_sat import SAT
|
|
|
from src.finetune import FineTune
|
|
|
from src.ge2e_clone import VoiceCloneGE2E
|
|
|
from src.tdnn_clone import VoiceCloneTDNN
|
|
|
from src.util import *
|
|
|
from starlette.responses import FileResponse
|
|
|
|
|
|
from paddlespeech.server.utils.audio_process import float2pcm
|
|
|
|
|
|
# Parse command-line options (currently only the listening port).
parser = argparse.ArgumentParser(prog='PaddleSpeechDemo', add_help=True)
parser.add_argument(
    "--port",
    type=int,
    default=8010,
    required=False,
    action="store",
    help="port of the app",
)

args = parser.parse_args()
# Port used by the uvicorn.run(...) call in the __main__ guard.
port = args.port
|
|
|
|
|
|
# Model instances, constructed once at import time and shared by the request
# handlers below.
# Original note: constructing these here affects fine-tuning, which is why
# fine-tuning is run through a command line instead.
vc_model = VoiceCloneGE2E()  # used by /vc/clone_g2p when func == 'ge2e'

vc_model_tdnn = VoiceCloneTDNN()  # used by /vc/clone_g2p for any other func

sat_model = SAT()  # used by /vc/clone_sat

ft_model = FineTune()  # used by the /finetune/* endpoints
|
|
|
|
|
|
# Configuration files and bundled resources.
tts_config = "conf/tts_online_application.yaml"
asr_config = "conf/ws_conformer_wenetspeech_application_faster.yaml"
asr_init_path = "source/demo/demo.wav"
db_path = "source/db/vc.sqlite"
ie_model_path = "source/model"

# Audio storage locations, grouped by feature and built from a shared root.
_VC_ROOT = "source/wav/vc"
VC_UPLOAD_PATH = _VC_ROOT + "/upload"
VC_OUT_PATH = _VC_ROOT + "/out"

_FT_ROOT = "source/wav/finetune"
FT_UPLOAD_PATH = _FT_ROOT + "/upload"
FT_OUT_PATH = _FT_ROOT + "/out"
FT_LABEL_PATH = _FT_ROOT + "/label.json"
FT_LABEL_TXT_PATH = _FT_ROOT + "/labels.txt"
FT_DEFAULT_PATH = _FT_ROOT + "/default"
# Fine-tune experiment outputs live outside the audio tree.
FT_EXP_BASE_PATH = "tmp_dir/finetune"

_SAT_ROOT = "source/wav/SAT"
SAT_UPLOAD_PATH = _SAT_ROOT + "/upload"
SAT_OUT_PATH = _SAT_ROOT + "/out"
SAT_LABEL_PATH = _SAT_ROOT + "/label.json"
|
|
|
|
|
|
def _load_label_dict(label_path):
    """Return the JSON object stored at ``label_path``, or {} if the file is absent."""
    if not os.path.exists(label_path):
        return {}
    with open(label_path, "r", encoding='utf8') as label_file:
        return json.load(label_file)


# Saved SAT labels: file name -> text (written by /vc/clone_sat).
sat_label_dic = _load_label_dict(SAT_LABEL_PATH)
# Saved fine-tune labels: file name -> {'text': ...} (read by /finetune/list).
ft_label_dic = _load_label_dict(FT_LABEL_PATH)
|
|
|
|
|
|
# Create the audio working directories at startup; exist_ok makes this a
# no-op for directories that already exist.
base_sources = [
    VC_UPLOAD_PATH,
    VC_OUT_PATH,
    FT_UPLOAD_PATH,
    FT_OUT_PATH,
    FT_DEFAULT_PATH,
    SAT_UPLOAD_PATH,
    SAT_OUT_PATH,
]
for path in base_sources:
    os.makedirs(path, exist_ok=True)
|
|
|
#####################################################################
########################### App initialization ######################
#####################################################################
# The FastAPI application; every handler below registers on it through an
# @app.get / @app.post decorator.
app = FastAPI()
|
|
|
|
|
|
######################################################################
########################### Request models ###########################
#####################################################################


# Request body schemas.
# VcBase: identifies an audio file on the server (used by the download,
# file, delete and base64-download endpoints, which read wavPath).
class VcBase(BaseModel):
    wavName: str  # file name of the audio
    wavPath: str  # server-side path of the audio file
|
|
|
|
|
|
|
|
|
# Request body for /vc/clone_g2p.
class VcBaseText(BaseModel):
    wavName: str  # name of the output file written under VC_OUT_PATH
    wavPath: str  # path of the reference (input) audio
    text: str  # text to synthesize in the cloned voice
    func: str  # 'ge2e' selects the GE2E model; any other value selects TDNN
|
|
|
|
|
|
|
|
|
# Request body for /vc/clone_sat.
class VcBaseSAT(BaseModel):
    old_str: str  # existing text of the input audio; also saved as its label
    new_str: str  # text to synthesize / edit towards
    language: str  # "zh" or "en"
    function: str  # "synthesize", "edit" or "crossclone"
    # NOTE(review): the original comment said "base64-encoded", but
    # /vc/clone_sat passes this value to os.path.realpath as a file path.
    # Confirm which one the client actually sends.
    wav: str
    filename: str  # label key; also used to build the output file name
|
|
|
|
|
|
|
|
|
# Request body for /finetune/list.
class FTPath(BaseModel):
    dataPath: str  # "default" selects FT_DEFAULT_PATH; otherwise a directory path
|
|
|
|
|
|
|
|
|
# Request body for /finetune/upload.
class VcBaseFT(BaseModel):
    wav: str  # base64-encoded audio bytes (decoded with base64.b64decode)
    filename: str  # name to save the audio under
    wav_path: str  # target directory (created if missing)
|
|
|
|
|
|
|
|
|
# Request body for /finetune/clone_finetune.
class VcBaseFTModel(BaseModel):
    wav_path: str  # training-data directory, or "default" for FT_DEFAULT_PATH
|
|
|
|
|
|
|
|
|
# Request body for /finetune/clone_finetune_syn.
class VcBaseFTSyn(BaseModel):
    exp_path: str  # fine-tune experiment directory (as returned by /finetune/clone_finetune)
    text: str  # text to synthesize
|
|
|
|
|
|
|
|
|
######################################################################
################## File listing and saving helpers ###################
#####################################################################


def getVCList(path):
    """Recursively list every file under ``path``.

    Returns a list of ``{'name': <file name>, 'path': <dir joined with name>}``
    dicts sorted by file name in descending order. A missing directory yields
    an empty list.
    """
    entries = [
        {'name': fname, 'path': os.path.join(dirpath, fname)}
        for dirpath, _subdirs, fnames in os.walk(path, topdown=False)
        for fname in fnames
    ]
    # list.sort is stable, like sorted(), so ties keep the walk order.
    entries.sort(key=lambda entry: entry['name'], reverse=True)
    return entries
|
|
|
|
|
|
|
|
|
async def saveFiles(files, SavePath):
    """Save uploaded audio files into ``SavePath`` and resample them to 16 kHz.

    Each upload is written to disk, re-loaded with librosa at 16 kHz and
    rewritten in place with soundfile. A failure is counted per file and
    does not stop the remaining uploads.

    Args:
        files: uploaded file objects (FastAPI ``UploadFile``).
        SavePath: destination directory.

    Returns:
        A summary string with the success/failure counts and failure reasons.
    """
    right = 0
    error = 0
    error_info = "错误文件:"
    for file in files:
        # UploadFile.filename may be None; normalise it so both the name
        # handling and the error report below always work on a str.
        filename = file.filename or ""
        try:
            if 'blob' in filename:
                # Browser recordings arrive as "blob"; generate a name for them.
                out_name = datetime.datetime.strftime(
                    datetime.datetime.now(), '%H%M') + randName(3) + ".wav"
            else:
                # Keep only the final path component so a crafted name such
                # as "../../x" cannot write outside SavePath.
                out_name = os.path.basename(filename)
            out_file_path = os.path.join(SavePath, out_name)

            print("上传文件名:", out_file_path)
            async with aiofiles.open(out_file_path, 'wb') as out_file:
                content = await file.read()  # async read
                await out_file.write(content)  # async write
            # Resample to 16 kHz and rewrite the file in place.
            wav, sr = librosa.load(out_file_path, sr=16000)
            sf.write(out_file_path, data=wav, samplerate=sr)
            right += 1
        except Exception as e:
            error += 1
            error_info = error_info + filename + " " + str(e) + "\n"
    return f"上传成功:{right}, 上传失败:{error}, 失败原因: {error_info}"
|
|
|
|
|
|
|
|
|
# Audio download
@app.post("/vc/download")
async def VcDownload(base: VcBase):
    """Return the file at ``base.wavPath`` as a file response.

    NOTE(security): ``wavPath`` comes straight from the request body and is not
    confined to any directory, so this can serve any file the process can read.
    It is not restricted here because /finetune/list hands out paths under a
    client-chosen dataPath. Restrict it if that flow is not needed.
    """
    # isfile rather than exists: a directory path would make FileResponse fail.
    if os.path.isfile(base.wavPath):
        return FileResponse(base.wavPath)
    return ErrorRequest(message="下载请求失败,文件不存在")
|
|
|
|
|
|
|
|
|
# Audio download as base64
@app.post("/vc/download_base64")
async def VcDownloadBase64(base: VcBase):
    """Return ``base.wavPath`` as base64-encoded raw PCM samples at 16 kHz.

    The audio is resampled to 16 kHz by librosa, converted with ``float2pcm``
    (float32 -> int16) and the raw sample bytes are base64-encoded. No WAV
    header is included.
    """
    # isfile rather than exists: librosa cannot load a directory.
    if not os.path.isfile(base.wavPath):
        return ErrorRequest(message="播放请求失败,文件不存在")
    samples, _sr = librosa.load(base.wavPath, sr=16000)
    pcm = float2pcm(samples)  # float32 to int16
    encoded = base64.b64encode(pcm.tobytes()).decode('utf8')
    return SuccessRequest(result=encoded)
|
|
|
|
|
|
|
|
|
######################################################################
########################### VC service ###############################
#####################################################################


# Upload files
@app.post("/vc/upload")
async def VcUpload(files: List[UploadFile]):
    """Save uploaded audio into VC_UPLOAD_PATH, resampled to 16 kHz.

    The per-file save/resample/error-count logic lives in ``saveFiles``
    (awaited here, since it is a coroutine). Its summary string is returned
    inside a success response.
    """
    summary = await saveFiles(files, VC_UPLOAD_PATH)
    return SuccessRequest(result=summary)
|
|
|
|
|
|
|
|
|
# List uploaded VC files
@app.get("/vc/list")
async def VcList():
    """Return every file under VC_UPLOAD_PATH, sorted by name (descending)."""
    return SuccessRequest(result=getVCList(VC_UPLOAD_PATH))
|
|
|
|
|
|
|
|
|
# Fetch an audio file
@app.post("/vc/file")
async def VcFileGet(base: VcBase):
    """Return the file at ``base.wavPath`` as a file response.

    NOTE(security): the path is taken from the request body unrestricted, so
    any readable file can be fetched. Confine it if clients do not need that.
    """
    # isfile rather than exists: a directory path would make FileResponse fail.
    if os.path.isfile(base.wavPath):
        return FileResponse(base.wavPath)
    # Fill ``message`` like the other endpoints do, and keep ``result`` for
    # existing clients that read it.
    return ErrorRequest(result="获取文件失败", message="获取文件失败")
|
|
|
|
|
|
|
|
|
def _is_under(path, root):
    """Return True if absolute ``path`` equals ``root`` or lies inside it."""
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # Raised for paths on different drives (Windows).
        return False


# Delete an audio file
@app.post("/vc/del")
async def VcFileDel(base: VcBase):
    """Delete ``base.wavPath`` if it is a file inside a managed audio folder.

    The path comes from the request body. It is resolved with realpath and
    deleted only when it lies under one of the upload/output directories this
    service writes to. Anything else, including directories, is reported as
    a failed delete.
    """
    target = os.path.realpath(base.wavPath)
    managed_roots = [
        os.path.realpath(p)
        for p in (VC_UPLOAD_PATH, VC_OUT_PATH, SAT_UPLOAD_PATH, SAT_OUT_PATH,
                  FT_UPLOAD_PATH, FT_OUT_PATH)
    ]
    allowed = any(_is_under(target, root) for root in managed_roots)
    if allowed and os.path.isfile(target):
        os.remove(target)
        return SuccessRequest(result="删除成功")
    # Fill ``message`` like the other endpoints; keep ``result`` for old clients.
    return ErrorRequest(result="删除失败", message="删除失败")
|
|
|
|
|
|
|
|
|
# Voice cloning (G2P front end)
@app.post("/vc/clone_g2p")
async def VcCloneG2P(base: VcBaseText):
    """Clone the voice in ``base.wavPath`` to speak ``base.text``.

    ``base.func == 'ge2e'`` selects the GE2E model; any other value selects the
    TDNN model. The output is written to VC_OUT_PATH under ``base.wavName``,
    and ``{"wavName", "wavPath"}`` is returned on success.
    """
    if not os.path.exists(base.wavPath):
        return ErrorRequest(message="克隆失败,音频不存在")

    # Keep only the final path component so a crafted wavName cannot write
    # outside VC_OUT_PATH.
    wavName = os.path.basename(base.wavName)
    wavPath = os.path.join(VC_OUT_PATH, wavName)
    model = vc_model if base.func == 'ge2e' else vc_model_tdnn
    try:
        model.vc(text=base.text, input_wav=base.wavPath, out_wav=wavPath)
    except Exception as e:
        print(e)
        return ErrorRequest(message="克隆失败,合成过程报错")
    return SuccessRequest(result={"wavName": wavName, "wavPath": wavPath})
|
|
|
|
|
|
|
|
|
######################################################################
########################### SAT service ##############################
#####################################################################

# SAT voice cloning / editing / cross-lingual cloning
@app.post("/vc/clone_sat")
async def VcCloneSAT(base: VcBaseSAT):
    """Run an SAT synthesize, edit or cross-lingual clone task.

    ``base.language`` ("zh" or "en") and ``base.function`` ("synthesize",
    "edit" or "crossclone") select the task. The input audio is read from the
    path in ``base.wav``. The output is written to SAT_OUT_PATH as
    ``sat_<syn|edit|cross>_<language>_<filename>``. Before the task runs,
    ``base.old_str`` is saved to SAT_LABEL_PATH as the label for
    ``base.filename``.

    Returns a success response with the model's result, or an error response
    for an unsupported option, a model exception or an empty result.
    """
    # Persist the label for this file if it is new or has changed.
    if sat_label_dic.get(base.filename) != base.old_str:
        sat_label_dic[base.filename] = base.old_str
        with open(SAT_LABEL_PATH, "w", encoding='utf8') as f:
            json.dump(sat_label_dic, f, ensure_ascii=False, indent=4)

    if base.language not in ("zh", "en"):
        return ErrorRequest(message="请检查功能选项是否正确,仅支持中文和英文")
    prefixes = {
        "synthesize": "sat_syn_",
        "edit": "sat_edit_",
        "crossclone": "sat_cross_",
    }
    if base.function not in prefixes:
        return ErrorRequest(
            message="请检查功能选项是否正确,仅支持:synthesize, edit, crossclone")

    # Put the real language in the output name so zh and en results for the
    # same file do not overwrite each other. Keep only the last path
    # component of the client-supplied filename so the output stays in
    # SAT_OUT_PATH.
    output_file_path = os.path.join(
        SAT_OUT_PATH,
        prefixes[base.function] + base.language + "_" +
        os.path.basename(base.filename))
    common = dict(
        old_str=base.old_str,
        new_str=base.new_str,
        input_name=os.path.realpath(base.wav),
        output_name=os.path.realpath(output_file_path))

    try:
        if base.function == "crossclone":
            target_lang = "en" if base.language == "zh" else "zh"
            sat_result = sat_model.crossclone(
                source_lang=base.language, target_lang=target_lang, **common)
        else:
            synthesize_edit = (sat_model.zh_synthesize_edit
                               if base.language == "zh" else
                               sat_model.en_synthesize_edit)
            sat_result = synthesize_edit(task_name=base.function, **common)
    except Exception as e:
        # Report model failures through the normal error response, not a 500.
        print(e)
        sat_result = None

    if sat_result:
        return SuccessRequest(result=sat_result, message="SAT合成成功")
    return ErrorRequest(message="SAT 合成失败,请从后台检查错误信息!")
|
|
|
|
|
|
|
|
|
# SAT file list
@app.get("/sat/list")
async def SatList():
    """List SAT uploads, each with its saved label.

    Every entry from getVCList(SAT_UPLOAD_PATH) gets a 'label' key holding the
    saved text for that file name, or "" if none has been saved.
    """
    listing = [
        dict(item, label=sat_label_dic.get(item['name'], ""))
        for item in getVCList(SAT_UPLOAD_PATH)
    ]
    return SuccessRequest(result=listing)
|
|
|
|
|
|
|
|
|
# Upload SAT audio files
@app.post("/sat/upload")
async def SATUpload(files: List[UploadFile]):
    """Save uploaded audio into SAT_UPLOAD_PATH, resampled to 16 kHz.

    The per-file save/resample/error-count logic lives in ``saveFiles``
    (awaited here, since it is a coroutine). Its summary string is returned
    inside a success response.
    """
    summary = await saveFiles(files, SAT_UPLOAD_PATH)
    return SuccessRequest(result=summary)
|
|
|
|
|
|
|
|
|
######################################################################
########################### Fine-tune service ########################
#####################################################################


# Fine-tune file list
@app.post("/finetune/list")
async def FineTuneList(Path: FTPath):
    """List the labelled fine-tune utterances for a data directory.

    ``Path.dataPath == "default"`` selects FT_DEFAULT_PATH; any other value is
    used as the directory path directly. One entry
    ``{'text', 'name', 'path'}`` is returned per key of ``ft_label_dic``, in
    dict order. ``path`` is "" when that file is missing from the directory.
    """
    dataPath = Path.dataPath
    FT_PATH = FT_DEFAULT_PATH if dataPath == "default" else dataPath

    # The result depends only on ft_label_dic and per-file existence checks,
    # so no full directory walk is needed.
    res = []
    for name, value in ft_label_dic.items():
        wav_path = os.path.join(FT_PATH, name)
        if not os.path.exists(wav_path):
            wav_path = ""
        res.append({'text': value['text'], 'name': name, 'path': wav_path})
    return SuccessRequest(result=res)
|
|
|
|
|
|
|
|
|
# One-click reset: create a fresh upload directory
@app.get('/finetune/newdir')
async def FTGetNewDir():
    """Create a new randomly named directory under FT_UPLOAD_PATH.

    FT_LABEL_TXT_PATH (labels.txt) is copied into it, and the new directory
    path is returned.
    """
    import shutil

    new_path = os.path.join(FT_UPLOAD_PATH, randName(3))
    if not os.path.exists(new_path):
        os.makedirs(new_path, exist_ok=True)
    # Copy labels.txt with shutil: portable, and no shell parses the paths.
    try:
        shutil.copy(FT_LABEL_TXT_PATH, new_path)
    except OSError as e:
        # Best effort: a failed copy is logged but does not fail the request.
        print(e)
    return SuccessRequest(result=new_path)
|
|
|
|
|
|
|
|
|
# Fine-tune upload
@app.post("/finetune/upload")
async def FTUpload(base: VcBaseFT):
    """Decode ``base.wav`` (base64) and save it as ``base.filename`` in ``base.wav_path``.

    The target directory is created if needed. Any failure (e.g. malformed
    base64 or an I/O error) is logged and reported as an upload failure.
    """
    try:
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(base.wav_path, exist_ok=True)
        # Keep only the final path component so the file stays in wav_path.
        out_file_path = os.path.join(base.wav_path,
                                     os.path.basename(base.filename))
        wav_b = base64.b64decode(base.wav)
        async with aiofiles.open(out_file_path, 'wb') as out_file:
            await out_file.write(wav_b)  # async write
        return SuccessRequest(result="上传成功")
    except Exception as e:
        print(e)
        # Fill ``message`` like the other endpoints; keep ``result`` for old clients.
        return ErrorRequest(result="上传失败", message="上传失败")
|
|
|
|
|
|
|
|
|
# Fine-tune training
@app.post("/finetune/clone_finetune")
async def FTModel(base: VcBaseFTModel):
    """Fine-tune on the data in ``base.wav_path`` ("default" -> FT_DEFAULT_PATH).

    The experiment directory is FT_EXP_BASE_PATH/<last component of the data
    path>. On success, returns the experiment directory reported by
    ``ft_model.finetune``.
    """
    data_path = FT_DEFAULT_PATH if base.wav_path == 'default' else base.wav_path
    if not os.path.exists(data_path):
        return ErrorRequest(message="数据文件夹不存在")

    # normpath drops a trailing separator, so "dir/" still maps to "dir"
    # instead of "" (which would collapse exp_dir onto FT_EXP_BASE_PATH).
    # basename also handles "/" on Windows, unlike split(os.sep).
    data_base = os.path.basename(os.path.normpath(data_path))
    exp_dir = os.path.join(FT_EXP_BASE_PATH, data_base)

    try:
        exp_dir = ft_model.finetune(
            input_dir=os.path.realpath(data_path),
            exp_dir=os.path.realpath(exp_dir))
    except Exception as e:
        print(e)
        return ErrorRequest(message="微调失败")
    if exp_dir:
        return SuccessRequest(result=exp_dir)
    return ErrorRequest(message="微调失败")
|
|
|
|
|
|
|
|
|
# Fine-tune synthesis
@app.post("/finetune/clone_finetune_syn")
async def FTSyn(base: VcBaseFTSyn):
    """Synthesize ``base.text`` with the fine-tuned model in ``base.exp_path``.

    Output goes to FT_OUT_PATH under a name from ``randName(5)``. On success,
    ``{"wavName": <name>.wav, "wavPath": <path from ft_model.synthesize>}`` is
    returned.
    """
    if not os.path.exists(base.exp_path):
        # Fill ``message`` like the other endpoints; keep ``result`` for old clients.
        return ErrorRequest(result="模型路径不存在", message="模型路径不存在")

    wav_name = randName(5)
    try:
        wav_path = ft_model.synthesize(
            text=base.text,
            wav_name=wav_name,
            out_wav_dir=os.path.realpath(FT_OUT_PATH),
            exp_dir=os.path.realpath(base.exp_path))
    except Exception as e:
        # Log the cause instead of discarding it.
        print(e)
        return ErrorRequest(message="音频合成失败")

    if wav_path:
        return SuccessRequest(
            result={"wavName": wav_name + ".wav", "wavPath": wav_path})
    return ErrorRequest(message="音频合成失败")
|
|
|
|
|
|
|
|
|
# Script entry point: serve the app on all interfaces at the --port value.
if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=port)
|