# agent/schema/schema_utils.py
import json
from collections import Counter

import jieba
import numpy as np
import pandas as pd

# config is expected to provide the runtime globals used below: QUESTION_TYPE,
# QUESTION_TYPE_LIST, DEBUG_VER, MAX_TOP_COLUMNS, USE_LOCAL_VECTORS, the
# input_text_* / database_* constants, the df_* DataFrames, items_another,
# and the embeddings `client`.
from config import *
from agent.utils import clean_text, find_dict_by_element


def parse_table_structures(input_text):
    """
    Parse the input text and extract table structure information, returning a
    dictionary with table names as keys and table structures as values.
    The format is expected as "==="-delimited pairs: "table_name === table_structure".

    Parameters:
        input_text (str): The raw text containing table structures.

    Returns:
        tables_dict (dict): A dictionary where keys are table names and
            values are the associated table structures.
    """
    # Splitting on "===" yields alternating name/structure fragments; pair them up.
    tables_text = input_text.split('===')[1:]
    tables_dict = {tables_text[i]: tables_text[i + 1] for i in range(0, len(tables_text), 2)}
    return tables_dict
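
# Illustrative usage (hypothetical input, not from the original module): the text
# alternates "=== name === structure" fragments, so splitting on "===" and pairing
# consecutive pieces yields {name: structure}.
#
#   >>> parse_table_structures("===A表结构===col1 INT===B表结构===col2 TEXT")
#   {'A表结构': 'col1 INT', 'B表结构': 'col2 TEXT'}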


def map_chinese_to_english_tables(chinese_names, english_names):
    """
    Map Chinese table names to their corresponding English table names.
    For each Chinese name there is assumed to be a matching English name
    (case-insensitive comparison).

    Parameters:
        chinese_names (list): A list of Chinese table names.
        english_names (list): A list of English table names.

    Returns:
        name_map (dict): A dictionary mapping Chinese table names to English table names.
    """
    name_map = {}
    for cname in chinese_names:
        # Find the corresponding English name (case-insensitive match); raises
        # IndexError if the documented one-to-one assumption is violated.
        english_match = [en for en in english_names if str(en).lower() == cname.lower()][0]
        name_map[cname] = english_match
    return name_map
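
# Illustrative usage (hypothetical names): matching is case-insensitive, so a
# lower-cased cleaned key still finds its canonical English spelling.
#
#   >>> map_chinese_to_english_tables(['lc_stockarchives'], ['LC_StockArchives'])
#   {'lc_stockarchives': 'LC_StockArchives'}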


def find_value_in_list_of_dicts(dict_list, key_to_match, value_to_match, key_to_return):
    """
    Search through a list of dictionaries, find the first dictionary where the
    value of key_to_match equals value_to_match, and return the value
    associated with key_to_return.

    Parameters:
        dict_list (list): A list of dictionaries to search through.
        key_to_match (str): The key whose value we want to match.
        value_to_match (str): The value we are looking for.
        key_to_return (str): The key whose value we want to return.

    Returns:
        (str): The value associated with key_to_return in the matching dictionary,
            or an empty string if no match is found.
    """
    for dictionary in dict_list:
        if dictionary.get(key_to_match) == value_to_match:
            return dictionary.get(key_to_return)
    return ''
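
# Illustrative usage (toy data):
#
#   >>> rows = [{'数据表名': 'A', '注释': 'x'}, {'数据表名': 'B', '注释': 'y'}]
#   >>> find_value_in_list_of_dicts(rows, '数据表名', 'B', '注释')
#   'y'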


def filter_table_comments(question, table_comments):
    """
    Filter a list of table comments based on the given question.
    Uses jieba for segmentation and removes stopwords, returning only comments
    that contain at least one of the segmented keywords.

    Parameters:
        question (str): The question text.
        table_comments (list): A list of comment strings to filter.

    Returns:
        filtered_comments (list): Filtered list of comments.
    """
    # Stopwords dropped from the segmented question. Several entries in the
    # original list were lost to encoding damage; only '多少' survives.
    stopwords = ['', '', '', '多少', '', '', '']
    seg_list = list(jieba.cut(question, cut_all=False))
    filtered_seg_list = [word for word in seg_list if word not in stopwords]
    filtered_comments = []
    for comment in table_comments:
        # Keep a comment if any question keyword appears in it.
        if any(keyword in comment for keyword in filtered_seg_list):
            filtered_comments.append(comment)
    return filtered_comments
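
# Illustrative usage (toy data; the exact segmentation depends on jieba's
# dictionary, but any comment containing a question keyword survives):
#
#   >>> filter_table_comments('营业收入是多少', ['营业收入', '净利润'])
#   ['营业收入']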


def get_table_schema(question=''):
    """
    Retrieve table schemas along with optional filtered field comments.
    If a question is provided, the comments are filtered based on question keywords.

    The function:
    1. Maps Chinese table names to English table names.
    2. For each table, retrieves its structure and finds associated comments.
    3. If a question is provided, filters the comments based on keywords
       extracted from the question.

    Parameters:
        question (str): The question text. If empty, no filtering is performed.

    Returns:
        table_maps (list): A list of dictionaries, each containing table schema information:
            {
                '数据表名': EnglishTableName,
                '数据表结构': TableStructure,
                '字段注释': FilteredComments  (only present if a question is provided)
            }
    """
    # Select the source text and metadata for the configured market scope
    # (全股 = all markets, 港股 = HK, 美股 = US, A股 = A-shares).
    if QUESTION_TYPE == "全股":
        parsed_tables = parse_table_structures(input_text_all)
        database_L = database_L_all
        database_table_en = database_table_en_all
    elif QUESTION_TYPE == "港股":
        parsed_tables = parse_table_structures(input_text_hk)
        database_L = database_L_hk
        database_table_en = database_table_en_hk
    elif QUESTION_TYPE == "美股":
        parsed_tables = parse_table_structures(input_text_us)
        database_L = database_L_us
        database_table_en = database_table_en_us
    elif QUESTION_TYPE == "A股":
        parsed_tables = parse_table_structures(input_text_cn)
        database_L = database_L_cn
        database_table_en = database_table_en_cn
    else:
        parsed_tables = parse_table_structures(input_text_all)
        database_L = database_L_all
        database_table_en = database_table_en_all
    # Clean up keys and values.
    cleaned_tables = {
        k.replace(' ', '').replace('表结构', ''): v.replace('--', '')
        for k, v in parsed_tables.items()
    }
    # List of Chinese table names (keys).
    chinese_table_names = list(cleaned_tables.keys())
    name_map = map_chinese_to_english_tables(chinese_table_names, database_table_en)
    table_maps = []
    for cname, structure in cleaned_tables.items():
        english_name = name_map.get(cname)
        comments = find_value_in_list_of_dicts(database_L, '数据表名', english_name, '注释')
        if question == '':
            # No filtering; just return table name and structure.
            table_map = {
                '数据表名': english_name,
                '数据表结构': structure
            }
        else:
            # Filter comments based on the question.
            filtered_comments = filter_table_comments(question, comments)
            table_map = {
                '数据表名': english_name,
                '数据表结构': structure,
                '字段注释': filtered_comments
            }
        table_maps.append(table_map)
    return table_maps
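
# Illustrative return shape (values depend on the config globals; the question
# string here is hypothetical):
#
#   get_table_schema('港股公司的注册地在哪里?')
#   # -> [{'数据表名': ..., '数据表结构': ..., '字段注释': [...]}, ...]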


def to_get_question_columns(question):
    """
    Given a question (string) and the global database_L_zh (a list of dicts),
    find the 列名 (column names) whose 列名中文描述 (Chinese descriptions) are
    mentioned in the question.

    If any matching columns are found, return a message instructing the caller
    to use these column names directly for data querying; otherwise return an
    empty string.

    Parameters:
        question (str): The input question text.

    Returns:
        str: A message listing identified column names, or '' if none are found.
    """
    # Select the column metadata matching the detected market scope.
    if len(QUESTION_TYPE_LIST) == 0 or len(QUESTION_TYPE_LIST) >= 2:
        database_L_zh = database_L_zh_all
    elif "港股" in QUESTION_TYPE_LIST:
        database_L_zh = database_L_zh_hk
    elif "美股" in QUESTION_TYPE_LIST:
        database_L_zh = database_L_zh_us
    elif "A股" in QUESTION_TYPE_LIST:
        database_L_zh = database_L_zh_cn
    else:
        database_L_zh = database_L_zh_all
    L_num = []
    for items in database_L_zh:
        L_num += items['列名中文描述']
    # Keep only descriptions that are unique across all tables.
    L_num_new = [item for item, count in Counter(L_num).items() if count == 1]
    # Drop NaN if any.
    series_num = pd.Series(L_num_new)
    L_num_new = list(series_num.dropna())
    # Remove known irrelevant items.
    irrelevant_items = ['年度', '占比']
    for irr in irrelevant_items:
        if irr in L_num_new:
            L_num_new.remove(irr)
    matched_columns = []
    for col_descs in L_num_new:
        # items_another (from config) maps each description to its alias variants.
        col_desc_another = items_another[col_descs]
        for col_desc in col_desc_another:
            # Check whether the description, or its cleaned form, appears in the question.
            if col_desc in question or clean_text(col_desc) in question:
                L_dict = find_dict_by_element(database_L_zh, col_descs)
                if not L_dict:
                    break
                # Map the Chinese description to the English column name.
                dict_zip = dict(zip(L_dict[0]['列名中文描述'], L_dict[0]['列名']))
                column_name = dict_zip[col_descs]
                data_table = L_dict[0]['数据表名']
                matched_columns.append({
                    '数据库表': data_table,
                    '列名': column_name,
                    '列名中文含义': col_descs
                })
                break
    if matched_columns:
        # Returned verbatim in Chinese because downstream prompts expect it:
        # "Some database column names have been obtained: {...}; use them directly."
        return f"已获得一部分数据库列名{matched_columns},请充分利用获得的列名直接查询数据。"
    else:
        return ''
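
# Illustrative call (hypothetical question; the result depends on config globals
# such as database_L_zh_all and items_another):
#
#   to_get_question_columns('公司的注册资本是多少?')
#   # -> '已获得一部分数据库列名[{...}],请充分利用获得的列名直接查询数据。'
#   #    (or '' if nothing matches)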


def get_table_schema_with_emb(tables_name=[], question='', threshold=0.35, print_all=True):
    """
    For each English table name in tables_name, keep only the columns whose
    descriptions are semantically close to the question (via embeddings) and
    return the filtered schemas together with each table's best similarity score.
    """
    # Select the source text and metadata for the detected market scope.
    if len(QUESTION_TYPE_LIST) == 0 or len(QUESTION_TYPE_LIST) >= 2:
        parsed_tables = parse_table_structures(input_text_all)
        database_L = database_L_all
        database_table_en = database_table_en_all
    elif "港股" in QUESTION_TYPE_LIST:
        parsed_tables = parse_table_structures(input_text_hk)
        database_L = database_L_hk
        database_table_en = database_table_en_hk
    elif "美股" in QUESTION_TYPE_LIST:
        parsed_tables = parse_table_structures(input_text_us)
        database_L = database_L_us
        database_table_en = database_table_en_us
    elif "A股" in QUESTION_TYPE_LIST:
        parsed_tables = parse_table_structures(input_text_cn)
        database_L = database_L_cn
        database_table_en = database_table_en_cn
    else:
        parsed_tables = parse_table_structures(input_text_all)
        database_L = database_L_all
        database_table_en = database_table_en_all
    # Clean up keys and values.
    cleaned_tables = {
        k.replace(' ', '').replace('表结构', ''): v.replace('--', '')
        for k, v in parsed_tables.items()
    }
    time_columns = ["TradingDay", "EndDate", "InfoPublDate"]
    # Bookkeeping columns that are never useful for answering questions.
    remove_columns = ['InitialInfoPublDate', 'XGRQ', 'InsertTime', 'UpdateTime']
    table_maps = []
    highest_score_list = []
    if tables_name == [] or question == '':
        return [], []
    for table_name in tables_name:
        if DEBUG_VER == 3:
            print(f"---------->table_name:{table_name}")
        for english_name, structure in cleaned_tables.items():
            if english_name.lower() != table_name.lower():
                continue
            # table_name is expected to be qualified as "database.table".
            df_tmp = df_all2[df_all2["table_name"] == table_name.split(".")[1]]
            imp_matched_columns, matched_columns_with_scores = extract_matching_columns_with_similarity(
                df_tmp, question, threshold, MAX_TOP_COLUMNS
            )
            if matched_columns_with_scores == [] and imp_matched_columns == []:
                continue
            # Drop the bookkeeping columns from the matches.
            matched_columns = [
                item["column"] for item in matched_columns_with_scores
                if item["column"] not in remove_columns
            ]
            highest_score_item = max(matched_columns_with_scores, key=lambda x: x["similarity_score"])
            highest_score = highest_score_item["similarity_score"]
            # (Debug prints of table_name / imp_matched_columns /
            # matched_columns_with_scores used to live here; they stay disabled.)
            filtered_structure = filter_structure(structure, matched_columns, time_columns)
            filtered_comments = extract_column_descriptions(filtered_structure, df_tmp)
            table_map = {
                '数据表名': table_name,
                '数据表结构': filtered_structure,
                '部分字段注释补充': filtered_comments
            }
            table_maps.append(table_map)
            highest_score_list.append(highest_score)
    if not table_maps or not highest_score_list:
        return [], []
    # Sort table_maps and highest_score_list together, best score first.
    combined_list = list(zip(table_maps, highest_score_list))
    sorted_combined_list = sorted(combined_list, key=lambda x: x[1], reverse=True)
    table_maps, highest_score_list = zip(*sorted_combined_list)
    return list(table_maps), list(highest_score_list)
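
# Illustrative call (hypothetical qualified table name; relies on config globals
# such as df_all2, MAX_TOP_COLUMNS and the embeddings client, so this is a sketch
# rather than a runnable test):
#
#   table_maps, scores = get_table_schema_with_emb(
#       ['hk.HK_StockArchives'], question='港股公司的注册地在哪里?', threshold=0.35)
#   # table_maps -> [{'数据表名': ..., '数据表结构': ..., '部分字段注释补充': [...]}, ...]
#   # scores     -> one best similarity per returned table, sorted descending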


def deep_find_tables(all_tables_name_list=[], question='', threshold=0.6, print_all=False, top_table_n=5):
    """
    Search all tables *not* already in all_tables_name_list (a list of English
    names) and return up to top_table_n additional candidate schemas, ranked by
    their best column-similarity score against the question.
    """
    # Select the metadata DataFrame for the detected market scope.
    if len(QUESTION_TYPE_LIST) == 0 or len(QUESTION_TYPE_LIST) >= 2:
        df_tmp = df_all
    elif "港股" in QUESTION_TYPE_LIST:
        df_tmp = df_hk
    elif "美股" in QUESTION_TYPE_LIST:
        df_tmp = df_us
    elif "A股" in QUESTION_TYPE_LIST:
        df_tmp = df_cn
    else:
        df_tmp = df_all
    if all_tables_name_list == [] or question == '':
        return []
    all_table_list = df_tmp['库表名英文'].values.tolist()
    all_table_list = [table for table in all_table_list if table not in all_tables_name_list]
    all_table_list = list(set(all_table_list))
    table_maps, highest_score_list = get_table_schema_with_emb(all_table_list, question, threshold, print_all)
    return table_maps[:top_table_n]
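
# Illustrative call (hypothetical names): tables already in all_tables_name_list
# are excluded, and at most top_table_n new candidates come back.
#
#   deep_find_tables(['us.US_CompanyInfo'], question='美股公司的上市日期?', top_table_n=5)
#   # -> up to 5 table_maps ranked by best column similarity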


def find_table_with_emb(all_tables_name_list=[], question='', use_local_vectors=USE_LOCAL_VECTORS):
    """
    Return the single highest-scoring table (the result list always holds one table):
    1. Batch-compare the question against every unique table description in
       df_tmp and record each table's similarity.
    2. Pick the candidate with the highest similarity and return it as a
       one-element list.
    """
    # Select the dataset matching QUESTION_TYPE_LIST.
    if len(QUESTION_TYPE_LIST) == 0 or len(QUESTION_TYPE_LIST) >= 2:
        df_tmp = df_all
        embeddings_path = df_all_embeddings_path
    elif "港股" in QUESTION_TYPE_LIST:
        df_tmp = df_hk
        embeddings_path = df_hk_embeddings_path
    elif "美股" in QUESTION_TYPE_LIST:
        df_tmp = df_us
        embeddings_path = df_us_embeddings_path
    elif "A股" in QUESTION_TYPE_LIST:
        df_tmp = df_cn
        embeddings_path = df_cn_embeddings_path
    else:
        df_tmp = df_all
        embeddings_path = df_all_embeddings_path
    # all_tables_name_list only gates execution; the candidates come from df_tmp.
    if not all_tables_name_list or question == '':
        return []
    candidate_scores = []
    df_unique = df_tmp[['库表名英文', 'table_describe']].drop_duplicates()
    candidate_ids_deep = df_unique['库表名英文'].tolist()
    candidate_texts_deep = df_unique['table_describe'].tolist()
    if use_local_vectors:
        # Compare the question against precomputed candidate vectors on disk.
        texts_for_batch_deep = [question]
        similarity_scores_deep = calculate_similarity(texts_for_batch_deep, local_vectors=embeddings_path)
    else:
        texts_for_batch_deep = [question] + candidate_texts_deep
        similarity_scores_deep = calculate_similarity(texts_for_batch_deep)
    for table, sim in zip(candidate_ids_deep, similarity_scores_deep):
        candidate_scores.append((table, sim))
    # No candidates: return an empty list.
    if not candidate_scores:
        return []
    # Take the most similar table regardless of any threshold.
    best_table, best_sim = max(candidate_scores, key=lambda x: x[1])
    if DEBUG_VER == 3:
        print(f"Best table: {best_table} with similarity: {best_sim}")
    return [best_table]
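
# Illustrative call (hypothetical name; needs the embeddings client or a local
# vector file):
#
#   find_table_with_emb(['cn.LC_StockArchives'], question='A股公司的注册资本?')
#   # -> ['<库表名英文 of the single best-scoring table>']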


def extract_matching_columns_with_similarity(df, question, threshold, top_n=25):
    """
    Extract the columns that match the question.
    1. For each row's text fields (column_description and 注释), first check for
       exact containment: if the text appears verbatim in the question, the
       column immediately scores 1.0.
    2. The remaining candidate texts are collected and compared against the
       question in one batch call to calculate_similarity; only scores at or
       above the threshold are kept.
    3. A column may have several candidate texts; its final score is the
       highest among them.

    Returns:
        imp_matched_columns: column names with an exact (1.0) match.
        topn_matched_columns: columns meeting the threshold with their
            similarity scores, sorted descending, at most top_n entries.
    """
    imp_matched_columns_with_scores = {}  # Exact containment matches (score 1.0).
    matched_columns_with_scores = {}      # Semantic matches (best score per column).
    candidate_texts = []                  # Texts queued for the batch similarity call.
    candidate_ids = []                    # Column name owning each candidate text.
    # Walk every row of the DataFrame.
    for _, row in df.iterrows():
        col_name = row["column_name"]
        # Handle the column_description field.
        if isinstance(row["column_description"], str):
            col_desc = row["column_description"]
            # Exact containment: if the description appears in the question, score 1.0.
            if col_desc in question:
                imp_matched_columns_with_scores[col_name] = 1.0
                matched_columns_with_scores[col_name] = 1.0
            else:
                # Otherwise queue it for semantic matching.
                candidate_texts.append(col_desc)
                candidate_ids.append(col_name)
        # Handle the 注释 field, assumed to hold multiple "，"-separated entries
        # (the separator literal was lost in the source; "，" is a reconstruction).
        if isinstance(row["注释"], str):
            words = row["注释"].split("，")
            for word in words:
                if word in question:
                    imp_matched_columns_with_scores[col_name] = 1.0
                    matched_columns_with_scores[col_name] = 1.0
    # Batch semantic similarity: compare the question against all candidates at once.
    if candidate_texts:
        # The first text is the question; the rest are the candidates.
        texts_for_batch = [question] + candidate_texts
        similarity_scores = calculate_similarity(texts_for_batch)
        for idx, sim in enumerate(similarity_scores):
            if sim >= threshold:
                col = candidate_ids[idx]
                # Keep the best score when a column has several candidates.
                matched_columns_with_scores[col] = max(matched_columns_with_scores.get(col, 0), sim)
    # Merge exact and semantic matches (exact matches win).
    all_matched_columns = {**matched_columns_with_scores, **imp_matched_columns_with_scores}
    # Sort by score descending (exact matches at 1.0 normally come first).
    unique_matched_columns = {col: score for col, score in sorted(all_matched_columns.items(), key=lambda x: -x[1])}
    # Columns with an exact match.
    imp_matched_columns = [col for col, score in unique_matched_columns.items() if score == 1.0]
    # Columns meeting the threshold, as a list of dicts sorted by score descending.
    matched_columns = sorted(
        [{"column": col, "similarity_score": score}
         for col, score in unique_matched_columns.items() if score >= threshold],
        key=lambda x: x["similarity_score"],
        reverse=True
    )
    topn_matched_columns = matched_columns[:top_n]
    return imp_matched_columns, topn_matched_columns
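
# Illustrative return shape (hypothetical column name): with the question
# '最新营业收入是多少' and a row whose column_description is '营业收入', the exact
# containment rule scores the column 1.0:
#
#   imp_matched_columns  -> ['OperatingRevenue']
#   topn_matched_columns -> [{'column': 'OperatingRevenue', 'similarity_score': 1.0}, ...]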


def filter_structure(structure, matched_columns, time_columns):
    """
    Keep only the schema rows for matched columns, time columns, and
    identifier-like columns (Code/Abbr/Name), plus the header.
    """
    # Split the header from the data section.
    sections = structure.split("\n\n", 1)
    header = sections[0]
    data_rows = sections[1] if len(sections) > 1 else ""

    def satisfies_conditions(row):
        row_fields = row.split()  # Fields are assumed to be whitespace-separated.
        # Exact field match against the matched or time columns.
        if any(col == field for field in row_fields for col in matched_columns):
            return True
        if any(col == field for field in row_fields for col in time_columns):
            return True
        # Always keep identifier-like columns.
        if any(keyword in row for keyword in ["Code", "Abbr", "Name"]):
            return True
        return False

    # Filter the data section row by row.
    filtered_rows = [row for row in data_rows.strip().split("\n") if satisfies_conditions(row)]
    # Reattach the header to the surviving rows.
    filtered_structure = header + "\n\n" + "\n".join(filtered_rows) if filtered_rows else header
    return filtered_structure
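
# Illustrative usage (toy schema text): identifier-like rows ("Code") and time
# columns survive; everything else is dropped.
#
#   >>> s = "表头说明\n\nSecuCode VARCHAR 证券代码\nTradingDay DATE 交易日\nXGRQ DATE 修改日期"
#   >>> filter_structure(s, matched_columns=[], time_columns=['TradingDay'])
#   '表头说明\n\nSecuCode VARCHAR 证券代码\nTradingDay DATE 交易日'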


def extract_column_descriptions(filtered_structure, df_tmp):
    """
    For every column named in filtered_structure, look up its 注释 (comment) in
    df_tmp and return the non-trivial ones as a list of {column: comment} dicts.
    """
    # The content starts after the first "\n\n"; a header-only structure has none.
    parts = filtered_structure.split("\n\n", 1)
    if len(parts) < 2:
        return []
    content = parts[1].strip()
    # Column names are the first whitespace-separated token of each row.
    column_names = []
    for line in content.split("\n"):
        fields = line.split()
        if fields:
            column_names.append(fields[0])
    # Turn df_tmp into a dict for fast lookups.
    column_dict = dict(zip(df_tmp["column_name"], df_tmp["注释"]))
    # Keep only comments longer than 3 characters.
    result = []
    for column_name in column_names:
        if column_name in column_dict and len(str(column_dict[column_name])) > 3:
            result.append({column_name: column_dict[column_name]})
    return result
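
# Illustrative usage (toy DataFrame):
#
#   >>> df = pd.DataFrame({'column_name': ['SecuCode'], '注释': ['证券交易代码']})
#   >>> extract_column_descriptions('头部\n\nSecuCode VARCHAR', df)
#   [{'SecuCode': '证券交易代码'}]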


def calculate_similarity(text_list, local_vectors=False):
    """
    Batch similarity computation:
    - text_list[0] is the base text; every other text is compared against it.
    - Candidates are requested in batches of at most 64; the function returns
      the cosine similarity of each candidate to the base text, rounded to
      4 decimal places.
    - If local_vectors is a file path, precomputed candidate embeddings are
      loaded from that JSON file instead of being requested.
    """
    base_text = text_list[0]
    # Request the embedding of the base text on its own.
    base_response = client.embeddings.create(
        model="embedding-3",
        input=[base_text]
    )
    base_embedding = np.array(base_response.data[0].embedding)
    all_similarities = []
    if local_vectors:
        if DEBUG_VER == 3:
            print(f'------>local_vectors:{local_vectors}')
        # Use locally stored vectors; they must line up with the candidate texts.
        with open(local_vectors, 'r', encoding='utf-8') as f:
            local_embeddings = json.load(f)
        candidate_embeddings = np.array(local_embeddings)
    else:
        # Request candidate embeddings in chunks of at most 64 texts.
        candidate_texts = text_list[1:]
        candidate_embeddings = []
        chunk_size = 64
        for i in range(0, len(candidate_texts), chunk_size):
            chunk = candidate_texts[i:i + chunk_size]
            response = client.embeddings.create(
                model="embedding-3",
                input=chunk
            )
            # Collect the embeddings; converted to a NumPy array below.
            embeddings = [item.embedding for item in response.data]
            candidate_embeddings.extend(embeddings)
        candidate_embeddings = np.array(candidate_embeddings)
    # Cosine similarity: dot / (||base|| * ||candidate||).
    dot_products = candidate_embeddings.dot(base_embedding)
    norm_base = np.linalg.norm(base_embedding)
    norm_candidates = np.linalg.norm(candidate_embeddings, axis=1)
    similarities = dot_products / (norm_base * norm_candidates)
    # Round to 4 decimal places and collect the results.
    sims = [round(float(sim), 4) for sim in similarities]
    all_similarities.extend(sims)
    return all_similarities
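
# The cosine step above, isolated as a minimal numpy sketch (toy vectors instead
# of real embeddings; no API call involved):
#
#   import numpy as np
#   base = np.array([1.0, 0.0])
#   cands = np.array([[1.0, 0.0], [0.0, 1.0]])
#   sims = cands.dot(base) / (np.linalg.norm(base) * np.linalg.norm(cands, axis=1))
#   # -> array([1., 0.])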