rag
This commit is contained in:
parent
f0c8120ada
commit
5bb3180cc4
311
rag/ragapi.py
311
rag/ragapi.py
@ -7,6 +7,7 @@ import json
|
||||
import math
|
||||
from rag.service_opts import get_service_params, sor_get_service_params
|
||||
from rag.rag_operations import RagOperations
|
||||
from rag.aslmapi import MemoryManager
|
||||
|
||||
helptext = """kyrag API:
|
||||
|
||||
@ -31,8 +32,25 @@ data: {
|
||||
3、docs文档
|
||||
path: /v1/docs
|
||||
|
||||
4、longmemory存储
|
||||
4. 添加用户消息到记忆:
|
||||
path: /v1/add_user_messages
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data: {
|
||||
"orgid": "用户组织ID",
|
||||
"messages": [{"role": "user", "content": "消息内容"}, ...]
|
||||
}
|
||||
|
||||
5. 获取用户所有记忆:
|
||||
path: /v1/get_user_memories
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data: {
|
||||
"orgid": "用户组织ID",
|
||||
"limit": 10 # 可选,默认为 10
|
||||
}
|
||||
"""
|
||||
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
@ -72,8 +90,9 @@ async def get_kdbs(request, params_kw, *params, **kw):
|
||||
})
|
||||
return result
|
||||
|
||||
async def fusedsearch(request, params_kw, *params, **kw):
|
||||
async def fusedsearch(request, params_kw, *params):
|
||||
"""融合搜索,调用服务化端点"""
|
||||
kw = request._run_ns
|
||||
f = kw.get('get_userorgid')
|
||||
orgid = await f()
|
||||
debug(f"orgid: {orgid},{f=}")
|
||||
@ -155,6 +174,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
||||
if use_rerank and search_results:
|
||||
final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
|
||||
userid, timings)
|
||||
debug(f"final_results: {final_results}")
|
||||
else:
|
||||
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
|
||||
|
||||
@ -175,6 +195,87 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
||||
}
|
||||
|
||||
|
||||
async def add_user_messages(request, params_kw, *params, **kw):
|
||||
"""添加用户消息到记忆库
|
||||
|
||||
Args:
|
||||
request: HTTP 请求对象
|
||||
params_kw: 请求参数字典,包含 orgid 和 messages
|
||||
*params: 任意位置参数
|
||||
**kw: 任意关键字参数(如 get_module_dbname, get_userorgid)
|
||||
|
||||
Returns:
|
||||
str: JSON 格式的响应,包含 status 和 result 或错误信息
|
||||
"""
|
||||
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
|
||||
orgid = params_kw.get('orgid')
|
||||
messages = params_kw.get('messages')
|
||||
|
||||
if not orgid or not isinstance(orgid, str):
|
||||
error("orgid 参数无效,必须为非空字符串")
|
||||
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
|
||||
|
||||
if not messages or not isinstance(messages, list):
|
||||
error("messages 参数无效,必须为非空列表")
|
||||
return json.dumps({"status": "error", "message": "messages 参数无效,必须为非空列表"})
|
||||
|
||||
try:
|
||||
manager = MemoryManager()
|
||||
result = await manager.add_messages_to_memory(messages, orgid)
|
||||
debug(f"用户消息添加结果: {result}")
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
error(f"添加用户消息失败: {str(e)}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
async def get_user_memories(request, params_kw, *params, **kw):
|
||||
"""获取用户的所有记忆记录
|
||||
|
||||
Args:
|
||||
request: HTTP 请求对象
|
||||
params_kw: 请求参数字典,包含 orgid 和可选的 limit
|
||||
*params: 任意位置参数
|
||||
**kw: 任意关键字参数(如 get_module_dbname, get_userorgid)
|
||||
|
||||
Returns:
|
||||
str: JSON 格式的响应,包含 status 和 result 或错误信息
|
||||
"""
|
||||
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
|
||||
orgid = params_kw.get('orgid')
|
||||
limit = params_kw.get('limit', 10)
|
||||
|
||||
if not orgid or not isinstance(orgid, str):
|
||||
error("orgid 参数无效,必须为非空字符串")
|
||||
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
|
||||
|
||||
try:
|
||||
limit = int(limit) if isinstance(limit, (str, int, float)) else 10
|
||||
manager = MemoryManager()
|
||||
result = await manager.get_all_memories(user_id=orgid, limit=limit)
|
||||
debug(f"用户 {orgid} 的记忆检索结果: {result}")
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
error(f"获取用户 {orgid} 的记忆失败: {str(e)}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||
if fiids:
|
||||
db = DBPools()
|
||||
dbname = kw.get('get_module_dbname')('rag')
|
||||
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
|
||||
if not result:
|
||||
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
|
||||
kdb_orgid = result[0].get('orgid')
|
||||
if kdb_orgid != orgid:
|
||||
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
|
||||
except Exception as e:
|
||||
error(f"orgid 验证失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||
if fiids:
|
||||
@ -213,208 +314,4 @@ def _combine_query_with_triplets(query, triplets):
|
||||
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||
return combined_text
|
||||
|
||||
# api_service = APIService()
|
||||
# start_time = time.time()
|
||||
# timing_stats = {}
|
||||
# try:
|
||||
# info(
|
||||
# f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
|
||||
#
|
||||
# if not query or not orgid or not fiids:
|
||||
# raise ValueError("query、orgid 和 knowledge_base_ids 不能为空")
|
||||
#
|
||||
# # 提取实体
|
||||
# entity_extract_start = time.time()
|
||||
# query_entities = await api_service.extract_entities(
|
||||
# request=request,
|
||||
# query=query,
|
||||
# upappid=service_params['entities'],
|
||||
# apiname="LTP/small",
|
||||
# user=userid
|
||||
# )
|
||||
# timing_stats["entity_extraction"] = time.time() - entity_extract_start
|
||||
# debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
|
||||
#
|
||||
# # 调用 Neo4j 服务进行三元组匹配
|
||||
# all_triplets = []
|
||||
# triplet_match_start = time.time()
|
||||
# for kb_id in fiids:
|
||||
# debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
|
||||
# try:
|
||||
# neo4j_result = await api_service.neo4j_match_triplets(
|
||||
# request=request,
|
||||
# query=query,
|
||||
# query_entities=query_entities,
|
||||
# userid=orgid,
|
||||
# knowledge_base_id=kb_id,
|
||||
# upappid=service_params['gdb'],
|
||||
# apiname="neo4j/matchtriplets",
|
||||
# user=userid
|
||||
# )
|
||||
# if neo4j_result.get("status") == "success":
|
||||
# triplets = neo4j_result.get("triplets", [])
|
||||
# all_triplets.extend(triplets)
|
||||
# debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
|
||||
# else:
|
||||
# error(
|
||||
# f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
|
||||
# except Exception as e:
|
||||
# error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
|
||||
# continue
|
||||
# timing_stats["triplet_matching"] = time.time() - triplet_match_start
|
||||
# debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
|
||||
#
|
||||
# # 拼接三元组文本
|
||||
# triplet_text_start = time.time()
|
||||
# triplet_texts = []
|
||||
# for triplet in all_triplets:
|
||||
# head = triplet.get('head', '')
|
||||
# type_ = triplet.get('type', '')
|
||||
# tail = triplet.get('tail', '')
|
||||
# if head and type_ and tail:
|
||||
# triplet_texts.append(f"{head} {type_} {tail}")
|
||||
# else:
|
||||
# debug(f"无效三元组: {triplet}")
|
||||
# combined_text = query
|
||||
# if triplet_texts:
|
||||
# combined_text += "".join(triplet_texts)
|
||||
# debug(
|
||||
# f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||
# timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
|
||||
# debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
|
||||
#
|
||||
# # 将拼接文本转换为向量
|
||||
# vector_start = time.time()
|
||||
# query_vector = await api_service.get_embeddings(
|
||||
# request=request,
|
||||
# texts=[combined_text],
|
||||
# upappid=service_params['embedding'],
|
||||
# apiname="BAAI/bge-m3",
|
||||
# user=userid
|
||||
# )
|
||||
# if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
|
||||
# raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
|
||||
# query_vector = query_vector[0] # 取第一个向量
|
||||
# timing_stats["vector_generation"] = time.time() - vector_start
|
||||
# debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")
|
||||
#
|
||||
# # 调用搜索端点
|
||||
# sum = limit + 5
|
||||
# search_start = time.time()
|
||||
# debug(f"orgid: {orgid}")
|
||||
# result = await api_service.milvus_search_query(
|
||||
# request=request,
|
||||
# query_vector=query_vector,
|
||||
# userid=orgid,
|
||||
# knowledge_base_ids=fiids,
|
||||
# limit=sum,
|
||||
# offset=0,
|
||||
# upappid=service_params['vdb'],
|
||||
# apiname="mlvus/searchquery",
|
||||
# user=userid
|
||||
# )
|
||||
# timing_stats["vector_search"] = time.time() - search_start
|
||||
# debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||
#
|
||||
# if result.get("status") != "success":
|
||||
# error(f"融合搜索失败: {result.get('message', '未知错误')}")
|
||||
# return {"results": [], "timing": timing_stats}
|
||||
#
|
||||
# unique_results = result.get("results", [])
|
||||
# sum = len(unique_results)
|
||||
# debug(f"从向量数据中搜索到{sum}条数据")
|
||||
# use_rerank = True
|
||||
# if use_rerank and unique_results:
|
||||
# rerank_start = time.time()
|
||||
# debug("开始重排序")
|
||||
# unique_results = await api_service.rerank_results(
|
||||
# request=request,
|
||||
# query=combined_text,
|
||||
# results=unique_results,
|
||||
# top_n=limit,
|
||||
# upappid=service_params['reranker'],
|
||||
# apiname="BAAI/bge-reranker-v2-m3",
|
||||
# user=userid
|
||||
# )
|
||||
# unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||
# timing_stats["reranking"] = time.time() - rerank_start
|
||||
# debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||
# debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
||||
# else:
|
||||
# unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
||||
#
|
||||
# timing_stats["total_time"] = time.time() - start_time
|
||||
# info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||
#
|
||||
# # dify_result = []
|
||||
# # for res in unique_results[:limit]:
|
||||
# # content = res.get('text', '')
|
||||
# # title = res.get('metadata', {}).get('filename', 'Untitled')
|
||||
# # document_id = res.get('metadata', {}).get('document_id', '')
|
||||
# # dify_result.append({
|
||||
# # 'metadata': {'document_id': document_id},
|
||||
# # 'title': title,
|
||||
# # 'content': content
|
||||
# # })
|
||||
# # info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
|
||||
# # debug(f"result: {dify_result}")
|
||||
# # return dify_result
|
||||
#
|
||||
# dify_records = []
|
||||
# dify_result = []
|
||||
# for res in unique_results[:limit]:
|
||||
# rerank_score = res.get('rerank_score', 0)
|
||||
# score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
|
||||
# score = max(0.0, min(1.0, score))
|
||||
# content = res.get('text', '')
|
||||
# title = res.get('metadata', {}).get('filename', 'Untitled')
|
||||
# document_id = res.get('metadata', {}).get('document_id', '')
|
||||
# dify_records.append({
|
||||
# "content": content,
|
||||
# "title": title,
|
||||
# "metadata": {"document_id": document_id, "score": score},
|
||||
# })
|
||||
# dify_result.append({
|
||||
# "content": content,
|
||||
# "title": title,
|
||||
# "metadata": {"document_id": document_id, "score": score},
|
||||
# })
|
||||
# info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
|
||||
# debug(f"records: {dify_records}, result: {dify_result}")
|
||||
# # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
|
||||
# return {"records": dify_records}
|
||||
#
|
||||
# # dify_result = []
|
||||
# # for res in unique_results[:limit]:
|
||||
# # rerank_score = res.get('rerank_score', 0)
|
||||
# # score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
|
||||
# # score = max(0.0, min(1.0, score))
|
||||
# # content = res.get('text', '')
|
||||
# # title = res.get('metadata', {}).get('filename', 'Untitled')
|
||||
# # document_id = res.get('metadata', {}).get('document_id', '')
|
||||
# # dify_result.append({
|
||||
# # "metadata": {
|
||||
# # "_source": "konwledge",
|
||||
# # "dataset_id":"111111",
|
||||
# # "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
|
||||
# # "document_id": document_id,
|
||||
# # "document_name": "test.docx",
|
||||
# # "data_source_type": "upload_file",
|
||||
# # "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
|
||||
# # "retriever_from": "workflow",
|
||||
# # "score": score,
|
||||
# # "segment_hit_count": 7,
|
||||
# # "segment_word_count": 275,
|
||||
# # "segment_position": 5,
|
||||
# # "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
|
||||
# # "doc_metadata": None,
|
||||
# # "position":1
|
||||
# # },
|
||||
# # "title": title,
|
||||
# # "content": content
|
||||
# # })
|
||||
# # return {"result": dify_result}
|
||||
#
|
||||
# except Exception as e:
|
||||
# error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
# return {"results": [], "timing": timing_stats}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user