You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

336 lines
18 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import time, gc
from config import *
from agent.question.question_utils import preprocessing_question, process_question, process_dict, add_question_mark, question_rew
from agent.llm_client import create_chat_completion, create_chat_completion_change_zhushi
from agent.schema.schema_utils import get_table_schema, find_table_with_emb, get_table_schema_with_emb, deep_find_tables
from agent.sql.exec import exec_sql_s
from agent.sql.process_sql import to_select, extract_table_names, all_tables_in_prompt
from agent.sql.knowledge import way_string_2
from agent.optimization import optimize_answer, find_kl_answer
from agent.utils.async_util import async_llm_chain_call
from agent.utils.io_utils import load_external_answers
from agent.utils.json_utils import find_json
def run_conversation_xietong(question, org_question):
global QUESTION_TYPE
global QUESTION_TYPE_LIST
global LL
global table_maps
global table_maps_LL
global all_tables_name_list
str_process_question = process_question(question)
if DEBUG_VER == 3:
print(f"-------------已查询获得事实: {str_process_question}")
if "ConstantDB.HK_SecuMain" in str(str_process_question) and "港股" not in QUESTION_TYPE_LIST:
QUESTION_TYPE_LIST.append("港股")
if "ConstantDB.US_SecuMain" in str(str_process_question) and "美股" not in QUESTION_TYPE_LIST:
QUESTION_TYPE_LIST.append("美股")
if "ConstantDB.SecuMain" in str(str_process_question) and "A股" not in QUESTION_TYPE_LIST:
QUESTION_TYPE_LIST.append("A股")
content_p_1 = content_ALL
if len(QUESTION_TYPE_LIST) == 0:
content_p_1 = content_ALL
elif len(QUESTION_TYPE_LIST) >= 2:
content_p_1 = content_ALL
elif "港股" in QUESTION_TYPE_LIST:
content_p_1 = content_HK
elif "美股" in QUESTION_TYPE_LIST:
content_p_1 = content_US
elif "A股" in QUESTION_TYPE_LIST:
content_p_1 = content_CN
else:
content_p_1 = content_ALL
if DEBUG_VER == 3:
print(f"---------QUESTION_TYPE_LIST:{QUESTION_TYPE_LIST}")
content_p = content_p_1.replace('<<question>>', str(question)).replace('<<fact_1>>',
str(str_process_question))
main_sql_prompts,highest_prompt_similar = way_string_2(org_question)
content_p = content_p + main_sql_prompts
# 执行函数部分
messages = []
messages.append({"role": "user", "content": "您好阿"})
messages.append({"role": "user", "content": content_p})
response = create_chat_completion(messages)
if "全部完成" in response.choices[0].message.content:
if "XX" in response.choices[0].message.content:
response.choices[0].message.content += "ConstantDB.SecuMain,ConstantDB.HK_SecuMain,ConstantDB.US_SecuMain"
else:
if DEBUG_VER == 3:
print("全部完成response",response.choices[0].message.content)
if "需要查询的数据库表" in response.choices[0].message.content or "需要查询数据库表" in response.choices[0].message.content:
pass
else:
return response.choices[0].message.content
messages.append({"role": "assistant", "content": response.choices[0].message.content})
table_maps = get_table_schema(question)
LL = [i for i in table_maps if i.get('数据表名') in response.choices[0].message.content]
tables_name_list = [i.get('数据表名') for i in table_maps if i.get('数据表名') in response.choices[0].message.content]
try:
sql_prompt_db_tables = extract_table_names(main_sql_prompts)
except Exception as e:
if DEBUG_VER == 3:
print(f"------->e:{e}")
if sql_prompt_db_tables:
all_tables_name_list += sql_prompt_db_tables
all_tables_name_list += tables_name_list
all_tables_name_list = list(set(all_tables_name_list))
if DEBUG_VER == 3:
print(f"-----------------all_tables_name_list:{all_tables_name_list}")
all_tables_name_list += find_table_with_emb(all_tables_name_list, question)
all_tables_name_list = list(set(all_tables_name_list))
if DEBUG_VER == 3:
print(f"-----------------find_table_with_emb all_tables_name_list:{all_tables_name_list}")
if all_tables_name_list == []:
highest_similarity_score_list = [0]
else:
table_maps_LL, highest_similarity_score_list = get_table_schema_with_emb(all_tables_name_list, question)
# print(f"--------------table_maps_LL:{table_maps_LL}")
if all_tables_in_prompt(tables_name_list, main_sql_prompts) and highest_prompt_similar > START_DEEP_THRESHOLD:
content_p_2 = """获取的表结构如下<list>,表结构中列名可以引用使用,表结构中数据示例只是参考不能引用。
我们现在开始查询当前问题可参考示例sql写出查询sql语句如果示例sql是一句的一步你也尽可能用一句sql得到结果。
如果需要多步我把查询结果告诉你你再告诉我下一步注意如果我返回的结果为空或者错误影响下一步调用请重新告诉我sql语句。
等你全部回答完成,不需要进行下一步调用时,记得提示我:<全部完成,答案如下>,将答案总结以json格式给我只需要总结当前问题。
查询技巧: 涉及计算的使用sql进行加减乘除数学计算。查询尽量使用InnerCode。查年用year()函数、月用month()、某日则用date()date(TradingDay) = 'YYYY-MM-DD'。sql查询语句不需要注释不然会报错。尽量利用表格中已有的字段。\n"""
else:
if DEEP_COLUMNS_RESEARCH:
if max(highest_similarity_score_list) < START_DEEP_THRESHOLD: # 没有大于0.5的说明这个表可能不对,新增表
print(f"-----------Start Deeping table search")
deep_table_maps_LL = deep_find_tables(all_tables_name_list, org_question) # 直接用原始的扫
print(f"--------------len(deep_table_maps_LL):{len(deep_table_maps_LL)}")
deep_tables_name_list = [i.get('数据表名') for i in deep_table_maps_LL]
all_tables_name_list += deep_tables_name_list
all_tables_name_list = list(set(all_tables_name_list))
if max(highest_similarity_score_list) < START_DEEP_THRESHOLD:
table_maps_LL, highest_similarity_score_list = get_table_schema_with_emb(all_tables_name_list, question)
if DEBUG_VER == 3:
print(f"--------------End Deeping table_maps_LL:{table_maps_LL}")
content_p_2 = """获取的表结构如下<list>,表结构中列名可以引用使用,表结构中数据示例只是参考不能引用。
我们现在开始查询当前问题请你分步写出查询sql语句我把查询结果告诉你你再告诉我下一步
注意如果我返回的结果为空或者错误影响下一步调用请重新告诉我sql语句。
等你全部回答完成,不需要进行下一步调用时,记得提示我:<全部完成,答案如下>,将答案总结以json格式给我只需要总结当前问题。
查询技巧: 涉及计算的使用sql进行加减乘除数学计算。查询尽量使用InnerCode。查年用year()、查月用month()、某日用date()date(TradingDay) = 'YYYY-MM-DD'。sql查询语句不需要注释不然会报错。尽量利用表格中已有的字段。\n"""
if len(str_process_question[1]) == 0:
if DEBUG_VER == 3:
print("--------------查询获得事实为空,跳过处理")
else:
if FRASH_DATA_DESCRIBE:
table_maps_LL = create_chat_completion_change_zhushi(str_process_question, table_maps_LL)
content_p_2 = main_sql_prompts + content_p_2.replace('<list>', str(table_maps_LL))
if DEBUG_VER == 3:
print(f"-------------content_p_2: {content_p_2}")
messages.append({"role": "user", "content": content_p_2}) ###开始对话
last_answer = run_conversation_until_complete(messages, max_rounds=MAX_ROUNDS)
return last_answer
def run_conversation_until_complete(messages, max_rounds=9):
"""
Test function to run a conversation loop until the assistant indicates completion.
"""
last_response = None # 用于存储最后一次对话的响应
round_count = 0 # 对话轮数计数器
# response = create_chat_completion(messages)
if ASYNC_LLM_TIME >= 2:
response = async_llm_chain_call(messages, MODEL)
else:
response = create_chat_completion(messages)
while True:
if round_count >= max_rounds:
break # 如果对话轮数超过最大值,则退出循环
question = response.choices[0].message.content
select_result = to_select(question)
messages.append({"role": "assistant", "content": question})
messages.append({"role": "user", "content": str(select_result)})
# response = create_chat_completion(messages)
if ASYNC_LLM_TIME >= 2:
response = async_llm_chain_call(messages, MODEL)
else:
response = create_chat_completion(messages)
last_response = response.choices[0].message.content # 存储最后一次响应
if "全部完成" in response.choices[0].message.content:
break # 如果检测到“回答完成”,则停止循环
round_count += 1 # 增加对话轮数计数
return last_response # 返回最后一次对话的内容
def run_conversation(question, org_question):
"""
Run a conversation flow given a question by:
1. Using run_conversation_xietong(question) to get an answer.
2. Attempting to find and parse JSON from the answer using find_json.
3. Converting the parsed result (or original text if parsing fails) into a descriptive sentence using process_dict.
Parameters:
question (str): The question to ask.
Returns:
str: The final processed answer as a descriptive string.
"""
org_answer = run_conversation_xietong(question, org_question)
parsed_data = find_json(org_answer)
last_answer = process_dict(parsed_data)
return last_answer,org_answer
def get_answer(question, org_question):
"""
Attempt to answer the given question by interacting with the
conversation model. If an error occurs, return a default error message.
Parameters:
question (str): The question that needs an answer.
Returns:
str: The answer string or an error message if an exception occurs.
"""
try:
if DEBUG_VER == 3:
print(f"Attempting to answer the question: {question}")
last_answer,org_answer = run_conversation(question, org_question)
return last_answer,org_answer
except Exception as e:
if DEBUG_VER == 3:
print(f"Error occurred while executing get_answer: {e}")
return "An error occurred while retrieving the answer."
def main_answer(q_json_list, out_path, idx_ranges):
"""
处理整个 JSON 数据列表:
- 如果当前索引在 idx_ranges 中任一范围内,则按照原有逻辑提取问题、重写问题、调用 get_answer 得到答案,
并更新每个问题对应的 answer 字段;
- 否则将每个 team 成员的 answer 字段更新为空。
参数:
q_json_list: JSON 数据列表
out_path: 输出文件路径
idx_ranges: 包含多个 (start, end) 范围的列表,例如 [(21, 22), (74, 75), (43, 44)]
"""
global QUESTION_TYPE, QUESTION_TYPE_LIST, all_tables_name_list, prev_tables_name_list, org_question
# 如果 idx_ranges 为空,则默认处理全部数据
if not idx_ranges:
idx_ranges = [(0, len(q_json_list))]
if REFRASH and Knowledge_file:
external_answers = load_external_answers(Knowledge_file)
total = len(q_json_list)
# 以流式写入的方式生成 JSON 数组
with open(out_path, 'w', encoding='utf-8') as out_f:
out_f.write('[\n')
for i in range(total):
try:
print("--> cleaned mem:", gc.collect())
print("--> id:", i)
item = q_json_list[i]
start_time = time.time()
# 判断当前索引是否在任一给定范围内
process_this_item = any(start <= i < end for start, end in idx_ranges)
if process_this_item:
# 处理需要调用 get_answer 生成答案的数据
questions_list = [(member["id"], member["question"]) for member in item["team"]]
answers_dict = {}
history_dict = {}
all_previous = ''
# 初始化相关全局变量
QUESTION_TYPE = ""
QUESTION_TYPE_LIST = []
question_num = 0
all_tables_name_list = []
prev_tables_name_list = []
if REFRASH:
get_ext_answer_num = 1
# 遍历当前 item 内所有问题
for question_id, question_text in questions_list:
try:
answer = ""
# 更新历史表变量(超过 MAX_TABLES_LIST 后清空)
all_tables_name_list = prev_tables_name_list
if len(all_tables_name_list) > MAX_TABLES_LIST:
if DEBUG_VER == 3:
print(f"----------------历史表过多,进行清空")
all_tables_name_list = []
question_num += 1
org_question = question_text
# 根据关键词更新 QUESTION_TYPE_LIST
if "港股" in question_text and "港股份" not in question_text and "港股" not in QUESTION_TYPE_LIST:
QUESTION_TYPE_LIST.append("港股")
if "美股" in question_text and "美股份" not in question_text and "美股" not in QUESTION_TYPE_LIST:
QUESTION_TYPE_LIST.append("美股")
if "A股" in question_text and "A股" not in QUESTION_TYPE_LIST:
QUESTION_TYPE_LIST.append("A股")
# 问题预处理
# if DEBUG_VER == 3:
print(f'---------->Org question_text:{question_text}')
question_text = preprocessing_question(question_text)
question_text = add_question_mark(question_text)
# 如果没有历史问答,则直接使用原问题;否则重写问题
if all_previous == '':
rewritten_question = question_text
else:
rewritten_question = question_rew(all_previous, question_text)
# 调用 get_answer 获取答案,并进行优化
if REFRASH:
matched_answer = find_kl_answer(question_num, history_dict, org_question, external_answers)
if matched_answer:
if get_ext_answer_num == question_num:
answer = matched_answer
history_dict[org_question] = answer
# 累加问答历史
all_previous += f"{question_num}个问题:" + question_text + f"{question_num}个的回答:" + answer + " \n"
get_ext_answer_num += 1
if not answer or answer == "":
answer,org_answer = get_answer(rewritten_question, org_question)
# 累加问答历史
all_previous += f"{question_num}个问题:" + question_text + f"{question_num}个的回答:" + answer + " \n"
answer = optimize_answer(rewritten_question, answer, org_answer)
# if DEBUG_VER == 3:
print(f"---------->answer:{answer}")
answer = answer.replace(";", "")
answers_dict[question_id] = answer
del answer
except Exception as e_inner:
print(f"Error processing question id {question_id}: {e_inner}")
answers_dict[question_id] = "异常"
# 更新 item 中各 team 成员的 answer 字段
for member in item["team"]:
member["answer"] = answers_dict.get(member["id"], "无答案")
else:
# 对于不在处理范围内的数据,直接置空 answer 字段
for member in item["team"]:
member["answer"] = ""
updated_data = {"tid": item["tid"], "team": item["team"]}
elapsed_time = time.time() - start_time
print(f"Completed processing JSON index {i} in {elapsed_time:.2f} seconds")
# 写入当前处理完的数据项。若不是最后一个,则尾部添加逗号。
json_str = json.dumps(updated_data, ensure_ascii=False, indent=4)
if i < total - 1:
out_f.write(json_str + ',\n')
else:
out_f.write(json_str + '\n')
out_f.flush() # 每处理完一道题就 flush 到磁盘,降低内存占用
except Exception as e_outer:
print(f"Error processing JSON item index {i}: {e_outer}")
out_f.write(']\n')