# Source: PaddleSpeech/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py

import base64
import sqlite3
import os
import numpy as np
from pkg_resources import resource_stream
def dict_factory(cursor, row):
    """Build a dict for one result row, keyed by column name.

    Intended for use as a ``sqlite3`` ``row_factory``: the key for each value
    is the first field of the matching ``cursor.description`` entry. If two
    columns share a name, the later one wins.
    """
    return {
        column[0]: row[position]
        for position, column in enumerate(cursor.description)
    }
class DataBase(object):
    """Small wrapper around a SQLite file holding the ``vprtable`` table.

    Each row stores a username, a base64-encoded vector and the path of a
    wav file. A single connection and a single cursor are opened in
    ``__init__`` and shared by every method.
    """

    def __init__(self, db_path: str):
        """Open the database at *db_path*, creating it if necessary.

        Args:
            db_path: Path of the SQLite file. Missing parent directories
                are created.
        """
        self.db_path = os.path.realpath(db_path)
        # sqlite3.connect() creates the file itself but not its directory.
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        self.conn = sqlite3.connect(self.db_path)
        # Must be set before the cursor is created so the cursor inherits it.
        self.conn.row_factory = dict_factory
        self.cursor = self.conn.cursor()
        self.init_database()

    def init_database(self):
        """Initialize the database, creating the table if it does not exist."""
        sql = """
            CREATE TABLE IF NOT EXISTS vprtable (
                `id` INTEGER PRIMARY KEY AUTOINCREMENT,
                `username` TEXT NOT NULL,
                `vector` TEXT NOT NULL,
                `wavpath` TEXT NOT NULL
            );
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def execute_base(self, sql, data_dict):
        """Execute *sql* with *data_dict* as bound parameters, then commit."""
        self.cursor.execute(sql, data_dict)
        self.conn.commit()

    def insert_one(self, username, vector_base64: str, wav_path):
        """Insert one row into ``vprtable``.

        Args:
            username: Value for the ``username`` column.
            vector_base64: Value for the ``vector`` column (see
                ``encode_vector``).
            wav_path: Value for the ``wavpath`` column; must exist on disk.

        Returns:
            ``(row_id, "data insert success")`` on success,
            ``(None, "wav not exists")`` if *wav_path* is missing, or
            ``(None, exception)`` if the insert itself failed.
        """
        if not os.path.exists(wav_path):
            return None, "wav not exists"

        sql = """
            insert into
            vprtable (username, vector, wavpath)
            values (?, ?, ?)
        """
        try:
            self.cursor.execute(sql, (username, vector_base64, wav_path))
            self.conn.commit()
        except Exception as e:
            # Deliberately broad: callers expect a (None, error) tuple rather
            # than an exception from this method.
            print(e)
            return None, e
        return self.cursor.lastrowid, "data insert success"

    def select_all(self):
        """Return every row of ``vprtable`` as a list."""
        sql = """
            SELECT * from vprtable
        """
        return self.cursor.execute(sql).fetchall()

    def select_by_id(self, vpr_id):
        """Return the rows whose ``id`` equals *vpr_id* (a list, maybe empty)."""
        # Bound parameter instead of string interpolation: the value can never
        # be interpreted as SQL.
        sql = """
            SELECT * from vprtable WHERE `id` = ?
        """
        return self.cursor.execute(sql, (vpr_id,)).fetchall()

    def select_by_username(self, username):
        """Return the rows whose ``username`` equals *username*."""
        # Bound parameter: safe for names containing quotes and immune to
        # SQL injection.
        sql = """
            SELECT * from vprtable WHERE `username` = ?
        """
        return self.cursor.execute(sql, (username,)).fetchall()

    def drop_by_username(self, username):
        """Delete every row whose ``username`` equals *username* and commit."""
        sql = """
            DELETE from vprtable WHERE `username` = ?
        """
        self.cursor.execute(sql, (username,))
        self.conn.commit()

    def drop_all(self):
        """Delete every row of ``vprtable`` (the table itself is kept)."""
        sql = """
            DELETE from vprtable
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def drop_table(self):
        """Drop the ``vprtable`` table."""
        sql = """
            DROP TABLE vprtable
        """
        self.cursor.execute(sql)
        self.conn.commit()

    def close(self):
        """Close the shared cursor and the underlying connection."""
        self.cursor.close()
        self.conn.close()

    def encode_vector(self, vector: np.ndarray):
        """Return the raw bytes of *vector* as a base64 ``str``.

        Neither dtype nor shape is stored in the result; the reader has to
        know them (see ``decode_vector``).
        """
        if isinstance(vector, np.ndarray):
            # b64encode needs a contiguous buffer; strided views are copied.
            vector = np.ascontiguousarray(vector)
        return base64.b64encode(vector).decode('utf8')

    def decode_vector(self, vector_base64, dtype=np.float32):
        """Decode a base64 string into a 1-D array of *dtype*.

        The returned array is a view over the decoded bytes, as produced by
        ``np.frombuffer``. *dtype* must match the dtype of the array that was
        encoded, otherwise the values are reinterpreted.
        """
        raw = base64.b64decode(vector_base64)
        return np.frombuffer(raw, dtype=dtype)