You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

367 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# agent/question/question_utils.py
import json
from config import *
from agent.llm_client import create_chat_completion
from agent.sql.exec import exec_sql_s
from agent.utils import find_json, dict_to_sentence
def preprocessing_question(question):
    """
    Normalize vague year-end wording in a question before SQL generation.

    Both "年末" and "年底" are rewritten to the explicit day "年12月31日",
    with "年末" handled first.

    Parameters:
    question (str): The raw question text.

    Returns:
    str: The question with year-end wording replaced.
    """
    for vague_term in ("年末", "年底"):
        if vague_term in question:
            question = question.replace(vague_term, "年12月31日")
    return question
def add_question_mark(text):
    """
    Make sure a question ends with a question mark.

    If ``text`` ends with neither an ASCII ``?`` nor a full-width ``？``, a
    full-width ``？`` is appended. Then, if the second-to-last character is a
    closing parenthesis (i.e. the text ends like ``...）？``), a ``？`` is
    inserted in front of the last matching opening parenthesis, unless a
    question mark already precedes it (an opening parenthesis at index 0
    always gets one). This turns ``金额是多少（保留两位小数）`` into
    ``金额是多少？（保留两位小数）？``.

    Both ASCII ``()`` and full-width ``（）`` parentheses are recognized.

    Parameters:
    text (str): The question text.

    Returns:
    str: The text with missing question marks added.
    """
    # Chinese questions usually use the full-width mark; accept either form.
    question_marks = ('?', '？')
    if not text.endswith(question_marks):
        text += '？'
    if len(text) >= 2 and text[-2] in (')', '）'):
        # Pair the closing parenthesis with an opening one of the same width.
        opening = '(' if text[-2] == ')' else '（'
        left_bracket_index = text.rfind(opening)
        if left_bracket_index != -1 and (
            left_bracket_index == 0
            or text[left_bracket_index - 1] not in question_marks
        ):
            text = text[:left_bracket_index] + '？' + text[left_bracket_index:]
    return text
def process_company_name(value):
    """
    Search the security master tables for a company name or related keyword.

    Candidate tables are ConstantDB.SecuMain, ConstantDB.HK_SecuMain and
    ConstantDB.US_SecuMain. When QUESTION_TYPE_LIST holds exactly one entry
    and it is "港股", "美股" or "A股", only the corresponding table is searched;
    otherwise all three are.

    For every table, an exact match is tried first on CompanyCode, SecuCode,
    ChiName, ChiNameAbbr, EngName, EngNameAbbr, SecuAbbr and ChiSpelling
    (ChiNameAbbr/EngNameAbbr are left out for the US table). Only if that
    returns nothing is a substring match (LIKE '%value%') tried on the same
    columns of that same table.

    Parameters:
    value (str): The company name, abbreviation, spelling or code to match.
        Non-string values are converted with str(); surrounding whitespace
        is ignored.

    Returns:
    list: (result, table) tuples, one per table that produced rows.
        Empty when nothing matched or when ``value`` is None/blank.
    """
    keyword = "" if value is None else str(value).strip()
    if not keyword:
        # A blank keyword would make the fuzzy query LIKE '%%', i.e. every row.
        return []

    tables = ['ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain', 'ConstantDB.US_SecuMain']
    # Restrict to one market only when exactly one market type was detected.
    if len(QUESTION_TYPE_LIST) == 1:
        for market, market_table in (
            ("港股", 'ConstantDB.HK_SecuMain'),
            ("美股", 'ConstantDB.US_SecuMain'),
            ("A股", 'ConstantDB.SecuMain'),
        ):
            if market in QUESTION_TYPE_LIST:
                tables = [market_table]
                break

    columns_to_match = ['CompanyCode', 'SecuCode', 'ChiName', 'ChiNameAbbr',
                        'EngName', 'EngNameAbbr', 'SecuAbbr', 'ChiSpelling']
    # Columns left out when querying the US table.
    us_excluded = {'ChiNameAbbr', 'EngNameAbbr'}
    # Double single quotes so the keyword cannot terminate the SQL string literal.
    # NOTE(review): LIKE wildcards (% and _) inside the keyword are not escaped.
    escaped = keyword.replace("'", "''")

    res_lst = []
    for table in tables:
        if 'US' in table:
            match_cols = [col for col in columns_to_match if col not in us_excluded]
        else:
            match_cols = columns_to_match
        select_clause = ', '.join(['InnerCode'] + match_cols)

        exact_where = ' OR '.join(f"{col} = '{escaped}'" for col in match_cols)
        result = exec_sql_s(f"SELECT {select_clause} FROM {table} WHERE {exact_where}")
        if not result:
            # No exact hit in this table: fall back to substring matching here.
            fuzzy_where = ' OR '.join(f"{col} LIKE '%{escaped}%'" for col in match_cols)
            result = exec_sql_s(f"SELECT {select_clause} FROM {table} WHERE {fuzzy_where}")
        if result:
            res_lst.append((result, table))

    if not res_lst and DEBUG_VER == 3:
        print(f"未在任何表中找到上市公司名称为 {keyword} 的信息。")
    return res_lst
def process_code(value):
    """
    Look up a security by its exact code (SecuCode) in the security master tables.

    Candidate tables are ConstantDB.SecuMain, ConstantDB.HK_SecuMain and
    ConstantDB.US_SecuMain. When QUESTION_TYPE_LIST holds exactly one entry
    and it is "港股", "美股" or "A股", only the corresponding table is searched;
    otherwise all three are. ChiNameAbbr/EngNameAbbr are not selected from
    the US table.

    Parameters:
    value (str | int): The code to search for, e.g. "600872". Non-string values
        (presumably an unquoted number in the NER JSON) are converted with
        str(); surrounding whitespace is ignored.

    Returns:
    list: (result, table) tuples for every table with a match; empty if none
        matched or ``value`` is None/blank.
    """
    code = "" if value is None else str(value).strip()
    if not code:
        return []

    tables = ['ConstantDB.SecuMain', 'ConstantDB.HK_SecuMain', 'ConstantDB.US_SecuMain']
    # Restrict to one market only when exactly one market type was detected.
    if len(QUESTION_TYPE_LIST) == 1:
        for market, market_table in (
            ("港股", 'ConstantDB.HK_SecuMain'),
            ("美股", 'ConstantDB.US_SecuMain'),
            ("A股", 'ConstantDB.SecuMain'),
        ):
            if market in QUESTION_TYPE_LIST:
                tables = [market_table]
                break

    columns_to_select = ['InnerCode', 'CompanyCode', 'SecuCode', 'ChiName', 'ChiNameAbbr',
                         'EngName', 'EngNameAbbr', 'SecuAbbr', 'ChiSpelling']
    us_excluded = {'ChiNameAbbr', 'EngNameAbbr'}
    # Double single quotes so the code cannot terminate the SQL string literal.
    escaped = code.replace("'", "''")

    res_lst = []
    for table in tables:
        if 'US' in table:
            select_cols = [col for col in columns_to_select if col not in us_excluded]
        else:
            select_cols = columns_to_select
        select_clause = ', '.join(select_cols)
        result = exec_sql_s(f"SELECT {select_clause} FROM {table} WHERE SecuCode = '{escaped}'")
        if result:
            res_lst.append((result, table))

    if not res_lst and DEBUG_VER == 3:
        print(f"未在任何表中找到代码为 {code} 的信息。")
    return res_lst
def process_jj(value):
    """
    Look up a fund company by its short name (基金公司简称).

    First searches InstitutionDB.LC_InstiArchive for institutions whose ChiName
    contains the name; if nothing matches, falls back to process_company_name.

    Parameters:
    value (str): Fund-company short name, e.g. "易方达". Non-string values are
        converted with str(); surrounding whitespace is ignored.

    Returns:
    list: [(rows, 'InstitutionDB.LC_InstiArchive')] on an institution match,
        otherwise whatever process_company_name returns (possibly empty).
        Empty when ``value`` is None/blank.
    """
    name = "" if value is None else str(value).strip()
    if not name:
        # LIKE '%%' would return every institution.
        return []
    # Double single quotes so the name cannot terminate the SQL string literal.
    escaped = name.replace("'", "''")
    sql = (
        "SELECT ChiName, CompanyCode FROM InstitutionDB.LC_InstiArchive "
        f"WHERE ChiName LIKE '%{escaped}%'"
    )
    res = exec_sql_s(sql)
    if res:
        return [(res, 'InstitutionDB.LC_InstiArchive')]
    # process_company_name escapes quotes itself, so pass the unescaped name.
    return process_company_name(name)
def process_items(item_list):
    """
    Resolve extracted entities to database rows and format what was found.

    Each item is expected to be a dict mapping an entity type to its value,
    e.g. [{"上市公司名称": "XX公司"}, {"代码": "600872"}]. Every key/value pair
    is dispatched by key:
    - "基金名称" / "上市公司名称" -> process_company_name
    - "代码" -> process_code
    - "基金公司简称" -> process_jj
    Unknown keys and non-dict items are skipped (reported when DEBUG_VER == 3).
    A lookup that raises is reported (when DEBUG_VER == 3) and skipped, so the
    remaining entities are still processed. A single dict is treated as a
    one-item list; None or any other non-list input yields no results.

    Parameters:
    item_list (list | dict | None): Entities parsed from the model's JSON reply.

    Returns:
    tuple: (res, tables)
    res (str): One section per match: the table name followed by the rows
        rendered with json.dumps.
    tables (list): Table names, in the same order as the sections in ``res``.
    """
    if isinstance(item_list, dict):
        item_list = [item_list]
    if not isinstance(item_list, (list, tuple)) or not item_list:
        return '', []

    handlers = {
        "基金名称": process_company_name,
        "上市公司名称": process_company_name,
        "代码": process_code,
        "基金公司简称": process_jj,
    }

    res_list = []
    for item in item_list:
        if not isinstance(item, dict):
            if DEBUG_VER == 3:
                print(f"process_items 发生错误: 无法解析的条目 {item!r}")
            continue
        for key, value in item.items():
            handler = handlers.get(key)
            if handler is None:
                if DEBUG_VER == 3:
                    print(f"无法识别的键:{key}")
                continue
            try:
                res_list.extend(handler(value))
            except Exception as e:
                # Best effort: one failed lookup must not discard the others.
                if DEBUG_VER == 3:
                    print(f"process_items 发生错误: {e}")

    # Drop empty entries before unpacking (result, table) pairs.
    res_list = [entry for entry in res_list if entry]
    tables = [table_name for _, table_name in res_list]
    res = ''.join(
        f"预处理程序通过表格:{table_name} 查询到以下内容:\n {json.dumps(result_data, ensure_ascii=False, indent=1)} \n"
        for result_data, table_name in res_list
    )
    return res, tables
def process_question(question):
    """
    Run LLM-based named entity recognition on a question and look the entities up.

    A system prompt with few-shot examples asks the model for a JSON list of
    entities of four kinds: 上市公司名称, 代码, 基金名称 and 基金公司简称. The
    JSON is pulled out of the reply with find_json and handed to process_items.

    Parameters:
    question (str): The user question, sent as the user message.

    Returns:
    tuple: (res, tables), exactly as returned by process_items.
    """
    ner_prompt = '''
你将会进行命名实体识别任务并输出实体json。你只需要识别以下4种实体
-上市公司名称
-代码
-基金名称
-基金公司简称
其中,上市公司名称可以是全称,简称,拼音缩写,代码包含股票代码和基金代码,基金名称包含债券型基金,
以下是几个示例:
user:唐山港集团股份有限公司是什么时间上市的回答XXXX-XX-XX
当年一共上市了多少家企业?
这些企业有多少是在北京注册的?
assistant:```json
[{"上市公司名称":"唐山港集团股份有限公司"}]
```
user:JD的职工总数有多少人
该公司披露的硕士或研究生学历(及以上)的有多少人?
20201月1日至年底退休了多少人
assistant:```json
[{"上市公司名称":"JD"}]
```
user:600872的全称、A股简称、法人、法律顾问、会计师事务所及董秘是
该公司实控人是否发生改变如果发生变化什么时候变成了谁是哪国人是否有永久境外居留权回答时间用XXXX-XX-XX
assistant:```json
[{"代码":"600872"}]
```
user:华夏鼎康债券A在2019年的分红次数是多少每次分红的派现比例是多少
基于上述分红数据在2019年最后一次分红时如果一位投资者持有1000份该基金税后可以获得多少分红收益
assistant:```json
[{"基金名称":"华夏鼎康债券A"}]
```
user:实体识别任务:```易方达基金管理有限公司在19年成立了多少支基金
哪支基金的规模最大?
这支基金20年最后一次分红派现比例多少钱```
assistant:```json
[{"基金公司简称":"易方达"}]
```
user:化工纳入过多少个子类概念?
assistant:```json
[]
```
'''
    conversation = [
        {'role': 'system', 'content': ner_prompt},
        {'role': 'user', 'content': question},
    ]
    completion = create_chat_completion(conversation)
    extracted_entities = find_json(completion.choices[0].message.content)
    return process_items(extracted_entities)
def _build_rewrite_prompt(context_text, original_question):
    """
    Build the few-shot prompt that asks the model to rewrite a follow-up question.

    Parameters:
    context_text (str): Earlier rounds of questions and answers.
    original_question (str): The question to be rewritten.

    Returns:
    str: The complete prompt text.
    """
    return (
        f"根据这些内容:'{context_text}',帮我重写当前问题:'{original_question}' ,让问题清晰明确,"
        "不改变原意,代词转成具体人事物,不要遗漏信息,只返回问题。"
        "如果当前问题中有时间代词(如“当年、当天”等)或指物代词(如“该公司”、“它”等),检查前面问题和回答(一般是上一个,就近原则)中是否明确了时间或主体等,并将这些信息补充到当前问题中。"
        "如果当前问题无法从前面问题和回答中找到代词(如时间)所指的具体信息,则表示当前代词,如时间指全部时间。"
        "问题可能需要时间,也可能不需要时间,如果不需要则在后面追加一个不带时间的小问题(不需要换行等,只需要接在原问题后面)。"
        "让我们一步一步思考\n"
        "以下是几个示例:\n"
        # Every example turn ends with "\n" so user/assistant turns stay visibly separated.
        "user:根据这些内容:'第1轮问题最新更新的2021年度报告中机构持有无限售流通A股数量合计最多的公司简称是 第1轮回答公司简称 帝尔激光',帮我重写当前问题:'在这份报告中该公司机构持有无限售流通A股比例合计是多少保留2位小数'\n"
        "assistant:最新更新的2021年度报告中,公司简称 帝尔激光 持有无限售流通A股比例合计是多少保留2位小数\n"
        "user:根据这些内容:'第1轮问题TK他是否已经成立了是或者否 第1轮回答',帮我重写当前问题:'这家公司17年最高收盘价是多少'\n"
        "assistant: 2017年TK这家公司的最高收盘价是多少\n"
        "user:根据这些内容:'第1轮问题TK他是否已经成立了是或者否 第1轮回答\n第2轮问题2017年TK这家公司的最高收盘价是多少 第2轮回答2017年TK最高收盘价 5.79',帮我重写当前问题:'当天有多少家公司成立了?'\n"
        # The answer must reuse the value from the context (5.79) rather than invent one.
        "assistant: 2017年TK最高收盘价 5.79 是什么时候?当天有多少家公司成立了?\n"
        "user:根据这些内容:'第1轮问题航锦科技股份有限公司是否变更过公司名称 第1轮回答没有\n第2轮问题航锦科技股份有限公司涉及回购的最大的一笔金额是多少第2轮回答43951008.0',帮我重写当前问题:该年度前十大股东的持股比例变成了多少?\n"
        "assistant:航锦科技股份有限公司涉及回购的最大的一笔金额 43951008.0是哪一年?该年度前十大股东的持股比例变成了多少?"
    )


def question_rew(context_text, original_question):
    """
    Rewrite a follow-up question so it is self-contained, using earlier context.

    The model is asked to resolve pronouns and time references (e.g. 当年, 该公司)
    from the previous questions/answers without changing the meaning, and to
    return only the rewritten question.

    Parameters:
    context_text (str): Earlier rounds of questions and answers.
    original_question (str): The question to be rewritten.

    Returns:
    str: The model's reply content (the rewritten question).
    """
    messages = [{"role": "user", "content": _build_rewrite_prompt(context_text, original_question)}]
    response = create_chat_completion(messages)
    return response.choices[0].message.content
def process_dict(d):
    """
    Flatten a (possibly nested) dictionary into a comma-separated description.

    Every non-dict entry, at any depth, is rendered as "key value". A nested
    dict value is first flattened the same way, and the pair
    {key: flattened_text} is then formatted with dict_to_sentence. Entries at
    each level are joined with ", ".

    Example (flat input):
        {"name": "ABC Corp", "year": 2021}  ->  "name ABC Corp, year 2021"

    Parameters:
    d (dict | object): The dictionary to describe; anything that is not a dict
        is returned as str(d).

    Returns:
    str: The description.
    """
    if not isinstance(d, dict):
        return str(d)

    def describe(mapping):
        return ", ".join(
            dict_to_sentence({key: describe(value)})
            if isinstance(value, dict)
            else f"{key} {value}"
            for key, value in mapping.items()
        )

    return describe(d)