|
|
# VPR demo: does not use MySQL or Milvus; intended for the Docker demonstration only.
|
|
|
import logging
|
|
|
import faiss
|
|
|
from matplotlib import use
|
|
|
import numpy as np
|
|
|
from .sql_helper import DataBase
|
|
|
from .vpr_encode import get_audio_embedding
|
|
|
|
|
|
class VPR:
    """Voice-print registry combining a database with a faiss index.

    Every enrolled voice print is stored as a row in ``self.db`` (rows are
    read through the keys ``id``, ``username``, ``vector`` and ``wavpath``)
    and mirrored into an in-memory faiss inner-product index keyed by the
    row id, so that a query embedding can be matched against all enrolled
    speakers.
    """

    def __init__(self, db_path, dim, top_k) -> None:
        """Open the database and build the faiss index.

        Args:
            db_path: Passed straight to ``DataBase``.
            dim: Dimensionality handed to ``faiss.IndexFlatIP``.
            top_k: Number of neighbours requested on every search.
        """
        # Basic configuration.
        self.db_path = db_path
        self.dim = dim
        self.top_k = top_k
        self.dtype = np.float32
        self.vpr_idx = 0

        # Database initialisation.
        self.db = DataBase(db_path)

        # faiss initialisation: a flat inner-product index wrapped in an
        # IndexIDMap so that vectors can be added/removed by database id.
        index_ip = faiss.IndexFlatIP(dim)
        self.index_ip = faiss.IndexIDMap(index_ip)
        self.init()

    def init(self):
        """Register every vector already stored in the database into faiss."""
        sql_dbs = self.db.select_all()
        if sql_dbs:
            for sql_db in sql_dbs:
                idx = sql_db['id']
                vc = self.db.decode_vector(sql_db['vector'])
                # The index is fed one row per id, so promote a flat
                # vector to a single-row matrix.
                if len(vc.shape) == 1:
                    vc = np.expand_dims(vc, axis=0)
                self.faiss_enroll(idx, vc)
        logging.info("faiss 构建完毕")

    def faiss_enroll(self, idx, vc):
        """Add vector(s) ``vc`` to the faiss index under the id ``idx``."""
        self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64'))

    def vpr_enroll(self, username, wav_path):
        """Enroll the voice print extracted from ``wav_path`` for ``username``.

        Returns:
            The id reported by ``self.db.insert_one`` (the vector is added
            to faiss only when that id is truthy), or ``None`` when no
            embedding could be extracted from the audio.
        """
        emb = get_audio_embedding(wav_path)
        # Check for failure BEFORE reshaping: np.expand_dims(None, ...)
        # returns an object array, which would hide the failure.
        if emb is None:
            return None

        emb = np.expand_dims(emb, axis=0)
        emb_bs64 = self.db.encode_vector(emb)
        last_idx, _ = self.db.insert_one(username, emb_bs64, wav_path)
        if last_idx:
            # Mirror the new database row into faiss.
            self.faiss_enroll(last_idx, emb)
        return last_idx

    def vpr_recog(self, wav_path):
        """Search the index for the voice print found in ``wav_path``.

        Returns:
            A list of ``(score, id)`` tuples in the order returned by faiss,
            where ``score`` is the inner product scaled by 100 and rounded
            to two decimals; ``None`` when no embedding could be extracted.
        """
        emb_search = get_audio_embedding(wav_path)
        if emb_search is None:
            logging.error("识别失败")
            return None

        emb_search = np.expand_dims(emb_search, axis=0)
        D, I = self.index_ip.search(emb_search, self.top_k)
        scores = D.tolist()[0]
        ids = I.tolist()[0]
        # faiss pads missing neighbours with the label -1; drop those.
        return [(round(score * 100, 2), idx)
                for score, idx in zip(scores, ids) if idx != -1]

    def do_search_vpr(self, wav_path):
        """Recognise ``wav_path`` and return one entry per distinct username.

        Returns:
            ``(spk_ids, paths, scores)``: usernames in match order (only the
            first hit of each username is kept), a list of empty strings of
            the same length, and the matching scores. All three lists are
            empty when recognition fails or nothing matches.
        """
        spk_ids, paths, scores = [], [], []
        seen = set()
        # vpr_recog returns None on failure; treat that as "no matches".
        recog_result = self.vpr_recog(wav_path) or []
        for score, idx in recog_result:
            username = self.db.select_by_id(idx)[0]['username']
            if username not in seen:
                seen.add(username)
                spk_ids.append(username)
                scores.append(score)
                paths.append("")
        return spk_ids, paths, scores

    def vpr_del(self, username):
        """Delete every voice print registered under ``username``."""
        # Look up the user's row ids and remove the matching vectors first,
        # then drop the rows themselves.
        res = self.db.select_by_username(username) or []
        for r in res:
            idx = r['id']
            self.index_ip.remove_ids(np.array((idx,)).astype('int64'))

        self.db.drop_by_username(username)

    def vpr_list(self):
        """Return all rows stored in the database."""
        return self.db.select_all()

    def do_list(self):
        """Return ``(usernames, ids)`` as parallel lists for all rows."""
        spk_ids, vpr_ids = [], []
        # init() treats a falsy select_all() result as "no rows"; do the
        # same here instead of failing on it.
        for res in self.db.select_all() or []:
            spk_ids.append(res['username'])
            vpr_ids.append(res['id'])
        return spk_ids, vpr_ids

    def do_get_wav(self, vpr_idx):
        """Return the ``wavpath`` of the first row matching ``vpr_idx``."""
        res = self.db.select_by_id(vpr_idx)
        return res[0]['wavpath']

    def vpr_data(self, idx):
        """Return the database result for the given id."""
        res = self.db.select_by_id(idx)
        return res

    def vpr_droptable(self):
        """Drop the database table and empty the faiss index."""
        # Drop the table.
        self.db.drop_table()
        # Clear faiss.
        self.index_ip.reset()
|
|
|
|