Merge pull request #1547 from qingen/database-search
[vec] add demo for searching vectors base on MySQL and Milvus, t…pull/1583/head
commit
41d194d123
@ -0,0 +1,88 @@
|
|||||||
|
version: '3.5'
|
||||||
|
|
||||||
|
services:
|
||||||
|
etcd:
|
||||||
|
container_name: milvus-etcd
|
||||||
|
image: quay.io/coreos/etcd:v3.5.0
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
environment:
|
||||||
|
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||||
|
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||||
|
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||||
|
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||||
|
|
||||||
|
minio:
|
||||||
|
container_name: milvus-minio
|
||||||
|
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
environment:
|
||||||
|
MINIO_ACCESS_KEY: minioadmin
|
||||||
|
MINIO_SECRET_KEY: minioadmin
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||||
|
command: minio server /minio_data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
standalone:
|
||||||
|
container_name: milvus-standalone
|
||||||
|
image: milvusdb/milvus:v2.0.1
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
ipv4_address: 172.16.23.10
|
||||||
|
command: ["milvus", "run", "standalone"]
|
||||||
|
environment:
|
||||||
|
ETCD_ENDPOINTS: etcd:2379
|
||||||
|
MINIO_ADDRESS: minio:9000
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||||
|
ports:
|
||||||
|
- "19530:19530"
|
||||||
|
depends_on:
|
||||||
|
- "etcd"
|
||||||
|
- "minio"
|
||||||
|
|
||||||
|
mysql:
|
||||||
|
container_name: audio-mysql
|
||||||
|
image: mysql:5.7
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
ipv4_address: 172.16.23.11
|
||||||
|
environment:
|
||||||
|
- MYSQL_ROOT_PASSWORD=123456
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
|
||||||
|
ports:
|
||||||
|
- "3306:3306"
|
||||||
|
|
||||||
|
webclient:
|
||||||
|
container_name: audio-webclient
|
||||||
|
image: qingen1/paddlespeech-audio-search-client:2.3
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
ipv4_address: 172.16.23.13
|
||||||
|
environment:
|
||||||
|
API_URL: 'http://127.0.0.1:8002'
|
||||||
|
ports:
|
||||||
|
- "8068:80"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost/"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
driver: bridge
|
||||||
|
ipam:
|
||||||
|
driver: default
|
||||||
|
config:
|
||||||
|
- subnet: 172.16.23.0/24
|
||||||
|
gateway: 172.16.23.1
|
After Width: | Height: | Size: 29 KiB |
After Width: | Height: | Size: 80 KiB |
After Width: | Height: | Size: 33 KiB |
After Width: | Height: | Size: 84 KiB |
@ -0,0 +1,12 @@
|
|||||||
|
soundfile==0.10.3.post1
|
||||||
|
librosa==0.8.0
|
||||||
|
numpy
|
||||||
|
pymysql
|
||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
diskcache==5.2.1
|
||||||
|
pymilvus==2.0.1
|
||||||
|
python-multipart
|
||||||
|
typing
|
||||||
|
starlette
|
||||||
|
pydantic
|
@ -0,0 +1,37 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
############### Milvus Configuration ###############
|
||||||
|
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
|
||||||
|
MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
|
||||||
|
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
|
||||||
|
INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
|
||||||
|
METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
|
||||||
|
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
|
||||||
|
TOP_K = int(os.getenv("TOP_K", "10"))
|
||||||
|
|
||||||
|
############### MySQL Configuration ###############
|
||||||
|
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||||
|
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||||
|
MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||||
|
MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
|
||||||
|
MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
|
||||||
|
|
||||||
|
############### Data Path ###############
|
||||||
|
UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
|
||||||
|
|
||||||
|
############### Number of Log Files ###############
|
||||||
|
LOGS_NUM = int(os.getenv("logs_num", "0"))
|
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_embedding(path):
|
||||||
|
"""
|
||||||
|
Use vpr_inference to generate embedding of audio
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
RESAMPLE_RATE = 16000
|
||||||
|
audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
|
||||||
|
|
||||||
|
# TODO add infer/python interface to get embedding, now fake it by rand
|
||||||
|
# vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
|
||||||
|
# embedding = vpr.inference(audio)
|
||||||
|
np.random.seed(hash(os.path.basename(path)) % 1000000)
|
||||||
|
embedding = np.random.rand(1, 2048)
|
||||||
|
embedding = embedding / np.linalg.norm(embedding)
|
||||||
|
embedding = embedding.tolist()[0]
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error with embedding:{e}")
|
||||||
|
return None
|
@ -0,0 +1,168 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from config import UPLOAD_PATH
|
||||||
|
from diskcache import Cache
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi import File
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from logs import LOGGER
|
||||||
|
from milvus_helpers import MilvusHelper
|
||||||
|
from mysql_helpers import MySQLHelper
|
||||||
|
from operations.count import do_count
|
||||||
|
from operations.drop import do_drop
|
||||||
|
from operations.load import do_load
|
||||||
|
from operations.search import do_search
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import FileResponse
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"])
|
||||||
|
|
||||||
|
MODEL = None
|
||||||
|
MILVUS_CLI = MilvusHelper()
|
||||||
|
MYSQL_CLI = MySQLHelper()
|
||||||
|
|
||||||
|
# Mkdir 'tmp/audio-data'
|
||||||
|
if not os.path.exists(UPLOAD_PATH):
|
||||||
|
os.makedirs(UPLOAD_PATH)
|
||||||
|
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get('/data')
|
||||||
|
def audio_path(audio_path):
|
||||||
|
# Get the audio file
|
||||||
|
try:
|
||||||
|
LOGGER.info(f"Successfully load audio: {audio_path}")
|
||||||
|
return FileResponse(audio_path)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"upload audio error: {e}")
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.get('/progress')
|
||||||
|
def get_progress():
|
||||||
|
# Get the progress of dealing with data
|
||||||
|
try:
|
||||||
|
cache = Cache('./tmp')
|
||||||
|
return f"current: {cache['current']}, total: {cache['total']}"
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Upload data error: {e}")
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
Table: Optional[str] = None
|
||||||
|
File: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/load')
|
||||||
|
async def load_audios(item: Item):
|
||||||
|
# Insert all the audio files under the file path to Milvus/MySQL
|
||||||
|
try:
|
||||||
|
total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
|
||||||
|
LOGGER.info(f"Successfully loaded data, total count: {total_num}")
|
||||||
|
return {'status': True, 'msg': "Successfully loaded data!"}
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/search')
|
||||||
|
async def search_audio(request: Request,
|
||||||
|
table_name: str=None,
|
||||||
|
audio: UploadFile=File(...)):
|
||||||
|
# Search the uploaded audio in Milvus/MySQL
|
||||||
|
try:
|
||||||
|
# Save the upload data to server.
|
||||||
|
content = await audio.read()
|
||||||
|
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||||
|
with open(query_audio_path, "wb+") as f:
|
||||||
|
f.write(content)
|
||||||
|
host = request.headers['host']
|
||||||
|
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||||
|
MILVUS_CLI, MYSQL_CLI)
|
||||||
|
names = []
|
||||||
|
for path, score in zip(paths, distances):
|
||||||
|
names.append(os.path.basename(path))
|
||||||
|
LOGGER.info(f"search result {path}, score {score}")
|
||||||
|
res = dict(zip(paths, zip(names, distances)))
|
||||||
|
# Sort results by distance metric, closest distances first
|
||||||
|
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||||
|
LOGGER.info("Successfully searched similar audio!")
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/search/local')
|
||||||
|
async def search_local_audio(request: Request,
|
||||||
|
query_audio_path: str,
|
||||||
|
table_name: str=None):
|
||||||
|
# Search the uploaded audio in Milvus/MySQL
|
||||||
|
try:
|
||||||
|
host = request.headers['host']
|
||||||
|
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||||
|
MILVUS_CLI, MYSQL_CLI)
|
||||||
|
names = []
|
||||||
|
for path, score in zip(paths, distances):
|
||||||
|
names.append(os.path.basename(path))
|
||||||
|
LOGGER.info(f"search result {path}, score {score}")
|
||||||
|
res = dict(zip(paths, zip(names, distances)))
|
||||||
|
# Sort results by distance metric, closest distances first
|
||||||
|
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||||
|
LOGGER.info("Successfully searched similar audio!")
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.get('/audio/count')
|
||||||
|
async def count_audio(table_name: str=None):
|
||||||
|
# Returns the total number of vectors in the system
|
||||||
|
try:
|
||||||
|
num = do_count(table_name, MILVUS_CLI)
|
||||||
|
LOGGER.info("Successfully count the number of data!")
|
||||||
|
return num
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/drop')
|
||||||
|
async def drop_tables(table_name: str=None):
|
||||||
|
# Delete the collection of Milvus and MySQL
|
||||||
|
try:
|
||||||
|
status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
|
||||||
|
LOGGER.info("Successfully drop tables in Milvus and MySQL!")
|
||||||
|
return status
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
uvicorn.run(app=app, host='0.0.0.0', port=8002)
|
@ -0,0 +1,185 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import METRIC_TYPE
|
||||||
|
from config import MILVUS_HOST
|
||||||
|
from config import MILVUS_PORT
|
||||||
|
from config import VECTOR_DIMENSION
|
||||||
|
from logs import LOGGER
|
||||||
|
from pymilvus import Collection
|
||||||
|
from pymilvus import CollectionSchema
|
||||||
|
from pymilvus import connections
|
||||||
|
from pymilvus import DataType
|
||||||
|
from pymilvus import FieldSchema
|
||||||
|
from pymilvus import utility
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusHelper:
|
||||||
|
"""
|
||||||
|
the basic operations of PyMilvus
|
||||||
|
|
||||||
|
# This example shows how to:
|
||||||
|
# 1. connect to Milvus server
|
||||||
|
# 2. create a collection
|
||||||
|
# 3. insert entities
|
||||||
|
# 4. create index
|
||||||
|
# 5. search
|
||||||
|
# 6. delete a collection
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
try:
|
||||||
|
self.collection = None
|
||||||
|
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to connect Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def set_collection(self, collection_name):
|
||||||
|
try:
|
||||||
|
if self.has_collection(collection_name):
|
||||||
|
self.collection = Collection(name=collection_name)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"There is no collection named:{collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to set collection in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def has_collection(self, collection_name):
|
||||||
|
# Return if Milvus has the collection
|
||||||
|
try:
|
||||||
|
return utility.has_collection(collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def create_collection(self, collection_name):
|
||||||
|
# Create milvus collection if not exists
|
||||||
|
try:
|
||||||
|
if not self.has_collection(collection_name):
|
||||||
|
field1 = FieldSchema(
|
||||||
|
name="id",
|
||||||
|
dtype=DataType.INT64,
|
||||||
|
descrition="int64",
|
||||||
|
is_primary=True,
|
||||||
|
auto_id=True)
|
||||||
|
field2 = FieldSchema(
|
||||||
|
name="embedding",
|
||||||
|
dtype=DataType.FLOAT_VECTOR,
|
||||||
|
descrition="speaker embeddings",
|
||||||
|
dim=VECTOR_DIMENSION,
|
||||||
|
is_primary=False)
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=[field1, field2], description="embeddings info")
|
||||||
|
self.collection = Collection(
|
||||||
|
name=collection_name, schema=schema)
|
||||||
|
LOGGER.debug(f"Create Milvus collection: {collection_name}")
|
||||||
|
else:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
return "OK"
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to create collection in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def insert(self, collection_name, vectors):
|
||||||
|
# Batch insert vectors to milvus collection
|
||||||
|
try:
|
||||||
|
self.create_collection(collection_name)
|
||||||
|
data = [vectors]
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
mr = self.collection.insert(data)
|
||||||
|
ids = mr.primary_keys
|
||||||
|
self.collection.load()
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
|
||||||
|
)
|
||||||
|
return ids
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to insert data to Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def create_index(self, collection_name):
|
||||||
|
# Create IVF_FLAT index on milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
default_index = {
|
||||||
|
"index_type": "IVF_SQ8",
|
||||||
|
"metric_type": METRIC_TYPE,
|
||||||
|
"params": {
|
||||||
|
"nlist": 16384
|
||||||
|
}
|
||||||
|
}
|
||||||
|
status = self.collection.create_index(
|
||||||
|
field_name="embedding", index_params=default_index)
|
||||||
|
if not status.code:
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Successfully create index in collection:{collection_name} with param:{default_index}"
|
||||||
|
)
|
||||||
|
return status
|
||||||
|
else:
|
||||||
|
raise Exception(status.message)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to create index: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_collection(self, collection_name):
|
||||||
|
# Delete Milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
self.collection.drop()
|
||||||
|
LOGGER.debug("Successfully drop collection!")
|
||||||
|
return "ok"
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to drop collection: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def search_vectors(self, collection_name, vectors, top_k):
|
||||||
|
# Search vector in milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
search_params = {
|
||||||
|
"metric_type": METRIC_TYPE,
|
||||||
|
"params": {
|
||||||
|
"nprobe": 16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res = self.collection.search(
|
||||||
|
vectors,
|
||||||
|
anns_field="embedding",
|
||||||
|
param=search_params,
|
||||||
|
limit=top_k)
|
||||||
|
LOGGER.debug(f"Successfully search in collection: {res}")
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to search vectors in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def count(self, collection_name):
|
||||||
|
# Get the number of milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
num = self.collection.num_entities
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Successfully get the num:{num} of the collection:{collection_name}"
|
||||||
|
)
|
||||||
|
return num
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to count vectors in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,133 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pymysql
|
||||||
|
from config import MYSQL_DB
|
||||||
|
from config import MYSQL_HOST
|
||||||
|
from config import MYSQL_PORT
|
||||||
|
from config import MYSQL_PWD
|
||||||
|
from config import MYSQL_USER
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLHelper():
|
||||||
|
"""
|
||||||
|
the basic operations of PyMySQL
|
||||||
|
|
||||||
|
# This example shows how to:
|
||||||
|
# 1. connect to MySQL server
|
||||||
|
# 2. create a table
|
||||||
|
# 3. insert data to table
|
||||||
|
# 4. search by milvus ids
|
||||||
|
# 5. delete table
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.conn = pymysql.connect(
|
||||||
|
host=MYSQL_HOST,
|
||||||
|
user=MYSQL_USER,
|
||||||
|
port=MYSQL_PORT,
|
||||||
|
password=MYSQL_PWD,
|
||||||
|
database=MYSQL_DB,
|
||||||
|
local_infile=True)
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
def test_connection(self):
|
||||||
|
try:
|
||||||
|
self.conn.ping()
|
||||||
|
except Exception:
|
||||||
|
self.conn = pymysql.connect(
|
||||||
|
host=MYSQL_HOST,
|
||||||
|
user=MYSQL_USER,
|
||||||
|
port=MYSQL_PORT,
|
||||||
|
password=MYSQL_PWD,
|
||||||
|
database=MYSQL_DB,
|
||||||
|
local_infile=True)
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
def create_mysql_table(self, table_name):
|
||||||
|
# Create mysql table if not exists
|
||||||
|
self.test_connection()
|
||||||
|
sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def load_data_to_mysql(self, table_name, data):
|
||||||
|
# Batch insert (Milvus_ids, img_path) to mysql
|
||||||
|
self.test_connection()
|
||||||
|
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
|
||||||
|
try:
|
||||||
|
self.cursor.executemany(sql, data)
|
||||||
|
self.conn.commit()
|
||||||
|
LOGGER.debug(
|
||||||
|
f"MYSQL loads data to table: {table_name} successfully")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def search_by_milvus_ids(self, ids, table_name):
|
||||||
|
# Get the img_path according to the milvus ids
|
||||||
|
self.test_connection()
|
||||||
|
str_ids = str(ids).replace('[', '').replace(']', '')
|
||||||
|
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
results = self.cursor.fetchall()
|
||||||
|
results = [res[0] for res in results]
|
||||||
|
LOGGER.debug("MYSQL search by milvus id.")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_table(self, table_name):
|
||||||
|
# Delete mysql table if exists
|
||||||
|
self.test_connection()
|
||||||
|
sql = "drop table if exists " + table_name + ";"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
LOGGER.debug(f"MYSQL delete table:{table_name}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_all_data(self, table_name):
|
||||||
|
# Delete all the data in mysql table
|
||||||
|
self.test_connection()
|
||||||
|
sql = 'delete from ' + table_name + ';'
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
self.conn.commit()
|
||||||
|
LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def count_table(self, table_name):
|
||||||
|
# Get the number of mysql table
|
||||||
|
self.test_connection()
|
||||||
|
sql = "select count(milvus_id) from " + table_name + ";"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
results = self.cursor.fetchall()
|
||||||
|
LOGGER.debug(f"MYSQL count table:{table_name}")
|
||||||
|
return results[0][0]
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def do_count(table_name, milvus_cli):
|
||||||
|
"""
|
||||||
|
Returns the total number of vectors in the system
|
||||||
|
"""
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
try:
|
||||||
|
if not milvus_cli.has_collection(table_name):
|
||||||
|
return None
|
||||||
|
num = milvus_cli.count(table_name)
|
||||||
|
return num
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error attempting to count table {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def do_drop(table_name, milvus_cli, mysql_cli):
|
||||||
|
"""
|
||||||
|
Delete the collection of Milvus and MySQL
|
||||||
|
"""
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
try:
|
||||||
|
if not milvus_cli.has_collection(table_name):
|
||||||
|
return "Collection is not exist"
|
||||||
|
status = milvus_cli.delete_collection(table_name)
|
||||||
|
mysql_cli.delete_table(table_name)
|
||||||
|
return status
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error attempting to drop table: {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from diskcache import Cache
|
||||||
|
from encode import get_audio_embedding
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def get_audios(path):
|
||||||
|
"""
|
||||||
|
List all wav and aif files recursively under the path folder.
|
||||||
|
"""
|
||||||
|
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
|
||||||
|
return [
|
||||||
|
item
|
||||||
|
for sublist in [[os.path.join(dir, file) for file in files]
|
||||||
|
for dir, _, files in list(os.walk(path))]
|
||||||
|
for item in sublist if os.path.splitext(item)[1] in supported_formats
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(audio_dir):
|
||||||
|
"""
|
||||||
|
Get the vector of audio
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cache = Cache('./tmp')
|
||||||
|
feats = []
|
||||||
|
names = []
|
||||||
|
audio_list = get_audios(audio_dir)
|
||||||
|
total = len(audio_list)
|
||||||
|
cache['total'] = total
|
||||||
|
for i, audio_path in enumerate(audio_list):
|
||||||
|
norm_feat = get_audio_embedding(audio_path)
|
||||||
|
if norm_feat is None:
|
||||||
|
continue
|
||||||
|
feats.append(norm_feat)
|
||||||
|
names.append(audio_path.encode())
|
||||||
|
cache['current'] = i + 1
|
||||||
|
print(
|
||||||
|
f"Extracting feature from audio No. {i + 1} , {total} audios in total"
|
||||||
|
)
|
||||||
|
return feats, names
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error with extracting feature from audio {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def format_data(ids, names):
|
||||||
|
"""
|
||||||
|
Combine the id of the vector and the name of the audio into a list
|
||||||
|
"""
|
||||||
|
data = []
|
||||||
|
for i in range(len(ids)):
|
||||||
|
value = (str(ids[i]), names[i])
|
||||||
|
data.append(value)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
|
||||||
|
"""
|
||||||
|
Import vectors to Milvus and data to Mysql respectively
|
||||||
|
"""
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
vectors, names = extract_features(audio_dir)
|
||||||
|
ids = milvus_cli.insert(table_name, vectors)
|
||||||
|
milvus_cli.create_index(table_name)
|
||||||
|
mysql_cli.create_mysql_table(table_name)
|
||||||
|
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
|
||||||
|
return len(ids)
|
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from config import TOP_K
|
||||||
|
from encode import get_audio_embedding
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
|
||||||
|
"""
|
||||||
|
Search the uploaded audio in Milvus/MySQL
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
feat = get_audio_embedding(audio_path)
|
||||||
|
vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
|
||||||
|
vids = [str(x.id) for x in vectors[0]]
|
||||||
|
paths = mysql_cli.search_by_milvus_ids(vids, table_name)
|
||||||
|
distances = [x.distance for x in vectors[0]]
|
||||||
|
for i in range(len(paths)):
|
||||||
|
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
|
||||||
|
paths[i] = tmp
|
||||||
|
distances[i] = (1 - distances[i]) * 100
|
||||||
|
return vids, paths, distances
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error with search: {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import gdown
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def download_audio_data():
|
||||||
|
"""
|
||||||
|
download audio data
|
||||||
|
"""
|
||||||
|
url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
|
||||||
|
gdown.download(url)
|
||||||
|
|
||||||
|
with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
|
||||||
|
zip_ref.extractall('./example_audio')
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop():
|
||||||
|
"""
|
||||||
|
Delete the collection of Milvus and MySQL
|
||||||
|
"""
|
||||||
|
response = client.post("/audio/drop")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_load():
|
||||||
|
"""
|
||||||
|
Insert all the audio files under the file path to Milvus/MySQL
|
||||||
|
"""
|
||||||
|
response = client.post("/audio/load", json={"File": "./example_audio"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
'status': True,
|
||||||
|
'msg': "Successfully loaded data!"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_progress():
|
||||||
|
"""
|
||||||
|
Get the progress of dealing with data
|
||||||
|
"""
|
||||||
|
response = client.get("/progress")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == "current: 20, total: 20"
|
||||||
|
|
||||||
|
|
||||||
|
def test_count():
|
||||||
|
"""
|
||||||
|
Returns the total number of vectors in the system
|
||||||
|
"""
|
||||||
|
response = client.get("audio/count")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_search():
|
||||||
|
"""
|
||||||
|
Search the uploaded audio in Milvus/MySQL
|
||||||
|
"""
|
||||||
|
response = client.post(
|
||||||
|
"/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert len(response.json()) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_data():
|
||||||
|
"""
|
||||||
|
Get the audio file
|
||||||
|
"""
|
||||||
|
response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
download_audio_data()
|
||||||
|
test_load()
|
||||||
|
test_count()
|
||||||
|
test_search()
|
||||||
|
test_drop()
|
Loading…
Reference in new issue