From 169dbd3a565557b759ad23f05a08c7c7b100c8f1 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Mon, 28 Jul 2025 18:52:14 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9init.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag/init.py | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 2 deletions(-) diff --git a/rag/init.py b/rag/init.py index 87ffb6d..916e7eb 100644 --- a/rag/init.py +++ b/rag/init.py @@ -1,7 +1,158 @@ -from appPublic.worker import awaitify +from appPublic.log import debug, error, info from ahserver.serverenv import ServerEnv +import aiohttp +from aiohttp import ClientSession, ClientTimeout +import json + +async def _make_connection_request(action: str, params: dict = None) -> dict: + """ + 通用函数,调用 MilvusConnection 的服务化接口。 + + 参数: + action (str): 操作类型,例如 'initialize'、'insert_document'。 + params (dict): 操作参数,默认为 None。 + + 返回: + dict: 服务响应,包含 status、message、collection_name 等字段。 + """ + if params is None: + params = {} + + url = "http://localhost:8888/v1/connection" + payload = {"action": action, "params": params} + + try: + debug(f"发起 {action} 请求: params={params}") + async with ClientSession(timeout=ClientTimeout(total=10)) as session: + async with session.post( + url, + headers={"Content-Type": "application/json"}, + json=payload + ) as response: + response_text = await response.text() + debug(f"收到 {action} 响应: status={response.status}, content={response_text}") + if response.status != 200: + error(f"{action} 请求失败: 状态码={response.status}, 响应={response_text}") + return { + "status": "error", + "message": f"请求失败: 状态码 {response.status}", + "collection_name": params.get("db_type", "ragdb"), + "document_id": "", + "status_code": response.status + } + result = await response.json() + info(f"{action} 请求成功: 结果={result}") + return result + except Exception as e: + error(f"{action} 请求异常: 错误={str(e)}") + return { + "status": "error", + "message": f"服务器错误: {str(e)}", + "collection_name": params.get("db_type", "ragdb"), + "document_id": "", + "status_code": 500 + } + +async def initialize() -> dict: + """初始化 服务""" + return await _make_connection_request("initialize") + +async def get_params() -> dict: + """获取服务参数""" + return await _make_connection_request("get_params") + +async def create_collection(db_type: str = "") -> dict: + """创建 Milvus 集合""" + return await _make_connection_request("create_collection", {"db_type": db_type}) + +async def delete_collection(db_type: str = "") -> dict: + """删除 Milvus 集合""" + return await _make_connection_request("delete_collection", {"db_type": db_type}) + +async def insert_document(file_path: str, userid: str, knowledge_base_id: str, db_type: str = "") -> dict: + """插入文档到 Milvus 并抽取三元组到 Neo4j""" + params = { + "file_path": file_path, + "userid": userid, + "knowledge_base_id": knowledge_base_id, + "db_type": db_type + } + return await _make_connection_request("insert_document", params) + +async def delete_document(userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> dict: + """删除指定文档的 Milvus 和 Neo4j 记录""" + params = { + "userid": userid, + "filename": filename, + "knowledge_base_id": knowledge_base_id, + "db_type": db_type + } + return await _make_connection_request("delete_document", params) + +async def delete_knowledge_base(userid: str, knowledge_base_id: str, db_type: str = "") -> dict: + """删除整个知识库的 Milvus 和 Neo4j 记录""" + params = { + "userid": userid, + "knowledge_base_id": knowledge_base_id, + "db_type": db_type + } + return await _make_connection_request("delete_knowledge_base", params) + +async def search_query(query: str, userid: str, knowledge_base_ids: list, limit: int = 5, offset: int = 0, + use_rerank: bool = True, db_type: str = "") -> dict: + """执行纯向量搜索""" + params = { + "query": query, + "userid": userid, + "knowledge_base_ids": knowledge_base_ids, + "limit": limit, + "offset": offset, + "use_rerank": use_rerank, + "db_type": db_type + } + return await _make_connection_request("search_query", params) + +async def fused_search(query: str, userid: str, knowledge_base_ids: list, limit: int = 5, offset: int = 0, + use_rerank: bool = True, db_type: str = "") -> dict: + """执行融合搜索(向量 + 三元组)""" + params = { + "query": query, + "userid": userid, + "knowledge_base_ids": knowledge_base_ids, + "limit": limit, + "offset": offset, + "use_rerank": use_rerank, + "db_type": db_type + } + return await _make_connection_request("fused_search", params) + +async def list_user_files(userid: str, db_type: str = "") -> dict: + """列出用户的所有知识库及其文件""" + params = { + "userid": userid, + "db_type": db_type + } + return await _make_connection_request("list_user_files", params) + +async def list_all_knowledge_bases(db_type: str = "") -> dict: + """列出所有用户的知识库及其文件""" + return await _make_connection_request("list_all_knowledge_bases", {"db_type": db_type}) def load_rag(): - env = ServerEnv() + """ + 初始化 ServerEnv,绑定 MilvusConnection 的所有功能。 + """ + env = ServerEnv() + env.initialize = initialize + env.get_params = get_params + env.create_collection = create_collection + env.delete_collection = delete_collection + env.insert_document = insert_document + env.delete_document = delete_document + env.delete_knowledge_base = delete_knowledge_base + env.search_query = search_query + env.fused_search = fused_search + env.list_user_files = list_user_files + env.list_all_knowledge_bases = list_all_knowledge_bases