wangmeihua 2025-10-10 11:48:16 +08:00
parent f0c8120ada
commit 5bb3180cc4


@@ -7,6 +7,7 @@ import json
import math
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
from rag.aslmapi import MemoryManager
helptext = """kyrag API:
@@ -31,8 +32,25 @@ data: {
3. docs documentation:
path: /v1/docs
4. longmemory storage:
4. Add user messages to memory:
path: /v1/add_user_messages
headers: {
"Content-Type": "application/json"
}
data: {
"orgid": "user organization ID",
"messages": [{"role": "user", "content": "message content"}, ...]
}
5. Get all memories for a user:
path: /v1/get_user_memories
headers: {
"Content-Type": "application/json"
}
data: {
"orgid": "user organization ID",
"limit": 10  # optional, defaults to 10
}
"""
async def docs(request, params_kw, *params, **kw):
@@ -72,8 +90,9 @@ async def get_kdbs(request, params_kw, *params, **kw):
})
return result
async def fusedsearch(request, params_kw, *params, **kw):
async def fusedsearch(request, params_kw, *params):
"""融合搜索,调用服务化端点"""
kw = request._run_ns
f = kw.get('get_userorgid')
orgid = await f()
debug(f"orgid: {orgid}{f=}")
@@ -155,6 +174,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
if use_rerank and search_results:
final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
userid, timings)
debug(f"final_results: {final_results}")
else:
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
@@ -175,6 +195,87 @@ async def fusedsearch(request, params_kw, *params, **kw):
}
async def add_user_messages(request, params_kw, *params, **kw):
"""添加用户消息到记忆库
Args:
request: HTTP 请求对象
params_kw: 请求参数字典包含 orgid messages
*params: 任意位置参数
**kw: 任意关键字参数 get_module_dbname, get_userorgid
Returns:
str: JSON 格式的响应包含 status result 或错误信息
"""
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
orgid = params_kw.get('orgid')
messages = params_kw.get('messages')
if not orgid or not isinstance(orgid, str):
error("Invalid orgid parameter; it must be a non-empty string")
return json.dumps({"status": "error", "message": "Invalid orgid parameter; it must be a non-empty string"})
if not messages or not isinstance(messages, list):
error("Invalid messages parameter; it must be a non-empty list")
return json.dumps({"status": "error", "message": "Invalid messages parameter; it must be a non-empty list"})
try:
manager = MemoryManager()
result = await manager.add_messages_to_memory(messages, orgid)
debug(f"用户消息添加结果: {result}")
return json.dumps(result, ensure_ascii=False)
except Exception as e:
error(f"添加用户消息失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)})
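# Sketch of a direct call, e.g. from a test; the success payload shape comes
# from MemoryManager.add_messages_to_memory and is not pinned down here:
#
#   params_kw = {"orgid": "org-001",
#                "messages": [{"role": "user", "content": "I prefer short answers"}]}
#   body = await add_user_messages(request, params_kw)
#   assert json.loads(body).get("status") != "error"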
async def get_user_memories(request, params_kw, *params, **kw):
"""获取用户的所有记忆记录
Args:
request: HTTP 请求对象
params_kw: 请求参数字典包含 orgid 和可选的 limit
*params: 任意位置参数
**kw: 任意关键字参数 get_module_dbname, get_userorgid
Returns:
str: JSON 格式的响应包含 status result 或错误信息
"""
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
orgid = params_kw.get('orgid')
limit = params_kw.get('limit', 10)
if not orgid or not isinstance(orgid, str):
error("orgid 参数无效,必须为非空字符串")
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
try:
limit = int(limit) if isinstance(limit, (str, int, float)) else 10
manager = MemoryManager()
result = await manager.get_all_memories(user_id=orgid, limit=limit)
debug(f"用户 {orgid} 的记忆检索结果: {result}")
return json.dumps(result, ensure_ascii=False)
except Exception as e:
error(f"获取用户 {orgid} 的记忆失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)})
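# Retrieval follows the same pattern; note that limit may arrive as a string
# on form-encoded requests, which the int() coercion above absorbs:
#
#   body = await get_user_memories(request, {"orgid": "org-001", "limit": "5"})
#   memories = json.loads(body)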
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
@@ -213,208 +314,4 @@ def _combine_query_with_triplets(query, triplets):
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
return combined_text
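# A sketch of the shape _combine_query_with_triplets produces, assuming
# triplets like {"head": "A100", "type": "has_memory", "tail": "80GB"}
# (hypothetical values; the removed code below joins "head type tail"
# strings onto the query with no separator):
#
#   _combine_query_with_triplets("A100 memory size?",
#                                [{"head": "A100", "type": "has_memory", "tail": "80GB"}])
#   -> "A100 memory size?A100 has_memory 80GB"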
# api_service = APIService()
# start_time = time.time()
# timing_stats = {}
# try:
# info(
# f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
#
# if not query or not orgid or not fiids:
# raise ValueError("query、orgid 和 knowledge_base_ids 不能为空")
#
# # Extract entities
# entity_extract_start = time.time()
# query_entities = await api_service.extract_entities(
# request=request,
# query=query,
# upappid=service_params['entities'],
# apiname="LTP/small",
# user=userid
# )
# timing_stats["entity_extraction"] = time.time() - entity_extract_start
# debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
#
# # Call the Neo4j service for triplet matching
# all_triplets = []
# triplet_match_start = time.time()
# for kb_id in fiids:
# debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
# try:
# neo4j_result = await api_service.neo4j_match_triplets(
# request=request,
# query=query,
# query_entities=query_entities,
# userid=orgid,
# knowledge_base_id=kb_id,
# upappid=service_params['gdb'],
# apiname="neo4j/matchtriplets",
# user=userid
# )
# if neo4j_result.get("status") == "success":
# triplets = neo4j_result.get("triplets", [])
# all_triplets.extend(triplets)
# debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
# else:
# error(
# f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
# except Exception as e:
# error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
# continue
# timing_stats["triplet_matching"] = time.time() - triplet_match_start
# debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
#
# # Concatenate triplet text
# triplet_text_start = time.time()
# triplet_texts = []
# for triplet in all_triplets:
# head = triplet.get('head', '')
# type_ = triplet.get('type', '')
# tail = triplet.get('tail', '')
# if head and type_ and tail:
# triplet_texts.append(f"{head} {type_} {tail}")
# else:
# debug(f"无效三元组: {triplet}")
# combined_text = query
# if triplet_texts:
# combined_text += "".join(triplet_texts)
# debug(
# f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
# timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
# debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
#
# # Convert the combined text into a vector
# vector_start = time.time()
# query_vector = await api_service.get_embeddings(
# request=request,
# texts=[combined_text],
# upappid=service_params['embedding'],
# apiname="BAAI/bge-m3",
# user=userid
# )
# if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
# raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
# query_vector = query_vector[0] # 取第一个向量
# timing_stats["vector_generation"] = time.time() - vector_start
# debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")
#
# # Call the search endpoint
# sum = limit + 5
# search_start = time.time()
# debug(f"orgid: {orgid}")
# result = await api_service.milvus_search_query(
# request=request,
# query_vector=query_vector,
# userid=orgid,
# knowledge_base_ids=fiids,
# limit=sum,
# offset=0,
# upappid=service_params['vdb'],
# apiname="mlvus/searchquery",
# user=userid
# )
# timing_stats["vector_search"] = time.time() - search_start
# debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
#
# if result.get("status") != "success":
# error(f"融合搜索失败: {result.get('message', '未知错误')}")
# return {"results": [], "timing": timing_stats}
#
# unique_results = result.get("results", [])
# sum = len(unique_results)
# debug(f"从向量数据中搜索到{sum}条数据")
# use_rerank = True
# if use_rerank and unique_results:
# rerank_start = time.time()
# debug("开始重排序")
# unique_results = await api_service.rerank_results(
# request=request,
# query=combined_text,
# results=unique_results,
# top_n=limit,
# upappid=service_params['reranker'],
# apiname="BAAI/bge-reranker-v2-m3",
# user=userid
# )
# unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
# timing_stats["reranking"] = time.time() - rerank_start
# debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
# debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
# else:
# unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
#
# timing_stats["total_time"] = time.time() - start_time
# info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
#
# # dify_result = []
# # for res in unique_results[:limit]:
# # content = res.get('text', '')
# # title = res.get('metadata', {}).get('filename', 'Untitled')
# # document_id = res.get('metadata', {}).get('document_id', '')
# # dify_result.append({
# # 'metadata': {'document_id': document_id},
# # 'title': title,
# # 'content': content
# # })
# # info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
# # debug(f"result: {dify_result}")
# # return dify_result
#
# dify_records = []
# dify_result = []
# for res in unique_results[:limit]:
# rerank_score = res.get('rerank_score', 0)
# score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
# score = max(0.0, min(1.0, score))
# content = res.get('text', '')
# title = res.get('metadata', {}).get('filename', 'Untitled')
# document_id = res.get('metadata', {}).get('document_id', '')
# dify_records.append({
# "content": content,
# "title": title,
# "metadata": {"document_id": document_id, "score": score},
# })
# dify_result.append({
# "content": content,
# "title": title,
# "metadata": {"document_id": document_id, "score": score},
# })
# info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
# debug(f"records: {dify_records}, result: {dify_result}")
# # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
# return {"records": dify_records}
#
# # dify_result = []
# # for res in unique_results[:limit]:
# # rerank_score = res.get('rerank_score', 0)
# # score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
# # score = max(0.0, min(1.0, score))
# # content = res.get('text', '')
# # title = res.get('metadata', {}).get('filename', 'Untitled')
# # document_id = res.get('metadata', {}).get('document_id', '')
# # dify_result.append({
# # "metadata": {
# # "_source": "konwledge",
# # "dataset_id":"111111",
# # "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
# # "document_id": document_id,
# # "document_name": "test.docx",
# # "data_source_type": "upload_file",
# # "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
# # "retriever_from": "workflow",
# # "score": score,
# # "segment_hit_count": 7,
# # "segment_word_count": 275,
# # "segment_position": 5,
# # "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
# # "doc_metadata": None,
# # "position":1
# # },
# # "title": title,
# # "content": content
# # })
# # return {"result": dify_result}
#
# except Exception as e:
# error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
# return {"results": [], "timing": timing_stats}