From 89a0ec90184765e050ce10a2114f46a9a85510d0 Mon Sep 17 00:00:00 2001 From: qingen Date: Wed, 13 Apr 2022 18:32:42 +0800 Subject: [PATCH] [vec][server] vpr demo support, test=doc fix #1695 --- demos/audio_searching/README.md | 6 +- demos/audio_searching/README_cn.md | 6 +- .../src/{main.py => audio_search.py} | 4 +- demos/audio_searching/src/mysql_helpers.py | 81 ++++++- demos/audio_searching/src/operations/count.py | 43 ++++ demos/audio_searching/src/operations/drop.py | 29 +++ demos/audio_searching/src/operations/load.py | 19 +- .../audio_searching/src/operations/search.py | 25 +++ .../{test_main.py => test_audio_search.py} | 6 +- demos/audio_searching/src/test_vpr_search.py | 115 ++++++++++ demos/audio_searching/src/vpr_search.py | 207 ++++++++++++++++++ paddlespeech/vector/cluster/diarization.py | 1 + 12 files changed, 524 insertions(+), 18 deletions(-) rename demos/audio_searching/src/{main.py => audio_search.py} (99%) rename demos/audio_searching/src/{test_main.py => test_audio_search.py} (96%) create mode 100644 demos/audio_searching/src/test_vpr_search.py create mode 100644 demos/audio_searching/src/vpr_search.py diff --git a/demos/audio_searching/README.md b/demos/audio_searching/README.md index 8a6f3863..87a1956b 100644 --- a/demos/audio_searching/README.md +++ b/demos/audio_searching/README.md @@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services. ```bash export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio - python src/main.py + python src/audio_search.py ``` Then you will see the Application is started: @@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services. ```bash wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz ``` - **Note**: If you want to build a quick demo, you can use ./src/test_main.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example + **Note**: If you want to build a quick demo, you can use ./src/test_audio_search.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example - Prepare model(Skip this step if you use the default model.) ```bash @@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services. The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library ```bash - python ./src/test_main.py + python ./src/test_audio_search.py ``` Output: diff --git a/demos/audio_searching/README_cn.md b/demos/audio_searching/README_cn.md index 0d0f42a0..a93dbdc1 100644 --- a/demos/audio_searching/README_cn.md +++ b/demos/audio_searching/README_cn.md @@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ```bash export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio - python src/main.py + python src/audio_search.py ``` 然后你会看到应用程序启动: @@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" ```bash wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz ``` - **注**:如果希望快速搭建 demo,可以采用 ./src/test_main.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例 + **注**:如果希望快速搭建 demo,可以采用 ./src/test_audio_search.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例 - 准备模型(如果使用默认模型,可以跳过此步骤) ```bash @@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" - 脚本测试(推荐) ```bash - python ./src/test_main.py + python ./src/test_audio_search.py ``` 注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding,存储建库,检索,删库 diff --git a/demos/audio_searching/src/main.py b/demos/audio_searching/src/audio_search.py similarity index 99% rename from demos/audio_searching/src/main.py rename to demos/audio_searching/src/audio_search.py index db091a39..f407b128 100644 --- a/demos/audio_searching/src/main.py +++ b/demos/audio_searching/src/audio_search.py @@ -20,7 +20,6 @@ from diskcache import Cache from fastapi import FastAPI from fastapi import File from fastapi import UploadFile -from logs import LOGGER from milvus_helpers import MilvusHelper from mysql_helpers import MySQLHelper from operations.count import do_count @@ -32,6 +31,8 @@ from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import FileResponse +from logs import LOGGER + app = FastAPI() app.add_middleware( CORSMiddleware, @@ -40,7 +41,6 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"]) -MODEL = None MILVUS_CLI = MilvusHelper() MYSQL_CLI = MySQLHelper() diff --git a/demos/audio_searching/src/mysql_helpers.py b/demos/audio_searching/src/mysql_helpers.py index 30383839..a0640c84 100644 --- a/demos/audio_searching/src/mysql_helpers.py +++ b/demos/audio_searching/src/mysql_helpers.py @@ -13,12 +13,14 @@ # limitations under the License. import sys +import numpy import pymysql from config import MYSQL_DB from config import MYSQL_HOST from config import MYSQL_PORT from config import MYSQL_PWD from config import MYSQL_USER + from logs import LOGGER @@ -69,7 +71,7 @@ class MySQLHelper(): sys.exit(1) def load_data_to_mysql(self, table_name, data): - # Batch insert (Milvus_ids, img_path) to mysql + # Batch insert (Milvus_ids, audio_path) to mysql self.test_connection() sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);" try: @@ -82,7 +84,7 @@ class MySQLHelper(): sys.exit(1) def search_by_milvus_ids(self, ids, table_name): - # Get the img_path according to the milvus ids + # Get the audio_path according to the milvus ids self.test_connection() str_ids = str(ids).replace('[', '').replace(']', '') sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");" @@ -120,14 +122,83 @@ class MySQLHelper(): sys.exit(1) def count_table(self, table_name): - # Get the number of mysql table + # Get the number of spk in mysql table self.test_connection() - sql = "select count(milvus_id) from " + table_name + ";" + sql = "select count(spk_id) from " + table_name + ";" try: self.cursor.execute(sql) results = self.cursor.fetchall() - LOGGER.debug(f"MYSQL count table:{table_name}") + LOGGER.debug(f"MYSQL count table:{results[0][0]}") return results[0][0] except Exception as e: LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") sys.exit(1) + + def create_mysql_table_vpr(self, table_name): + # Create mysql table if not exists + self.test_connection() + sql = "create table if not exists " + table_name + "(spk_id TEXT, audio_path TEXT, embedding TEXT);" + try: + self.cursor.execute(sql) + LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def load_data_to_mysql_vpr(self, table_name, data): + # Insert (spk, audio, embedding) to mysql + self.test_connection() + sql = "insert into " + table_name + " (spk_id,audio_path,embedding) values (%s,%s,%s);" + try: + self.cursor.execute(sql, data) + LOGGER.debug( + f"MYSQL loads data to table: {table_name} successfully") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def list_vpr(self, table_name): + # Get all records in mysql + self.test_connection() + sql = "select * from " + table_name + " ;" + try: + self.cursor.execute(sql) + results = self.cursor.fetchall() + self.conn.commit() + spk_ids = [res[0] for res in results] + audio_paths = [res[1] for res in results] + embeddings = [ + numpy.array( + str(res[2]).replace('[', '').replace(']', '').split(",")) + for res in results + ] + return spk_ids, audio_paths, embeddings + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def search_audio_vpr(self, table_name, spk_id): + # Get the audio_path according to the spk_id + self.test_connection() + sql = "select audio_path from " + table_name + " where spk_id='" + spk_id + "' ;" + try: + self.cursor.execute(sql) + results = self.cursor.fetchall() + LOGGER.debug( + f"MYSQL search by spk id {spk_id} to get audio {results[0][0]}.") + return results[0][0] + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) + + def delete_data_vpr(self, table_name, spk_id): + # Delete a record by spk_id in mysql table + self.test_connection() + sql = "delete from " + table_name + " where spk_id='" + spk_id + "';" + try: + self.cursor.execute(sql) + LOGGER.debug( + f"MYSQL delete a record {spk_id} in table {table_name}") + except Exception as e: + LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}") + sys.exit(1) diff --git a/demos/audio_searching/src/operations/count.py b/demos/audio_searching/src/operations/count.py index 9a1f4208..2116afd7 100644 --- a/demos/audio_searching/src/operations/count.py +++ b/demos/audio_searching/src/operations/count.py @@ -14,6 +14,7 @@ import sys from config import DEFAULT_TABLE + from logs import LOGGER @@ -31,3 +32,45 @@ def do_count(table_name, milvus_cli): except Exception as e: LOGGER.error(f"Error attempting to count table {e}") sys.exit(1) + + +def do_count_vpr(table_name, mysql_cli): + """ + Returns the total number of spk in the system + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + num = mysql_cli.count_table(table_name) + return num + except Exception as e: + LOGGER.error(f"Error attempting to count table {e}") + sys.exit(1) + + +def do_list(table_name, mysql_cli): + """ + Returns the total records of vpr in the system + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + spk_ids, audio_paths, _ = mysql_cli.list_vpr(table_name) + return spk_ids, audio_paths + except Exception as e: + LOGGER.error(f"Error attempting to count table {e}") + sys.exit(1) + + +def do_get(table_name, spk_id, mysql_cli): + """ + Returns the audio path by spk_id in the system + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + audio_apth = mysql_cli.search_audio_vpr(table_name, spk_id) + return audio_apth + except Exception as e: + LOGGER.error(f"Error attempting to count table {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/operations/drop.py b/demos/audio_searching/src/operations/drop.py index f8278ddd..432da426d 100644 --- a/demos/audio_searching/src/operations/drop.py +++ b/demos/audio_searching/src/operations/drop.py @@ -14,6 +14,7 @@ import sys from config import DEFAULT_TABLE + from logs import LOGGER @@ -32,3 +33,31 @@ def do_drop(table_name, milvus_cli, mysql_cli): except Exception as e: LOGGER.error(f"Error attempting to drop table: {e}") sys.exit(1) + + +def do_drop_vpr(table_name, mysql_cli): + """ + Delete the table of MySQL + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + mysql_cli.delete_table(table_name) + return "OK" + except Exception as e: + LOGGER.error(f"Error attempting to drop table: {e}") + sys.exit(1) + + +def do_delete(table_name, spk_id, mysql_cli): + """ + Delete a record by spk_id in MySQL + """ + if not table_name: + table_name = DEFAULT_TABLE + try: + mysql_cli.delete_data_vpr(table_name, spk_id) + return "OK" + except Exception as e: + LOGGER.error(f"Error attempting to drop table: {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py index 80b6375f..5852d6ea 100644 --- a/demos/audio_searching/src/operations/load.py +++ b/demos/audio_searching/src/operations/load.py @@ -17,6 +17,7 @@ import sys from config import DEFAULT_TABLE from diskcache import Cache from encode import get_audio_embedding + from logs import LOGGER @@ -26,8 +27,9 @@ def get_audios(path): """ supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] return [ - item for sublist in [[os.path.join(dir, file) for file in files] - for dir, _, files in list(os.walk(path))] + item + for sublist in [[os.path.join(dir, file) for file in files] + for dir, _, files in list(os.walk(path))] for item in sublist if os.path.splitext(item)[1] in supported_formats ] @@ -82,3 +84,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli): mysql_cli.create_mysql_table(table_name) mysql_cli.load_data_to_mysql(table_name, format_data(ids, names)) return len(ids) + + +def do_enroll(table_name, spk_id, audio_path, mysql_cli): + """ + Import spk_id,audio_path,embedding to Mysql + """ + if not table_name: + table_name = DEFAULT_TABLE + embedding = get_audio_embedding(audio_path) + mysql_cli.create_mysql_table_vpr(table_name) + data = (spk_id, audio_path, str(embedding)) + mysql_cli.load_data_to_mysql_vpr(table_name, data) + return "OK" diff --git a/demos/audio_searching/src/operations/search.py b/demos/audio_searching/src/operations/search.py index 9cf48abf..160634a9 100644 --- a/demos/audio_searching/src/operations/search.py +++ b/demos/audio_searching/src/operations/search.py @@ -13,9 +13,11 @@ # limitations under the License. import sys +import numpy from config import DEFAULT_TABLE from config import TOP_K from encode import get_audio_embedding + from logs import LOGGER @@ -39,3 +41,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli): except Exception as e: LOGGER.error(f"Error with search: {e}") sys.exit(1) + + +def do_search_vpr(host, table_name, audio_path, mysql_cli): + """ + Search the uploaded audio in MySQL + """ + try: + if not table_name: + table_name = DEFAULT_TABLE + emb = get_audio_embedding(audio_path) + emb = numpy.array(emb) + spk_ids, paths, vectors = mysql_cli.list_vpr(table_name) + scores = [numpy.dot(emb, x.astype(numpy.float64)) for x in vectors] + spk_ids = [str(x) for x in spk_ids] + paths = [str(x) for x in paths] + for i in range(len(paths)): + tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i]) + paths[i] = tmp + scores[i] = scores[i] * 100 + return spk_ids, paths, scores + except Exception as e: + LOGGER.error(f"Error with search: {e}") + sys.exit(1) diff --git a/demos/audio_searching/src/test_main.py b/demos/audio_searching/src/test_audio_search.py similarity index 96% rename from demos/audio_searching/src/test_main.py rename to demos/audio_searching/src/test_audio_search.py index 32030bae..cb91e156 100644 --- a/demos/audio_searching/src/test_main.py +++ b/demos/audio_searching/src/test_audio_search.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from audio_search import app from fastapi.testclient import TestClient -from main import app from utils.utility import download from utils.utility import unpack @@ -22,7 +22,7 @@ client = TestClient(app) def download_audio_data(): """ - download audio data + Download audio data """ url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz" md5sum = "52ac69316c1aa1fdef84da7dd2c67b39" @@ -64,7 +64,7 @@ def test_count(): """ Returns the total number of vectors in the system """ - response = client.get("audio/count") + response = client.get("/audio/count") assert response.status_code == 200 assert response.json() == 20 diff --git a/demos/audio_searching/src/test_vpr_search.py b/demos/audio_searching/src/test_vpr_search.py new file mode 100644 index 00000000..8cc8dc84 --- /dev/null +++ b/demos/audio_searching/src/test_vpr_search.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from fastapi.testclient import TestClient +from vpr_search import app + +from utils.utility import download +from utils.utility import unpack + +client = TestClient(app) + + +def download_audio_data(): + """ + Download audio data + """ + url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz" + md5sum = "52ac69316c1aa1fdef84da7dd2c67b39" + target_dir = "./" + filepath = download(url, md5sum, target_dir) + unpack(filepath, target_dir, True) + + +def test_drop(): + """ + Delete the table of MySQL + """ + response = client.post("/vpr/drop") + assert response.status_code == 200 + + +def test_enroll_local(spk: str, audio: str): + """ + Enroll the audio to MySQL + """ + response = client.post("/vpr/enroll/local?spk_id=" + spk + + "&audio_path=.%2Fexample_audio%2F" + audio + ".wav") + assert response.status_code == 200 + assert response.json() == { + 'status': True, + 'msg': "Successfully enroll data!" + } + + +def test_search_local(): + """ + Search the spk in MySQL by audio + """ + response = client.post( + "/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav") + assert response.status_code == 200 + + +def test_list(): + """ + Get all records in MySQL + """ + response = client.get("/vpr/list") + assert response.status_code == 200 + + +def test_data(spk: str): + """ + Get the audio file by spk_id in MySQL + """ + response = client.get("/vpr/data?spk_id=" + spk) + assert response.status_code == 200 + + +def test_del(spk: str): + """ + Delete the record in MySQL by spk_id + """ + response = client.post("/vpr/del?spk_id=" + spk) + assert response.status_code == 200 + + +def test_count(): + """ + Get the number of spk in MySQL + """ + response = client.get("/vpr/count") + assert response.status_code == 200 + + +if __name__ == "__main__": + download_audio_data() + + test_enroll_local("spk1", "arms_strikes") + test_enroll_local("spk2", "sword_wielding") + test_enroll_local("spk3", "test") + test_list() + test_data("spk1") + test_count() + test_search_local() + + test_del("spk1") + test_count() + test_search_local() + + test_enroll_local("spk1", "arms_strikes") + test_count() + test_search_local() + + test_drop() diff --git a/demos/audio_searching/src/vpr_search.py b/demos/audio_searching/src/vpr_search.py new file mode 100644 index 00000000..9ca65062 --- /dev/null +++ b/demos/audio_searching/src/vpr_search.py @@ -0,0 +1,207 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import uvicorn +from config import UPLOAD_PATH +from fastapi import FastAPI +from fastapi import File +from fastapi import UploadFile +from mysql_helpers import MySQLHelper +from operations.count import do_count_vpr +from operations.count import do_get +from operations.count import do_list +from operations.drop import do_delete +from operations.drop import do_drop_vpr +from operations.load import do_enroll +from operations.search import do_search_vpr +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import FileResponse + +from logs import LOGGER + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"]) + +MYSQL_CLI = MySQLHelper() + +# Mkdir 'tmp/audio-data' +if not os.path.exists(UPLOAD_PATH): + os.makedirs(UPLOAD_PATH) + LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}") + + +@app.post('/vpr/enroll') +async def vpr_enroll(table_name: str=None, + spk_id: str=None, + audio: UploadFile=File(...)): + # Enroll the uploaded audio with spk-id into MySQL + try: + # Save the upload data to server. + content = await audio.read() + audio_path = os.path.join(UPLOAD_PATH, audio.filename) + with open(audio_path, "wb+") as f: + f.write(content) + do_enroll(table_name, spk_id, audio_path, MYSQL_CLI) + LOGGER.info(f"Successfully enrolled {spk_id} online!") + return {'status': True, 'msg': "Successfully enroll data!"} + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/vpr/enroll/local') +async def vpr_enroll_local(table_name: str=None, + spk_id: str=None, + audio_path: str=None): + # Enroll the local audio with spk-id into MySQL + try: + do_enroll(table_name, spk_id, audio_path, MYSQL_CLI) + LOGGER.info(f"Successfully enrolled {spk_id} locally!") + return {'status': True, 'msg': "Successfully enroll data!"} + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/vpr/recog') +async def vpr_recog(request: Request, + table_name: str=None, + audio: UploadFile=File(...)): + # Voice print recognition online + try: + # Save the upload data to server. + content = await audio.read() + query_audio_path = os.path.join(UPLOAD_PATH, audio.filename) + with open(query_audio_path, "wb+") as f: + f.write(content) + host = request.headers['host'] + spk_ids, paths, scores = do_search_vpr(host, table_name, + query_audio_path, MYSQL_CLI) + for spk_id, path, score in zip(spk_ids, paths, scores): + LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ") + res = dict(zip(spk_ids, zip(paths, scores))) + # Sort results by distance metric, closest distances first + res = sorted(res.items(), key=lambda item: item[1][1], reverse=True) + LOGGER.info("Successfully speaker recognition online!") + return res + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/vpr/recog/local') +async def vpr_recog_local(request: Request, + table_name: str=None, + audio_path: str=None): + # Voice print recognition locally + try: + host = request.headers['host'] + spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path, + MYSQL_CLI) + for spk_id, path, score in zip(spk_ids, paths, scores): + LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ") + res = dict(zip(spk_ids, zip(paths, scores))) + # Sort results by distance metric, closest distances first + res = sorted(res.items(), key=lambda item: item[1][1], reverse=True) + LOGGER.info("Successfully speaker recognition locally!") + return res + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/vpr/del') +async def vpr_del(table_name: str=None, spk_id: str=None): + # Delete a record by spk_id in MySQL + try: + do_delete(table_name, spk_id, MYSQL_CLI) + LOGGER.info("Successfully delete a record by spk_id in MySQL") + return {'status': True, 'msg': "Successfully delete data!"} + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.get('/vpr/list') +async def vpr_list(table_name: str=None): + # Get all records in MySQL + try: + spk_ids, audio_paths = do_list(table_name, MYSQL_CLI) + for i in range(len(spk_ids)): + LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}") + LOGGER.info("Successfully list all records from mysql!") + return spk_ids, audio_paths + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.get('/vpr/data') +async def vpr_data( + table_name: str=None, + spk_id: str=None, ): + # Get the audio file from path by spk_id in MySQL + try: + audio_path = do_get(table_name, spk_id, MYSQL_CLI) + LOGGER.info(f"Successfully get audio path {audio_path}!") + return FileResponse(audio_path) + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.get('/vpr/count') +async def vpr_count(table_name: str=None): + # Get the total number of spk in MySQL + try: + num = do_count_vpr(table_name, MYSQL_CLI) + LOGGER.info("Successfully count the number of spk!") + return num + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.post('/vpr/drop') +async def drop_tables(table_name: str=None): + # Delete the table of MySQL + try: + do_drop_vpr(table_name, MYSQL_CLI) + LOGGER.info("Successfully drop tables in MySQL!") + return {'status': True, 'msg': "Successfully drop tables!"} + except Exception as e: + LOGGER.error(e) + return {'status': False, 'msg': e}, 400 + + +@app.get('/data') +def audio_path(audio_path): + # Get the audio file from path + try: + LOGGER.info(f"Successfully get audio: {audio_path}") + return FileResponse(audio_path) + except Exception as e: + LOGGER.error(f"get audio error: {e}") + return {'status': False, 'msg': e}, 400 + + +if __name__ == '__main__': + uvicorn.run(app=app, host='0.0.0.0', port=8002) diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py index 597aa480..ee00cb53 100644 --- a/paddlespeech/vector/cluster/diarization.py +++ b/paddlespeech/vector/cluster/diarization.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified from speechbrain(https://github.com/speechbrain/speechbrain) """ This script contains basic functions used for speaker diarization. This script has an optional dependency on open source sklearn library.