You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/demos/speech_web/speech_server/vc.py

551 lines
18 KiB

import argparse
import base64
import datetime
import json
import os
from typing import List
import aiofiles
import librosa
import soundfile as sf
import uvicorn
from fastapi import FastAPI
from fastapi import UploadFile
from pydantic import BaseModel
from src.ernie_sat import SAT
from src.finetune import FineTune
from src.ge2e_clone import VoiceCloneGE2E
from src.tdnn_clone import VoiceCloneTDNN
from src.util import *
from starlette.responses import FileResponse
from paddlespeech.server.utils.audio_process import float2pcm
# Parse command-line options.
parser = argparse.ArgumentParser(prog='PaddleSpeechDemo', add_help=True)
parser.add_argument(
    "--port",
    action="store",
    type=int,
    help="port of the app",
    default=8010,
    required=False)
# parse_known_args (rather than parse_args): this runs at import time, so when
# the module is imported by another launcher (e.g. `uvicorn vc:app ...`) or a
# test runner, their own arguments are still in sys.argv. parse_args would
# abort the import with SystemExit on those; here they are reported and ignored.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print("Ignoring unrecognized arguments:", _unknown_args)
port = args.port
# Model wrappers are constructed once, at import time, and shared by the
# request handlers below.
# Original author's note (translated): constructing these here affects
# fine-tuning, which is why fine-tuning is run via a command line instead.
# Keep the construction order as-is.
vc_model = VoiceCloneGE2E()
vc_model_tdnn = VoiceCloneTDNN()
sat_model = SAT()
ft_model = FineTune()
# --- Configuration files and model resources (relative to the working dir) ---
tts_config, asr_config = (
    "conf/tts_online_application.yaml",
    "conf/ws_conformer_wenetspeech_application_faster.yaml",
)
asr_init_path, db_path, ie_model_path = (
    "source/demo/demo.wav",
    "source/db/vc.sqlite",
    "source/model",
)

# --- Voice-clone (GE2E / TDNN) audio folders ---
VC_UPLOAD_PATH, VC_OUT_PATH = "source/wav/vc/upload", "source/wav/vc/out"

# --- Fine-tuning data, labels and experiment output ---
FT_UPLOAD_PATH, FT_OUT_PATH = (
    "source/wav/finetune/upload",
    "source/wav/finetune/out",
)
FT_LABEL_PATH, FT_LABEL_TXT_PATH = (
    "source/wav/finetune/label.json",
    "source/wav/finetune/labels.txt",
)
FT_DEFAULT_PATH = "source/wav/finetune/default"
FT_EXP_BASE_PATH = "tmp_dir/finetune"

# --- SAT (speech editing / cross-lingual cloning) audio and labels ---
SAT_UPLOAD_PATH, SAT_OUT_PATH = "source/wav/SAT/upload", "source/wav/SAT/out"
SAT_LABEL_PATH = "source/wav/SAT/label.json"
def _load_label_dict(label_path):
    """Return the parsed JSON content of ``label_path``, or ``{}`` if absent."""
    if not os.path.exists(label_path):
        return {}
    with open(label_path, "r", encoding='utf8') as label_file:
        return json.load(label_file)


# Transcripts entered for SAT uploads, keyed by uploaded file name.
sat_label_dic = _load_label_dict(SAT_LABEL_PATH)
# Fine-tuning label entries: file name -> entry carrying a 'text' field.
ft_label_dic = _load_label_dict(FT_LABEL_PATH)
# Create every working folder up front (folders that already exist are left alone).
base_sources = [
    VC_UPLOAD_PATH, VC_OUT_PATH,
    FT_UPLOAD_PATH, FT_OUT_PATH, FT_DEFAULT_PATH,
    SAT_UPLOAD_PATH, SAT_OUT_PATH,
]
for source_dir in base_sources:
    os.makedirs(source_dir, exist_ok=True)
#####################################################################
########################### App initialization ######################
#####################################################################
# The single FastAPI application; every route below is registered on it.
app = FastAPI()
######################################################################
########################### Request schemas ##########################
#####################################################################
# Request body shapes. Only `#` comments are used on purpose: pydantic copies
# class docstrings into the generated OpenAPI schema.
# Generic request referring to a server-side audio file.
class VcBase(BaseModel):
    wavName: str  # file name (used as the output name by clone requests)
    wavPath: str  # path of the file on the server, used as given
# Voice-clone (G2P) request: clone the voice in wavPath to speak `text`.
class VcBaseText(BaseModel):
    wavName: str  # output file name, placed under VC_OUT_PATH by /vc/clone_g2p
    wavPath: str  # reference audio path on the server
    text: str  # text to synthesize
    func: str  # 'ge2e' selects the GE2E model; any other value uses TDNN
# SAT request (synthesize / edit / cross-lingual clone).
class VcBaseSAT(BaseModel):
    old_str: str  # transcript of the source audio
    new_str: str  # target text
    language: str  # "zh" or "en" (language of the source audio)
    function: str  # "synthesize", "edit" or "crossclone"
    # Originally documented as "base64-encoded", but /vc/clone_sat passes this
    # value to os.path.realpath as the input audio path.
    # NOTE(review): confirm which form the client actually sends.
    wav: str
    filename: str  # uploaded file name; key for the stored transcript
# Fine-tune listing request.
class FTPath(BaseModel):
    dataPath: str  # "default" selects FT_DEFAULT_PATH, otherwise a folder path
# Fine-tune upload request: one recording for a fine-tune data folder.
class VcBaseFT(BaseModel):
    wav: str  # base64-encoded file content (decoded by /finetune/upload)
    filename: str  # name to save the recording under
    wav_path: str  # target folder; created if it does not exist
# Fine-tune training request.
class VcBaseFTModel(BaseModel):
    wav_path: str  # 'default' selects FT_DEFAULT_PATH, otherwise the data folder
# Synthesis request using a fine-tuned model.
class VcBaseFTSyn(BaseModel):
    exp_path: str  # experiment folder returned by /finetune/clone_finetune
    text: str  # text to synthesize
######################################################################
########################### File listing / saving helpers ############
#####################################################################
def getVCList(path):
    """List every file below ``path``, recursively.

    Args:
        path: folder to scan; a missing folder yields an empty list.

    Returns:
        A list of ``{'name': <file name>, 'path': <folder joined with name>}``
        dicts, sorted by file name in descending order.
    """
    entries = [
        {'name': fname, 'path': os.path.join(folder, fname)}
        for folder, _subdirs, fnames in os.walk(path, topdown=False)
        for fname in fnames
    ]
    # list.sort is stable, so equal names keep their os.walk order.
    entries.sort(key=lambda entry: entry['name'], reverse=True)
    return entries
async def saveFiles(files, SavePath):
    """Save uploaded audio files into ``SavePath`` as 16 kHz wav files.

    Args:
        files: uploaded files exposing ``filename`` and an awaitable ``read()``
            (e.g. FastAPI ``UploadFile``).
        SavePath: destination folder (expected to exist already).

    Returns:
        A summary string with the success/failure counts and, for each failed
        file, its name and the exception message. Failures never raise.
    """
    right = 0
    error = 0
    error_info = "错误文件:"
    for file in files:
        # The client controls `filename`: it may be missing (None) and may
        # contain directory components such as "../". Keep only the final
        # component so nothing is written outside SavePath.
        filename = file.filename or ""
        try:
            if 'blob' in filename:
                # Browser recordings arrive as "blob": give them an HHMM +
                # random-suffix name instead.
                out_file_path = os.path.join(
                    SavePath,
                    datetime.datetime.strftime(datetime.datetime.now(),
                                               '%H%M') + randName(3) + ".wav")
            else:
                out_file_path = os.path.join(SavePath,
                                             os.path.basename(filename))
            print("上传文件名:", out_file_path)
            async with aiofiles.open(out_file_path, 'wb') as out_file:
                content = await file.read()
                await out_file.write(content)
            # Re-encode in place at 16 kHz (librosa resamples while loading).
            wav, sr = librosa.load(out_file_path, sr=16000)
            sf.write(out_file_path, data=wav, samplerate=sr)
            right += 1
        except Exception as e:
            error += 1
            error_info = error_info + filename + " " + str(e) + "\n"
    return f"上传成功:{right}, 上传失败:{error}, 失败原因: {error_info}"
# Audio download
@app.post("/vc/download")
async def VcDownload(base: VcBase):
    """Send the file at ``base.wavPath`` back as a file response.

    Returns an ErrorRequest when the path is not an existing regular file.
    """
    # NOTE(review): wavPath is used verbatim from the request and is not
    # restricted to the audio folders, so any file readable by the server
    # process can be fetched. Consider whitelisting the served directories.
    # isfile rather than exists: a directory passes exists(), but
    # FileResponse cannot serve it and would fail while sending.
    if os.path.isfile(base.wavPath):
        return FileResponse(base.wavPath)
    return ErrorRequest(message="下载请求失败,文件不存在")
# Audio download as base64
@app.post("/vc/download_base64")
async def VcDownloadBase64(base: VcBase):
    """Return the audio at ``base.wavPath`` as base64-encoded raw PCM.

    The audio is loaded with librosa at 16 kHz, converted to int16 samples by
    ``float2pcm``, and the raw sample bytes (no WAV header) are base64-encoded
    into the ``result`` field. Returns an ErrorRequest when the path is not an
    existing regular file.
    """
    # isfile rather than exists: a directory would pass exists() and then
    # make librosa.load raise, escaping as an unhandled server error.
    if not os.path.isfile(base.wavPath):
        return ErrorRequest(message="播放请求失败,文件不存在")
    wav, sr = librosa.load(base.wavPath, sr=16000)
    wav = float2pcm(wav)  # float32 -> int16
    wav_base64 = base64.b64encode(wav.tobytes()).decode('utf8')
    return SuccessRequest(result=wav_base64)
######################################################################
########################### VC services ##############################
#####################################################################
# Upload reference audio for voice cloning
@app.post("/vc/upload")
async def VcUpload(files: List[UploadFile]):
    """Save uploaded reference audio into VC_UPLOAD_PATH.

    Returns:
        SuccessRequest whose result is the upload summary produced by
        ``saveFiles`` (success/failure counts plus per-file errors).
    """
    # saveFiles is a coroutine, so it must be awaited. Calling it without
    # await only yields an un-run coroutine, which is presumably why the
    # logic used to be duplicated inline here.
    res = await saveFiles(files, VC_UPLOAD_PATH)
    return SuccessRequest(result=res)
# List uploaded voice-clone reference audio
@app.get("/vc/list")
async def VcList():
    """Return every file under VC_UPLOAD_PATH as produced by getVCList."""
    return SuccessRequest(result=getVCList(VC_UPLOAD_PATH))
# Fetch an audio file
@app.post("/vc/file")
async def VcFileGet(base: VcBase):
    """Send the file at ``base.wavPath`` back as a file response.

    Returns an ErrorRequest when the path is not an existing regular file.
    """
    # NOTE(review): as with /vc/download, wavPath is not restricted to the
    # audio folders.
    # isfile rather than exists: FileResponse cannot serve a directory.
    if os.path.isfile(base.wavPath):
        return FileResponse(base.wavPath)
    return ErrorRequest(result="获取文件失败")
# Delete an audio file
@app.post("/vc/del")
async def VcFileDel(base: VcBase):
    """Delete the file at ``base.wavPath``.

    Returns SuccessRequest on deletion, or an ErrorRequest when the path is
    not an existing regular file.
    """
    # NOTE(review): wavPath comes straight from the client and is not
    # confined to the upload folders, so any file the server process can
    # write may be deleted. Consider restricting this to known directories.
    # isfile rather than exists: os.remove on a directory raises, which
    # would surface as an unhandled server error.
    if not os.path.isfile(base.wavPath):
        return ErrorRequest(result="删除失败")
    os.remove(base.wavPath)
    return SuccessRequest(result="删除成功")
# Voice cloning (G2P front end)
@app.post("/vc/clone_g2p")
async def VcCloneG2P(base: VcBaseText):
    """Clone the voice in ``base.wavPath`` and synthesize ``base.text``.

    ``base.func == 'ge2e'`` uses the GE2E model; any other value uses TDNN.
    The output is written to VC_OUT_PATH/<wavName>.

    Returns:
        SuccessRequest with ``{"wavName", "wavPath"}`` (wavPath as reported by
        the model) on success, otherwise an ErrorRequest.
    """
    if not os.path.exists(base.wavPath):
        return ErrorRequest(message="克隆失败,音频不存在")
    # Keep only the final component of the client-chosen name so the output
    # always lands inside VC_OUT_PATH (no "../" escapes).
    wavName = os.path.basename(base.wavName)
    model = vc_model if base.func == 'ge2e' else vc_model_tdnn
    try:
        wavPath = model.vc(
            text=base.text,
            input_wav=base.wavPath,
            out_wav=os.path.join(VC_OUT_PATH, wavName))
    except Exception as e:
        print(e)
        return ErrorRequest(message="克隆失败,合成过程报错")
    if wavPath:
        return SuccessRequest(result={"wavName": wavName, "wavPath": wavPath})
    return ErrorRequest(message="克隆失败,检查克隆脚本是否有效")
######################################################################
########################### SAT services #############################
#####################################################################
# Voice cloning / speech editing with SAT
@app.post("/vc/clone_sat")
async def VcCloneSAT(base: VcBaseSAT):
    """Run a SAT task on an uploaded recording.

    ``base.language`` ("zh" or "en") is the language of the source audio, and
    ``base.function`` selects the task:

    * ``synthesize`` / ``edit``: same-language task via
      ``zh_synthesize_edit`` / ``en_synthesize_edit``.
    * ``crossclone``: cross-lingual clone into the other language.

    The transcript ``base.old_str`` is stored in ``sat_label_dic`` under
    ``base.filename`` and persisted to SAT_LABEL_PATH whenever it changes.

    Returns:
        SuccessRequest carrying the SAT model's result, or an ErrorRequest
        for invalid options or when the model fails.
    """
    # Persist the transcript if it is new or changed.
    if sat_label_dic.get(base.filename) != base.old_str:
        sat_label_dic[base.filename] = base.old_str
        with open(SAT_LABEL_PATH, "w", encoding='utf8') as f:
            json.dump(sat_label_dic, f, ensure_ascii=False, indent=4)

    if base.language not in ("zh", "en"):
        return ErrorRequest(message="请检查功能选项是否正确,仅支持中文和英文")
    prefixes = {
        "synthesize": "sat_syn_",
        "edit": "sat_edit_",
        "crossclone": "sat_cross_",
    }
    if base.function not in prefixes:
        return ErrorRequest(
            message="请检查功能选项是否正确,仅支持:synthesize, edit, crossclone")

    # The output name carries the real source language, so Chinese and
    # English results for the same upload no longer overwrite each other.
    # basename() keeps the client-supplied name inside SAT_OUT_PATH.
    output_file_path = os.path.join(
        SAT_OUT_PATH, prefixes[base.function] + base.language + "_" +
        os.path.basename(base.filename))
    input_name = os.path.realpath(base.wav)
    output_name = os.path.realpath(output_file_path)

    try:
        if base.function == "crossclone":
            sat_result = sat_model.crossclone(
                old_str=base.old_str,
                new_str=base.new_str,
                input_name=input_name,
                output_name=output_name,
                source_lang=base.language,
                target_lang="en" if base.language == "zh" else "zh")
        else:
            synthesize_edit = (sat_model.zh_synthesize_edit
                               if base.language == "zh" else
                               sat_model.en_synthesize_edit)
            # task_name is exactly the requested function: synthesize / edit.
            sat_result = synthesize_edit(
                old_str=base.old_str,
                new_str=base.new_str,
                input_name=input_name,
                output_name=output_name,
                task_name=base.function)
    except Exception as e:
        # Log for the backend (the error message points users there) and
        # report a failure instead of a bare 500.
        print(e)
        sat_result = None

    if sat_result:
        return SuccessRequest(result=sat_result, message="SAT合成成功")
    return ErrorRequest(message="SAT 合成失败,请从后台检查错误信息!")
# List SAT uploads
@app.get("/sat/list")
async def SatList():
    """List files under SAT_UPLOAD_PATH, each annotated with its transcript.

    Every entry from getVCList gains a ``label`` key holding the stored
    transcript for that file name, or ``""`` when none was saved.
    """
    res = getVCList(SAT_UPLOAD_PATH)
    for entry in res:
        entry['label'] = sat_label_dic.get(entry['name'], "")
    return SuccessRequest(result=res)
# Upload SAT audio
@app.post("/sat/upload")
async def SATUpload(files: List[UploadFile]):
    """Save uploaded SAT source audio into SAT_UPLOAD_PATH.

    Returns:
        SuccessRequest whose result is the upload summary produced by
        ``saveFiles`` (success/failure counts plus per-file errors).
    """
    # Share the upload logic with /vc/upload instead of duplicating it.
    # saveFiles is a coroutine and must be awaited.
    res = await saveFiles(files, SAT_UPLOAD_PATH)
    return SuccessRequest(result=res)
######################################################################
########################### Fine-tune services #######################
#####################################################################
# List fine-tune labels and the matching recordings
@app.post("/finetune/list")
async def FineTuneList(Path: FTPath):
    """List every fine-tune label with its recording in the chosen folder.

    ``Path.dataPath == "default"`` selects FT_DEFAULT_PATH; any other value
    is used as the folder path.

    Returns:
        SuccessRequest with one ``{'text', 'name', 'path'}`` entry per label
        in ``ft_label_dic``. ``path`` is ``""`` when the recording is absent.
    """
    # (A getVCList() scan of the folder used to run here, but its result was
    # never used; it only cost a full directory walk per request.)
    ft_path = FT_DEFAULT_PATH if Path.dataPath == "default" else Path.dataPath
    res = []
    for name, value in ft_label_dic.items():
        wav_path = os.path.join(ft_path, name)
        res.append({
            'text': value['text'],
            'name': name,
            'path': wav_path if os.path.exists(wav_path) else "",
        })
    return SuccessRequest(result=res)
# One-click reset: allocate a fresh fine-tune data folder
@app.get('/finetune/newdir')
async def FTGetNewDir():
    """Create FT_UPLOAD_PATH/<random name>, seed it with labels.txt, return it.

    Copying the label file is best-effort: a failure is logged and the new
    folder is still returned, matching the earlier unchecked `cp` call.
    """
    import shutil

    new_path = os.path.join(FT_UPLOAD_PATH, randName(3))
    os.makedirs(new_path, exist_ok=True)
    # shutil.copy instead of os.system("cp ..."): no shell, works on every
    # platform, and the error is visible instead of a discarded exit code.
    try:
        shutil.copy(FT_LABEL_TXT_PATH, new_path)
    except OSError as e:
        print(e)
    return SuccessRequest(result=new_path)
# Fine-tune: upload one recording
@app.post("/finetune/upload")
async def FTUpload(base: VcBaseFT):
    """Decode the base64 audio in ``base.wav`` into ``base.wav_path``.

    The folder is created if needed. Returns SuccessRequest on success, or an
    ErrorRequest on any failure (bad base64, I/O error, ...).
    """
    try:
        # exist_ok avoids failing when the folder appears between a check and
        # the create call (e.g. concurrent uploads into the same folder).
        # NOTE(review): wav_path itself is client-supplied and unrestricted.
        os.makedirs(base.wav_path, exist_ok=True)
        # Only the final component of the client file name is kept, so the
        # file cannot be written outside wav_path.
        out_file_path = os.path.join(base.wav_path,
                                     os.path.basename(base.filename))
        wav_b = base64.b64decode(base.wav)
        async with aiofiles.open(out_file_path, 'wb') as out_file:
            await out_file.write(wav_b)
        return SuccessRequest(result="上传成功")
    except Exception as e:
        # Log the cause instead of silently discarding it.
        print(e)
        return ErrorRequest(result="上传失败")
# Fine-tune: train
@app.post("/finetune/clone_finetune")
async def FTModel(base: VcBaseFTModel):
    """Fine-tune on the recordings in ``base.wav_path``.

    ``'default'`` selects FT_DEFAULT_PATH. The experiment folder is
    FT_EXP_BASE_PATH/<last component of the data folder>.

    Returns:
        SuccessRequest with the experiment folder reported by
        ``ft_model.finetune``, or an ErrorRequest when the data folder is
        missing or fine-tuning fails.
    """
    data_path = FT_DEFAULT_PATH if base.wav_path == 'default' else base.wav_path
    if not os.path.exists(data_path):
        return ErrorRequest(message="数据文件夹不存在")
    # normpath drops a trailing separator first. `split(os.sep)[-1]` returned
    # "" for paths like "upload/abc/", which made every such experiment share
    # FT_EXP_BASE_PATH itself. basename also accepts '/' on Windows.
    data_base = os.path.basename(os.path.normpath(data_path))
    exp_dir = os.path.join(FT_EXP_BASE_PATH, data_base)
    try:
        exp_dir = ft_model.finetune(
            input_dir=os.path.realpath(data_path),
            exp_dir=os.path.realpath(exp_dir))
    except Exception as e:
        print(e)
        return ErrorRequest(message="微调失败")
    if exp_dir:
        return SuccessRequest(result=exp_dir)
    return ErrorRequest(message="微调失败")
# Fine-tune: synthesize with the fine-tuned model
@app.post("/finetune/clone_finetune_syn")
async def FTSyn(base: VcBaseFTSyn):
    """Synthesize ``base.text`` with the model in ``base.exp_path``.

    Output goes to FT_OUT_PATH under a random name.

    Returns:
        SuccessRequest with ``{"wavName", "wavPath"}`` on success, otherwise
        an ErrorRequest.
    """
    if not os.path.exists(base.exp_path):
        return ErrorRequest(result="模型路径不存在")
    try:
        wav_name = randName(5)
        wav_path = ft_model.synthesize(
            text=base.text,
            wav_name=wav_name,
            out_wav_dir=os.path.realpath(FT_OUT_PATH),
            exp_dir=os.path.realpath(base.exp_path))
    except Exception as e:
        # Log the cause (as the other model endpoints do) instead of
        # silently discarding it.
        print(e)
        return ErrorRequest(message="音频合成失败")
    if wav_path:
        # NOTE(review): the ".wav" suffix assumes ft_model.synthesize saves
        # the file as <wav_name>.wav; confirm against src/finetune.
        res = {"wavName": wav_name + ".wav", "wavPath": wav_path}
        return SuccessRequest(result=res)
    return ErrorRequest(message="音频合成失败")
if __name__ == "__main__":
    # Serve the API on all network interfaces, on the --port given above.
    uvicorn.run(app=app, port=port, host='0.0.0.0')