添加ragapi.py

This commit is contained in:
wangmeihua 2025-08-15 15:26:42 +08:00
parent 93d5786b15
commit 0fdb0d3393

211
rag/ragapi.py Normal file
View File

@ -0,0 +1,211 @@
import json
import time
import traceback

from rag.uapi_service import APIService
from rag.folderinfo import RagFileMgr
# Plain-text help document served by the /v1/docs endpoint.
# BUG FIX: the JSON example was missing a comma after the "query" entry and
# the third item was garbled ("3docs文档").
helptext = """kyrag API:
1. 得到kdb表:
path: /v1/get_kdbs
headers: {
"Content-Type": "application/json"
}
response:
2. 向量检索文本块:
path: /v1/fusedsearch
headers: {
"Content-Type": "application/json"
},
data: {
"query": "什么是知识抽取。",
"fiids": ["1"]
}
3. docs: 返回本帮助文档
"""


async def docs(request, params_kw, *params, **kw):
    """Return the plain-text API help document (``helptext``)."""
    return helptext
async def get_kdbs(request, params_kw, *params, **kw):
    """Return every row (id, name, description) of the ``kdb`` table.

    Returns a JSON string: the row list on success, or
    ``{"status": "error", "message": ...}`` on failure — the handler never
    raises, so the HTTP layer always receives a body.
    """
    # NOTE(review): DBPools / error() are assumed to be provided by the
    # serving framework (they are not imported in this file) — confirm.
    db = DBPools()
    # Resolve this module's database name via the framework callback.
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
    SELECT id, name, description
    FROM kdb
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {})
            if not opts_result:
                # Use the module's logging convention instead of the stray
                # print() the original used on entry.
                error("未找到 kdb 表记录")
                return json.dumps({"status": "error", "message": "未找到记录"})
            # ensure_ascii=False keeps Chinese names readable in the payload.
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        error(f"查询 kdb 表失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})
async def fusedsearch(request, params_kw, *params, **kw):
    """融合搜索: entity extraction + Neo4j triplet matching + vector search
    + optional reranking, via APIService endpoints.

    ``params_kw`` keys:
        query: user query text (required).
        fiids: list of knowledge-base ids (the legacy key ``fiid`` is still
               accepted for backward compatibility).

    Returns ``{"results": [...], "timing": {...}}``; on failure the result
    list is empty and the partial timing stats are still returned.
    """
    # Framework-injected helpers resolve the caller's org and user ids.
    orgid = await kw.get('get_userorgid')()
    userid = await kw.get('get_user')()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')
    # BUG FIX: the documented parameter is "fiids"; the original read only
    # "fiid". Accept both so no caller breaks.
    fiids = params_kw.get('fiids') or params_kw.get('fiid', [])
    start_time = time.time()
    timing_stats = {}
    # BUG FIX: fail fast up front. The original checked an undefined name
    # (knowledge_base_ids) inside the try, and an empty fiids list would have
    # raised an uncaught IndexError at RagFileMgr(fiids[0]).
    if not query or not orgid or not fiids:
        error("query、orgid 和 knowledge_base_ids 不能为空")
        return {"results": [], "timing": timing_stats}
    # 验证每个 fiid 所属的 orgid 与当前用户的 orgid 是否一致
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
    SELECT orgid
    FROM kdb
    WHERE id = ${fiid}$
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            # BUG FIX: the original validated only fiids[0]; every requested
            # knowledge base must belong to the caller's org.
            for fiid in fiids:
                result = await sor.sqlExe(sql_opts, {"fiid": fiid})
                if not result:
                    raise ValueError(f"未找到 fiid={fiid} 的记录")
                kdb_orgid = result[0].get('orgid')
                if kdb_orgid != orgid:
                    raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
    except Exception as e:
        error(f"orgid 验证失败: {str(e)}")
        return json.dumps({"status": "error", "message": str(e)})
    # BUG FIX: the original passed the literal string "fiids[0]".
    ragfilemgr = RagFileMgr(fiids[0])
    service_params = ragfilemgr.get_service_params(orgid)
    api_service = APIService()
    try:
        info(
            f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
        # 提取实体
        entity_extract_start = time.time()
        query_entities = await api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
        # 调用 Neo4j 服务进行三元组匹配
        all_triplets = []
        triplet_match_start = time.time()
        for kb_id in fiids:
            debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=query_entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
                else:
                    error(
                        f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
            except Exception as e:
                # Graph matching is best-effort: one failing knowledge base
                # must not abort the whole search.
                error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
                continue
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
        # 拼接三元组文本
        triplet_text_start = time.time()
        triplet_texts = []
        for triplet in all_triplets:
            head = triplet.get('head', '')
            type_ = triplet.get('type', '')
            tail = triplet.get('tail', '')
            if head and type_ and tail:
                triplet_texts.append(f"{head} {type_} {tail}")
            else:
                debug(f"无效三元组: {triplet}")
        combined_text = query
        if triplet_texts:
            combined_text += " [三元组] " + "; ".join(triplet_texts)
        debug(
            f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")
        # 将拼接文本转换为向量
        vector_start = time.time()
        query_vector = await api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            user=userid
        )
        # 1024 is the embedding model's output dimension — keep in sync with
        # the vector-db collection schema.
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
        query_vector = query_vector[0]  # 取第一个向量
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
        # 调用搜索端点
        limit = 5
        search_start = time.time()
        # BUG FIX: fiids is already a list; the original sent [fiids] (a
        # nested list) to the search endpoint.
        # NOTE(review): apiname "mlvus/searchquery" looks like a typo for
        # "milvus/..." but is kept verbatim — it must match the service route.
        result = await api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
        if result.get("status") != "success":
            error(f"融合搜索失败: {result.get('message', '未知错误')}")
            return {"results": [], "timing": timing_stats}
        unique_results = result.get("results", [])
        use_rerank = True
        if use_rerank and unique_results:
            rerank_start = time.time()
            debug("开始重排序")
            # BUG FIX: the original awaited api_service(...) — the instance is
            # not callable (TypeError). Assuming the reranker wrapper is named
            # `rerank` — TODO(review): confirm the method name on APIService.
            unique_results = await api_service.rerank(
                request=request,
                query=combined_text,
                results=unique_results,
                top_n=limit,
                upappid=service_params['reranker'],
                apiname="BAAI/bge-reranker-v2-m3",
                user=userid
            )
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
            debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            # Strip rerank scores when reranking is skipped.
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
        timing_stats["total_time"] = time.time() - start_time
        info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
        return {"results": unique_results[:limit], "timing": timing_stats}
    except Exception as e:
        error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {"results": [], "timing": timing_stats}