添加ragapi.py
This commit is contained in:
parent
93d5786b15
commit
0fdb0d3393
211
rag/ragapi.py
Normal file
211
rag/ragapi.py
Normal file
@ -0,0 +1,211 @@
|
||||
from rag.uapi_service import APIService
|
||||
from rag.folderinfo import RagFileMgr
|
||||
|
||||
helptext = """kyrag API:
|
||||
|
||||
1. 得到kdb表:
|
||||
path: /v1/get_kdbs
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
response:
|
||||
|
||||
2. 向量检索文本块:
|
||||
path: /v1/fusedsearch
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data: {
|
||||
"query": "什么是知识抽取。"
|
||||
"fiids":["1"]
|
||||
}
|
||||
3、docs文档
|
||||
"""
|
||||
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
return helptext
|
||||
|
||||
async def get_kdbs(request, params_kw, *params, **kw):
|
||||
"""返回 kdb 表的全部内容,返回 JSON"""
|
||||
print("初始化数据库连接池...")
|
||||
db = DBPools()
|
||||
dbname = kw.get('get_module_dbname')('rag')
|
||||
sql_opts = """
|
||||
SELECT id, name, description
|
||||
FROM kdb
|
||||
"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
opts_result = await sor.sqlExe(sql_opts, {})
|
||||
if not opts_result:
|
||||
error("未找到 kdb 表记录")
|
||||
return json.dumps({"status": "error", "message": "未找到记录"})
|
||||
return json.dumps(opts_result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
error(f"查询 kdb 表失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
async def fusedsearch(request, params_kw, *params, **kw):
|
||||
"""融合搜索,调用服务化端点"""
|
||||
f = kw.get('get_userorgid')
|
||||
orgid = await f()
|
||||
f = kw.get('get_user')
|
||||
userid = await f()
|
||||
debug(f"params_kw: {params_kw}")
|
||||
query = params_kw.get('query', '')
|
||||
fiids = params_kw.get('fiid', [])
|
||||
|
||||
# 验证 fiids的orgid与orgid = await f()是否一致
|
||||
if fiids:
|
||||
db = DBPools()
|
||||
dbname = kw.get('get_module_dbname')('rag')
|
||||
sql_opts = """
|
||||
SELECT orgid
|
||||
FROM kdb
|
||||
WHERE id = ${fiid}$
|
||||
"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
result = await sor.sqlExe(sql_opts, {"fiid": fiids[0]})
|
||||
if not result:
|
||||
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
|
||||
kdb_orgid = result[0].get('orgid')
|
||||
if kdb_orgid != orgid:
|
||||
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
|
||||
except Exception as e:
|
||||
error(f"orgid 验证失败: {str(e)}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
ragfilemgr = RagFileMgr("fiids[0]")
|
||||
service_params = ragfilemgr.get_service_params(orgid)
|
||||
|
||||
api_service = APIService()
|
||||
start_time = time.time()
|
||||
timing_stats = {}
|
||||
try:
|
||||
info(
|
||||
f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
|
||||
|
||||
if not query or not orgid or not knowledge_base_ids:
|
||||
raise ValueError("query、orgid 和 knowledge_base_ids 不能为空")
|
||||
|
||||
# 提取实体
|
||||
entity_extract_start = time.time()
|
||||
query_entities = await api_service.extract_entities(
|
||||
request=request,
|
||||
query=query,
|
||||
upappid=service_params['entities'],
|
||||
apiname="LTP/small",
|
||||
user=userid
|
||||
)
|
||||
timing_stats["entity_extraction"] = time.time() - entity_extract_start
|
||||
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
|
||||
|
||||
# 调用 Neo4j 服务进行三元组匹配
|
||||
all_triplets = []
|
||||
triplet_match_start = time.time()
|
||||
for kb_id in fiids:
|
||||
debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
|
||||
try:
|
||||
neo4j_result = await api_service.neo4j_match_triplets(
|
||||
request=request,
|
||||
query=query,
|
||||
query_entities=query_entities,
|
||||
userid=orgid,
|
||||
knowledge_base_id=kb_id,
|
||||
upappid=service_params['gdb'],
|
||||
apiname="neo4j/matchtriplets",
|
||||
user=userid
|
||||
)
|
||||
if neo4j_result.get("status") == "success":
|
||||
triplets = neo4j_result.get("triplets", [])
|
||||
all_triplets.extend(triplets)
|
||||
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
|
||||
else:
|
||||
error(
|
||||
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
|
||||
except Exception as e:
|
||||
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
|
||||
continue
|
||||
timing_stats["triplet_matching"] = time.time() - triplet_match_start
|
||||
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
|
||||
|
||||
# 拼接三元组文本
|
||||
triplet_text_start = time.time()
|
||||
triplet_texts = []
|
||||
for triplet in all_triplets:
|
||||
head = triplet.get('head', '')
|
||||
type_ = triplet.get('type', '')
|
||||
tail = triplet.get('tail', '')
|
||||
if head and type_ and tail:
|
||||
triplet_texts.append(f"{head} {type_} {tail}")
|
||||
else:
|
||||
debug(f"无效三元组: {triplet}")
|
||||
combined_text = query
|
||||
if triplet_texts:
|
||||
combined_text += " [三元组] " + "; ".join(triplet_texts)
|
||||
debug(
|
||||
f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
|
||||
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
|
||||
|
||||
# 将拼接文本转换为向量
|
||||
vector_start = time.time()
|
||||
query_vector = await api_service.get_embeddings(
|
||||
request=request,
|
||||
texts=[combined_text],
|
||||
upappid=service_params['embedding'],
|
||||
user=userid
|
||||
)
|
||||
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
|
||||
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
|
||||
query_vector = query_vector[0] # 取第一个向量
|
||||
timing_stats["vector_generation"] = time.time() - vector_start
|
||||
debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")
|
||||
|
||||
# 调用搜索端点
|
||||
limit = 5
|
||||
search_start = time.time()
|
||||
result = await api_service.milvus_search_query(
|
||||
request=request,
|
||||
query_vector=query_vector,
|
||||
userid=orgid,
|
||||
knowledge_base_ids=[fiids],
|
||||
upappid=service_params['vdb'],
|
||||
apiname="mlvus/searchquery",
|
||||
user=userid
|
||||
)
|
||||
timing_stats["vector_search"] = time.time() - search_start
|
||||
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||
|
||||
if result.get("status") != "success":
|
||||
error(f"融合搜索失败: {result.get('message', '未知错误')}")
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
unique_results = result.get("results", [])
|
||||
use_rerank = True
|
||||
if use_rerank and unique_results:
|
||||
rerank_start = time.time()
|
||||
debug("开始重排序")
|
||||
unique_results = await api_service(
|
||||
request=request,
|
||||
query=combined_text,
|
||||
results=unique_results,
|
||||
top_n=limit,
|
||||
upappid=service_params['reranker'],
|
||||
apiname="BAAI/bge-reranker-v2-m3",
|
||||
user=userid
|
||||
)
|
||||
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||
timing_stats["reranking"] = time.time() - rerank_start
|
||||
debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
||||
else:
|
||||
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
||||
|
||||
timing_stats["total_time"] = time.time() - start_time
|
||||
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||
return {"results": unique_results[:limit], "timing": timing_stats}
|
||||
|
||||
except Exception as e:
|
||||
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
return {"results": [], "timing": timing_stats}
|
||||
Loading…
x
Reference in New Issue
Block a user