From 75faf7e55ebde5ec63fb9191130a20cfb60c7c21 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Wed, 27 Aug 2025 13:21:52 +0800 Subject: [PATCH] rag --- rag/ragapi.py | 46 ++++++++++++++++++++++++++++++++++------------ rag/ragprogram.py | 1 + 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/rag/ragapi.py b/rag/ragapi.py index b02addc..b143d24 100644 --- a/rag/ragapi.py +++ b/rag/ragapi.py @@ -28,6 +28,9 @@ data: { 3、docs文档 path: /v1/docs + +4、longmemory存储 + """ async def docs(request, params_kw, *params, **kw): @@ -35,15 +38,20 @@ async def docs(request, params_kw, *params, **kw): async def get_kdbs(request, params_kw, *params, **kw): """返回 kdb 表的全部内容,返回 JSON""" + f = kw.get('get_userorgid') + orgid = await f() + debug(f"orgid: {orgid},{f=}") + debug(f"params_kw: {params_kw}") db = DBPools() dbname = kw.get('get_module_dbname')('rag') sql_opts = """ SELECT id, name, description FROM kdb + WHERE orgid = ${orgid}$ """ try: async with db.sqlorContext(dbname) as sor: - opts_result = await sor.sqlExe(sql_opts, {}) + opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid}) if not opts_result: error("未找到 kdb 表记录") return json.dumps({"status": "error", "message": "未找到记录"}) @@ -52,20 +60,32 @@ async def get_kdbs(request, params_kw, *params, **kw): error(f"查询 kdb 表失败: {str(e)}, 堆栈: {traceback.format_exc()}") return json.dumps({"status": "error", "message": str(e)}) + except Exception as e: + error(f"列出用户文件失败: {str(e)}, 堆栈: {traceback.format_exc()}") + result.update({ + "status": "error", + "files_by_knowledge_base": {}, + "message": f"列出用户文件失败: {str(e)}", + "status_code": 400 + }) + return result + async def fusedsearch(request, params_kw, *params, **kw): """融合搜索,调用服务化端点""" - # f = kw.get('get_userorgid') - # orgid = await f() - # debug(f"orgid: {orgid},{f=}") - # f = kw.get('get_user') - # userid = await f() + f = kw.get('get_userorgid') + orgid = await f() + debug(f"orgid: {orgid},{f=}") + f = kw.get('get_user') + userid = await f() debug(f"params_kw: {params_kw}") - orgid = "04J6VbxLqB_9RPMcgOv_8" - userid = "04J6VbxLqB_9RPMcgOv_8" + # orgid = "04J6VbxLqB_9RPMcgOv_8" + # userid = "04J6VbxLqB_9RPMcgOv_8" query = params_kw.get('query', '') fiids = params_kw.get('fiids', []) + limit = int(params_kw.get('limit', 5)) debug(f"fiids: {fiids}") - + if isinstance(fiids, str): + fiids = [f.strip() for f in fiids.split(',') if f.strip()] # 验证 fiids的orgid与orgid = await f()是否一致 if fiids: db = DBPools() @@ -153,7 +173,7 @@ async def fusedsearch(request, params_kw, *params, **kw): debug(f"无效三元组: {triplet}") combined_text = query if triplet_texts: - combined_text += " [三元组] " + "; ".join(triplet_texts) + combined_text += "".join(triplet_texts) debug( f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") timing_stats["triplet_text_combine"] = time.time() - triplet_text_start @@ -175,14 +195,14 @@ async def fusedsearch(request, params_kw, *params, **kw): debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒") # 调用搜索端点 - limit = 5 + sum = limit + 5 search_start = time.time() result = await api_service.milvus_search_query( request=request, query_vector=query_vector, userid=orgid, knowledge_base_ids=fiids, - limit=limit, + limit=sum, offset=0, upappid=service_params['vdb'], apiname="mlvus/searchquery", @@ -196,6 +216,8 @@ async def fusedsearch(request, params_kw, *params, **kw): return {"results": [], "timing": timing_stats} unique_results = result.get("results", []) + sum = len(unique_results) + debug(f"从向量数据中搜索到{sum}条数据") use_rerank = True if use_rerank and unique_results: rerank_start = time.time() diff --git a/rag/ragprogram.py b/rag/ragprogram.py index d711dd6..a6e8a3b 100644 --- a/rag/ragprogram.py +++ b/rag/ragprogram.py @@ -1,5 +1,6 @@ from appPublic.timeUtils import curDateString, dateAdd from ahserver.serverenv import get_serverenv +from sqlor.dbpools import DBPools async def set_program(request, program_type, quota, term=1): db = DBPools()