数据库服务 rag服务

This commit is contained in:
wangmeihua 2025-07-23 17:50:59 +08:00
parent 83331a914b
commit c32c16512e
6 changed files with 459 additions and 894 deletions

0
llmengine/base_connection.py Normal file → Executable file
View File

0
llmengine/base_db.py Normal file → Executable file
View File

106
llmengine/connection.py Normal file → Executable file
View File

@ -1,4 +1,4 @@
import milvus_connection import llmengine.milvus_connection
from traceback import format_exc from traceback import format_exc
import argparse import argparse
from aiohttp import web from aiohttp import web
@ -403,8 +403,8 @@ async def delete_file(request, params_kw, *params, **kw):
result = await engine.handle_connection("delete_document", { result = await engine.handle_connection("delete_document", {
"userid": userid, "userid": userid,
"filename": filename, "filename": filename,
"db_type": db_type, "knowledge_base_id": knowledge_base_id,
"knowledge_base_id": knowledge_base_id "db_type": db_type
}) })
debug(f'Delete result: {result=}') debug(f'Delete result: {result=}')
status = 200 if result.get("status") == "success" else 400 status = 200 if result.get("status") == "success" else 400
@ -454,60 +454,15 @@ async def delete_knowledge_base(request, params_kw, *params, **kw):
"status_code": 400 "status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def fused_search_query(request, params_kw, *params, **kw):
    """Serve a fused (hybrid) search request against the RAG vector store.

    Reads `query`, `userid` and `knowledge_base_ids` (all required) plus the
    optional `db_type`, `limit`, `offset` and `use_rerank` from *params_kw*,
    forwards them to the engine's "fused_search" handler, and wraps the
    outcome in an aiohttp JSON response.  Failures come back as HTTP 400
    with the resolved collection name for context.
    """
    debug(f'{params_kw=}')
    env = ServerEnv()
    search_engine = env.engine
    query = params_kw.get('query')
    userid = params_kw.get('userid')
    db_type = params_kw.get('db_type', '')
    knowledge_base_ids = params_kw.get('knowledge_base_ids')
    limit = params_kw.get('limit')
    offset = params_kw.get('offset', 0)
    use_rerank = params_kw.get('use_rerank', True)
    # Collections are partitioned per db_type; empty db_type means the default one.
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"

    def dump(obj):
        # Keep non-ASCII (CJK) text readable in the serialized response.
        return json.dumps(obj, ensure_ascii=False)

    try:
        if not (query and userid and knowledge_base_ids):
            debug(f'query, userid 或 knowledge_base_ids 未提供')
            return web.json_response({
                "status": "error",
                "message": "query, userid 或 knowledge_base_ids 未提供",
                "collection_name": collection_name
            }, dumps=dump, status=400)
        result = await search_engine.handle_connection("fused_search", {
            "query": query,
            "userid": userid,
            "db_type": db_type,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank
        })
        debug(f'{result=}')
        return web.json_response({
            "status": "success",
            "results": result.get("results", []),
            "timing": result.get("timing", {}),
            "collection_name": collection_name
        }, dumps=dump)
    except Exception as e:
        error(f'融合搜索失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "message": str(e),
            "collection_name": collection_name
        }, dumps=dump, status=400)
async def search_query(request, params_kw, *params, **kw): async def search_query(request, params_kw, *params, **kw):
debug(f'{params_kw=}') debug(f'Received search_query params: {params_kw=}')
se = ServerEnv() se = ServerEnv()
engine = se.engine engine = se.engine
query = params_kw.get('query') query = params_kw.get('query')
userid = params_kw.get('userid') userid = params_kw.get('userid')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
knowledge_base_ids = params_kw.get('knowledge_base_ids') knowledge_base_ids = params_kw.get('knowledge_base_ids')
limit = params_kw.get('limit') limit = params_kw.get('limit', 5)
offset = params_kw.get('offset', 0) offset = params_kw.get('offset', 0)
use_rerank = params_kw.get('use_rerank', True) use_rerank = params_kw.get('use_rerank', True)
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
@ -522,13 +477,13 @@ async def search_query(request, params_kw, *params, **kw):
result = await engine.handle_connection("search_query", { result = await engine.handle_connection("search_query", {
"query": query, "query": query,
"userid": userid, "userid": userid,
"db_type": db_type,
"knowledge_base_ids": knowledge_base_ids, "knowledge_base_ids": knowledge_base_ids,
"limit": limit, "limit": limit,
"offset": offset, "offset": offset,
"use_rerank": use_rerank "use_rerank": use_rerank,
"db_type": db_type
}) })
debug(f'{result=}') debug(f'Search result: {result=}')
response = { response = {
"status": "success", "status": "success",
"results": result.get("results", []), "results": result.get("results", []),
@ -544,6 +499,51 @@ async def search_query(request, params_kw, *params, **kw):
"collection_name": collection_name "collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def fused_search_query(request, params_kw, *params, **kw):
    """Handle a fused (hybrid) search request over the RAG vector store.

    Requires `query`, `userid` and `knowledge_base_ids` in *params_kw*;
    `db_type` (default ""), `limit` (default 5), `offset` (default 0) and
    `use_rerank` (default True) are optional.  Delegates to the engine's
    "fused_search" action and returns an aiohttp JSON response; any error
    is reported with HTTP 400 together with the resolved collection name.
    """
    debug(f'Received fused_search_query params: {params_kw=}')
    env = ServerEnv()
    search_engine = env.engine
    query = params_kw.get('query')
    userid = params_kw.get('userid')
    db_type = params_kw.get('db_type', '')
    knowledge_base_ids = params_kw.get('knowledge_base_ids')
    limit = params_kw.get('limit', 5)
    offset = params_kw.get('offset', 0)
    use_rerank = params_kw.get('use_rerank', True)
    # One Milvus collection per db_type; an empty db_type selects "ragdb".
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"

    def dump(obj):
        # Preserve non-ASCII (CJK) characters in the serialized response.
        return json.dumps(obj, ensure_ascii=False)

    try:
        if not (query and userid and knowledge_base_ids):
            debug(f'query, userid 或 knowledge_base_ids 未提供')
            return web.json_response({
                "status": "error",
                "message": "query, userid 或 knowledge_base_ids 未提供",
                "collection_name": collection_name
            }, dumps=dump, status=400)
        result = await search_engine.handle_connection("fused_search", {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank,
            "db_type": db_type
        })
        debug(f'Fused search result: {result=}')
        return web.json_response({
            "status": "success",
            "results": result.get("results", []),
            "timing": result.get("timing", {}),
            "collection_name": collection_name
        }, dumps=dump)
    except Exception as e:
        error(f'融合搜索失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "message": str(e),
            "collection_name": collection_name
        }, dumps=dump, status=400)
async def list_user_files(request, params_kw, *params, **kw): async def list_user_files(request, params_kw, *params, **kw):
debug(f'{params_kw=}') debug(f'{params_kw=}')
se = ServerEnv() se = ServerEnv()

22
llmengine/db_service.py Normal file → Executable file
View File

@ -276,33 +276,17 @@ async def insert_document(request, params_kw, *params, **kw):
debug(f'Received params: {params_kw=}') debug(f'Received params: {params_kw=}')
se = ServerEnv() se = ServerEnv()
engine = se.engine engine = se.engine
userid = params_kw.get('userid', '') chunks = params_kw.get('chunks', '')
knowledge_base_id = params_kw.get('knowledge_base_id', '')
document_id = params_kw.get('document_id', '')
texts = params_kw.get('texts', [])
embeddings = params_kw.get('embeddings', [])
filename = params_kw.get('filename', '')
file_path = params_kw.get('file_path', '')
upload_time = params_kw.get('upload_time', '')
file_type = params_kw.get('file_type', '')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
required_fields = ['userid', 'knowledge_base_id', 'texts', 'embeddings'] required_fields = ['chunks']
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
result = await engine.handle_connection("insert_document", { result = await engine.handle_connection("insert_document", {
"userid": userid, "chunks": chunks,
"knowledge_base_id": knowledge_base_id,
"document_id": document_id,
"texts": texts,
"embeddings": embeddings,
"filename": filename,
"file_path": file_path,
"upload_time": upload_time,
"file_type": file_type,
"db_type": db_type "db_type": db_type
}) })
debug(f'Insert result: {result=}') debug(f'Insert result: {result=}')

1043
llmengine/milvus_connection.py Normal file → Executable file

File diff suppressed because it is too large Load Diff

128
llmengine/milvus_db.py Normal file → Executable file
View File

@ -7,6 +7,7 @@ from typing import Dict, List, Any
import uuid import uuid
from datetime import datetime from datetime import datetime
from llmengine.base_db import connection_register, BaseDBConnection from llmengine.base_db import connection_register, BaseDBConnection
import time
class MilvusDBConnection(BaseDBConnection): class MilvusDBConnection(BaseDBConnection):
_instance = None _instance = None
@ -74,30 +75,13 @@ class MilvusDBConnection(BaseDBConnection):
elif action == "delete_collection": elif action == "delete_collection":
return await self._delete_collection(db_type) return await self._delete_collection(db_type)
elif action == "insert_document": elif action == "insert_document":
userid = params.get("userid", "") chunks = params.get("chunks", [])
knowledge_base_id = params.get("knowledge_base_id", "") return await self._insert_document(chunks, db_type)
document_id = params.get("document_id", str(uuid.uuid4()))
texts = params.get("texts", [])
embeddings = params.get("embeddings", [])
filename = params.get("filename", "")
file_path = params.get("file_path", "")
upload_time = params.get("upload_time", datetime.now().isoformat())
file_type = params.get("file_type", "")
if not userid or not knowledge_base_id or not texts or not embeddings:
return {"status": "error", "message": "userid、knowledge_base_id、texts 和 embeddings 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if len(knowledge_base_id) > 100 or len(userid) > 100:
return {"status": "error", "message": "userid 或 knowledge_base_id 的长度应小于 100",
"collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._insert_document(collection_name, userid, knowledge_base_id, document_id, texts, embeddings,
filename, file_path, upload_time, file_type)
elif action == "delete_document": elif action == "delete_document":
userid = params.get("userid", "") userid = params.get("userid", "")
filename = params.get("filename", "") filename = params.get("filename", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
db_type = params.get("db_type", "")
if not userid or not filename or not knowledge_base_id: if not userid or not filename or not knowledge_base_id:
return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空", return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
@ -107,7 +91,7 @@ class MilvusDBConnection(BaseDBConnection):
if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100: if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100:
return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制", return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._delete_document(db_type, userid, filename, knowledge_base_id) return await self._delete_document(userid, filename, knowledge_base_id, db_type)
elif action == "delete_knowledge_base": elif action == "delete_knowledge_base":
userid = params.get("userid", "") userid = params.get("userid", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
@ -127,13 +111,14 @@ class MilvusDBConnection(BaseDBConnection):
knowledge_base_ids = params.get("knowledge_base_ids", []) knowledge_base_ids = params.get("knowledge_base_ids", [])
limit = params.get("limit", 5) limit = params.get("limit", 5)
offset = params.get("offset", 0) offset = params.get("offset", 0)
db_type = params.get("db_type", "")
if not query_vector or not userid or not knowledge_base_ids: if not query_vector or not userid or not knowledge_base_ids:
return {"status": "error", "message": "query_vector、userid 或 knowledge_base_ids 不能为空", return {"status": "error", "message": "query_vector、userid 或 knowledge_base_ids 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
if limit < 1 or limit > 16384: if limit < 1 or limit > 16384:
return {"status": "error", "message": "limit 必须在 1 到 16384 之间", return {"status": "error", "message": "limit 必须在 1 到 16384 之间",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._search_query(collection_name, query_vector, userid, knowledge_base_ids, limit, offset) return await self._search_query(query_vector, userid, knowledge_base_ids, limit, offset, db_type)
elif action == "list_user_files": elif action == "list_user_files":
userid = params.get("userid", "") userid = params.get("userid", "")
if not userid: if not userid:
@ -300,36 +285,84 @@ class MilvusDBConnection(BaseDBConnection):
"message": str(e) "message": str(e)
} }
async def _insert_document(self, collection_name: str, userid: str, knowledge_base_id: str, document_id: str, async def _insert_document(self, chunks: List[Dict], db_type: str = "") -> Dict[str, Any]:
texts: List[str], embeddings: List[List[float]], filename: str, file_path: str,
upload_time: str, file_type: str) -> Dict[str, Any]:
"""插入文档到 Milvus""" """插入文档到 Milvus"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
document_id = chunks[0]["document_id"] if chunks else ""
try: try:
# 检查集合是否存在 # 检查集合是否存在
create_result = await self._create_collection(collection_name.split('_')[-1] if '_' in collection_name else "") create_result = await self._create_collection(db_type)
if create_result["status"] == "error": if create_result["status"] == "error":
raise RuntimeError(f"集合创建失败: {create_result['message']}") raise RuntimeError(f"集合创建失败: {create_result['message']}")
# 检查输入数据 # 检查输入数据
if len(texts) != len(embeddings): if not chunks:
raise ValueError("texts 和 embeddings 的长度必须一致") raise ValueError("chunks 不能为空")
if not all(isinstance(emb, list) and len(emb) == 1024 for emb in embeddings): for chunk in chunks:
raise ValueError("embeddings 必须是长度为 1024 的浮点数列表") if not isinstance(chunk, dict):
raise ValueError("每个 chunk 必须是一个字典")
required_fields = ["text", "vector", "document_id", "filename", "file_path", "upload_time", "file_type",
"userid", "knowledge_base_id"]
if not all(k in chunk for k in required_fields):
raise ValueError(f"chunk 缺少必要字段: {', '.join(set(required_fields) - set(chunk.keys()))}")
if not isinstance(chunk["vector"], list) or len(chunk["vector"]) != 1024:
raise ValueError("vector 必须是长度为 1024 的浮点数列表")
# 插入 Milvus # 验证 userid 和 knowledge_base_id 一致性
if len(set(chunk["userid"] for chunk in chunks)) > 1:
raise ValueError("所有 chunk 的 userid 必须一致")
if len(set(chunk["knowledge_base_id"] for chunk in chunks)) > 1:
raise ValueError("所有 chunk 的 knowledge_base_id 必须一致")
if len(set(chunk["filename"] for chunk in chunks)) > 1:
raise ValueError("所有 chunk 的 filename 必须一致")
# 检查是否已存在相同的 userid、knowledge_base_id 和 filename
collection = Collection(collection_name) collection = Collection(collection_name)
collection.load() collection.load()
expr = f"userid == '{chunks[0]['userid']}' and knowledge_base_id == '{chunks[0]['knowledge_base_id']}' and filename == '{chunks[0]['filename']}'"
debug(f"检查重复文档: {expr}")
results = collection.query(expr=expr, output_fields=["document_id"], limit=1)
if results:
debug(
f"找到重复文档: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}")
return {
"status": "error",
"document_id": document_id,
"collection_name": collection_name,
"message": f"文档已存在: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}",
"status_code": 400
}
# 提取数据
userids = [chunk["userid"] for chunk in chunks]
knowledge_base_ids = [chunk["knowledge_base_id"] for chunk in chunks]
texts = [chunk["text"] for chunk in chunks]
embeddings = [chunk["vector"] for chunk in chunks]
document_ids = [chunk["document_id"] for chunk in chunks]
filenames = [chunk["filename"] for chunk in chunks]
file_paths = [chunk["file_path"] for chunk in chunks]
upload_times = [chunk["upload_time"] for chunk in chunks]
file_types = [chunk["file_type"] for chunk in chunks]
# 构造插入数据
data = { data = {
"userid": [userid] * len(texts), "userid": userids,
"knowledge_base_id": [knowledge_base_id] * len(texts), "knowledge_base_id": knowledge_base_ids,
"document_id": [document_id] * len(texts), "document_id": document_ids,
"text": texts, "text": texts,
"vector": embeddings, "vector": embeddings,
"filename": [filename] * len(texts), "filename": filenames,
"file_path": [file_path] * len(texts), "file_path": file_paths,
"upload_time": [upload_time] * len(texts), "upload_time": upload_times,
"file_type": [file_type] * len(texts), "file_type": file_types,
} }
schema_fields = [field.name for field in collection.schema.fields if field.name != "pk"]
debug(f"Schema fields: {schema_fields}")
debug(f"Data keys: {list(data.keys())}")
if list(data.keys()) != schema_fields:
raise ValueError(f"数据字段顺序不匹配,期望: {schema_fields}, 实际: {list(data.keys())}")
collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"]) collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"])
collection.flush() collection.flush()
debug(f"成功插入 {len(texts)} 个文档到集合 {collection_name}") debug(f"成功插入 {len(texts)} 个文档到集合 {collection_name}")
@ -340,8 +373,17 @@ class MilvusDBConnection(BaseDBConnection):
"message": f"成功插入 {len(texts)} 个文档到 {collection_name}", "message": f"成功插入 {len(texts)} 个文档到 {collection_name}",
"status_code": 200 "status_code": 200
} }
except MilvusException as e:
error(f"Milvus 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return {
"status": "error",
"document_id": document_id,
"collection_name": collection_name,
"message": f"Milvus 插入失败: {str(e)}",
"status_code": 400
}
except Exception as e: except Exception as e:
error(f"插入文档失败: {str(e)}") error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return { return {
"status": "error", "status": "error",
"document_id": document_id, "document_id": document_id,
@ -350,7 +392,8 @@ class MilvusDBConnection(BaseDBConnection):
"status_code": 400 "status_code": 400
} }
async def _delete_document(self, db_type: str, userid: str, filename: str, knowledge_base_id: str) -> Dict[str, Any]: async def _delete_document(self, userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> Dict[
str, Any]:
"""删除用户指定文件数据,仅处理 Milvus 记录""" """删除用户指定文件数据,仅处理 Milvus 记录"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
@ -550,9 +593,10 @@ class MilvusDBConnection(BaseDBConnection):
"status_code": 400 "status_code": 400
} }
async def _search_query(self, collection_name: str, query_vector: List[float], userid: str, async def _search_query(self, query_vector: List[float], userid: str,
knowledge_base_ids: List[str], limit: int = 5, offset: int = 0) -> Dict[str, Any]: knowledge_base_ids: List[str], limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
"""基于向量搜索 Milvus 集合""" """基于向量搜索 Milvus 集合"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
timing_stats = {} timing_stats = {}
start_time = time.time() start_time = time.time()
try: try: