You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/demos/speech_web/speech_server/src/SpeechBase/vpr.py

119 lines
3.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# VPR demo: does not use MySQL or Milvus; intended only for the Docker demo.
import logging
import faiss
from matplotlib import use
import numpy as np
from .sql_helper import DataBase
from .vpr_encode import get_audio_embedding
class VPR:
    """Voice-print registry backed by a ``DataBase`` and an in-memory faiss index.

    Embeddings are persisted through ``self.db`` and mirrored into a flat
    inner-product faiss index keyed by the database row id, so a search hit
    can be mapped back to its stored row.
    """

    def __init__(self, db_path, dim, top_k) -> None:
        """Open the database and build the faiss index from its stored rows.

        Args:
            db_path: Passed straight to ``DataBase``.
            dim: Dimensionality of the embeddings held by the faiss index.
            top_k: Number of nearest neighbours requested per search.
        """
        self.db_path = db_path
        self.dim = dim
        self.top_k = top_k
        self.dtype = np.float32
        self.vpr_idx = 0
        # Persistent storage for usernames, encoded vectors and wav paths.
        self.db = DataBase(db_path)
        # Inner-product index wrapped in an ID map so that vectors can be
        # added and removed under the database row id.
        self.index_ip = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
        self.init()

    @staticmethod
    def _id_array(idx):
        """Return ``idx`` as the 1-element int64 array used for faiss ids."""
        return np.array((idx,)).astype('int64')

    def init(self):
        """Register every vector already stored in the database with faiss."""
        for row in self.db.select_all() or ():
            vector = self.db.decode_vector(row['vector'])
            if vector.ndim == 1:
                # The index is fed one row per vector, so promote a flat
                # vector to shape (1, n).
                vector = np.expand_dims(vector, axis=0)
            self.faiss_enroll(row['id'], vector)
        logging.info("faiss 构建完毕")

    def faiss_enroll(self, idx, vc):
        """Add vector ``vc`` to the faiss index under id ``idx``."""
        self.index_ip.add_with_ids(vc, self._id_array(idx))

    def vpr_enroll(self, username, wav_path):
        """Enroll a speaker from an audio file.

        Args:
            username: Name stored alongside the embedding.
            wav_path: Audio file the embedding is extracted from.

        Returns:
            The id reported by ``self.db.insert_one``, or ``None`` when no
            embedding could be extracted from ``wav_path``.
        """
        emb = get_audio_embedding(wav_path)
        if emb is None:
            # Check before expand_dims: np.expand_dims(None, ...) yields an
            # object array, which would hide the failure and store garbage.
            logging.error("注册失败")
            return None

        emb = np.expand_dims(emb, axis=0)
        emb_bs64 = self.db.encode_vector(emb)
        last_idx, _ = self.db.insert_one(username, emb_bs64, wav_path)
        if last_idx:
            # Only mirror into faiss once the database accepted the row.
            self.faiss_enroll(last_idx, emb)
        return last_idx

    def vpr_recog(self, wav_path):
        """Search the index for the speakers closest to an audio file.

        Returns:
            A list of ``(score, id)`` tuples, where ``score`` is the faiss
            result scaled by 100 and rounded to two decimals, or ``None``
            when no embedding could be extracted from ``wav_path``.
        """
        emb_search = get_audio_embedding(wav_path)
        if emb_search is None:
            logging.error("识别失败")
            return None

        emb_search = np.expand_dims(emb_search, axis=0)
        distances, ids = self.index_ip.search(emb_search, self.top_k)
        # One query row was submitted, so only the first result row matters.
        distances = distances.tolist()[0]
        ids = ids.tolist()[0]
        # Slots reported with id -1 carry no match and are dropped.
        return [(round(dist * 100, 2), idx)
                for dist, idx in zip(distances, ids) if idx != -1]

    def do_search_vpr(self, wav_path):
        """Recognise a speaker and return parallel result lists.

        Returns:
            ``(spk_ids, paths, scores)``: the distinct usernames in result
            order, an empty-string placeholder path per username, and the
            score of the first hit for each username. All three lists are
            empty when recognition fails or nothing matches.
        """
        spk_ids, paths, scores = [], [], []
        recog_result = self.vpr_recog(wav_path)
        if not recog_result:
            # vpr_recog returns None on failure; report "no matches"
            # instead of raising while iterating None.
            return spk_ids, paths, scores

        for score, idx in recog_result:
            rows = self.db.select_by_id(idx)
            if not rows:
                # The index holds an id the database no longer has; skip it.
                continue
            username = rows[0]['username']
            if username not in spk_ids:
                spk_ids.append(username)
                scores.append(score)
                paths.append("")
        return spk_ids, paths, scores

    def vpr_del(self, username):
        """Delete every voice print registered under ``username``."""
        # Remove the vectors from faiss by row id first, then drop the rows.
        for row in self.db.select_by_username(username) or ():
            self.index_ip.remove_ids(self._id_array(row['id']))
        self.db.drop_by_username(username)

    def vpr_list(self):
        """Return all rows as reported by ``self.db.select_all()``."""
        return self.db.select_all()

    def do_list(self):
        """Return ``(usernames, ids)`` as two parallel lists over all rows."""
        spk_ids, vpr_ids = [], []
        for row in self.db.select_all() or ():
            spk_ids.append(row['username'])
            vpr_ids.append(row['id'])
        return spk_ids, vpr_ids

    def do_get_wav(self, vpr_idx):
        """Return the ``'wavpath'`` field of the row with id ``vpr_idx``."""
        res = self.db.select_by_id(vpr_idx)
        return res[0]['wavpath']

    def vpr_data(self, idx):
        """Return whatever ``self.db.select_by_id`` yields for ``idx``."""
        return self.db.select_by_id(idx)

    def vpr_droptable(self):
        """Drop the database table and clear the faiss index."""
        self.db.drop_table()
        self.index_ip.reset()