You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1.5 KiB
42 lines
1.5 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import sys
|
|
|
|
from config import DEFAULT_TABLE
|
|
from config import TOP_K
|
|
from encode import get_audio_embedding
|
|
from logs import LOGGER
|
|
|
|
|
|
def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
    """
    Search the uploaded audio in Milvus/MySQL

    The audio at ``audio_path`` is embedded with ``get_audio_embedding``,
    the embedding is searched in Milvus for the ``TOP_K`` best hits, and the
    ids of those hits are looked up in MySQL.

    Args:
        host: Host inserted into the returned ``http://<host>/data`` URLs.
        table_name: Table to search; ``DEFAULT_TABLE`` is used when falsy.
        audio_path: Path of the query audio, passed to
            ``get_audio_embedding``.
        milvus_cli: Client exposing
            ``search_vectors(table_name, vectors, top_k)``.
        mysql_cli: Client exposing
            ``search_by_milvus_ids(ids, table_name)``.

    Returns:
        A tuple ``(vids, paths, distances)``:
        ``vids`` are the hit ids as strings, ``paths`` are URLs of the form
        ``http://<host>/data?audio_path=<path>`` built from the MySQL
        results, and ``distances`` holds ``(1 - distance) * 100`` for every
        hit (presumably a similarity percentage; confirm against the
        metric configured for the Milvus collection).

    Note:
        Any exception is logged and the process is terminated with
        ``sys.exit(1)``; nothing is re-raised to the caller.
    """
    try:
        if not table_name:
            table_name = DEFAULT_TABLE
        feat = get_audio_embedding(audio_path)
        vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
        # Only one query vector is sent, so only the first result set matters.
        hits = vectors[0]
        vids = [str(hit.id) for hit in hits]
        found_paths = mysql_cli.search_by_milvus_ids(vids, table_name)
        url_prefix = "http://" + str(host) + "/data?audio_path="
        paths = [url_prefix + str(path) for path in found_paths]
        # Scale every hit's distance, independent of how many paths MySQL
        # returned, so the returned list never mixes raw and scaled values.
        distances = [(1 - hit.distance) * 100 for hit in hits]
        return vids, paths, distances
    except Exception as e:
        LOGGER.error(f"Error with search: {e}")
        sys.exit(1)
|