[vec][server] vpr demo support, test=doc fix #1695

pull/1696/head
qingen 3 years ago
parent b02e0daedd
commit 89a0ec9018

@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash ```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py python src/audio_search.py
``` ```
Then you will see the Application is started: Then you will see the Application is started:
@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash ```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
``` ```
**Note**: If you want to build a quick demo, you can use ./src/test_main.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example **Note**: If you want to build a quick demo, you can use ./src/test_audio_search.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example
- Prepare model(Skip this step if you use the default model.) - Prepare model(Skip this step if you use the default model.)
```bash ```bash
@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services.
The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library
```bash ```bash
python ./src/test_main.py python ./src/test_audio_search.py
``` ```
Output Output

@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash ```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py python src/audio_search.py
``` ```
然后你会看到应用程序启动: 然后你会看到应用程序启动:
@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash ```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
``` ```
**注**:如果希望快速搭建 demo可以采用 ./src/test_main.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例 **注**:如果希望快速搭建 demo可以采用 ./src/test_audio_search.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例
- 准备模型(如果使用默认模型,可以跳过此步骤) - 准备模型(如果使用默认模型,可以跳过此步骤)
```bash ```bash
@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
- 脚本测试(推荐) - 脚本测试(推荐)
```bash ```bash
python ./src/test_main.py python ./src/test_audio_search.py
``` ```
注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding存储建库检索删库 注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding存储建库检索删库

@ -20,7 +20,6 @@ from diskcache import Cache
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import File from fastapi import File
from fastapi import UploadFile from fastapi import UploadFile
from logs import LOGGER
from milvus_helpers import MilvusHelper from milvus_helpers import MilvusHelper
from mysql_helpers import MySQLHelper from mysql_helpers import MySQLHelper
from operations.count import do_count from operations.count import do_count
@ -32,6 +31,8 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import FileResponse from starlette.responses import FileResponse
from logs import LOGGER
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -40,7 +41,6 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"]) allow_headers=["*"])
MODEL = None
MILVUS_CLI = MilvusHelper() MILVUS_CLI = MilvusHelper()
MYSQL_CLI = MySQLHelper() MYSQL_CLI = MySQLHelper()

@ -13,12 +13,14 @@
# limitations under the License. # limitations under the License.
import sys import sys
import numpy
import pymysql import pymysql
from config import MYSQL_DB from config import MYSQL_DB
from config import MYSQL_HOST from config import MYSQL_HOST
from config import MYSQL_PORT from config import MYSQL_PORT
from config import MYSQL_PWD from config import MYSQL_PWD
from config import MYSQL_USER from config import MYSQL_USER
from logs import LOGGER from logs import LOGGER
@ -69,7 +71,7 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def load_data_to_mysql(self, table_name, data): def load_data_to_mysql(self, table_name, data):
# Batch insert (Milvus_ids, img_path) to mysql # Batch insert (Milvus_ids, audio_path) to mysql
self.test_connection() self.test_connection()
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);" sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
try: try:
@ -82,7 +84,7 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def search_by_milvus_ids(self, ids, table_name): def search_by_milvus_ids(self, ids, table_name):
# Get the img_path according to the milvus ids # Get the audio_path according to the milvus ids
self.test_connection() self.test_connection()
str_ids = str(ids).replace('[', '').replace(']', '') str_ids = str(ids).replace('[', '').replace(']', '')
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");" sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
@ -120,14 +122,83 @@ class MySQLHelper():
sys.exit(1) sys.exit(1)
def count_table(self, table_name): def count_table(self, table_name):
# Get the number of mysql table # Get the number of spk in mysql table
self.test_connection() self.test_connection()
sql = "select count(milvus_id) from " + table_name + ";" sql = "select count(spk_id) from " + table_name + ";"
try: try:
self.cursor.execute(sql) self.cursor.execute(sql)
results = self.cursor.fetchall() results = self.cursor.fetchall()
LOGGER.debug(f"MYSQL count table:{table_name}") LOGGER.debug(f"MYSQL count table:{results[0][0]}")
return results[0][0] return results[0][0]
except Exception as e: except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1) sys.exit(1)
def create_mysql_table_vpr(self, table_name):
# Create mysql table if not exists
self.test_connection()
sql = "create table if not exists " + table_name + "(spk_id TEXT, audio_path TEXT, embedding TEXT);"
try:
self.cursor.execute(sql)
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def load_data_to_mysql_vpr(self, table_name, data):
# Insert (spk, audio, embedding) to mysql
self.test_connection()
sql = "insert into " + table_name + " (spk_id,audio_path,embedding) values (%s,%s,%s);"
try:
self.cursor.execute(sql, data)
LOGGER.debug(
f"MYSQL loads data to table: {table_name} successfully")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def list_vpr(self, table_name):
# Get all records in mysql
self.test_connection()
sql = "select * from " + table_name + " ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
self.conn.commit()
spk_ids = [res[0] for res in results]
audio_paths = [res[1] for res in results]
embeddings = [
numpy.array(
str(res[2]).replace('[', '').replace(']', '').split(","))
for res in results
]
return spk_ids, audio_paths, embeddings
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def search_audio_vpr(self, table_name, spk_id):
# Get the audio_path according to the spk_id
self.test_connection()
sql = "select audio_path from " + table_name + " where spk_id='" + spk_id + "' ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
LOGGER.debug(
f"MYSQL search by spk id {spk_id} to get audio {results[0][0]}.")
return results[0][0]
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def delete_data_vpr(self, table_name, spk_id):
# Delete a record by spk_id in mysql table
self.test_connection()
sql = "delete from " + table_name + " where spk_id='" + spk_id + "';"
try:
self.cursor.execute(sql)
LOGGER.debug(
f"MYSQL delete a record {spk_id} in table {table_name}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)

@ -14,6 +14,7 @@
import sys import sys
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from logs import LOGGER from logs import LOGGER
@ -31,3 +32,45 @@ def do_count(table_name, milvus_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error attempting to count table {e}") LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1) sys.exit(1)
def do_count_vpr(table_name, mysql_cli):
"""
Returns the total number of spk in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
num = mysql_cli.count_table(table_name)
return num
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_list(table_name, mysql_cli):
"""
Returns the total records of vpr in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
spk_ids, audio_paths, _ = mysql_cli.list_vpr(table_name)
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_get(table_name, spk_id, mysql_cli):
"""
Returns the audio path by spk_id in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
audio_apth = mysql_cli.search_audio_vpr(table_name, spk_id)
return audio_apth
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)

@ -14,6 +14,7 @@
import sys import sys
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from logs import LOGGER from logs import LOGGER
@ -32,3 +33,31 @@ def do_drop(table_name, milvus_cli, mysql_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}") LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1) sys.exit(1)
def do_drop_vpr(table_name, mysql_cli):
"""
Delete the table of MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_table(table_name)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
def do_delete(table_name, spk_id, mysql_cli):
"""
Delete a record by spk_id in MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_data_vpr(table_name, spk_id)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)

@ -17,6 +17,7 @@ import sys
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from diskcache import Cache from diskcache import Cache
from encode import get_audio_embedding from encode import get_audio_embedding
from logs import LOGGER from logs import LOGGER
@ -26,8 +27,9 @@ def get_audios(path):
""" """
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [ return [
item for sublist in [[os.path.join(dir, file) for file in files] item
for dir, _, files in list(os.walk(path))] for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats for item in sublist if os.path.splitext(item)[1] in supported_formats
] ]
@ -82,3 +84,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
mysql_cli.create_mysql_table(table_name) mysql_cli.create_mysql_table(table_name)
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names)) mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
return len(ids) return len(ids)
def do_enroll(table_name, spk_id, audio_path, mysql_cli):
"""
Import spk_id,audio_path,embedding to Mysql
"""
if not table_name:
table_name = DEFAULT_TABLE
embedding = get_audio_embedding(audio_path)
mysql_cli.create_mysql_table_vpr(table_name)
data = (spk_id, audio_path, str(embedding))
mysql_cli.load_data_to_mysql_vpr(table_name, data)
return "OK"

@ -13,9 +13,11 @@
# limitations under the License. # limitations under the License.
import sys import sys
import numpy
from config import DEFAULT_TABLE from config import DEFAULT_TABLE
from config import TOP_K from config import TOP_K
from encode import get_audio_embedding from encode import get_audio_embedding
from logs import LOGGER from logs import LOGGER
@ -39,3 +41,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
except Exception as e: except Exception as e:
LOGGER.error(f"Error with search: {e}") LOGGER.error(f"Error with search: {e}")
sys.exit(1) sys.exit(1)
def do_search_vpr(host, table_name, audio_path, mysql_cli):
"""
Search the uploaded audio in MySQL
"""
try:
if not table_name:
table_name = DEFAULT_TABLE
emb = get_audio_embedding(audio_path)
emb = numpy.array(emb)
spk_ids, paths, vectors = mysql_cli.list_vpr(table_name)
scores = [numpy.dot(emb, x.astype(numpy.float64)) for x in vectors]
spk_ids = [str(x) for x in spk_ids]
paths = [str(x) for x in paths]
for i in range(len(paths)):
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
paths[i] = tmp
scores[i] = scores[i] * 100
return spk_ids, paths, scores
except Exception as e:
LOGGER.error(f"Error with search: {e}")
sys.exit(1)

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from audio_search import app
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from main import app
from utils.utility import download from utils.utility import download
from utils.utility import unpack from utils.utility import unpack
@ -22,7 +22,7 @@ client = TestClient(app)
def download_audio_data(): def download_audio_data():
""" """
download audio data Download audio data
""" """
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz" url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39" md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
@ -64,7 +64,7 @@ def test_count():
""" """
Returns the total number of vectors in the system Returns the total number of vectors in the system
""" """
response = client.get("audio/count") response = client.get("/audio/count")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == 20 assert response.json() == 20

@ -0,0 +1,115 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.testclient import TestClient
from vpr_search import app
from utils.utility import download
from utils.utility import unpack
client = TestClient(app)
def download_audio_data():
"""
Download audio data
"""
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
target_dir = "./"
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir, True)
def test_drop():
"""
Delete the table of MySQL
"""
response = client.post("/vpr/drop")
assert response.status_code == 200
def test_enroll_local(spk: str, audio: str):
"""
Enroll the audio to MySQL
"""
response = client.post("/vpr/enroll/local?spk_id=" + spk +
"&audio_path=.%2Fexample_audio%2F" + audio + ".wav")
assert response.status_code == 200
assert response.json() == {
'status': True,
'msg': "Successfully enroll data!"
}
def test_search_local():
"""
Search the spk in MySQL by audio
"""
response = client.post(
"/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
assert response.status_code == 200
def test_list():
"""
Get all records in MySQL
"""
response = client.get("/vpr/list")
assert response.status_code == 200
def test_data(spk: str):
"""
Get the audio file by spk_id in MySQL
"""
response = client.get("/vpr/data?spk_id=" + spk)
assert response.status_code == 200
def test_del(spk: str):
"""
Delete the record in MySQL by spk_id
"""
response = client.post("/vpr/del?spk_id=" + spk)
assert response.status_code == 200
def test_count():
"""
Get the number of spk in MySQL
"""
response = client.get("/vpr/count")
assert response.status_code == 200
if __name__ == "__main__":
download_audio_data()
test_enroll_local("spk1", "arms_strikes")
test_enroll_local("spk2", "sword_wielding")
test_enroll_local("spk3", "test")
test_list()
test_data("spk1")
test_count()
test_search_local()
test_del("spk1")
test_count()
test_search_local()
test_enroll_local("spk1", "arms_strikes")
test_count()
test_search_local()
test_drop()

@ -0,0 +1,207 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uvicorn
from config import UPLOAD_PATH
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr
from operations.count import do_get
from operations.count import do_list
from operations.drop import do_delete
from operations.drop import do_drop_vpr
from operations.load import do_enroll
from operations.search import do_search_vpr
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
from logs import LOGGER
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"])
MYSQL_CLI = MySQLHelper()
# Mkdir 'tmp/audio-data'
if not os.path.exists(UPLOAD_PATH):
os.makedirs(UPLOAD_PATH)
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
@app.post('/vpr/enroll')
async def vpr_enroll(table_name: str=None,
spk_id: str=None,
audio: UploadFile=File(...)):
# Enroll the uploaded audio with spk-id into MySQL
try:
# Save the upload data to server.
content = await audio.read()
audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(audio_path, "wb+") as f:
f.write(content)
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} online!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/enroll/local')
async def vpr_enroll_local(table_name: str=None,
spk_id: str=None,
audio_path: str=None):
# Enroll the local audio with spk-id into MySQL
try:
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} locally!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog')
async def vpr_recog(request: Request,
table_name: str=None,
audio: UploadFile=File(...)):
# Voice print recognition online
try:
# Save the upload data to server.
content = await audio.read()
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(query_audio_path, "wb+") as f:
f.write(content)
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name,
query_audio_path, MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition online!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog/local')
async def vpr_recog_local(request: Request,
table_name: str=None,
audio_path: str=None):
# Voice print recognition locally
try:
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition locally!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/del')
async def vpr_del(table_name: str=None, spk_id: str=None):
# Delete a record by spk_id in MySQL
try:
do_delete(table_name, spk_id, MYSQL_CLI)
LOGGER.info("Successfully delete a record by spk_id in MySQL")
return {'status': True, 'msg': "Successfully delete data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/list')
async def vpr_list(table_name: str=None):
# Get all records in MySQL
try:
spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
for i in range(len(spk_ids)):
LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
LOGGER.info("Successfully list all records from mysql!")
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/data')
async def vpr_data(
table_name: str=None,
spk_id: str=None, ):
# Get the audio file from path by spk_id in MySQL
try:
audio_path = do_get(table_name, spk_id, MYSQL_CLI)
LOGGER.info(f"Successfully get audio path {audio_path}!")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/count')
async def vpr_count(table_name: str=None):
# Get the total number of spk in MySQL
try:
num = do_count_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully count the number of spk!")
return num
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/drop')
async def drop_tables(table_name: str=None):
# Delete the table of MySQL
try:
do_drop_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully drop tables in MySQL!")
return {'status': True, 'msg': "Successfully drop tables!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/data')
def audio_path(audio_path):
# Get the audio file from path
try:
LOGGER.info(f"Successfully get audio: {audio_path}")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(f"get audio error: {e}")
return {'status': False, 'msg': e}, 400
if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=8002)

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain)
""" """
This script contains basic functions used for speaker diarization. This script contains basic functions used for speaker diarization.
This script has an optional dependency on open source sklearn library. This script has an optional dependency on open source sklearn library.

Loading…
Cancel
Save