From 3dcb706d2f32f742847da6ed6293a4b703026e9b Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Fri, 10 Oct 2025 13:35:01 +0800 Subject: [PATCH] rag --- rag/init.py | 174 +++------------------------------------------------- 1 file changed, 7 insertions(+), 167 deletions(-) diff --git a/rag/init.py b/rag/init.py index 63a1e91..b392b76 100644 --- a/rag/init.py +++ b/rag/init.py @@ -1,175 +1,15 @@ -# -from appPublic.log import debug, error, info -from ahserver.serverenv import ServerEnv -import aiohttp -from aiohttp import ClientSession, ClientTimeout -import json -from .file import file_uploaded, file_deleted -from .folderinfo import RagFileMgr -from .ragprogram import set_program, get_rag_programs -from .ragllm_utils import get_ragllms_by_catelog from appPublic.registerfunction import RegisterFunction -from sqlor.dbpools import DBPools - -async def _make_connection_request(action: str, params: dict = None) -> dict: - """ - 通用函数,调用 MilvusConnection 的服务化接口。 - - 参数: - action (str): 操作类型,例如 'initialize'、'insert_document'。 - params (dict): 操作参数,默认为 None。 - - 返回: - dict: 服务响应,包含 status、message、collection_name 等字段。 - """ - if params is None: - params = {} - - url = f"http://localhost:8888/v1/{action}" - - try: - debug(f"发起 {action} 请求: params={params}") - async with ClientSession(timeout=ClientTimeout(total=10)) as session: - async with session.post( - url, - headers={"Content-Type": "application/json"}, - json=params - ) as response: - response_text = await response.text() - debug(f"收到 {action} 响应: status={response.status}, content={response_text}") - if response.status != 200: - error(f"{action} 请求失败: 状态码={response.status}, 响应={response_text}") - return { - "status": "error", - "message": f"请求失败: 状态码 {response.status}", - "collection_name": params.get("db_type", "ragdb"), - "document_id": "", - "status_code": response.status - } - result = await response.json() - info(f"{action} 请求成功: 结果={result}") - return result - except Exception as e: - error(f"{action} 请求异常: 错误={str(e)}") - return { - "status": "error", - "message": f"服务器错误: {str(e)}", - "collection_name": params.get("db_type", "ragdb"), - "document_id": "", - "status_code": 500 - } - -async def create_collection(db_type: str = "") -> dict: - """创建 Milvus 集合""" - return await _make_connection_request("create_collection", {"db_type": db_type}) - -async def delete_collection(db_type: str = "") -> dict: - """删除 Milvus 集合""" - return await _make_connection_request("delete_collection", {"db_type": db_type}) - -async def insert_document(file_path: str, userid: str, knowledge_base_id: str, db_type: str = "") -> dict: - """插入文档到 Milvus 并抽取三元组到 Neo4j""" - params = { - "file_path": file_path, - "userid": userid, - "knowledge_base_id": knowledge_base_id, - "db_type": db_type - } - return await _make_connection_request("insert_document", params) - -async def delete_document(userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> dict: - """删除指定文档的 Milvus 和 Neo4j 记录""" - params = { - "userid": userid, - "filename": filename, - "knowledge_base_id": knowledge_base_id, - "db_type": db_type - } - return await _make_connection_request("delete_document", params) - -async def delete_knowledge_base(userid: str, knowledge_base_id: str, db_type: str = "") -> dict: - """删除整个知识库的 Milvus 和 Neo4j 记录""" - params = { - "userid": userid, - "knowledge_base_id": knowledge_base_id, - "db_type": db_type - } - return await _make_connection_request("delete_knowledge_base", params) - -async def search_query(query: str, userid: str, knowledge_base_ids: list, limit: int = 5, offset: int = 0, - use_rerank: bool = True, db_type: str = "") -> dict: - """执行纯向量搜索""" - params = { - "query": query, - "userid": userid, - "knowledge_base_ids": knowledge_base_ids, - "limit": limit, - "offset": offset, - "use_rerank": use_rerank, - "db_type": db_type - } - return await _make_connection_request("search_query", params) - -async def fused_search(query: str, userid: str, knowledge_base_ids: list, limit: int = 5, offset: int = 0, - use_rerank: bool = True, db_type: str = "") -> dict: - """执行融合搜索(向量 + 三元组)""" - params = { - "query": query, - "userid": userid, - "knowledge_base_ids": knowledge_base_ids, - "limit": limit, - "offset": offset, - "use_rerank": use_rerank, - "db_type": db_type - } - return await _make_connection_request("fused_search", params) - -async def list_user_files(userid: str, db_type: str = "") -> dict: - """列出用户的所有知识库及其文件""" - params = { - "userid": userid, - "db_type": db_type - } - return await _make_connection_request("list_user_files", params) - -async def list_all_knowledge_bases(db_type: str = "") -> dict: - """列出所有用户的知识库及其文件""" - return await _make_connection_request("list_all_knowledge_bases", {"db_type": db_type}) - -async def docs() -> dict: - """列出所有用户的知识库及其文件""" - return await _make_connection_request("docs", {}) - -async def get_user_kdbs(request): - env = request._run_ns - db = env.DBPools() - dbname = env.get_module_dbname('rag') - userorgid = await env.get_userorgid() - async with db.sqlorContext(dbname) as sor: - sql = "select * from kdb where orgid = ${orgid}$" - recs = await sor.sqlExe(sql, {'orgid': userorgid}) - return recs - return [] +from rag.ragapi import docs, get_kdbs, fusedsearch, add_user_messages, get_user_memories def load_rag(): """ 初始化 ServerEnv,绑定 MilvusConnection 的所有功能。 """ - env = ServerEnv() - env.create_collection = create_collection - env.delete_collection = delete_collection - env.insert_document = insert_document - env.delete_document = delete_document - env.delete_knowledge_base = delete_knowledge_base - env.search_query = search_query - env.fused_search = fused_search - env.list_user_files = list_user_files - env.list_all_knowledge_bases = list_all_knowledge_bases - env.docs = docs - env.RagFileMgr = RagFileMgr - env.set_program = set_program - env.get_rag_programs = get_rag_programs - env.get_ragllms_by_catelog = get_ragllms_by_catelog - env.get_user_kdbs = get_user_kdbs + rf = RegisterFunction() + rf.register('docs', docs) + rf.register('get_kdbs', get_kdbs) + rf.register('fusedsearch', fusedsearch) + rf.register('add_user_messages', add_user_messages) + rf.register('get_user_memories', get_user_memories)