You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
5.6 KiB
169 lines
5.6 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
from typing import Optional

import uvicorn
from config import UPLOAD_PATH
from diskcache import Cache
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from milvus_helpers import MilvusHelper
from mysql_helpers import MySQLHelper
from operations.count import do_count
from operations.drop import do_drop
from operations.load import do_load
from operations.search import do_search
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
from starlette.responses import JSONResponse
|
|
|
|
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and any
# header (presumably so a browser front end hosted elsewhere can call the API).
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; narrow allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
# NOTE(review): MODEL is not referenced by any code visible in this module;
# it may be used elsewhere or be a leftover — confirm before removing.
MODEL = None

# Shared database clients, constructed once when this module is imported and
# passed to the operations.* helpers by the Milvus/MySQL endpoints.
MILVUS_CLI = MilvusHelper()

MYSQL_CLI = MySQLHelper()
|
|
|
|
# Ensure the directory that receives uploaded query audio exists.
# Attempting the creation directly (EAFP) instead of checking os.path.exists
# first avoids a race in which another process creates the directory between
# the check and os.makedirs, which would crash this import.
try:
    os.makedirs(UPLOAD_PATH)
except FileExistsError:
    # Already present: nothing to do.
    pass
else:
    LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
|
|
|
|
|
@app.get('/data')
def audio_path(audio_path: str):
    """Serve an audio file from the server's filesystem.

    Args:
        audio_path: Filesystem path of the file, taken from the
            ``audio_path`` query parameter.

    Returns:
        A ``FileResponse`` streaming the file, or a 400 JSON response
        ``{'status': False, 'msg': <error text>}`` when the path does not
        name a regular file.
    """
    # SECURITY NOTE(review): the path comes straight from the client and is
    # not confined to any directory, so any file readable by the server
    # process can be fetched through this endpoint. Consider restricting it
    # to the directories that were loaded before exposing the service.
    try:
        # Starlette's FileResponse only stats the file when the response is
        # sent, i.e. after this handler has returned, so a missing file would
        # never reach this except clause. Check explicitly instead.
        if not os.path.isfile(audio_path):
            raise FileNotFoundError(f"No such audio file: {audio_path}")
        LOGGER.info(f"Successfully load audio: {audio_path}")
        return FileResponse(audio_path)
    except Exception as e:
        LOGGER.error(f"Load audio error: {e}")
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly. str(e) keeps the message
        # JSON-serializable.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
@app.get('/progress')
def get_progress():
    """Report the progress counters stored in the disk cache under ``./tmp``.

    Returns:
        A string ``"current: <current>, total: <total>"`` built from the
        cache's ``current`` and ``total`` entries, or a 400 JSON response
        ``{'status': False, 'msg': <error text>}`` if they cannot be read
        (for example, when those keys have not been written yet).
    """
    try:
        # Use the cache as a context manager so it is closed after every
        # request instead of leaking one open cache per call.
        with Cache('./tmp') as cache:
            return f"current: {cache['current']}, total: {cache['total']}"
    except Exception as e:
        LOGGER.error(f"Get progress error: {e}")
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
# Request body for POST /audio/load.
class Item(BaseModel):
    # Name of the table to load into; None is passed through to do_load
    # unchanged (presumably it then uses a default — confirm in operations.load).
    Table: Optional[str] = None
    # Path handed to do_load as the source of the audio files to insert.
    File: str
|
|
|
|
|
|
@app.post('/audio/load')
async def load_audios(item: Item):
    """Insert all audio files under ``item.File`` into Milvus/MySQL.

    Args:
        item: Request body; ``Table`` is the target table (may be None) and
            ``File`` is the path to load audio from.

    Returns:
        ``{'status': True, 'msg': 'Successfully loaded data!'}`` on success,
        or a 400 JSON response ``{'status': False, 'msg': <error text>}`` on
        failure.
    """
    # NOTE(review): do_load is a synchronous call inside an async handler, so
    # the event loop is blocked until it returns and no other request
    # (including GET /progress) is served meanwhile. Moving it to a worker
    # thread requires MILVUS_CLI / MYSQL_CLI to tolerate concurrent use —
    # confirm before changing.
    try:
        total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
        LOGGER.info(f"Successfully loaded data, total count: {total_num}")
        return {'status': True, 'msg': "Successfully loaded data!"}
    except Exception as e:
        LOGGER.error(e)
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
@app.post('/audio/search')
async def search_audio(request: Request,
                       table_name: str = None,
                       audio: UploadFile = File(...)):
    """Search Milvus/MySQL for audio similar to an uploaded clip.

    The uploaded bytes are saved under ``UPLOAD_PATH`` and that file path is
    passed to ``do_search``.

    Args:
        request: Incoming request; its ``Host`` header is passed to do_search.
        table_name: Table to search; None is passed through unchanged.
        audio: The uploaded query audio file.

    Returns:
        A list of ``(path, (file_name, score))`` pairs sorted by score in
        descending order, or a 400 JSON response
        ``{'status': False, 'msg': <error text>}`` on failure.
    """
    try:
        # The file name is client-controlled. Keep only its final component so
        # names such as "../x" or "/abs/x" cannot make os.path.join resolve
        # outside UPLOAD_PATH.
        file_name = os.path.basename(audio.filename or "")
        if file_name in ("", ".", ".."):
            raise ValueError("The uploaded audio has no usable file name.")
        query_audio_path = os.path.join(UPLOAD_PATH, file_name)

        # Save the upload to disk so do_search can read it by path.
        content = await audio.read()
        with open(query_audio_path, "wb") as f:
            f.write(content)

        host = request.headers['host']
        _, paths, distances = do_search(host, table_name, query_audio_path,
                                        MILVUS_CLI, MYSQL_CLI)
        names = []
        for path, score in zip(paths, distances):
            names.append(os.path.basename(path))
            LOGGER.info(f"search result {path}, score {score}")
        # Keyed by path: a path returned more than once keeps only its last
        # (name, score) pair.
        res = dict(zip(paths, zip(names, distances)))
        # Sort by score, largest first.
        # NOTE(review): this treats a larger value as a closer match; confirm
        # that do_search returns a similarity rather than a raw distance.
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully searched similar audio!")
        return res
    except Exception as e:
        LOGGER.error(e)
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
@app.post('/audio/search/local')
async def search_local_audio(request: Request,
                             query_audio_path: str,
                             table_name: str = None):
    """Search Milvus/MySQL for audio similar to a file already on the server.

    Args:
        request: Incoming request; its ``Host`` header is passed to do_search.
        query_audio_path: Server-side path of the query audio, passed to
            do_search as given.
        table_name: Table to search; None is passed through unchanged.

    Returns:
        A list of ``(path, (file_name, score))`` pairs sorted by score in
        descending order, or a 400 JSON response
        ``{'status': False, 'msg': <error text>}`` on failure.
    """
    # SECURITY NOTE(review): query_audio_path is chosen by the client and is
    # not confined to any directory; any server file readable by this
    # process can be fed to do_search.
    try:
        host = request.headers['host']
        _, paths, distances = do_search(host, table_name, query_audio_path,
                                        MILVUS_CLI, MYSQL_CLI)
        names = []
        for path, score in zip(paths, distances):
            names.append(os.path.basename(path))
            LOGGER.info(f"search result {path}, score {score}")
        # Keyed by path: a path returned more than once keeps only its last
        # (name, score) pair.
        res = dict(zip(paths, zip(names, distances)))
        # Sort by score, largest first.
        # NOTE(review): this treats a larger value as a closer match; confirm
        # that do_search returns a similarity rather than a raw distance.
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully searched similar audio!")
        return res
    except Exception as e:
        LOGGER.error(e)
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
@app.get('/audio/count')
async def count_audio(table_name: str = None):
    """Return the vector count that ``do_count`` reports for ``table_name``.

    Args:
        table_name: Table to count; None is passed through unchanged.

    Returns:
        Whatever do_count returns, or a 400 JSON response
        ``{'status': False, 'msg': <error text>}`` on failure.
    """
    try:
        num = do_count(table_name, MILVUS_CLI)
        LOGGER.info("Successfully count the number of data!")
        return num
    except Exception as e:
        LOGGER.error(e)
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
@app.post('/audio/drop')
async def drop_tables(table_name: str = None):
    """Drop the Milvus collection and MySQL table for ``table_name``.

    Args:
        table_name: Table to drop; None is passed through to do_drop unchanged.

    Returns:
        Whatever do_drop returns, or a 400 JSON response
        ``{'status': False, 'msg': <error text>}`` on failure.
    """
    try:
        status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
        LOGGER.info("Successfully drop tables in Milvus and MySQL!")
        return status
    except Exception as e:
        LOGGER.error(e)
        # Returning a (dict, 400) tuple does not set the status in FastAPI;
        # build the error response explicitly.
        return JSONResponse(status_code=400,
                            content={'status': False, 'msg': str(e)})
|
|
|
|
|
|
if __name__ == '__main__':
    # When executed as a script, serve the API with uvicorn on every
    # interface at port 8002.
    uvicorn.run(app, host='0.0.0.0', port=8002)
|