diff --git a/llmengine/base_connection.py b/llmengine/base_connection.py
old mode 100644
new mode 100755
diff --git a/llmengine/base_db.py b/llmengine/base_db.py
old mode 100644
new mode 100755
diff --git a/llmengine/connection.py b/llmengine/connection.py
old mode 100644
new mode 100755
index 933503f..06f37ec
--- a/llmengine/connection.py
+++ b/llmengine/connection.py
@@ -1,4 +1,4 @@
-import milvus_connection
+import llmengine.milvus_connection
 from traceback import format_exc
 import argparse
 from aiohttp import web
@@ -403,8 +403,8 @@ async def delete_file(request, params_kw, *params, **kw):
     result = await engine.handle_connection("delete_document", {
         "userid": userid,
         "filename": filename,
-        "db_type": db_type,
-        "knowledge_base_id": knowledge_base_id
+        "knowledge_base_id": knowledge_base_id,
+        "db_type": db_type
     })
     debug(f'Delete result: {result=}')
     status = 200 if result.get("status") == "success" else 400
@@ -454,60 +454,15 @@ async def delete_knowledge_base(request, params_kw, *params, **kw):
         "status_code": 400
     }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

-async def fused_search_query(request, params_kw, *params, **kw):
-    debug(f'{params_kw=}')
-    se = ServerEnv()
-    engine = se.engine
-    query = params_kw.get('query')
-    userid = params_kw.get('userid')
-    db_type = params_kw.get('db_type', '')
-    knowledge_base_ids = params_kw.get('knowledge_base_ids')
-    limit = params_kw.get('limit')
-    offset = params_kw.get('offset', 0)
-    use_rerank = params_kw.get('use_rerank', True)
-    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
-    try:
-        if not all([query, userid, knowledge_base_ids]):
-            debug(f'query, userid 或 knowledge_base_ids 未提供')
-            return web.json_response({
-                "status": "error",
-                "message": "query, userid 或 knowledge_base_ids 未提供",
-                "collection_name": collection_name
-            }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
-        result = await engine.handle_connection("fused_search", {
-            "query": query,
-            "userid": userid,
-            "db_type": db_type,
-            "knowledge_base_ids": knowledge_base_ids,
-            "limit": limit,
-            "offset": offset,
-            "use_rerank": use_rerank
-        })
-        debug(f'{result=}')
-        response = {
-            "status": "success",
-            "results": result.get("results", []),
-            "timing": result.get("timing", {}),
-            "collection_name": collection_name
-        }
-        return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
-    except Exception as e:
-        error(f'融合搜索失败: {str(e)}')
-        return web.json_response({
-            "status": "error",
-            "message": str(e),
-            "collection_name": collection_name
-        }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
-
 async def search_query(request, params_kw, *params, **kw):
-    debug(f'{params_kw=}')
+    debug(f'Received search_query params: {params_kw=}')
     se = ServerEnv()
     engine = se.engine
     query = params_kw.get('query')
     userid = params_kw.get('userid')
     db_type = params_kw.get('db_type', '')
     knowledge_base_ids = params_kw.get('knowledge_base_ids')
-    limit = params_kw.get('limit')
+    limit = params_kw.get('limit', 5)
     offset = params_kw.get('offset', 0)
     use_rerank = params_kw.get('use_rerank', True)
     collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
@@ -522,13 +477,13 @@ async def search_query(request, params_kw, *params, **kw):
         result = await engine.handle_connection("search_query", {
             "query": query,
             "userid": userid,
-            "db_type": db_type,
             "knowledge_base_ids": knowledge_base_ids,
             "limit": limit,
             "offset": offset,
-            "use_rerank": use_rerank
+            "use_rerank": use_rerank,
+            "db_type": db_type
         })
-        debug(f'{result=}')
+        debug(f'Search result: {result=}')
         response = {
             "status": "success",
             "results": result.get("results", []),
@@ -544,6 +499,51 @@ async def search_query(request, params_kw, *params, **kw):
             "collection_name": collection_name
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

+async def fused_search_query(request, params_kw, *params, **kw):
+    debug(f'Received fused_search_query params: {params_kw=}')
+    se = ServerEnv()
+    engine = se.engine
+    query = params_kw.get('query')
+    userid = params_kw.get('userid')
+    db_type = params_kw.get('db_type', '')
+    knowledge_base_ids = params_kw.get('knowledge_base_ids')
+    limit = params_kw.get('limit', 5)
+    offset = params_kw.get('offset', 0)
+    use_rerank = params_kw.get('use_rerank', True)
+    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
+    try:
+        if not all([query, userid, knowledge_base_ids]):
+            debug(f'query, userid 或 knowledge_base_ids 未提供')
+            return web.json_response({
+                "status": "error",
+                "message": "query, userid 或 knowledge_base_ids 未提供",
+                "collection_name": collection_name
+            }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
+        result = await engine.handle_connection("fused_search", {
+            "query": query,
+            "userid": userid,
+            "knowledge_base_ids": knowledge_base_ids,
+            "limit": limit,
+            "offset": offset,
+            "use_rerank": use_rerank,
+            "db_type": db_type
+        })
+        debug(f'Fused search result: {result=}')
+        response = {
+            "status": "success",
+            "results": result.get("results", []),
+            "timing": result.get("timing", {}),
+            "collection_name": collection_name
+        }
+        return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
+    except Exception as e:
+        error(f'融合搜索失败: {str(e)}')
+        return web.json_response({
+            "status": "error",
+            "message": str(e),
+            "collection_name": collection_name
+        }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
+
 async def list_user_files(request, params_kw, *params, **kw):
     debug(f'{params_kw=}')
     se = ServerEnv()
diff --git a/llmengine/db_service.py b/llmengine/db_service.py
old mode 100644
new mode 100755
index 6ed65a2..164bed3
--- a/llmengine/db_service.py
+++ b/llmengine/db_service.py
@@ -276,33 +276,17 @@ async def insert_document(request, params_kw, *params, **kw):
     debug(f'Received params: {params_kw=}')
     se = ServerEnv()
     engine = se.engine
-    userid = params_kw.get('userid', '')
-    knowledge_base_id = params_kw.get('knowledge_base_id', '')
-    document_id = params_kw.get('document_id', '')
-    texts = params_kw.get('texts', [])
-    embeddings = params_kw.get('embeddings', [])
-    filename = params_kw.get('filename', '')
-    file_path = params_kw.get('file_path', '')
-    upload_time = params_kw.get('upload_time', '')
-    file_type = params_kw.get('file_type', '')
+    chunks = params_kw.get('chunks', [])
     db_type = params_kw.get('db_type', '')
     collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
     try:
-        required_fields = ['userid', 'knowledge_base_id', 'texts', 'embeddings']
+        required_fields = ['chunks']
         missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
         if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
        result = await engine.handle_connection("insert_document", {
-            "userid": userid,
-            "knowledge_base_id": knowledge_base_id,
-            "document_id": document_id,
-            "texts": texts,
-            "embeddings": embeddings,
-            "filename": filename,
-            "file_path": file_path,
-            "upload_time": upload_time,
-            "file_type": file_type,
+            "chunks": chunks,
             "db_type": db_type
        })
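+        # Each element of "chunks" is a flat dict carrying both the text/vector
+        # pair and its metadata: userid, knowledge_base_id, document_id,
+        # filename, file_path, upload_time, file_type (vector: 1024 floats).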
        debug(f'Insert result: {result=}')
diff --git a/llmengine/milvus_connection.py b/llmengine/milvus_connection.py
old mode 100644
new mode 100755
index 3de3f49..9d08f31
--- a/llmengine/milvus_connection.py
+++ b/llmengine/milvus_connection.py
@@ -2,11 +2,11 @@ from appPublic.jsonConfig import getConfig
 import os
 from appPublic.log import debug, error, info
 import yaml
-from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
 from threading import Lock
 from llmengine.base_connection import connection_register
 from typing import Dict, List, Any
 import aiohttp
+from aiohttp import ClientSession, ClientTimeout
 from langchain_core.documents import Document
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 import uuid
@@ -41,34 +41,43 @@ class MilvusConnection:
             return
         try:
             config = getConfig()
-            self.db_path = config['milvus_db']
             self.neo4j_uri = config['neo4j']['uri']
             self.neo4j_user = config['neo4j']['user']
             self.neo4j_password = config['neo4j']['password']
         except KeyError as e:
             error(f"配置文件缺少必要字段: {str(e)}")
             raise RuntimeError(f"配置文件缺少必要字段: {str(e)}")
-        self._initialize_connection()
         self._initialized = True
-        info(f"MilvusConnection initialized with db_path: {self.db_path}")
+        info("MilvusConnection initialized; Milvus operations are delegated to the HTTP service")

-    def _initialize_connection(self):
-        """初始化 Milvus 连接,确保单一连接"""
+    @retry(stop=stop_after_attempt(3))
+    async def _make_api_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
+        debug(f"开始 API 请求: action={action}, params={params}")
         try:
-            db_dir = os.path.dirname(self.db_path)
-            if not os.path.exists(db_dir):
-                os.makedirs(db_dir, exist_ok=True)
-                debug(f"创建 Milvus 目录: {db_dir}")
-            if not os.access(db_dir, os.W_OK):
-                raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
-            if not connections.has_connection("default"):
-                connections.connect("default", uri=self.db_path)
-                debug(f"已连接到 Milvus Lite,路径: {self.db_path}")
-            else:
-                debug("已存在 Milvus 连接,跳过重复连接")
+            async with ClientSession(timeout=ClientTimeout(total=10)) as session:
+                url = f"http://localhost:8886/v1/{action}"
+                debug(f"发起 POST 请求: {url}")
+                async with session.post(
+                    url,
+                    headers={"Content-Type": "application/json"},
+                    json=params
+                ) as response:
+                    debug(f"收到响应: status={response.status}, headers={response.headers}")
+                    response_text = await response.text()
+                    debug(f"响应内容: {response_text}")
+                    result = await response.json()
+                    debug(f"API 响应内容: {result}")
+                    if response.status == 400:  # 客户端错误,直接返回
+                        debug(f"客户端错误,状态码: {response.status}, 返回响应: {result}")
+                        return result
+                    if response.status != 200:
+                        error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
+                        raise RuntimeError(f"API 调用失败: {response.status}")
+                    debug(f"API 调用成功: {action}, 响应: {result}")
+                    return result
         except Exception as e:
-            error(f"连接 Milvus 失败: {str(e)}")
-            raise RuntimeError(f"连接 Milvus 失败: {str(e)}")
+            error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
+            raise RuntimeError(f"API 调用失败: {str(e)}")

     async def handle_connection(self, action: str, params: Dict = None) -> Dict:
         """处理数据库操作"""
@@ -87,9 +96,9 @@ class MilvusConnection:
                         "document_id": "", "status_code": 400}
             if action == "initialize":
-                return {"status": "success", "message": f"Milvus 连接已初始化,路径: {self.db_path}"}
+                return {"status": "success", "message": "Milvus 服务已初始化"}
             elif action == "get_params":
-                return {"status": "success", "params": {"uri": self.db_path}}
+                return {"status": "success", "params": {}}
             elif action == "create_collection":
                 return await self._create_collection(db_type)
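+            # NOTE: from this point on every action is proxied to the
+            # standalone Milvus service via _make_api_request() (POST to
+            # /v1/<action> on localhost:8886) instead of using pymilvus
+            # in-process.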
             elif action == "delete_collection":
@@ -121,7 +130,7 @@ class MilvusConnection:
                 if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100:
                     return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制",
                             "collection_name": collection_name, "document_id": "", "status_code": 400}
-                return await self._delete_document(db_type, userid, filename, knowledge_base_id)
+                return await self._delete_document(userid, filename, knowledge_base_id, db_type)
             elif action == "delete_knowledge_base":
                 userid = params.get("userid", "")
                 knowledge_base_id = params.get("knowledge_base_id", "")
@@ -135,59 +144,36 @@ class MilvusConnection:
                     return {"status": "error", "message": "userid 或 knowledge_base_id 的长度超出限制",
                             "collection_name": collection_name, "document_id": "", "status_code": 400}
                 return await self._delete_knowledge_base(db_type, userid, knowledge_base_id)
+            elif action == "search_query":
+                query = params.get("query", "")
+                userid = params.get("userid", "")
+                knowledge_base_ids = params.get("knowledge_base_ids", [])
+                limit = params.get("limit", 5)
+                offset = params.get("offset", 0)
+                db_type = params.get("db_type", "")
+                use_rerank = params.get("use_rerank", True)
+                if not query or not userid or not knowledge_base_ids:
+                    return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
+                            "collection_name": collection_name, "document_id": "", "status_code": 400}
+                return await self._search_query(query, userid, knowledge_base_ids, limit, offset, use_rerank, db_type)
             elif action == "fused_search":
                 query = params.get("query", "")
                 userid = params.get("userid", "")
                 knowledge_base_ids = params.get("knowledge_base_ids", [])
                 limit = params.get("limit", 5)
-                if not query or not userid or not knowledge_base_ids:
-                    return {
-                        "status": "error",
-                        "message": "query、userid 或 knowledge_base_ids 不能为空",
-                        "collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}",
-                        "document_id": "",
-                        "status_code": 400
-                    }
-                if limit < 1 or limit > 16384:
-                    return {
-                        "status": "error",
-                        "message": "limit 必须在 1 到 16384 之间",
-                        "collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}",
-                        "document_id": "",
-                        "status_code": 400
-                    }
-                return await self._fused_search(
-                    query,
-                    userid,
-                    params.get("db_type", ""),
-                    knowledge_base_ids,
-                    limit,
-                    params.get("offset", 0),
-                    params.get("use_rerank", True)
-                )
-            elif action == "search_query":
-                query = params.get("query", "")
-                userid = params.get("userid", "")
-                limit = params.get("limit", "")
-                knowledge_base_ids = params.get("knowledge_base_ids", [])
+                offset = params.get("offset", 0)
+                db_type = params.get("db_type", "")
+                use_rerank = params.get("use_rerank", True)
                 if not query or not userid or not knowledge_base_ids:
                     return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
                             "collection_name": collection_name, "document_id": "", "status_code": 400}
-                return await self._search_query(
-                    query,
-                    userid,
-                    db_type,
-                    knowledge_base_ids,
-                    limit,
-                    params.get("offset", 0),
-                    params.get("use_rerank", True)
-                )
+                return await self._fused_search(query, userid, knowledge_base_ids, limit, offset, use_rerank, db_type)
             elif action == "list_user_files":
                 userid = params.get("userid", "")
                 if not userid:
                     return {"status": "error", "message": "userid 不能为空",
                             "collection_name": collection_name, "document_id": "", "status_code": 400}
-                return await self._list_user_files(userid)
+                return await self._list_user_files(userid, db_type)
             elif action == "list_all_knowledge_bases":
"list_all_knowledge_bases": return await self._list_all_knowledge_bases(db_type) else: @@ -198,118 +184,34 @@ class MilvusConnection: return { "status": "error", "message": f"服务器错误: {str(e)}", - "collection_name": params.get("db_type", "ragdb") if params else "ragdb", + "collection_name": collection_name, "document_id": "", "status_code": 400 } - async def _create_collection(self, db_type: str = "") -> Dict: + async def _create_collection(self, db_type: str = "") -> Dict[str, Any]: """创建 Milvus 集合""" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - # 根据 db_type 决定集合名称 - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if len(collection_name) > 255: raise ValueError(f"集合名称 {collection_name} 超过 255 个字符") - if db_type and "_" in db_type: - raise ValueError("db_type 不能包含下划线") - if db_type and len(db_type) > 100: + if len(db_type) > 100: raise ValueError("db_type 的长度应小于 100") - debug(f"集合名称: {collection_name}") - - fields = [ - FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True), - FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100), - FieldSchema(name="knowledge_base_id", dtype=DataType.VARCHAR, max_length=100), - FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36), - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024), - FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255), - FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024), - FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64), - FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64), - ] - schema = CollectionSchema( - fields=fields, - description="统一数据集合,包含用户ID、知识库ID、document_id 和元数据字段", - auto_id=True, - primary_field="pk", - ) - - if utility.has_collection(collection_name): - try: - collection = Collection(collection_name) - existing_schema = collection.schema - expected_fields = {f.name for f in fields} - actual_fields = {f.name for f in existing_schema.fields} - vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None) - - schema_compatible = False - if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR: - dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None - schema_compatible = dim == 1024 - debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, " - f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, " - f"dim={dim if dim is not None else '未定义'}") - if not schema_compatible: - debug(f"集合 {collection_name} 的 schema 不兼容,原因: " - f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, " - f"vector_field: {vector_field is not None}, " - f"dtype: {vector_field.dtype if vector_field else '无'}, " - f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}") - utility.drop_collection(collection_name) - else: - collection.load() - debug(f"集合 {collection_name} 已存在并加载成功") - return { - "status": "success", - "collection_name": collection_name, - "message": f"集合 {collection_name} 已存在" - } - except Exception as e: - error(f"加载集合 {collection_name} 失败: {str(e)}") - return { - "status": "error", - "collection_name": collection_name, - "message": str(e) - } - - try: - collection = 
-                collection = Collection(collection_name, schema)
-                collection.create_index(
-                    field_name="vector",
-                    index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
-                )
-                for field in ["userid", "knowledge_base_id", "document_id", "filename", "file_path", "upload_time", "file_type"]:
-                    collection.create_index(
-                        field_name=field,
-                        index_params={"index_type": "INVERTED"}
-                    )
-                collection.load()
-                debug(f"成功创建并加载集合: {collection_name}")
-                return {
-                    "status": "success",
-                    "collection_name": collection_name,
-                    "message": f"集合 {collection_name} 创建成功"
-                }
-            except Exception as e:
-                error(f"创建集合 {collection_name} 失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "message": str(e)
-                }
+            debug(f"调用创建集合端点: {collection_name}, 参数: {{'db_type': '{db_type}'}}")
+            result = await self._make_api_request("createcollection", {"db_type": db_type})
+            return result
         except Exception as e:
-            error(f"创建集合失败: {str(e)}")
+            error(f"创建集合失败: {str(e)}, 堆栈: {traceback.format_exc()}")
             return {
                 "status": "error",
-                "collection_name":collection_name,
-                "message": str(e)
+                "collection_name": collection_name,
+                "message": f"创建集合失败: {str(e)}",
+                "status_code": 400
             }

     async def _delete_collection(self, db_type: str = "") -> Dict:
-        """删除 Milvus 集合"""
+        """通过服务化端点删除 Milvus 集合"""
         try:
-            # 根据 db_type 决定集合名称
             collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
             if len(collection_name) > 255:
                 raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
@@ -317,37 +219,17 @@ class MilvusConnection:
                 raise ValueError("db_type 不能包含下划线")
             if db_type and len(db_type) > 100:
                 raise ValueError("db_type 的长度应小于 100")
-            debug(f"集合名称: {collection_name}")
+            debug(f"调用删除集合端点: {collection_name}")

-            if not utility.has_collection(collection_name):
-                debug(f"集合 {collection_name} 不存在")
-                return {
-                    "status": "success",
-                    "collection_name": collection_name,
-                    "message": f"集合 {collection_name} 不存在,无需删除"
-                }
-
-            try:
-                utility.drop_collection(collection_name)
-                debug(f"成功删除集合: {collection_name}")
-                return {
-                    "status": "success",
-                    "collection_name": collection_name,
-                    "message": f"集合 {collection_name} 删除成功"
-                }
-            except Exception as e:
-                error(f"删除集合 {collection_name} 失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "message": str(e)
-                }
+            result = await self._make_api_request("deletecollection", {"db_type": db_type})
+            return result
         except Exception as e:
             error(f"删除集合失败: {str(e)}")
             return {
                 "status": "error",
                 "collection_name": collection_name,
-                "message": str(e)
+                "message": str(e),
+                "status_code": 400
             }

     async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str, db_type: str = "") -> Dict[str, Any]:
@@ -360,21 +242,15 @@ class MilvusConnection:
         timings = {}
         start_total = time.time()
         start_neo4j = None
         try:
-            # 检查是否已存在相同的 file_path、userid 和 knowledge_base_id
-            collection = Collection(collection_name)
-            expr = f'file_path == "{file_path}" && userid == "{userid}" && knowledge_base_id == "{knowledge_base_id}"'
-            debug(f"检查重复文档: {expr}")
-            start_check = time.time()
-            results = collection.query(expr=expr, output_fields=["document_id"])
-            timings["check_duplicate"] = time.time() - start_check
-            debug(f"检查重复文档耗时: {timings['check_duplicate']:.2f} 秒")
-            if results:
-                raise ValueError(
-                    f"文档已存在: file_path={file_path}, userid={userid}, knowledge_base_id={knowledge_base_id}")
-
+            # 验证参数
+            if not userid or not knowledge_base_id:
+                raise ValueError("userid 和 knowledge_base_id 不能为空")
+            if "_" in userid or "_" in knowledge_base_id:
+                raise ValueError("userid 和 knowledge_base_id 不能包含下划线")
+            if len(userid) > 100 or len(knowledge_base_id) > 100:
+                raise ValueError("userid 或 knowledge_base_id 的长度超出限制")
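+            # Pipeline from here: load and split the file into chunks, embed
+            # each chunk (1024-d vectors), POST the flattened chunks to the
+            # insertdocument endpoint, then extract triples into Neo4j.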
             if not os.path.exists(file_path):
                 raise ValueError(f"文件 {file_path} 不存在")
@@ -413,49 +289,58 @@ class MilvusConnection:
             filename = os.path.basename(file_path).rsplit('.', 1)[0]
             upload_time = datetime.now().isoformat()
-            documents = []
-            for i, chunk in enumerate(chunks):
-                chunk.metadata.update({
-                    'userid': userid,
-                    'knowledge_base_id': knowledge_base_id,
-                    'document_id': document_id,
-                    'filename': filename + '.' + ext,
-                    'file_path': file_path,
-                    'upload_time': upload_time,
-                    'file_type': ext,
-                })
-                documents.append(chunk)
-                debug(f"文档块 {i} 元数据: {chunk.metadata}")
-            # 确保集合存在
-            debug(f"确保集合 {collection_name} 存在")
-            start_create = time.time()
-            create_result = await self._create_collection(db_type)
-            timings["create_collection"] = time.time() - start_create
-            debug(f"集合创建耗时: {timings['create_collection']:.2f} 秒")
-            if create_result["status"] == "error":
-                raise RuntimeError(f"集合创建失败: {create_result['message']}")
-
-            # 生成嵌入
+            # 生成嵌入向量
             debug("调用嵌入服务生成向量")
-            texts = [doc.page_content for doc in documents]
-            start_embed = time.time()
+            start_embedding = time.time()
+            texts = [chunk.page_content for chunk in chunks]
             embeddings = await self._get_embeddings(texts)
-            timings["generate_embeddings"] = time.time() - start_embed
-            debug(f"生成嵌入耗时: {timings['generate_embeddings']:.2f} 秒")
+            if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
+                raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
+            timings["generate_embeddings"] = time.time() - start_embedding
+            debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")

-            # 插入 Milvus
+            # 构造 chunks 参数(展平结构)
+            chunks_data = []
+            for i, chunk in enumerate(chunks):
+                chunks_data.append({
+                    "userid": userid,
+                    "knowledge_base_id": knowledge_base_id,
+                    "text": chunk.page_content,
+                    "vector": embeddings[i].tolist(),
+                    "document_id": document_id,
+                    "filename": filename + '.' + ext,
+                    "file_path": file_path,
+                    "upload_time": upload_time,
+                    "file_type": ext,
+                })
+
+            # 调用 Milvus 插入端点
+            debug(f"调用插入文件端点: {file_path}")
             start_milvus = time.time()
-            await self._insert_to_milvus(collection_name, documents, embeddings)
+            result = await self._make_api_request("insertdocument", {
+                "chunks": chunks_data,
+                "db_type": db_type,
+            })
             timings["insert_milvus"] = time.time() - start_milvus
             debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
-            info(f"成功插入 {len(documents)} 个文档块到 {collection_name}")
+
+            if result.get("status") != "success":
+                timings["total"] = time.time() - start_total
+                return {
+                    "status": "error",
+                    "document_id": document_id,
+                    "collection_name": collection_name,
+                    "timings": timings,
+                    "message": result.get("message", "未知错误"),
+                    "status_code": 400
+                }

             # 三元组抽取
             debug("调用三元组抽取服务")
             start_triples = time.time()
             try:
-                chunk_texts = [doc.page_content for doc in documents]
+                chunk_texts = [doc.page_content for doc in chunks]
                 debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")

                 tasks = [self._extract_triples(chunk) for chunk in chunk_texts]
@@ -503,13 +388,13 @@ class MilvusConnection:
                     info(f"文件 {file_path} 三元组成功插入 Neo4j")
                 else:
                     debug(f"文件 {file_path} 未抽取到三元组")
                timings["insert_neo4j"] = time.time() - start_neo4j if start_neo4j is not None else 0
                 debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
             except Exception as e:
                 timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
                     timings["extract_triples"]
                timings["insert_neo4j"] = time.time() - start_neo4j if start_neo4j is not None else 0
                 debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
                 timings["total"] = time.time() - start_total
                 return {
@@ -535,9 +420,8 @@ class MilvusConnection:
             }

         except Exception as e:
-            error(f"插入文档失败: {str(e)}")
+            error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
             timings["total"] = time.time() - start_total
-            debug(f"总耗时: {timings['total']:.2f} 秒")
             return {
                 "status": "error",
                 "document_id": document_id,
@@ -553,7 +437,6 @@ class MilvusConnection:
         retry=retry_if_exception_type((aiohttp.ClientError, RuntimeError)),
         before_sleep=lambda retry_state: debug(f"重试三元组抽取服务,第 {retry_state.attempt_number} 次")
     )
-
     async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """调用嵌入服务获取文本的向量,带缓存"""
         try:
@@ -617,148 +500,59 @@ class MilvusConnection:
             debug(f"Request #{request_id} traceback: {traceback.format_exc()}")
             raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}")

-    async def _insert_to_milvus(self, collection_name: str, documents: List[Document],
-                                embeddings: List[List[float]]) -> None:
-        """将文档和嵌入向量插入 Milvus 集合"""
-        try:
-            if not connections.has_connection("default"):
-                self._initialize_connection()
-            collection = Collection(collection_name)
-            collection.load()
-            data = {
-                "userid": [doc.metadata["userid"] for doc in documents],
-                "knowledge_base_id": [doc.metadata["knowledge_base_id"] for doc in documents],
-                "document_id": [doc.metadata["document_id"] for doc in documents],
-                "text": [doc.page_content for doc in documents],
-                "vector": embeddings,
-                "filename": [doc.metadata["filename"] for doc in documents],
-                "file_path": [doc.metadata["file_path"] for doc in documents],
-                "upload_time": [doc.metadata["upload_time"] for doc in documents],
-                "file_type": [doc.metadata["file_type"] for doc in documents],
-            }
-            collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"])
-            collection.flush()
-            debug(f"成功插入 {len(documents)} 个文档到集合 {collection_name}")
-        except Exception as e:
-            error(f"插入 Milvus 失败: {str(e)}")
-            raise RuntimeError(f"插入 Milvus 失败: {str(e)}")
-
-    async def _delete_document(self, db_type: str, userid: str, filename: str, knowledge_base_id: str) -> Dict[str, Any]:
+    async def _delete_document(self, userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> Dict[str, Any]:
         """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
         try:
-            if not utility.has_collection(collection_name):
-                debug(f"集合 {collection_name} 不存在")
-                return {
-                    "status": "success",
-                    "collection_name": collection_name,
-                    "document_id": "",
-                    "message": f"集合 {collection_name} 不存在,无需删除",
-                    "status_code": 200
-                }
+            # 调用 Milvus 删除文件端点
+            debug(f"调用删除文件端点: userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}")
+            milvus_result = await self._make_api_request("deletedocument", {
+                "userid": userid,
+                "filename": filename,
+                "knowledge_base_id": knowledge_base_id,
+                "db_type": db_type
+            })

-            try:
-                collection = Collection(collection_name)
-                collection.load()
-                debug(f"加载集合: {collection_name}")
-            except Exception as e:
-                error(f"加载集合 {collection_name} 失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "document_id": "",
-                    "message": f"加载集合失败: {str(e)}",
-                    "status_code": 400
-                }
+            if milvus_result.get("status") != "success":
+                error(f"Milvus 删除文件失败: {milvus_result.get('message', '未知错误')}")
+                return milvus_result

-            expr = f"userid == '{userid}' and filename == '{filename}' and knowledge_base_id == '{knowledge_base_id}'"
-            debug(f"查询表达式: {expr}")
-            try:
-                results = collection.query(
-                    expr=expr,
-                    output_fields=["document_id"],
-                    limit=1000
-                )
-                if not results:
-                    debug(
-                        f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录")
-                    return {
-                        "status": "success",
-                        "collection_name": collection_name,
-                        "document_id": "",
-                        "message": f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录,无需删除",
-                        "status_code": 200
-                    }
-                document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
-                debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
-            except Exception as e:
-                error(f"查询 document_id 失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "document_id": "",
-                    "message": f"查询失败: {str(e)}",
-                    "status_code": 400
-                }
+            document_ids = milvus_result.get("document_id", "").split(",") if milvus_result.get("document_id") else []

-            total_deleted = 0
             neo4j_deleted_nodes = 0
             neo4j_deleted_rels = 0
-            for doc_id in document_ids:
-                try:
-                    # 删除 Milvus 记录
-                    delete_expr = f"document_id == '{doc_id}'"
-                    debug(f"删除表达式: {delete_expr}")
-                    delete_result = collection.delete(delete_expr)
-                    deleted_count = delete_result.delete_count
-                    total_deleted += deleted_count
-                    info(f"成功删除 document_id={doc_id} 的 {deleted_count} 条 Milvus 记录")
-                    # 删除 Neo4j 三元组
-                    try:
-                        graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
-                        query = """
-                        MATCH (n {document_id: $document_id})
-                        OPTIONAL MATCH (n)-[r {document_id: $document_id}]->()
-                        WITH collect(r) AS rels, collect(n) AS nodes
-                        FOREACH (r IN rels | DELETE r)
-                        FOREACH (n IN nodes | DELETE n)
-                        RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types
-                        """
-                        result = graph.run(query, document_id=doc_id).data()
-                        nodes_deleted = result[0]['node_count'] if result else 0
-                        rels_deleted = result[0]['rel_count'] if result else 0
-                        rel_types = result[0]['rel_types'] if result else []
-                        info(
-                            f"成功删除 document_id={doc_id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}")
-                        neo4j_deleted_nodes += nodes_deleted
-                        neo4j_deleted_rels += rels_deleted
-                    except Exception as e:
-                        error(f"删除 document_id={doc_id} 的 Neo4j 三元组失败: {str(e)}")
-                        continue
+            # 删除 Neo4j 数据
+            for doc_id in document_ids:
+                if not doc_id:
+                    continue
+                try:
+                    graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
+                    query = """
+                    MATCH (n {document_id: $document_id})
+                    OPTIONAL MATCH (n)-[r {document_id: $document_id}]->()
+                    WITH collect(r) AS rels, collect(n) AS nodes
+                    FOREACH (r IN rels | DELETE r)
+                    FOREACH (n IN nodes | DELETE n)
+                    RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types
+                    """
+                    result = graph.run(query, document_id=doc_id).data()
+                    nodes_deleted = result[0]['node_count'] if result else 0
+                    rels_deleted = result[0]['rel_count'] if result else 0
+                    rel_types = result[0]['rel_types'] if result else []
+                    info(
+                        f"成功删除 document_id={doc_id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}")
+                    neo4j_deleted_nodes += nodes_deleted
+                    neo4j_deleted_rels += rels_deleted
                 except Exception as e:
-                    error(f"删除 document_id={doc_id} 的 Milvus 记录失败: {str(e)}")
+                    error(f"删除 document_id={doc_id} 的 Neo4j 三元组失败: {str(e)}")
                     continue

-            if total_deleted == 0:
-                debug(
-                    f"没有删除任何 Milvus 记录,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}")
-                return {
-                    "status": "success",
-                    "collection_name": collection_name,
-                    "document_id": "",
-                    "message": f"没有删除任何记录,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}",
-                    "status_code": 200
-                }
-
-            info(
-                f"总计删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}")
             return {
                 "status": "success",
                 "collection_name": collection_name,
                 "document_id": ",".join(document_ids),
-                "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}",
+                "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
                 "status_code": 200
             }

@@ -776,71 +570,19 @@ class MilvusConnection:
         """删除用户的整个知识库,包括 Milvus 和 Neo4j 中的记录"""
         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
         try:
-            if not utility.has_collection(collection_name):
-                debug(f"集合 {collection_name} 不存在")
-                return {
-                    "status": "success",
-                    "collection_name": collection_name,
-                    "deleted_files": [],
-                    "message": f"集合 {collection_name} 不存在,无需删除",
-                    "status_code": 200
-                }
+            # 调用 Milvus 删除知识库端点
+            debug(f"调用删除知识库端点: userid={userid}, knowledge_base_id={knowledge_base_id}")
+            milvus_result = await self._make_api_request("deleteknowledgebase", {
+                "userid": userid,
+                "knowledge_base_id": knowledge_base_id,
+                "db_type": db_type
+            })

-            try:
-                collection = Collection(collection_name)
-                debug(f"加载集合: {collection_name}")
-            except Exception as e:
-                error(f"加载集合 {collection_name} 失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "deleted_files": [],
-                    "message": f"加载集合失败: {str(e)}",
-                    "status_code": 400
-                }
+            if milvus_result.get("status") != "success":
+                error(f"Milvus 删除知识库失败: {milvus_result.get('message', '未知错误')}")
+                return milvus_result

-            # 查询被删除的文件列表
-            deleted_files = []
-            try:
-                expr = f"userid == '{userid}' and knowledge_base_id == '{knowledge_base_id}'"
-                debug(f"查询表达式: {expr}")
-                results = collection.query(
-                    expr=expr,
-                    output_fields=["file_path"],
-                    limit=1000
-                )
-                if results:
-                    deleted_files = list(set(result["file_path"] for result in results if "file_path" in result))
-                    debug(f"找到 {len(deleted_files)} 个唯一文件: {deleted_files}")
-                else:
-                    debug(f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录")
-            except Exception as e:
-                error(f"查询 file_path 失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "deleted_files": [],
-                    "message": f"查询 file_path 失败: {str(e)}",
-                    "status_code": 400
-                }
-
-            # 删除 Milvus 记录
-            total_deleted = 0
-            try:
-                delete_expr = f"userid == '{userid}' and knowledge_base_id == '{knowledge_base_id}'"
-                debug(f"删除表达式: {delete_expr}")
-                delete_result = collection.delete(delete_expr)
-                total_deleted = delete_result.delete_count
-                info(f"成功删除 {total_deleted} 条 Milvus 记录")
-            except Exception as e:
-                error(f"删除 Milvus 记录失败: {str(e)}")
-                return {
-                    "status": "error",
-                    "collection_name": collection_name,
-                    "deleted_files": deleted_files,
-                    "message": f"删除 Milvus 记录失败: {str(e)}",
-                    "status_code": 400
-                }
+            deleted_files = milvus_result.get("deleted_files", [])

             # 删除 Neo4j 数据
             neo4j_deleted_nodes = 0
@@ -850,13 +592,13 @@ class MilvusConnection:
                 graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
                 debug("Neo4j 连接成功")
                 query = """
-                MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
-                OPTIONAL MATCH (n)-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->()
-                WITH collect(r) AS rels, collect(n) AS nodes
-                FOREACH (r IN rels | DELETE r)
-                FOREACH (n IN nodes | DELETE n)
-                RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types
-                """
+                    MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
+                    OPTIONAL MATCH (n)-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->()
+                    WITH collect(r) AS rels, collect(n) AS nodes
+                    FOREACH (r IN rels | DELETE r)
+                    FOREACH (n IN nodes | DELETE n)
+                    RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types
+                """
                 result = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id).data()
                 nodes_deleted = result[0]['node_count'] if result else 0
                 rels_deleted = result[0]['rel_count'] if result else 0
@@ -870,11 +612,11 @@ class MilvusConnection:
                     "status": "success",
                     "collection_name": collection_name,
                     "deleted_files": deleted_files,
-                    "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {str(e)}",
+                    "message": f"成功删除 Milvus 知识库,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {str(e)}",
                     "status_code": 200
                 }

-            if total_deleted == 0 and neo4j_deleted_nodes == 0 and neo4j_deleted_rels == 0:
+            if not deleted_files and neo4j_deleted_nodes == 0 and neo4j_deleted_rels == 0:
                 debug(f"没有删除任何记录,userid={userid}, knowledge_base_id={knowledge_base_id}")
                 return {
                     "status": "success",
@@ -885,12 +627,12 @@ class MilvusConnection:
                 }

             info(
-                f"总计删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}, userid={userid}, knowledge_base_id={knowledge_base_id}")
+                f"总计删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}")
             return {
                 "status": "success",
                 "collection_name": collection_name,
                 "deleted_files": deleted_files,
-                "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}, userid={userid}, knowledge_base_id={knowledge_base_id}",
+                "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}",
                 "status_code": 200
             }

@@ -1089,163 +831,12 @@ class MilvusConnection:
             error(f"重排序服务调用失败: {str(e)}")
             return results

-    async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
-                            offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
-        """融合搜索,将查询与所有三元组拼接后向量化搜索"""
-        start_time = time.time()  # 记录开始时间
+    async def _search_query(self, query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
+                            offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
+        """纯向量搜索,调用服务化端点"""
+        start_time = time.time()
         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
-        timing_stats = {}  # 记录各步骤耗时
-        try:
-            info(
-                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
-
-            if not query or not userid or not knowledge_base_ids:
-                raise ValueError("query、userid 和 knowledge_base_ids 不能为空")
-            if "_" in userid or (db_type and "_" in db_type):
-                raise ValueError("userid 和 db_type 不能包含下划线")
-            if (db_type and len(db_type) > 100) or len(userid) > 100:
-                raise ValueError("db_type 或 userid 的长度超出限制")
-            if limit < 1 or limit > 16384 or offset < 0:
-                raise ValueError("limit 必须在 1 到 16384 之间,offset 必须大于或等于 0")
-
-            if not utility.has_collection(collection_name):
-                debug(f"集合 {collection_name} 不存在")
-                return {"results": [], "timing": timing_stats}
-
-            try:
-                collection = Collection(collection_name)
-                collection.load()
-                debug(f"加载集合: {collection_name}")
-                timing_stats["collection_load"] = time.time() - start_time
-                debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒")
-            except Exception as e:
-                error(f"加载集合 {collection_name} 失败: {str(e)}")
-                return {"results": [], "timing": timing_stats}
-
-            entity_extract_start = time.time()
-            query_entities = await self._extract_entities(query)
-            timing_stats["entity_extraction"] = time.time() - entity_extract_start
-            debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
-
-            all_triplets = []
-            triplet_match_start = time.time()
-            for kb_id in knowledge_base_ids:
-                debug(f"处理知识库: {kb_id}")
-                matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id)
-                debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)} 条")
-                all_triplets.extend(matched_triplets)
-            timing_stats["triplet_matching"] = time.time() - triplet_match_start
-            debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
-
-            if not all_triplets:
-                debug("未找到任何匹配的三元组")
-                return {"results": [], "timing": timing_stats}
-
-            triplet_text_start = time.time()
-            triplet_texts = []
-            for triplet in all_triplets:
-                head = triplet.get('head', '')
-                type_ = triplet.get('type', '')
-                tail = triplet.get('tail', '')
-                if head and type_ and tail:
-                    triplet_texts.append(f"{head} {type_} {tail}")
-                else:
-                    debug(f"无效三元组: {triplet}")
-            combined_text = query
-            if triplet_texts:
-                combined_text += " [三元组] " + "; ".join(triplet_texts)
-            debug(
-                f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
-            timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
-            debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
-
-            embedding_start = time.time()
-            embeddings = await self._get_embeddings([combined_text])
-            query_vector = embeddings[0]
-            debug(f"拼接文本向量维度: {len(query_vector)}")
-            timing_stats["embedding_generation"] = time.time() - embedding_start
-            debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒")
-
-            search_start = time.time()
-            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
-            kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
-            expr = f"userid == '{userid}' and ({kb_expr})"
-            debug(f"搜索表达式: {expr}")
-
-            try:
-                results = collection.search(
-                    data=[query_vector],
-                    anns_field="vector",
-                    param=search_params,
-                    limit=100,
-                    expr=expr,
-                    output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
-                                   "file_type"],
-                    offset=offset
-                )
-            except Exception as e:
-                error(f"向量搜索失败: {str(e)}")
-                return {"results": [], "timing": timing_stats}
-            timing_stats["vector_search"] = time.time() - search_start
-            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
-
-            search_results = []
-            for hits in results:
-                for hit in hits:
-                    metadata = {
-                        "userid": hit.entity.get("userid"),
-                        "document_id": hit.entity.get("document_id"),
-                        "filename": hit.entity.get("filename"),
-                        "file_path": hit.entity.get("file_path"),
-                        "upload_time": hit.entity.get("upload_time"),
-                        "file_type": hit.entity.get("file_type")
-                    }
-                    result = {
-                        "text": hit.entity.get("text"),
-                        "distance": hit.distance,
-                        "source": "fused_query_with_triplets",
-                        "metadata": metadata
-                    }
-                    search_results.append(result)
-                    debug(
-                        f"搜索命中: text={result['text'][:100]}..., distance={hit.distance}, source={result['source']}")
-
-            unique_results = []
-            seen_texts = set()
-            dedup_start = time.time()
-            for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
-                if result['text'] not in seen_texts:
-                    unique_results.append(result)
-                    seen_texts.add(result['text'])
-            timing_stats["deduplication"] = time.time() - dedup_start
-            debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒")
-            info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
-
-            if use_rerank and unique_results:
-                rerank_start = time.time()
-                debug("开始重排序")
-                unique_results = await self._rerank_results(combined_text, unique_results, limit)
-                unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
-                timing_stats["reranking"] = time.time() - rerank_start
-                debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
-                debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
-            else:
-                unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
-
-            timing_stats["total_time"] = time.time() - start_time
-            info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
-            return {"results": unique_results[:limit], "timing": timing_stats}
-
-        except Exception as e:
-            error(f"融合搜索失败: {str(e)}")
-            return {"results": [], "timing": timing_stats}
-
-    async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
-                            offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
-        """纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
-        start_time = time.time()  # 记录开始时间
-        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
-        timing_stats = {}  # 记录各步骤耗时
+        timing_stats = {}
         try:
             info(
                 f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
@@ -1274,82 +865,33 @@ class MilvusConnection:
                 if "_" in kb_id:
                     raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}")

-            if not utility.has_collection(collection_name):
-                debug(f"集合 {collection_name} 不存在")
-                return {"results": [], "timing": timing_stats}
-
-            try:
-                collection = Collection(collection_name)
-                collection.load()
-                debug(f"加载集合: {collection_name}")
-                timing_stats["collection_load"] = time.time() - start_time
-                debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒")
-            except Exception as e:
-                error(f"加载集合 {collection_name} 失败: {str(e)}")
-                return {"results": [], "timing": timing_stats}
-
-            embedding_start = time.time()
-            embeddings = await self._get_embeddings([query])
-            query_vector = embeddings[0]
-            debug(f"查询向量维度: {len(query_vector)}")
-            timing_stats["embedding_generation"] = time.time() - embedding_start
-            debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒")
+            # 将查询文本转换为向量
+            vector_start = time.time()
+            query_vector = await self._get_embeddings([query])
+            if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
+                raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
+            query_vector = query_vector[0]  # 取第一个向量
+            timing_stats["vector_generation"] = time.time() - vector_start
+            debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")

+            # 调用纯向量搜索端点
             search_start = time.time()
-            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
-            kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
-            expr = f"userid == '{userid}' and ({kb_id_expr})"
-            debug(f"搜索表达式: {expr}")
-
-            try:
-                results = collection.search(
-                    data=[query_vector],
-                    anns_field="vector",
-                    param=search_params,
-                    limit=100,
-                    expr=expr,
-                    output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
-                                   "file_type"],
-                    offset=offset
-                )
-            except Exception as e:
-                error(f"搜索失败: {str(e)}")
-                return {"results": [], "timing": timing_stats}
+            result = await self._make_api_request("searchquery", {
+                "query_vector": query_vector.tolist(),
+                "userid": userid,
+                "knowledge_base_ids": knowledge_base_ids,
+                "limit": limit,
+                "offset": offset,
+                "db_type": db_type
+            })
             timing_stats["vector_search"] = time.time() - search_start
             debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")

-            search_results = []
-            for hits in results:
-                for hit in hits:
-                    metadata = {
-                        "userid": hit.entity.get("userid"),
-                        "document_id": hit.entity.get("document_id"),
-                        "filename": hit.entity.get("filename"),
-                        "file_path": hit.entity.get("file_path"),
-                        "upload_time": hit.entity.get("upload_time"),
-                        "file_type": hit.entity.get("file_type")
-                    }
-                    result = {
-                        "text": hit.entity.get("text"),
-                        "distance": hit.distance,
-                        "source": "vector_query",
-                        "metadata": metadata
-                    }
-                    search_results.append(result)
-                    debug(
-                        f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")
-
-            dedup_start = time.time()
-            unique_results = []
-            seen_texts = set()
-            for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
-                if result['text'] not in seen_texts:
-                    unique_results.append(result)
-                    seen_texts.add(result['text'])
-            timing_stats["deduplication"] = time.time() - dedup_start
-            debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒")
-            info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
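+            # The searchquery endpoint is assumed to return hits already
+            # deduplicated and sorted; only reranking is applied client-side.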
+            if result.get("status") != "success":
+                error(f"纯向量搜索失败: {result.get('message', '未知错误')}")
+                return {"results": [], "timing": timing_stats}
+            unique_results = result.get("results", [])
             if use_rerank and unique_results:
                 rerank_start = time.time()
                 debug("开始重排序")
@@ -1366,7 +908,108 @@ class MilvusConnection:
             return {"results": unique_results[:limit], "timing": timing_stats}

         except Exception as e:
-            error(f"纯向量搜索失败: {str(e)}")
+            error(f"纯向量搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
+            return {"results": [], "timing": timing_stats}
+
+    async def _fused_search(self, query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
+                            offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
+        """融合搜索,调用服务化端点"""
+        start_time = time.time()
+        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
+        timing_stats = {}
+        try:
+            info(
+                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
+
+            if not query or not userid or not knowledge_base_ids:
+                raise ValueError("query、userid 和 knowledge_base_ids 不能为空")
+            if "_" in userid or (db_type and "_" in db_type):
+                raise ValueError("userid 和 db_type 不能包含下划线")
+            if (db_type and len(db_type) > 100) or len(userid) > 100:
+                raise ValueError("db_type 或 userid 的长度超出限制")
+            if limit < 1 or limit > 16384 or offset < 0:
+                raise ValueError("limit 必须在 1 到 16384 之间,offset 必须大于或等于 0")
+
+            # 提取实体
+            entity_extract_start = time.time()
+            query_entities = await self._extract_entities(query)
+            timing_stats["entity_extraction"] = time.time() - entity_extract_start
+            debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
+
+            # 匹配三元组
+            all_triplets = []
+            triplet_match_start = time.time()
+            for kb_id in knowledge_base_ids:
+                debug(f"处理知识库: {kb_id}")
+                matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id)
+                debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)} 条")
+                all_triplets.extend(matched_triplets)
+            timing_stats["triplet_matching"] = time.time() - triplet_match_start
+            debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
+
+            # 拼接三元组文本
+            triplet_text_start = time.time()
+            triplet_texts = []
+            for triplet in all_triplets:
+                head = triplet.get('head', '')
+                type_ = triplet.get('type', '')
+                tail = triplet.get('tail', '')
+                if head and type_ and tail:
+                    triplet_texts.append(f"{head} {type_} {tail}")
+                else:
+                    debug(f"无效三元组: {triplet}")
+            combined_text = query
+            if triplet_texts:
+                combined_text += " [三元组] " + "; ".join(triplet_texts)
+            debug(
+                f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
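+            # Fusion strategy: embed the query concatenated with the matched
+            # KG triplets and run a plain vector search on that single embedding.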
+            timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
+            debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
+
+            # 将拼接文本转换为向量
+            vector_start = time.time()
+            query_vector = await self._get_embeddings([combined_text])
+            if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
+                raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
+            query_vector = query_vector[0]  # 取第一个向量
+            timing_stats["vector_generation"] = time.time() - vector_start
+            debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")
+
+            # 调用融合搜索端点
+            search_start = time.time()
+            result = await self._make_api_request("searchquery", {  # 注意:使用 searchquery 端点
+                "query_vector": query_vector.tolist(),
+                "userid": userid,
+                "knowledge_base_ids": knowledge_base_ids,
+                "limit": limit,
+                "offset": offset,
+                "db_type": db_type
+            })
+            timing_stats["vector_search"] = time.time() - search_start
+            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
+
+            if result.get("status") != "success":
+                error(f"融合搜索失败: {result.get('message', '未知错误')}")
+                return {"results": [], "timing": timing_stats}
+
+            unique_results = result.get("results", [])
+            if use_rerank and unique_results:
+                rerank_start = time.time()
+                debug("开始重排序")
+                unique_results = await self._rerank_results(combined_text, unique_results, limit)
+                unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
+                timing_stats["reranking"] = time.time() - rerank_start
+                debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
+                debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
+            else:
+                unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
+
+            timing_stats["total_time"] = time.time() - start_time
+            info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
+            return {"results": unique_results[:limit], "timing": timing_stats}
+
+        except Exception as e:
+            error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
             return {"results": [], "timing": timing_stats}

     async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
@@ -1382,53 +1025,17 @@ class MilvusConnection:
             if (db_type and len(db_type) > 100) or len(userid) > 100:
                 raise ValueError("userid 或 db_type 的长度超出限制")

-            if not utility.has_collection(collection_name):
-                debug(f"集合 {collection_name} 不存在")
+            # 调用列出用户文件端点
+            result = await self._make_api_request("listuserfiles", {
+                "userid": userid,
+                "db_type": db_type
+            })
+
+            if result.get("status") != "success":
+                error(f"列出用户文件失败: {result.get('message', '未知错误')}")
                 return {}

-            try:
-                collection = Collection(collection_name)
-                collection.load()
-                debug(f"加载集合: {collection_name}")
-            except Exception as e:
-                error(f"加载集合 {collection_name} 失败: {str(e)}")
-                return {}
-
-            expr = f"userid == '{userid}'"
-            debug(f"查询表达式: {expr}")
-
-            try:
-                results = collection.query(
-                    expr=expr,
-                    output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"],
-                    limit=1000
-                )
-            except Exception as e:
-                error(f"查询用户文件失败: {str(e)}")
-                return {}
-
-            files_by_kb = {}
-            seen_document_ids = set()
-            for result in results:
-                document_id = result.get("document_id")
-                kb_id = result.get("knowledge_base_id")
-                if document_id not in seen_document_ids:
-                    seen_document_ids.add(document_id)
-                    file_info = {
-                        "document_id": document_id,
-                        "filename": result.get("filename"),
-                        "file_path": result.get("file_path"),
-                        "upload_time": result.get("upload_time"),
result.get("upload_time"), - "file_type": result.get("file_type"), - "knowledge_base_id": kb_id - } - if kb_id not in files_by_kb: - files_by_kb[kb_id] = [] - files_by_kb[kb_id].append(file_info) - debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}") - - info(f"找到 {len(seen_document_ids)} 个文件,userid={userid}, 知识库数量={len(files_by_kb)}") - return files_by_kb + return result.get("files_by_knowledge_base", {}) except Exception as e: error(f"列出用户文件失败: {str(e)}") @@ -1445,82 +1052,12 @@ class MilvusConnection: if db_type and len(db_type) > 100: raise ValueError("db_type 的长度应小于 100") - if not utility.has_collection(collection_name): - debug(f"集合 {collection_name} 不存在") - return { - "status": "success", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"集合 {collection_name} 不存在", - "status_code": 200 - } + # 调用列出所有知识库端点 + result = await self._make_api_request("listallknowledgebases", { + "db_type": db_type + }) - try: - collection = Collection(collection_name) - collection.load() - debug(f"加载集合: {collection_name}") - except Exception as e: - error(f"加载集合 {collection_name} 失败: {str(e)}") - return { - "status": "error", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"加载集合失败: {str(e)}", - "status_code": 400 - } - - # 查询所有用户的文件,按 userid 和 knowledge_base_id 分组 - expr = "userid != ''" # 查询所有非空用户 - debug(f"查询表达式: {expr}") - try: - results = collection.query( - expr=expr, - output_fields=["userid", "knowledge_base_id", "document_id", "filename", "file_path", "upload_time", - "file_type"], - limit=10000 # 假设最大 10000 条记录,需根据实际数据量调整 - ) - except Exception as e: - error(f"查询所有用户文件失败: {str(e)}") - return { - "status": "error", - "users_knowledge_bases": {}, - "collection_name": collection_name, - "message": f"查询失败: {str(e)}", - "status_code": 400 - } - - users_knowledge_bases = {} - seen_document_ids = set() - for result in results: - userid = result.get("userid") - kb_id = result.get("knowledge_base_id") - document_id = result.get("document_id") - if document_id not in seen_document_ids: - seen_document_ids.add(document_id) - file_info = { - "document_id": document_id, - "filename": result.get("filename"), - "file_path": result.get("file_path"), - "upload_time": result.get("upload_time"), - "file_type": result.get("file_type"), - "knowledge_base_id": kb_id - } - if userid not in users_knowledge_bases: - users_knowledge_bases[userid] = {} - if kb_id not in users_knowledge_bases[userid]: - users_knowledge_bases[userid][kb_id] = [] - users_knowledge_bases[userid][kb_id].append(file_info) - debug( - f"找到文件: userid={userid}, knowledge_base_id={kb_id}, document_id={document_id}, filename={result.get('filename')}") - - info(f"找到 {len(seen_document_ids)} 个文件,涉及 {len(users_knowledge_bases)} 个用户") - return { - "status": "success", - "users_knowledge_bases": users_knowledge_bases, - "collection_name": collection_name, - "message": f"成功列出 {len(users_knowledge_bases)} 个用户的知识库和文件", - "status_code": 200 - } + return result except Exception as e: error(f"列出所有用户知识库失败: {str(e)}") @@ -1532,5 +1069,5 @@ class MilvusConnection: "status_code": 400 } -connection_register('Milvus', MilvusConnection) +connection_register('Rag', MilvusConnection) info("MilvusConnection registered") \ No newline at end of file diff --git a/llmengine/milvus_db.py b/llmengine/milvus_db.py old mode 100644 new mode 100755 index 600d99e..c2d2ba6 --- a/llmengine/milvus_db.py +++ b/llmengine/milvus_db.py @@ -7,6 +7,7 @@ from typing import 
diff --git a/llmengine/milvus_db.py b/llmengine/milvus_db.py
old mode 100644
new mode 100755
index 600d99e..c2d2ba6
--- a/llmengine/milvus_db.py
+++ b/llmengine/milvus_db.py
@@ -7,6 +7,7 @@
 from typing import Dict, List, Any
 import uuid
 from datetime import datetime
 from llmengine.base_db import connection_register, BaseDBConnection
+import time
 
 class MilvusDBConnection(BaseDBConnection):
     _instance = None
@@ -74,30 +75,13 @@ class MilvusDBConnection(BaseDBConnection):
         elif action == "delete_collection":
             return await self._delete_collection(db_type)
         elif action == "insert_document":
-            userid = params.get("userid", "")
-            knowledge_base_id = params.get("knowledge_base_id", "")
-            document_id = params.get("document_id", str(uuid.uuid4()))
-            texts = params.get("texts", [])
-            embeddings = params.get("embeddings", [])
-            filename = params.get("filename", "")
-            file_path = params.get("file_path", "")
-            upload_time = params.get("upload_time", datetime.now().isoformat())
-            file_type = params.get("file_type", "")
-            if not userid or not knowledge_base_id or not texts or not embeddings:
-                return {"status": "error", "message": "userid、knowledge_base_id、texts 和 embeddings 不能为空",
-                        "collection_name": collection_name, "document_id": "", "status_code": 400}
-            if "_" in userid or "_" in knowledge_base_id:
-                return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
-                        "collection_name": collection_name, "document_id": "", "status_code": 400}
-            if len(knowledge_base_id) > 100 or len(userid) > 100:
-                return {"status": "error", "message": "userid 或 knowledge_base_id 的长度应小于 100",
-                        "collection_name": collection_name, "document_id": "", "status_code": 400}
-            return await self._insert_document(collection_name, userid, knowledge_base_id, document_id, texts, embeddings,
-                                               filename, file_path, upload_time, file_type)
+            chunks = params.get("chunks", [])
+            return await self._insert_document(chunks, db_type)
         elif action == "delete_document":
             userid = params.get("userid", "")
             filename = params.get("filename", "")
             knowledge_base_id = params.get("knowledge_base_id", "")
+            db_type = params.get("db_type", "")
             if not userid or not filename or not knowledge_base_id:
                 return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空",
                         "collection_name": collection_name, "document_id": "", "status_code": 400}
@@ -107,7 +91,7 @@ class MilvusDBConnection(BaseDBConnection):
             if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100:
                 return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制",
                         "collection_name": collection_name, "document_id": "", "status_code": 400}
-            return await self._delete_document(db_type, userid, filename, knowledge_base_id)
+            return await self._delete_document(userid, filename, knowledge_base_id, db_type)
         elif action == "delete_knowledge_base":
             userid = params.get("userid", "")
             knowledge_base_id = params.get("knowledge_base_id", "")
@@ -127,13 +111,14 @@ class MilvusDBConnection(BaseDBConnection):
             knowledge_base_ids = params.get("knowledge_base_ids", [])
             limit = params.get("limit", 5)
             offset = params.get("offset", 0)
+            db_type = params.get("db_type", "")
             if not query_vector or not userid or not knowledge_base_ids:
                 return {"status": "error", "message": "query_vector、userid 或 knowledge_base_ids 不能为空",
                         "collection_name": collection_name, "document_id": "", "status_code": 400}
             if limit < 1 or limit > 16384:
                 return {"status": "error", "message": "limit 必须在 1 到 16384 之间",
                         "collection_name": collection_name, "document_id": "", "status_code": 400}
-            return await self._search_query(collection_name, query_vector, userid, knowledge_base_ids, limit, offset)
+            return await self._search_query(query_vector, userid, knowledge_base_ids, limit, offset, db_type)
         elif action == "list_user_files":
             userid = params.get("userid", "")
             if not userid:
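The insert_document action now takes a single `chunks` list instead of nine parallel parameters, so every per-file attribute travels inside each chunk. A payload that satisfies the validation in `_insert_document` further down would look like this (all values invented for illustration):

# Illustrative payload for the new insert_document action; values are made up.
chunks = [
    {
        "text": "first chunk of extracted text",
        "vector": [0.01] * 1024,            # must be exactly 1024 floats
        "document_id": "doc-001",           # same for every chunk of one file
        "filename": "report.pdf",
        "file_path": "/uploads/report.pdf",
        "upload_time": "2024-01-01T00:00:00",
        "file_type": "pdf",
        "userid": "user1",                  # must match across all chunks
        "knowledge_base_id": "kb1",         # must match across all chunks
    },
]
# Inside an async handler:
result = await engine.handle_connection("insert_document", {"chunks": chunks, "db_type": ""})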
@@ -300,36 +285,84 @@ class MilvusDBConnection(BaseDBConnection):
                 "message": str(e)
             }
 
-    async def _insert_document(self, collection_name: str, userid: str, knowledge_base_id: str, document_id: str,
-                               texts: List[str], embeddings: List[List[float]], filename: str, file_path: str,
-                               upload_time: str, file_type: str) -> Dict[str, Any]:
+    async def _insert_document(self, chunks: List[Dict], db_type: str = "") -> Dict[str, Any]:
         """插入文档到 Milvus"""
+        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
+        document_id = chunks[0]["document_id"] if chunks else ""
         try:
             # 检查集合是否存在
-            create_result = await self._create_collection(collection_name.split('_')[-1] if '_' in collection_name else "")
+            create_result = await self._create_collection(db_type)
             if create_result["status"] == "error":
                 raise RuntimeError(f"集合创建失败: {create_result['message']}")
 
             # 检查输入数据
-            if len(texts) != len(embeddings):
-                raise ValueError("texts 和 embeddings 的长度必须一致")
-            if not all(isinstance(emb, list) and len(emb) == 1024 for emb in embeddings):
-                raise ValueError("embeddings 必须是长度为 1024 的浮点数列表")
+            if not chunks:
+                raise ValueError("chunks 不能为空")
+            for chunk in chunks:
+                if not isinstance(chunk, dict):
+                    raise ValueError("每个 chunk 必须是一个字典")
+                required_fields = ["text", "vector", "document_id", "filename", "file_path", "upload_time", "file_type",
+                                   "userid", "knowledge_base_id"]
+                if not all(k in chunk for k in required_fields):
+                    raise ValueError(f"chunk 缺少必要字段: {', '.join(set(required_fields) - set(chunk.keys()))}")
+                if not isinstance(chunk["vector"], list) or len(chunk["vector"]) != 1024:
+                    raise ValueError("vector 必须是长度为 1024 的浮点数列表")
 
-            # 插入 Milvus
+            # Verify that userid and knowledge_base_id are consistent across chunks
+            if len(set(chunk["userid"] for chunk in chunks)) > 1:
+                raise ValueError("所有 chunk 的 userid 必须一致")
+            if len(set(chunk["knowledge_base_id"] for chunk in chunks)) > 1:
+                raise ValueError("所有 chunk 的 knowledge_base_id 必须一致")
+            if len(set(chunk["filename"] for chunk in chunks)) > 1:
+                raise ValueError("所有 chunk 的 filename 必须一致")
+
+            # Check for an existing document with the same userid, knowledge_base_id and filename
             collection = Collection(collection_name)
             collection.load()
+            expr = f"userid == '{chunks[0]['userid']}' and knowledge_base_id == '{chunks[0]['knowledge_base_id']}' and filename == '{chunks[0]['filename']}'"
+            debug(f"检查重复文档: {expr}")
+            results = collection.query(expr=expr, output_fields=["document_id"], limit=1)
+            if results:
+                debug(f"找到重复文档: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}")
+                return {
+                    "status": "error",
+                    "document_id": document_id,
+                    "collection_name": collection_name,
+                    "message": f"文档已存在: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}",
+                    "status_code": 400
+                }
+
+            # Extract the column data from the chunks
+            userids = [chunk["userid"] for chunk in chunks]
+            knowledge_base_ids = [chunk["knowledge_base_id"] for chunk in chunks]
+            texts = [chunk["text"] for chunk in chunks]
+            embeddings = [chunk["vector"] for chunk in chunks]
+            document_ids = [chunk["document_id"] for chunk in chunks]
+            filenames = [chunk["filename"] for chunk in chunks]
+            file_paths = [chunk["file_path"] for chunk in chunks]
+            upload_times = [chunk["upload_time"] for chunk in chunks]
+            file_types = [chunk["file_type"] for chunk in chunks]
+
+            # Build the insert payload, one column per schema field
             data = {
-                "userid": [userid] * len(texts),
-                "knowledge_base_id": [knowledge_base_id] * len(texts),
-                "document_id": [document_id] * len(texts),
+                "userid": userids,
+                "knowledge_base_id": knowledge_base_ids,
+                "document_id": document_ids,
                 "text": texts,
                 "vector": embeddings,
-                "filename": [filename] * len(texts),
-                "file_path": [file_path] * len(texts),
-                "upload_time": [upload_time] * len(texts),
-                "file_type": [file_type] * len(texts),
+                "filename": filenames,
+                "file_path": file_paths,
+                "upload_time": upload_times,
+                "file_type": file_types,
             }
+
+            schema_fields = [field.name for field in collection.schema.fields if field.name != "pk"]
+            debug(f"Schema fields: {schema_fields}")
+            debug(f"Data keys: {list(data.keys())}")
+            if list(data.keys()) != schema_fields:
+                raise ValueError(f"数据字段顺序不匹配,期望: {schema_fields}, 实际: {list(data.keys())}")
+
+            collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"])
             collection.flush()
             debug(f"成功插入 {len(texts)} 个文档到集合 {collection_name}")
@@ -340,8 +373,17 @@ class MilvusDBConnection(BaseDBConnection):
                 "message": f"成功插入 {len(texts)} 个文档到 {collection_name}",
                 "status_code": 200
             }
+        except MilvusException as e:
+            error(f"Milvus 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
+            return {
+                "status": "error",
+                "document_id": document_id,
+                "collection_name": collection_name,
+                "message": f"Milvus 插入失败: {str(e)}",
+                "status_code": 400
+            }
         except Exception as e:
-            error(f"插入文档失败: {str(e)}")
+            error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
             return {
                 "status": "error",
                 "document_id": document_id,
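The strict key-order check added above exists because `Collection.insert` is handed a bare list of columns, and pymilvus matches list positions to fields purely by schema order. A tiny standalone illustration of that contract follows; the collection and field names mirror this PR's schema, while the row values are invented:

# Minimal illustration of pymilvus' column-based insert: list positions are
# matched to schema order, which is why the diff validates key order first.
from pymilvus import Collection

collection = Collection("ragdb")
data = {
    "userid": ["user1"], "knowledge_base_id": ["kb1"], "document_id": ["doc1"],
    "text": ["hello"], "vector": [[0.0] * 1024], "filename": ["a.txt"],
    "file_path": ["/tmp/a.txt"], "upload_time": ["2024-01-01T00:00:00"],
    "file_type": ["txt"],
}
ordered = [f.name for f in collection.schema.fields if f.name != "pk"]
collection.insert([data[name] for name in ordered])  # strictly schema order
collection.flush()  # make the rows visible to subsequent queries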
@@ -350,7 +392,8 @@ class MilvusDBConnection(BaseDBConnection):
                 "status_code": 400
             }
 
-    async def _delete_document(self, db_type: str, userid: str, filename: str, knowledge_base_id: str) -> Dict[str, Any]:
+    async def _delete_document(self, userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> Dict[str, Any]:
         """删除用户指定文件数据,仅处理 Milvus 记录"""
         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
         try:
@@ -550,9 +593,10 @@ class MilvusDBConnection(BaseDBConnection):
                 "status_code": 400
             }
 
-    async def _search_query(self, collection_name: str, query_vector: List[float], userid: str,
-                            knowledge_base_ids: List[str], limit: int = 5, offset: int = 0) -> Dict[str, Any]:
+    async def _search_query(self, query_vector: List[float], userid: str,
+                            knowledge_base_ids: List[str], limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
         """基于向量搜索 Milvus 集合"""
+        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
         timing_stats = {}
         start_time = time.time()
         try: