import time, gc from config import * from agent.question.question_utils import preprocessing_question, process_question, process_dict, add_question_mark, question_rew from agent.llm_client import create_chat_completion, create_chat_completion_change_zhushi from agent.schema.schema_utils import get_table_schema, find_table_with_emb, get_table_schema_with_emb, deep_find_tables from agent.sql.exec import exec_sql_s from agent.sql.process_sql import to_select, extract_table_names, all_tables_in_prompt from agent.sql.knowledge import way_string_2 from agent.optimization import optimize_answer, find_kl_answer from agent.utils.async_util import async_llm_chain_call from agent.utils.io_utils import load_external_answers from agent.utils.json_utils import find_json def run_conversation_xietong(question, org_question): global QUESTION_TYPE global QUESTION_TYPE_LIST global LL global table_maps global table_maps_LL global all_tables_name_list str_process_question = process_question(question) if DEBUG_VER == 3: print(f"-------------已查询获得事实: {str_process_question}") if "ConstantDB.HK_SecuMain" in str(str_process_question) and "港股" not in QUESTION_TYPE_LIST: QUESTION_TYPE_LIST.append("港股") if "ConstantDB.US_SecuMain" in str(str_process_question) and "美股" not in QUESTION_TYPE_LIST: QUESTION_TYPE_LIST.append("美股") if "ConstantDB.SecuMain" in str(str_process_question) and "A股" not in QUESTION_TYPE_LIST: QUESTION_TYPE_LIST.append("A股") content_p_1 = content_ALL if len(QUESTION_TYPE_LIST) == 0: content_p_1 = content_ALL elif len(QUESTION_TYPE_LIST) >= 2: content_p_1 = content_ALL elif "港股" in QUESTION_TYPE_LIST: content_p_1 = content_HK elif "美股" in QUESTION_TYPE_LIST: content_p_1 = content_US elif "A股" in QUESTION_TYPE_LIST: content_p_1 = content_CN else: content_p_1 = content_ALL if DEBUG_VER == 3: print(f"---------QUESTION_TYPE_LIST:{QUESTION_TYPE_LIST}") content_p = content_p_1.replace('<>', str(question)).replace('<>', str(str_process_question)) main_sql_prompts,highest_prompt_similar = way_string_2(org_question) content_p = content_p + main_sql_prompts # 执行函数部分 messages = [] messages.append({"role": "user", "content": "您好阿"}) messages.append({"role": "user", "content": content_p}) response = create_chat_completion(messages) if "全部完成" in response.choices[0].message.content: if "XX" in response.choices[0].message.content: response.choices[0].message.content += "ConstantDB.SecuMain,ConstantDB.HK_SecuMain,ConstantDB.US_SecuMain" else: if DEBUG_VER == 3: print("全部完成:response:",response.choices[0].message.content) if "需要查询的数据库表" in response.choices[0].message.content or "需要查询数据库表" in response.choices[0].message.content: pass else: return response.choices[0].message.content messages.append({"role": "assistant", "content": response.choices[0].message.content}) table_maps = get_table_schema(question) LL = [i for i in table_maps if i.get('数据表名') in response.choices[0].message.content] tables_name_list = [i.get('数据表名') for i in table_maps if i.get('数据表名') in response.choices[0].message.content] try: sql_prompt_db_tables = extract_table_names(main_sql_prompts) except Exception as e: if DEBUG_VER == 3: print(f"------->e:{e}") if sql_prompt_db_tables: all_tables_name_list += sql_prompt_db_tables all_tables_name_list += tables_name_list all_tables_name_list = list(set(all_tables_name_list)) if DEBUG_VER == 3: print(f"-----------------all_tables_name_list:{all_tables_name_list}") all_tables_name_list += find_table_with_emb(all_tables_name_list, question) all_tables_name_list = list(set(all_tables_name_list)) if DEBUG_VER == 3: print(f"-----------------find_table_with_emb all_tables_name_list:{all_tables_name_list}") if all_tables_name_list == []: highest_similarity_score_list = [0] else: table_maps_LL, highest_similarity_score_list = get_table_schema_with_emb(all_tables_name_list, question) # print(f"--------------table_maps_LL:{table_maps_LL}") if all_tables_in_prompt(tables_name_list, main_sql_prompts) and highest_prompt_similar > START_DEEP_THRESHOLD: content_p_2 = """获取的表结构如下,表结构中列名可以引用使用,表结构中数据示例只是参考不能引用。 我们现在开始查询当前问题,可参考示例sql写出查询sql语句,如果示例sql是一句的(一步),你也尽可能用一句sql得到结果。 如果需要多步,我把查询结果告诉你,你再告诉我下一步,注意如果我返回的结果为空或者错误影响下一步调用,请重新告诉我sql语句。 等你全部回答完成,不需要进行下一步调用时,记得提示我:<全部完成,答案如下>,将答案总结以json格式给我,只需要总结当前问题。 查询技巧: 涉及计算的使用sql进行加减乘除数学计算。查询尽量使用InnerCode。查年用year()函数、月用month()、某日则用date(),如:date(TradingDay) = 'YYYY-MM-DD'。sql查询语句不需要注释,不然会报错。尽量利用表格中已有的字段。\n""" else: if DEEP_COLUMNS_RESEARCH: if max(highest_similarity_score_list) < START_DEEP_THRESHOLD: # 没有大于0.5的说明这个表可能不对,新增表 print(f"-----------Start Deeping table search") deep_table_maps_LL = deep_find_tables(all_tables_name_list, org_question) # 直接用原始的扫 print(f"--------------len(deep_table_maps_LL):{len(deep_table_maps_LL)}") deep_tables_name_list = [i.get('数据表名') for i in deep_table_maps_LL] all_tables_name_list += deep_tables_name_list all_tables_name_list = list(set(all_tables_name_list)) if max(highest_similarity_score_list) < START_DEEP_THRESHOLD: table_maps_LL, highest_similarity_score_list = get_table_schema_with_emb(all_tables_name_list, question) if DEBUG_VER == 3: print(f"--------------End Deeping table_maps_LL:{table_maps_LL}") content_p_2 = """获取的表结构如下,表结构中列名可以引用使用,表结构中数据示例只是参考不能引用。 我们现在开始查询当前问题,请你分步写出查询sql语句,我把查询结果告诉你,你再告诉我下一步, 注意如果我返回的结果为空或者错误影响下一步调用,请重新告诉我sql语句。 等你全部回答完成,不需要进行下一步调用时,记得提示我:<全部完成,答案如下>,将答案总结以json格式给我,只需要总结当前问题。 查询技巧: 涉及计算的使用sql进行加减乘除数学计算。查询尽量使用InnerCode。查年用year()、查月用month()、某日用date(),如:date(TradingDay) = 'YYYY-MM-DD'。sql查询语句不需要注释,不然会报错。尽量利用表格中已有的字段。\n""" if len(str_process_question[1]) == 0: if DEBUG_VER == 3: print("--------------查询获得事实为空,跳过处理") else: if FRASH_DATA_DESCRIBE: table_maps_LL = create_chat_completion_change_zhushi(str_process_question, table_maps_LL) content_p_2 = main_sql_prompts + content_p_2.replace('', str(table_maps_LL)) if DEBUG_VER == 3: print(f"-------------content_p_2: {content_p_2}") messages.append({"role": "user", "content": content_p_2}) ###开始对话 last_answer = run_conversation_until_complete(messages, max_rounds=MAX_ROUNDS) return last_answer def run_conversation_until_complete(messages, max_rounds=9): """ Test function to run a conversation loop until the assistant indicates completion. """ last_response = None # 用于存储最后一次对话的响应 round_count = 0 # 对话轮数计数器 # response = create_chat_completion(messages) if ASYNC_LLM_TIME >= 2: response = async_llm_chain_call(messages, MODEL) else: response = create_chat_completion(messages) while True: if round_count >= max_rounds: break # 如果对话轮数超过最大值,则退出循环 question = response.choices[0].message.content select_result = to_select(question) messages.append({"role": "assistant", "content": question}) messages.append({"role": "user", "content": str(select_result)}) # response = create_chat_completion(messages) if ASYNC_LLM_TIME >= 2: response = async_llm_chain_call(messages, MODEL) else: response = create_chat_completion(messages) last_response = response.choices[0].message.content # 存储最后一次响应 if "全部完成" in response.choices[0].message.content: break # 如果检测到“回答完成”,则停止循环 round_count += 1 # 增加对话轮数计数 return last_response # 返回最后一次对话的内容 def run_conversation(question, org_question): """ Run a conversation flow given a question by: 1. Using run_conversation_xietong(question) to get an answer. 2. Attempting to find and parse JSON from the answer using find_json. 3. Converting the parsed result (or original text if parsing fails) into a descriptive sentence using process_dict. Parameters: question (str): The question to ask. Returns: str: The final processed answer as a descriptive string. """ org_answer = run_conversation_xietong(question, org_question) parsed_data = find_json(org_answer) last_answer = process_dict(parsed_data) return last_answer,org_answer def get_answer(question, org_question): """ Attempt to answer the given question by interacting with the conversation model. If an error occurs, return a default error message. Parameters: question (str): The question that needs an answer. Returns: str: The answer string or an error message if an exception occurs. """ try: if DEBUG_VER == 3: print(f"Attempting to answer the question: {question}") last_answer,org_answer = run_conversation(question, org_question) return last_answer,org_answer except Exception as e: if DEBUG_VER == 3: print(f"Error occurred while executing get_answer: {e}") return "An error occurred while retrieving the answer." def main_answer(q_json_list, out_path, idx_ranges): """ 处理整个 JSON 数据列表: - 如果当前索引在 idx_ranges 中任一范围内,则按照原有逻辑提取问题、重写问题、调用 get_answer 得到答案, 并更新每个问题对应的 answer 字段; - 否则将每个 team 成员的 answer 字段更新为空。 参数: q_json_list: JSON 数据列表 out_path: 输出文件路径 idx_ranges: 包含多个 (start, end) 范围的列表,例如 [(21, 22), (74, 75), (43, 44)] """ global QUESTION_TYPE, QUESTION_TYPE_LIST, all_tables_name_list, prev_tables_name_list, org_question # 如果 idx_ranges 为空,则默认处理全部数据 if not idx_ranges: idx_ranges = [(0, len(q_json_list))] if REFRASH and Knowledge_file: external_answers = load_external_answers(Knowledge_file) total = len(q_json_list) # 以流式写入的方式生成 JSON 数组 with open(out_path, 'w', encoding='utf-8') as out_f: out_f.write('[\n') for i in range(total): try: print("--> cleaned mem:", gc.collect()) print("--> id:", i) item = q_json_list[i] start_time = time.time() # 判断当前索引是否在任一给定范围内 process_this_item = any(start <= i < end for start, end in idx_ranges) if process_this_item: # 处理需要调用 get_answer 生成答案的数据 questions_list = [(member["id"], member["question"]) for member in item["team"]] answers_dict = {} history_dict = {} all_previous = '' # 初始化相关全局变量 QUESTION_TYPE = "" QUESTION_TYPE_LIST = [] question_num = 0 all_tables_name_list = [] prev_tables_name_list = [] if REFRASH: get_ext_answer_num = 1 # 遍历当前 item 内所有问题 for question_id, question_text in questions_list: try: answer = "" # 更新历史表变量(超过 MAX_TABLES_LIST 后清空) all_tables_name_list = prev_tables_name_list if len(all_tables_name_list) > MAX_TABLES_LIST: if DEBUG_VER == 3: print(f"----------------历史表过多,进行清空") all_tables_name_list = [] question_num += 1 org_question = question_text # 根据关键词更新 QUESTION_TYPE_LIST if "港股" in question_text and "港股份" not in question_text and "港股" not in QUESTION_TYPE_LIST: QUESTION_TYPE_LIST.append("港股") if "美股" in question_text and "美股份" not in question_text and "美股" not in QUESTION_TYPE_LIST: QUESTION_TYPE_LIST.append("美股") if "A股" in question_text and "A股" not in QUESTION_TYPE_LIST: QUESTION_TYPE_LIST.append("A股") # 问题预处理 # if DEBUG_VER == 3: print(f'---------->Org question_text:{question_text}') question_text = preprocessing_question(question_text) question_text = add_question_mark(question_text) # 如果没有历史问答,则直接使用原问题;否则重写问题 if all_previous == '': rewritten_question = question_text else: rewritten_question = question_rew(all_previous, question_text) # 调用 get_answer 获取答案,并进行优化 if REFRASH: matched_answer = find_kl_answer(question_num, history_dict, org_question, external_answers) if matched_answer: if get_ext_answer_num == question_num: answer = matched_answer history_dict[org_question] = answer # 累加问答历史 all_previous += f"第{question_num}个问题:" + question_text + f" 第{question_num}个的回答:" + answer + " \n" get_ext_answer_num += 1 if not answer or answer == "": answer,org_answer = get_answer(rewritten_question, org_question) # 累加问答历史 all_previous += f"第{question_num}个问题:" + question_text + f" 第{question_num}个的回答:" + answer + " \n" answer = optimize_answer(rewritten_question, answer, org_answer) # if DEBUG_VER == 3: print(f"---------->answer:{answer}") answer = answer.replace(";", "") answers_dict[question_id] = answer del answer except Exception as e_inner: print(f"Error processing question id {question_id}: {e_inner}") answers_dict[question_id] = "异常" # 更新 item 中各 team 成员的 answer 字段 for member in item["team"]: member["answer"] = answers_dict.get(member["id"], "无答案") else: # 对于不在处理范围内的数据,直接置空 answer 字段 for member in item["team"]: member["answer"] = "" updated_data = {"tid": item["tid"], "team": item["team"]} elapsed_time = time.time() - start_time print(f"Completed processing JSON index {i} in {elapsed_time:.2f} seconds") # 写入当前处理完的数据项。若不是最后一个,则尾部添加逗号。 json_str = json.dumps(updated_data, ensure_ascii=False, indent=4) if i < total - 1: out_f.write(json_str + ',\n') else: out_f.write(json_str + '\n') out_f.flush() # 每处理完一道题就 flush 到磁盘,降低内存占用 except Exception as e_outer: print(f"Error processing JSON item index {i}: {e_outer}") out_f.write(']\n')