Add ragapi.py
This commit is contained in:
parent
93d5786b15
commit
0fdb0d3393
211
rag/ragapi.py
Normal file
@@ -0,0 +1,211 @@
import json
import time
import traceback

# Assumed import paths for the connection pool and logging helpers used below;
# the original commit does not show where DBPools, debug, info and error come from.
from sqlor.dbpools import DBPools
from appPublic.log import debug, info, error

from rag.uapi_service import APIService
from rag.folderinfo import RagFileMgr

helptext = """kyrag API:

1. Get the kdb table:
    path: /v1/get_kdbs
    headers: {
        "Content-Type": "application/json"
    }
    response:

2. Vector search over text chunks:
    path: /v1/fusedsearch
    headers: {
        "Content-Type": "application/json"
    }
    data: {
        "query": "什么是知识抽取。",
        "fiids": ["1"]
    }

3. docs: this help text
"""

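# Note on responses (derived from the handlers below): /v1/get_kdbs returns the
# kdb rows as a JSON array of {"id", "name", "description"} objects, and
# /v1/fusedsearch returns {"results": [...], "timing": {...}} with per-stage
# timing statistics.
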
async def docs(request, params_kw, *params, **kw):
    return helptext

async def get_kdbs(request, params_kw, *params, **kw):
    """Return every row of the kdb table as JSON."""
    print("initializing database connection pool ...")
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
        SELECT id, name, description
        FROM kdb
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {})
            if not opts_result:
                error("no records found in the kdb table")
                return json.dumps({"status": "error", "message": "no records found"})
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        error(f"querying the kdb table failed: {str(e)}, traceback: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})

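# Fused-search flow (summary of the handler below): extract entities from the
# query with the LTP service, match graph triplets for each requested knowledge
# base via the Neo4j service, append the matched triplets to the query text,
# embed the combined text into a 1024-dim vector, run a Milvus vector search
# over the knowledge bases, and finally rerank the hits with
# BAAI/bge-reranker-v2-m3 before returning the top results with timing stats.
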
async def fusedsearch(request, params_kw, *params, **kw):
    """Fused search: combine graph triplets with the query, then run vector search and reranking."""
    f = kw.get('get_userorgid')
    orgid = await f()
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')
    fiids = params_kw.get('fiids', [])  # key name follows the "fiids" field in the helptext

    # Verify that the orgid of the requested knowledge base matches the caller's orgid
    if fiids:
        db = DBPools()
        dbname = kw.get('get_module_dbname')('rag')
        sql_opts = """
            SELECT orgid
            FROM kdb
            WHERE id = ${fiid}$
        """
        try:
            async with db.sqlorContext(dbname) as sor:
                result = await sor.sqlExe(sql_opts, {"fiid": fiids[0]})
                if not result:
                    raise ValueError(f"no record found for fiid={fiids[0]}")
                kdb_orgid = result[0].get('orgid')
                if kdb_orgid != orgid:
                    raise ValueError(f"orgid mismatch: kdb.orgid={kdb_orgid}, user orgid={orgid}")
        except Exception as e:
            error(f"orgid verification failed: {str(e)}")
            return json.dumps({"status": "error", "message": str(e)})

    ragfilemgr = RagFileMgr(fiids[0])
    service_params = ragfilemgr.get_service_params(orgid)

    api_service = APIService()
    start_time = time.time()
    timing_stats = {}
    try:
        info(f"starting fused search: query={query}, orgid={orgid}, fiids={fiids}")

        if not query or not orgid or not fiids:
            raise ValueError("query, orgid and fiids must not be empty")

        # Extract entities from the query
        entity_extract_start = time.time()
        query_entities = await api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f} s")

        # Call the Neo4j service to match triplets
        all_triplets = []
        triplet_match_start = time.time()
        for kb_id in fiids:
            debug(f"calling Neo4j triplet matching: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=query_entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"knowledge base {kb_id} matched {len(triplets)} triplets: {triplets[:5]}")
                else:
                    error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
            except Exception as e:
                error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                continue
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"total triplet matching time: {timing_stats['triplet_matching']:.3f} s")

        # Join the matched triplets into a text string appended to the query
        triplet_text_start = time.time()
        triplet_texts = []
        for triplet in all_triplets:
            head = triplet.get('head', '')
            type_ = triplet.get('type', '')
            tail = triplet.get('tail', '')
            if head and type_ and tail:
                triplet_texts.append(f"{head} {type_} {tail}")
            else:
                debug(f"invalid triplet: {triplet}")
        combined_text = query
        if triplet_texts:
            combined_text += " [三元组] " + "; ".join(triplet_texts)
        debug(f"combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"combining triplet text took {timing_stats['triplet_text_combine']:.3f} s")

        # Convert the combined text into a query vector
        vector_start = time.time()
        query_vector = await api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("the query vector must be a list of 1024 floats")
        query_vector = query_vector[0]  # take the first (and only) vector
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"query vector generation took {timing_stats['vector_generation']:.3f} s")

        # Call the vector search endpoint
        limit = 5
        search_start = time.time()
        result = await api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"vector search took {timing_stats['vector_search']:.3f} s")

        if result.get("status") != "success":
            error(f"fused search failed: {result.get('message', 'unknown error')}")
            return {"results": [], "timing": timing_stats}

        unique_results = result.get("results", [])
        use_rerank = True
        if use_rerank and unique_results:
            rerank_start = time.time()
            debug("starting reranking")
            # NOTE: the original code awaited the APIService instance itself here;
            # the `rerank` method name is an assumption.
            unique_results = await api_service.rerank(
                request=request,
                query=combined_text,
                results=unique_results,
                top_n=limit,
                upappid=service_params['reranker'],
                apiname="BAAI/bge-reranker-v2-m3",
                user=userid
            )
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"reranking took {timing_stats['reranking']:.3f} s")
            debug(f"rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]

        timing_stats["total_time"] = time.time() - start_time
        info(f"fused search finished, returning {len(unique_results)} results, total time: {timing_stats['total_time']:.3f} s")
        return {"results": unique_results[:limit], "timing": timing_stats}

    except Exception as e:
        error(f"fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
        return {"results": [], "timing": timing_stats}
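

# A minimal client sketch for the /v1/fusedsearch endpoint documented in the
# helptext above; the host and port are assumptions, the payload keys ("query",
# "fiids") and the Content-Type header follow the helptext.
if __name__ == "__main__":
    import asyncio

    import aiohttp

    async def _demo():
        async with aiohttp.ClientSession() as sess:
            async with sess.post(
                "http://localhost:8080/v1/fusedsearch",  # assumed service address
                json={"query": "什么是知识抽取。", "fiids": ["1"]},
                headers={"Content-Type": "application/json"},
            ) as resp:
                print(await resp.json())

    asyncio.run(_demo())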