# agent/schema/schema_utils.py import jieba import numpy as np from collections import Counter from config import * from agent.utils import clean_text, find_dict_by_element def parse_table_structures(input_text): """ 用于解析输入文本并提取表结构信息,返回一个字典,其中表名作为键,表结构作为值 Parse the input text to extract table structures. The format is expected as pairs: "table_name === table_structure". Parameters: input_text (str): The raw text containing table structures. Returns: tables_dict (dict): A dictionary where keys are table names and values are the associated table structures. """ tables_text = input_text.split('===')[1:] tables_dict = {tables_text[i]: tables_text[i + 1] for i in range(0, len(tables_text), 2)} return tables_dict def map_chinese_to_english_tables(chinese_names, english_names): """ 将中文的表名映射到对应的英文表名 Map Chinese table names to their corresponding English table names. For each Chinese name, there is a matching English name (case-insensitive comparison). Parameters: chinese_names (list): A list of Chinese table names. english_names (list): A list of English table names. Returns: name_map (dict): A dictionary mapping Chinese table names to English table names. """ name_map = {} for cname in chinese_names: # Find the corresponding English name (case-insensitive match) english_match = [en for en in english_names if str(en).lower() == cname.lower()][0] name_map[cname] = english_match return name_map def find_value_in_list_of_dicts(dict_list, key_to_match, value_to_match, key_to_return): """ 在字典列表中查找满足条件的字典,并返回其中指定键的值。 Search through a list of dictionaries and find the first dictionary where the value of key_to_match equals value_to_match, then return the value associated with key_to_return. Parameters: dict_list (list): A list of dictionaries to search through. key_to_match (str): The key whose value we want to match. value_to_match (str): The value we are looking for. key_to_return (str): The key whose value we want to return. Returns: (str): The value associated with key_to_return in the matching dictionary, or an empty string if no match is found. """ for dictionary in dict_list: if dictionary.get(key_to_match) == value_to_match: return dictionary.get(key_to_return) return '' def filter_table_comments(question, table_comments): """ 根据输入问题从表注释列表中筛选出与问题相关的注释。 Filter a list of table comments based on the given question. Uses jieba for segmentation and removes stopwords, returning only comments that contain at least one of the segmented keywords. Parameters: question (str): The question text. table_comments (list): A list of comment strings to filter. Returns: filtered_comments (list): Filtered list of comments. """ stopwords = ['?', '有', '的', '多少', '人', '(', ')'] seg_list = list(jieba.cut(question, cut_all=False)) filtered_seg_list = [word for word in seg_list if word not in stopwords] filtered_comments = [] for comment in table_comments: if any(keyword in comment for keyword in filtered_seg_list): filtered_comments.append(comment) return filtered_comments def get_table_schema(question=''): """ 获取表格的结构信息以及字段的注释 Retrieve table schemas along with optional filtered field comments. If a question is provided, the comments will be filtered based on question keywords. The function: 1. Maps Chinese table names to English table names. 2. For each table, retrieves its structure and finds associated comments. 3. If a question is provided, filter the comments based on keywords extracted from the question. Parameters: question (str): The question text. If empty, no filtering is performed. Returns: table_maps (list): A list of dictionaries, each containing table schema information. { '数据表名': EnglishTableName, '数据表结构': TableStructure, '字段注释': FilteredComments (optional if question is provided) } """ if QUESTION_TYPE == "全股": parsed_tables = parse_table_structures(input_text_all) database_L = database_L_all database_table_en = database_table_en_all elif QUESTION_TYPE == "港股": parsed_tables = parse_table_structures(input_text_hk) database_L = database_L_hk database_table_en = database_table_en_hk elif QUESTION_TYPE == "美股": parsed_tables = parse_table_structures(input_text_us) database_L = database_L_us database_table_en = database_table_en_us elif QUESTION_TYPE == "A股": parsed_tables = parse_table_structures(input_text_cn) database_L = database_L_cn database_table_en = database_table_en_cn else: parsed_tables = parse_table_structures(input_text_all) database_L = database_L_all database_table_en = database_table_en_all # Clean up keys and values cleaned_tables = { k.replace(' ', '').replace('表结构', ''): v.replace('--', '') for k, v in parsed_tables.items() } # List of Chinese table names (keys) chinese_table_names = list(cleaned_tables.keys()) name_map = map_chinese_to_english_tables(chinese_table_names, database_table_en) table_maps = [] for cname, structure in cleaned_tables.items(): english_name = name_map.get(cname) comments = find_value_in_list_of_dicts(database_L, '数据表名', english_name, '注释') if question == '': # No filtering, just return table name and structure table_map = { '数据表名': english_name, '数据表结构': structure } else: # Filter comments based on question filtered_comments = filter_table_comments(question, comments) table_map = { '数据表名': english_name, '数据表结构': structure, '字段注释': filtered_comments } table_maps.append(table_map) return table_maps def to_get_question_columns(question): """ Given a question (string) and a global variable database_L_zh (list of dicts), find 列名 that correspond to 列名中文描述 mentioned in the question. If any matching columns are found, return a message instructing the user to use these column names directly for data querying. If none are found, return an empty string. Parameters: question (str): The input question text. Returns: str: A message with identified column names or an empty string if none found. """ database_L_zh = database_L_zh_all if len(QUESTION_TYPE_LIST) == 0: database_L_zh = database_L_zh_all elif len(QUESTION_TYPE_LIST) >= 2: database_L_zh = database_L_zh_all elif "港股" in QUESTION_TYPE_LIST: database_L_zh = database_L_zh_hk elif "美股" in QUESTION_TYPE_LIST: database_L_zh = database_L_zh_us elif "A股" in QUESTION_TYPE_LIST: database_L_zh = database_L_zh_cn else: database_L_zh = database_L_zh_all L_num = [] for items in database_L_zh: L_num += items['列名中文描述'] # Get unique column descriptions L_num_new = [item for item, count in Counter(L_num).items() if count == 1] # Drop NaN if any series_num = pd.Series(L_num_new) L_num_new = list(series_num.dropna()) # Remove known irrelevant items irrelevant_items = ['年度', '占比'] for irr in irrelevant_items: if irr in L_num_new: L_num_new.remove(irr) matched_columns = [] for col_descs in L_num_new: col_desc_another = items_another[col_descs] for col_desc in col_desc_another: # Check if the column description or its cleaned version appears in the question if col_desc in question or clean_text(col_desc) in question: L_dict = find_dict_by_element(database_L_zh, col_descs) if not L_dict: break # Create a mapping from Chinese description to English column name dict_zip = dict(zip(L_dict[0]['列名中文描述'], L_dict[0]['列名'])) column_name = dict_zip[col_descs] data_table = L_dict[0]['数据表名'] matched_columns.append({ '数据库表': data_table, '列名': column_name, '列名中文含义': col_descs }) break if matched_columns: return f"已获得一部分数据库列名{matched_columns},请充分利用获得的列名直接查询数据。" else: return '' def get_table_schema_with_emb(tables_name=[], question='', threshold=0.35, print_all=True): # tables_name是英文名的list if len(QUESTION_TYPE_LIST) == 0: parsed_tables = parse_table_structures(input_text_all) database_L = database_L_all database_table_en = database_table_en_all elif len(QUESTION_TYPE_LIST) >= 2: parsed_tables = parse_table_structures(input_text_all) database_L = database_L_all database_table_en = database_table_en_all elif "港股" in QUESTION_TYPE_LIST: parsed_tables = parse_table_structures(input_text_hk) database_L = database_L_hk database_table_en = database_table_en_hk elif "美股" in QUESTION_TYPE_LIST: parsed_tables = parse_table_structures(input_text_us) database_L = database_L_us database_table_en = database_table_en_us elif "A股" in QUESTION_TYPE_LIST: parsed_tables = parse_table_structures(input_text_cn) database_L = database_L_cn database_table_en = database_table_en_cn else: parsed_tables = parse_table_structures(input_text_all) database_L = database_L_all database_table_en = database_table_en_all # Clean up keys and values cleaned_tables = { k.replace(' ', '').replace('表结构', ''): v.replace('--', '') for k, v in parsed_tables.items() } columns_to_clean = ["column_description"] time_columns = ["TradingDay", "EndDate", "InfoPublDate"] remove_columns = ['InitialInfoPublDate', 'XGRQ','InsertTime','UpdateTime'] table_maps = [] highest_score_list = [] if tables_name == [] or question == '': return None for table_name in tables_name: if DEBUG_VER == 3: print(f"---------->table_name:{table_name}") for english_name, structure in cleaned_tables.items(): if english_name.lower() == table_name.lower(): # print(f"------------threshold:{threshold}") filtered_comments = [] df_tmp = df_all2[df_all2["table_name"]==table_name.split(".")[1]] imp_matched_columns, matched_columns_with_scores = extract_matching_columns_with_similarity( df_tmp, question, threshold, MAX_TOP_COLUMNS ) if matched_columns_with_scores == [] and imp_matched_columns == []: # print(f"------------pass") pass else: # print(f"------------no pass") # 过滤掉不需要的列 remove_columns = ["InitialInfoPublDate", "XGRQ", "InsertTime", "UpdateTime"] matched_columns = [ item["column"] for item in matched_columns_with_scores if item["column"] not in remove_columns ] matched_columns = [column for column in matched_columns if column not in remove_columns] highest_score_item = max(matched_columns_with_scores, key=lambda x: x["similarity_score"]) highest_score = highest_score_item["similarity_score"] if print_all: pass # print(f"------------table_name:{table_name}") # print(f"------------imp_matched_columns:{imp_matched_columns}, matched_columns_with_scores:{matched_columns_with_scores}") else: if (highest_score > threshold) or imp_matched_columns != []: pass # print(f"------------table_name:{table_name}") # print(f"------------imp_matched_columns:{imp_matched_columns}, matched_columns_with_scores:{matched_columns_with_scores}") filtered_structure = filter_structure(structure, matched_columns, time_columns) filtered_comments = extract_column_descriptions(filtered_structure, df_tmp) table_map = { '数据表名': table_name, '数据表结构': filtered_structure, '部分字段注释补充': filtered_comments } table_maps.append(table_map) highest_score_list.append(highest_score) if not table_maps or not highest_score_list: return [], [] # 将 table_maps 和 highest_score_list 按分数排序 combined_list = list(zip(table_maps, highest_score_list)) sorted_combined_list = sorted(combined_list, key=lambda x: x[1], reverse=True) table_maps, highest_score_list = zip(*sorted_combined_list) table_maps = list(table_maps) highest_score_list = list(highest_score_list) return table_maps, highest_score_list def deep_find_tables(all_tables_name_list=[], question='', threshold=0.6, print_all=False, top_table_n=5): # tables_name是英文名的list if len(QUESTION_TYPE_LIST) == 0: df_tmp = df_all elif len(QUESTION_TYPE_LIST) >= 2: df_tmp = df_all elif "港股" in QUESTION_TYPE_LIST: df_tmp = df_hk elif "美股" in QUESTION_TYPE_LIST: df_tmp = df_us elif "A股" in QUESTION_TYPE_LIST: df_tmp = df_cn else: df_tmp = df_all if all_tables_name_list==[] or question=='': return [] all_table_list = df_tmp['库表名英文'].values.tolist() all_table_list = [table for table in all_table_list if table not in all_tables_name_list] all_table_list = list(set(all_table_list)) table_maps, highest_score_list = get_table_schema_with_emb(all_table_list, question, threshold, print_all) return table_maps[:top_table_n] # %% [code] {"execution":{"iopub.status.busy":"2025-03-07T09:14:36.671853Z","iopub.execute_input":"2025-03-07T09:14:36.672272Z","iopub.status.idle":"2025-03-07T09:14:36.693550Z","shell.execute_reply.started":"2025-03-07T09:14:36.672191Z","shell.execute_reply":"2025-03-07T09:14:36.692326Z"},"jupyter":{"source_hidden":true}} def find_table_with_emb(all_tables_name_list=[], question='', use_local_vectors=USE_LOCAL_VECTORS): """ 判断后返回最高分的表(始终返回1个): 1. 则对 df_tmp 中所有唯一的库表进行批量比较,并记录每个表的相似度。 2. 最后合并候选结果,从中选取相似度最高的表返回(返回列表中只含1个表)。 """ # 根据 QUESTION_TYPE_LIST 选择对应的数据集 if len(QUESTION_TYPE_LIST) == 0: df_tmp = df_all embeddings_path = df_all_embeddings_path elif len(QUESTION_TYPE_LIST) >= 2: df_tmp = df_all embeddings_path = df_all_embeddings_path elif "港股" in QUESTION_TYPE_LIST: df_tmp = df_hk embeddings_path = df_hk_embeddings_path elif "美股" in QUESTION_TYPE_LIST: df_tmp = df_us embeddings_path = df_us_embeddings_path elif "A股" in QUESTION_TYPE_LIST: df_tmp = df_cn embeddings_path = df_cn_embeddings_path else: df_tmp = df_all embeddings_path = df_all_embeddings_path if not all_tables_name_list or question == '': return [] candidate_scores = [] df_unique = df_tmp[['库表名英文', 'table_describe']].drop_duplicates() candidate_ids_deep = df_unique['库表名英文'].tolist() candidate_texts_deep = df_unique['table_describe'].tolist() if use_local_vectors: texts_for_batch_deep = [question] similarity_scores_deep = calculate_similarity(texts_for_batch_deep, local_vectors=embeddings_path) else: texts_for_batch_deep = [question] + candidate_texts_deep similarity_scores_deep = calculate_similarity(texts_for_batch_deep) for table, sim in zip(candidate_ids_deep, similarity_scores_deep): candidate_scores.append((table, sim)) # 如果没有候选表,则返回空列表 if not candidate_scores: return [] # 从所有候选中选取相似度最高的表(不论是否达到 threshold) best_table, best_sim = max(candidate_scores, key=lambda x: x[1]) if DEBUG_VER == 3: print(f"Best table: {best_table} with similarity: {best_sim}") return [best_table] def extract_matching_columns_with_similarity(df, question, threshold, top_n=25): """ 根据question提取匹配列: 1. 对于每行的字段(如 column_description 与 注释),先做全量匹配, 如果文本完全包含在question中,则直接赋分1.0。 2. 同时收集所有候选文本,最后批量调用 calculate_similarity, 将 question 与所有候选文本比较,得到相似度分数(只保留达到阈值的)。 3. 同一列可能对应多个候选文本,最终保留相似度最高的得分。 返回: - imp_matched_columns:全量匹配(得分1.0)的列名列表 - topn_matched_columns:满足阈值的列及其相似度分数(按分数降序排列,最多 top_n 个) """ imp_matched_columns_with_scores = {} # 全量匹配得分为1.0的列 matched_columns_with_scores = {} # 语义匹配得分(可能多次出现,取最高分) candidate_texts = [] # 收集待计算相似度的候选文本 candidate_ids = [] # 对应候选文本所属的列名 # 遍历DataFrame中的每一行 for _, row in df.iterrows(): col_name = row["column_name"] # 处理 column_description 字段 if isinstance(row["column_description"], str): col_desc = row["column_description"] # 全量匹配检查:如果候选文本出现在question中,则直接记1.0分 if col_desc in question: imp_matched_columns_with_scores[col_name] = 1.0 matched_columns_with_scores[col_name] = 1.0 # 无论全量匹配与否,都加入候选列表进行语义匹配 else: candidate_texts.append(col_desc) candidate_ids.append(col_name) # 处理 注释 字段(假设多个注释以“,”分隔) if isinstance(row["注释"], str): words = row["注释"].split(",") for word in words: if word in question: imp_matched_columns_with_scores[col_name] = 1.0 matched_columns_with_scores[col_name] = 1.0 # candidate_texts.append(word) # candidate_ids.append(col_name) # 批量计算语义相似度:将 question 与所有候选文本一次性比较 if candidate_texts: # 构造输入列表:第一个文本为 question,其余为所有候选文本 texts_for_batch = [question] + candidate_texts similarity_scores = calculate_similarity(texts_for_batch) # 遍历每个候选文本的相似度 for idx, sim in enumerate(similarity_scores): if sim >= threshold: col = candidate_ids[idx] # 同一列可能有多个候选文本,保留相似度最高的得分 matched_columns_with_scores[col] = max(matched_columns_with_scores.get(col, 0), sim) # 合并全量匹配和语义匹配的结果 all_matched_columns = {**matched_columns_with_scores, **imp_matched_columns_with_scores} # 按相似度分数降序排列(注意:由于全量匹配的分数为1.0,通常会排在最前面) unique_matched_columns = {col: score for col, score in sorted(all_matched_columns.items(), key=lambda x: -x[1])} # 提取全量匹配列 imp_matched_columns = [col for col, score in unique_matched_columns.items() if score == 1.0] # 筛选出满足阈值的列,并构造输出格式(字典列表),按分数降序排序 matched_columns = sorted( [{"column": col, "similarity_score": score} for col, score in unique_matched_columns.items() if score >= threshold], key=lambda x: x["similarity_score"], reverse=True ) topn_matched_columns = matched_columns[:top_n] return imp_matched_columns, topn_matched_columns def filter_structure(structure, matched_columns, time_columns): # 分割表头和数据部分 sections = structure.split("\n\n", 1) header = sections[0] # 表头部分 data_rows = sections[1] if len(sections) > 1 else "" # 数据部分 # 条件检查函数 def satisfies_conditions(row): row_fields = row.split() # 假设字段之间是用空格分隔的 # 完全匹配检查 if any(col == field for field in row_fields for col in matched_columns): return True if any(col == field for field in row_fields for col in time_columns): return True if any(keyword in row for keyword in ["Code", "Abbr", "Name"]): return True return False # 逐行过滤数据部分 filtered_rows = [] for row in data_rows.strip().split("\n"): if satisfies_conditions(row): filtered_rows.append(row) # 将过滤后的内容与表头合并 filtered_structure = header + "\n\n" + "\n".join(filtered_rows) if filtered_rows else header return filtered_structure def extract_column_descriptions(filtered_structure, df_tmp): # 从"\n\n"开始提取内容部分 content = filtered_structure.split("\n\n", 1)[1].strip() # 提取列名(每行第一个空格前的部分) column_names = [] for line in content.split("\n"): column_name = line.split()[0] # 获取第一个空格前的内容 column_names.append(column_name) # 转换 df_tmp 为字典形式,方便查找 column_dict = dict(zip(df_tmp["column_name"], df_tmp["注释"])) # 构造结果列表 result = [] for column_name in column_names: if column_name in column_dict and len(str(column_dict[column_name]))>3: result.append({column_name: column_dict[column_name]}) return result def calculate_similarity(text_list, local_vectors=False): """ 批量计算相似度: - 输入一个文本列表,其中第一个文本作为基准,其余文本与基准比较 - 当候选文本超过 64 条时,会分批请求,最后返回所有候选文本与基准文本的相似度(保留4位小数) """ base_text = text_list[0] # 先单独请求基准文本的 embedding base_response = client.embeddings.create( model="embedding-3", input=[base_text] ) base_embedding = base_response.data[0].embedding base_embedding = np.array(base_embedding) all_similarities = [] if local_vectors: if DEBUG_VER == 3: print(f'------>local_vectors:{local_vectors}') # 使用本地保存的向量 with open(local_vectors, 'r', encoding='utf-8') as f: local_embeddings = json.load(f) # 确保读取到的向量与候选文本对应 candidate_embeddings = np.array(local_embeddings) else: # 批量处理候选文本,每次最多请求64条 candidate_texts = text_list[1:] candidate_embeddings = [] chunk_size = 64 # 每次最多请求64条 for i in range(0, len(candidate_texts), chunk_size): chunk = candidate_texts[i:i+chunk_size] response = client.embeddings.create( model="embedding-3", input=chunk ) # 提取候选文本的 embedding 并转换为 NumPy 数组 embeddings = [item.embedding for item in response.data] candidate_embeddings.extend(embeddings) candidate_embeddings = np.array(candidate_embeddings) # 计算余弦相似度: dot / (||base|| * ||candidate||) dot_products = candidate_embeddings.dot(base_embedding) norm_base = np.linalg.norm(base_embedding) norm_candidates = np.linalg.norm(candidate_embeddings, axis=1) similarities = dot_products / (norm_base * norm_candidates) # 保留4位小数,并加入结果列表 sims = [round(float(sim), 4) for sim in similarities] all_similarities.extend(sims) return all_similarities