This commit is contained in:
wangmeihua 2025-08-15 17:44:03 +08:00
parent 0fdb0d3393
commit a939351267

View File

@ -1,5 +1,10 @@
from rag.uapi_service import APIService from rag.uapi_service import APIService
from rag.folderinfo import RagFileMgr from rag.folderinfo import RagFileMgr
from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
import time
import traceback
import json
helptext = """kyrag API: helptext = """kyrag API:
@ -9,6 +14,7 @@ headers: {
"Content-Type": "application/json" "Content-Type": "application/json"
} }
response: response:
[{"id": "1", "name": "textdb", "description": "文本数据库"}, {"id": "testkdb", "name": "testkdb", "description": ""}, {"id": "Vdtbt3qBfocteit1HIxVH", "name": "trm", "description": ""}]
2. 向量检索文本块: 2. 向量检索文本块:
path: /v1/fusedsearch path: /v1/fusedsearch
@ -19,7 +25,9 @@ data: {
"query": "什么是知识抽取。" "query": "什么是知识抽取。"
"fiids":["1"] "fiids":["1"]
} }
3docs文档 3docs文档
path: /v1/docs
""" """
async def docs(request, params_kw, *params, **kw): async def docs(request, params_kw, *params, **kw):
@ -27,7 +35,6 @@ async def docs(request, params_kw, *params, **kw):
async def get_kdbs(request, params_kw, *params, **kw): async def get_kdbs(request, params_kw, *params, **kw):
"""返回 kdb 表的全部内容,返回 JSON""" """返回 kdb 表的全部内容,返回 JSON"""
print("初始化数据库连接池...")
db = DBPools() db = DBPools()
dbname = kw.get('get_module_dbname')('rag') dbname = kw.get('get_module_dbname')('rag')
sql_opts = """ sql_opts = """
@ -47,13 +54,17 @@ async def get_kdbs(request, params_kw, *params, **kw):
async def fusedsearch(request, params_kw, *params, **kw): async def fusedsearch(request, params_kw, *params, **kw):
"""融合搜索,调用服务化端点""" """融合搜索,调用服务化端点"""
f = kw.get('get_userorgid') # f = kw.get('get_userorgid')
orgid = await f() # orgid = await f()
f = kw.get('get_user') # debug(f"orgid: {orgid}{f=}")
userid = await f() # f = kw.get('get_user')
# userid = await f()
debug(f"params_kw: {params_kw}") debug(f"params_kw: {params_kw}")
orgid = "04J6VbxLqB_9RPMcgOv_8"
userid = "04J6VbxLqB_9RPMcgOv_8"
query = params_kw.get('query', '') query = params_kw.get('query', '')
fiids = params_kw.get('fiid', []) fiids = params_kw.get('fiids', [])
debug(f"fiids: {fiids}")
# 验证 fiids的orgid与orgid = await f()是否一致 # 验证 fiids的orgid与orgid = await f()是否一致
if fiids: if fiids:
@ -62,11 +73,11 @@ async def fusedsearch(request, params_kw, *params, **kw):
sql_opts = """ sql_opts = """
SELECT orgid SELECT orgid
FROM kdb FROM kdb
WHERE id = ${fiid}$ WHERE id = ${id}$
""" """
try: try:
async with db.sqlorContext(dbname) as sor: async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"fiid": fiids[0]}) result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result: if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录") raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid') kdb_orgid = result[0].get('orgid')
@ -76,7 +87,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
error(f"orgid 验证失败: {str(e)}") error(f"orgid 验证失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)}) return json.dumps({"status": "error", "message": str(e)})
ragfilemgr = RagFileMgr("fiids[0]") ragfilemgr = RagFileMgr("fiids[0]")
service_params = ragfilemgr.get_service_params(orgid) service_params = await ragfilemgr.get_service_params(orgid)
api_service = APIService() api_service = APIService()
start_time = time.time() start_time = time.time()
@ -85,7 +96,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
info( info(
f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}") f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
if not query or not orgid or not knowledge_base_ids: if not query or not orgid or not fiids:
raise ValueError("query、orgid 和 knowledge_base_ids 不能为空") raise ValueError("query、orgid 和 knowledge_base_ids 不能为空")
# 提取实体 # 提取实体
@ -154,6 +165,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
request=request, request=request,
texts=[combined_text], texts=[combined_text],
upappid=service_params['embedding'], upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid user=userid
) )
if not query_vector or not all(len(vec) == 1024 for vec in query_vector): if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
@ -169,7 +181,9 @@ async def fusedsearch(request, params_kw, *params, **kw):
request=request, request=request,
query_vector=query_vector, query_vector=query_vector,
userid=orgid, userid=orgid,
knowledge_base_ids=[fiids], knowledge_base_ids=fiids,
limit=limit,
offset=0,
upappid=service_params['vdb'], upappid=service_params['vdb'],
apiname="mlvus/searchquery", apiname="mlvus/searchquery",
user=userid user=userid
@ -186,7 +200,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
if use_rerank and unique_results: if use_rerank and unique_results:
rerank_start = time.time() rerank_start = time.time()
debug("开始重排序") debug("开始重排序")
unique_results = await api_service( unique_results = await api_service.rerank_results(
request=request, request=request,
query=combined_text, query=combined_text,
results=unique_results, results=unique_results,