[wip][vec] add demo for searching vectors base on MySQL and Milvus, test=doc #1543
parent
fd20056718
commit
b1f11ccd65
@ -0,0 +1,73 @@
|
||||
version: '3.5'
|
||||
|
||||
services:
|
||||
etcd:
|
||||
container_name: milvus-etcd
|
||||
image: quay.io/coreos/etcd:v3.5.0
|
||||
networks:
|
||||
app_net:
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
|
||||
minio:
|
||||
container_name: milvus-minio
|
||||
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||
networks:
|
||||
app_net:
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.0.1
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.10
|
||||
command: ["milvus", "run", "standalone"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
MINIO_ADDRESS: minio:9000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||
ports:
|
||||
- "19530:19530"
|
||||
depends_on:
|
||||
- "etcd"
|
||||
- "minio"
|
||||
|
||||
|
||||
mysql:
|
||||
container_name: audio-mysql
|
||||
image: mysql:5.7
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.11
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=123456
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
|
||||
ports:
|
||||
- "3306:3306"
|
||||
|
||||
networks:
|
||||
app_net:
|
||||
driver: bridge
|
||||
ipam:
|
||||
driver: default
|
||||
config:
|
||||
- subnet: 172.16.23.0/24
|
||||
gateway: 172.16.23.1
|
@ -0,0 +1,12 @@
|
||||
soundfile==0.10.3.post1
|
||||
librosa==0.8.0
|
||||
numpy
|
||||
pymysql
|
||||
fastapi
|
||||
uvicorn
|
||||
diskcache==5.2.1
|
||||
pymilvus==2.0.1
|
||||
python-multipart
|
||||
typing
|
||||
starlette
|
||||
pydantic
|
@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
############### Milvus Configuration ###############
|
||||
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
|
||||
MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
|
||||
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
|
||||
INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
|
||||
METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
|
||||
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
|
||||
TOP_K = int(os.getenv("TOP_K", "10"))
|
||||
|
||||
############### MySQL Configuration ###############
|
||||
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
|
||||
MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
|
||||
|
||||
############### Data Path ###############
|
||||
UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
|
||||
|
||||
############### Number of Log Files ###############
|
||||
LOGS_NUM = int(os.getenv("logs_num", "0"))
|
@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import librosa
|
||||
import numpy as np
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def get_audio_embedding(path):
|
||||
"""
|
||||
Use vpr_inference to generate embedding of audio
|
||||
"""
|
||||
try:
|
||||
RESAMPLE_RATE = 16000
|
||||
audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
|
||||
|
||||
# TODO add infer/python interface to get embedding, now fake it by rand
|
||||
# vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
|
||||
# embedding = vpr.inference(audio)
|
||||
|
||||
embedding = np.random.rand(1, 2048)
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
embedding = embedding.tolist()[0]
|
||||
return embedding
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with embedding:{e}")
|
||||
return None
|
@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from config import UPLOAD_PATH
|
||||
from diskcache import Cache
|
||||
from fastapi import FastAPI
|
||||
from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from logs import LOGGER
|
||||
from milvus_helpers import MilvusHelper
|
||||
from mysql_helpers import MySQLHelper
|
||||
from operations.count import do_count
|
||||
from operations.drop import do_drop
|
||||
from operations.load import do_load
|
||||
from operations.search import do_search
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"])
|
||||
|
||||
MODEL = None
|
||||
MILVUS_CLI = MilvusHelper()
|
||||
MYSQL_CLI = MySQLHelper()
|
||||
|
||||
# Mkdir 'tmp/audio-data'
|
||||
if not os.path.exists(UPLOAD_PATH):
|
||||
os.makedirs(UPLOAD_PATH)
|
||||
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||
|
||||
|
||||
@app.get('/data')
|
||||
def audio_path(audio_path):
|
||||
# Get the audio file
|
||||
try:
|
||||
LOGGER.info(f"Successfully load audio: {audio_path}")
|
||||
return FileResponse(audio_path)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"upload audio error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/progress')
|
||||
def get_progress():
|
||||
# Get the progress of dealing with data
|
||||
try:
|
||||
cache = Cache('./tmp')
|
||||
return f"current: {cache['current']}, total: {cache['total']}"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Upload data error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
Table: Optional[str] = None
|
||||
File: str
|
||||
|
||||
|
||||
@app.post('/audio/load')
|
||||
async def load_audios(item: Item):
|
||||
# Insert all the audio files under the file path to Milvus/MySQL
|
||||
try:
|
||||
total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
|
||||
LOGGER.info(f"Successfully loaded data, total count: {total_num}")
|
||||
return {'status': True, 'msg': "Successfully loaded data!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/search')
|
||||
async def search_audio(request: Request,
|
||||
table_name: str=None,
|
||||
audio: UploadFile=File(...)):
|
||||
# Search the uploaded audio in Milvus/MySQL
|
||||
try:
|
||||
# Save the upload data to server.
|
||||
content = await audio.read()
|
||||
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||
with open(query_audio_path, "wb+") as f:
|
||||
f.write(content)
|
||||
host = request.headers['host']
|
||||
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||
MILVUS_CLI, MYSQL_CLI)
|
||||
names = []
|
||||
for i in paths:
|
||||
names.append(os.path.basename(i))
|
||||
res = dict(zip(paths, zip(names, distances)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1])
|
||||
LOGGER.info("Successfully searched similar audio!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/search/local')
|
||||
async def search_local_audio(request: Request,
|
||||
query_audio_path: str,
|
||||
table_name: str=None):
|
||||
# Search the uploaded audio in Milvus/MySQL
|
||||
try:
|
||||
host = request.headers['host']
|
||||
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||
MILVUS_CLI, MYSQL_CLI)
|
||||
names = []
|
||||
for i in paths:
|
||||
names.append(os.path.basename(i))
|
||||
res = dict(zip(paths, zip(names, distances)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1])
|
||||
LOGGER.info("Successfully searched similar audio!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/audio/count')
|
||||
async def count_audio(table_name: str=None):
|
||||
# Returns the total number of vectors in the system
|
||||
try:
|
||||
num = do_count(table_name, MILVUS_CLI)
|
||||
LOGGER.info("Successfully count the number of data!")
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/drop')
|
||||
async def drop_tables(table_name: str=None):
|
||||
# Delete the collection of Milvus and MySQL
|
||||
try:
|
||||
status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
|
||||
LOGGER.info("Successfully drop tables in Milvus and MySQL!")
|
||||
return status
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app=app, host='127.0.0.1', port=8002)
|
@ -0,0 +1,186 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import METRIC_TYPE
|
||||
from config import MILVUS_HOST
|
||||
from config import MILVUS_PORT
|
||||
from config import VECTOR_DIMENSION
|
||||
from logs import LOGGER
|
||||
from pymilvus import Collection
|
||||
from pymilvus import CollectionSchema
|
||||
from pymilvus import connections
|
||||
from pymilvus import DataType
|
||||
from pymilvus import FieldSchema
|
||||
from pymilvus import utility
|
||||
|
||||
|
||||
class MilvusHelper:
|
||||
"""
|
||||
the basic operations of PyMilvus
|
||||
|
||||
# This example shows how to:
|
||||
# 1. connect to Milvus server
|
||||
# 2. create a collection
|
||||
# 3. insert entities
|
||||
# 4. create index
|
||||
# 5. search
|
||||
# 6. delete a collection
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self.collection = None
|
||||
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
LOGGER.debug(
|
||||
f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to connect Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def set_collection(self, collection_name):
|
||||
try:
|
||||
if self.has_collection(collection_name):
|
||||
self.collection = Collection(name=collection_name)
|
||||
else:
|
||||
raise Exception(
|
||||
f"There is no collection named:{collection_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to load data to Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def has_collection(self, collection_name):
|
||||
# Return if Milvus has the collection
|
||||
try:
|
||||
return utility.has_collection(collection_name)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to load data to Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def create_collection(self, collection_name):
|
||||
# Create milvus collection if not exists
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
field1 = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
descrition="int64",
|
||||
is_primary=True,
|
||||
auto_id=True)
|
||||
field2 = FieldSchema(
|
||||
name="embedding",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
descrition="speaker embeddings",
|
||||
dim=VECTOR_DIMENSION,
|
||||
is_primary=False)
|
||||
schema = CollectionSchema(
|
||||
fields=[field1, field2], description="embeddings info")
|
||||
self.collection = Collection(
|
||||
name=collection_name, schema=schema)
|
||||
LOGGER.debug(f"Create Milvus collection: {collection_name}")
|
||||
else:
|
||||
self.set_collection(collection_name)
|
||||
return "OK"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to load data to Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def insert(self, collection_name, vectors):
|
||||
# Batch insert vectors to milvus collection
|
||||
try:
|
||||
self.create_collection(collection_name)
|
||||
data = [vectors]
|
||||
self.set_collection(collection_name)
|
||||
mr = self.collection.insert(data)
|
||||
ids = mr.primary_keys
|
||||
self.collection.load()
|
||||
LOGGER.debug(
|
||||
f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
|
||||
)
|
||||
return ids
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to load data to Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def create_index(self, collection_name):
|
||||
# Create IVF_FLAT index on milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
default_index = {
|
||||
"index_type": "IVF_SQ8",
|
||||
"metric_type": METRIC_TYPE,
|
||||
"params": {
|
||||
"nlist": 16384
|
||||
}
|
||||
}
|
||||
status = self.collection.create_index(
|
||||
field_name="embedding", index_params=default_index)
|
||||
if not status.code:
|
||||
LOGGER.debug(
|
||||
f"Successfully create index in collection:{collection_name} with param:{default_index}"
|
||||
)
|
||||
return status
|
||||
else:
|
||||
raise Exception(status.message)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to create index: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_collection(self, collection_name):
|
||||
# Delete Milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
self.collection.drop()
|
||||
LOGGER.debug("Successfully drop collection!")
|
||||
return "ok"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to drop collection: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def search_vectors(self, collection_name, vectors, top_k):
|
||||
# Search vector in milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
search_params = {
|
||||
"metric_type": METRIC_TYPE,
|
||||
"params": {
|
||||
"nprobe": 16
|
||||
}
|
||||
}
|
||||
# data = [vectors]
|
||||
res = self.collection.search(
|
||||
vectors,
|
||||
anns_field="embedding",
|
||||
param=search_params,
|
||||
limit=top_k)
|
||||
LOGGER.debug(f"Successfully search in collection: {res}")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to search vectors in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def count(self, collection_name):
|
||||
# Get the number of milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
num = self.collection.num_entities
|
||||
LOGGER.debug(
|
||||
f"Successfully get the num:{num} of the collection:{collection_name}"
|
||||
)
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to count vectors in Milvus: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import pymysql
|
||||
from config import MYSQL_DB
|
||||
from config import MYSQL_HOST
|
||||
from config import MYSQL_PORT
|
||||
from config import MYSQL_PWD
|
||||
from config import MYSQL_USER
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
class MySQLHelper():
|
||||
"""
|
||||
the basic operations of PyMySQL
|
||||
|
||||
# This example shows how to:
|
||||
# 1. connect to MySQL server
|
||||
# 2. create a table
|
||||
# 3. insert data to table
|
||||
# 4. search by milvus ids
|
||||
# 5. delete table
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.conn = pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_PORT,
|
||||
password=MYSQL_PWD,
|
||||
database=MYSQL_DB,
|
||||
local_infile=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def test_connection(self):
|
||||
try:
|
||||
self.conn.ping()
|
||||
except Exception:
|
||||
self.conn = pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_PORT,
|
||||
password=MYSQL_PWD,
|
||||
database=MYSQL_DB,
|
||||
local_infile=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def create_mysql_table(self, table_name):
|
||||
# Create mysql table if not exists
|
||||
self.test_connection()
|
||||
sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def load_data_to_mysql(self, table_name, data):
|
||||
# Batch insert (Milvus_ids, img_path) to mysql
|
||||
self.test_connection()
|
||||
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
|
||||
try:
|
||||
self.cursor.executemany(sql, data)
|
||||
self.conn.commit()
|
||||
LOGGER.debug(
|
||||
f"MYSQL loads data to table: {table_name} successfully")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def search_by_milvus_ids(self, ids, table_name):
|
||||
# Get the img_path according to the milvus ids
|
||||
self.test_connection()
|
||||
str_ids = str(ids).replace('[', '').replace(']', '')
|
||||
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
results = self.cursor.fetchall()
|
||||
results = [res[0] for res in results]
|
||||
LOGGER.debug("MYSQL search by milvus id.")
|
||||
return results
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_table(self, table_name):
|
||||
# Delete mysql table if exists
|
||||
self.test_connection()
|
||||
sql = "drop table if exists " + table_name + ";"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
LOGGER.debug(f"MYSQL delete table:{table_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_all_data(self, table_name):
|
||||
# Delete all the data in mysql table
|
||||
self.test_connection()
|
||||
sql = 'delete from ' + table_name + ';'
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
self.conn.commit()
|
||||
LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def count_table(self, table_name):
|
||||
# Get the number of mysql table
|
||||
self.test_connection()
|
||||
sql = "select count(milvus_id) from " + table_name + ";"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
results = self.cursor.fetchall()
|
||||
LOGGER.debug(f"MYSQL count table:{table_name}")
|
||||
return results[0][0]
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_count(table_name, milvus_cli):
|
||||
"""
|
||||
Returns the total number of vectors in the system
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
try:
|
||||
if not milvus_cli.has_collection(table_name):
|
||||
return None
|
||||
num = milvus_cli.count(table_name)
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error attempting to count table {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_drop(table_name, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Delete the collection of Milvus and MySQL
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
try:
|
||||
if not milvus_cli.has_collection(table_name):
|
||||
return "Collection is not exist"
|
||||
status = milvus_cli.delete_collection(table_name)
|
||||
mysql_cli.delete_table(table_name)
|
||||
return status
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error attempting to drop table: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
from diskcache import Cache
|
||||
from encode import get_audio_embedding
|
||||
|
||||
from ..config import DEFAULT_TABLE
|
||||
from ..logs import LOGGER
|
||||
|
||||
|
||||
def get_audios(path):
|
||||
"""
|
||||
List all wav and aif files recursively under the path folder.
|
||||
"""
|
||||
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
|
||||
return [
|
||||
item
|
||||
for sublist in [[os.path.join(dir, file) for file in files]
|
||||
for dir, _, files in list(os.walk(path))]
|
||||
for item in sublist if os.path.splitext(item)[1] in supported_formats
|
||||
]
|
||||
|
||||
|
||||
def extract_features(audio_dir):
|
||||
"""
|
||||
Get the vector of audio
|
||||
"""
|
||||
try:
|
||||
cache = Cache('./tmp')
|
||||
feats = []
|
||||
names = []
|
||||
audio_list = get_audios(audio_dir)
|
||||
total = len(audio_list)
|
||||
cache['total'] = total
|
||||
for i, audio_path in enumerate(audio_list):
|
||||
norm_feat = get_audio_embedding(audio_path)
|
||||
if norm_feat is None:
|
||||
continue
|
||||
feats.append(norm_feat)
|
||||
names.append(audio_path.encode())
|
||||
cache['current'] = i + 1
|
||||
print(
|
||||
f"Extracting feature from audio No. {i + 1} , {total} audios in total"
|
||||
)
|
||||
return feats, names
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with extracting feature from audio {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def format_data(ids, names):
|
||||
"""
|
||||
Combine the id of the vector and the name of the audio into a list
|
||||
"""
|
||||
data = []
|
||||
for i in range(len(ids)):
|
||||
value = (str(ids[i]), names[i])
|
||||
data.append(value)
|
||||
return data
|
||||
|
||||
|
||||
def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Import vectors to Milvus and data to Mysql respectively
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
vectors, names = extract_features(audio_dir)
|
||||
ids = milvus_cli.insert(table_name, vectors)
|
||||
milvus_cli.create_index(table_name)
|
||||
mysql_cli.create_mysql_table(table_name)
|
||||
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
|
||||
return len(ids)
|
@ -0,0 +1,40 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from config import TOP_K
|
||||
from encode import get_audio_embedding
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Search the uploaded audio in Milvus/MySQL
|
||||
"""
|
||||
try:
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
feat = get_audio_embedding(audio_path)
|
||||
vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
|
||||
vids = [str(x.id) for x in vectors[0]]
|
||||
paths = mysql_cli.search_by_milvus_ids(vids, table_name)
|
||||
distances = [x.distance for x in vectors[0]]
|
||||
for i in range(len(paths)):
|
||||
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
|
||||
paths[i] = tmp
|
||||
return vids, paths, distances
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with search: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,96 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def download_audio_data():
|
||||
"""
|
||||
download audio data
|
||||
"""
|
||||
url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
|
||||
gdown.download(url)
|
||||
|
||||
with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
|
||||
zip_ref.extractall('./example_audio')
|
||||
|
||||
|
||||
def test_drop():
|
||||
"""
|
||||
Delete the collection of Milvus and MySQL
|
||||
"""
|
||||
response = client.post("/audio/drop")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_load():
|
||||
"""
|
||||
Insert all the audio files under the file path to Milvus/MySQL
|
||||
"""
|
||||
response = client.post("/audio/load", json={"File": "./example_audio"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
'status': True,
|
||||
'msg': "Successfully loaded data!"
|
||||
}
|
||||
|
||||
|
||||
def test_progress():
|
||||
"""
|
||||
Get the progress of dealing with data
|
||||
"""
|
||||
response = client.get("/progress")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "current: 20, total: 20"
|
||||
|
||||
|
||||
def test_count():
|
||||
"""
|
||||
Returns the total number of vectors in the system
|
||||
"""
|
||||
response = client.get("audio/count")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == 20
|
||||
|
||||
|
||||
def test_search():
|
||||
"""
|
||||
Search the uploaded audio in Milvus/MySQL
|
||||
"""
|
||||
response = client.post(
|
||||
"/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
def test_data():
|
||||
"""
|
||||
Get the audio file
|
||||
"""
|
||||
response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_audio_data()
|
||||
test_drop()
|
||||
test_load()
|
||||
test_count()
|
||||
test_search()
|
||||
test_drop()
|
Loading…
Reference in new issue