数据库服务 rag服务
This commit is contained in:
parent
83331a914b
commit
c32c16512e
0
llmengine/base_connection.py
Normal file → Executable file
0
llmengine/base_connection.py
Normal file → Executable file
0
llmengine/base_db.py
Normal file → Executable file
0
llmengine/base_db.py
Normal file → Executable file
106
llmengine/connection.py
Normal file → Executable file
106
llmengine/connection.py
Normal file → Executable file
@ -1,4 +1,4 @@
|
||||
import milvus_connection
|
||||
import llmengine.milvus_connection
|
||||
from traceback import format_exc
|
||||
import argparse
|
||||
from aiohttp import web
|
||||
@ -403,8 +403,8 @@ async def delete_file(request, params_kw, *params, **kw):
|
||||
result = await engine.handle_connection("delete_document", {
|
||||
"userid": userid,
|
||||
"filename": filename,
|
||||
"db_type": db_type,
|
||||
"knowledge_base_id": knowledge_base_id
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"db_type": db_type
|
||||
})
|
||||
debug(f'Delete result: {result=}')
|
||||
status = 200 if result.get("status") == "success" else 400
|
||||
@ -454,60 +454,15 @@ async def delete_knowledge_base(request, params_kw, *params, **kw):
|
||||
"status_code": 400
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
async def fused_search_query(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
query = params_kw.get('query')
|
||||
userid = params_kw.get('userid')
|
||||
db_type = params_kw.get('db_type', '')
|
||||
knowledge_base_ids = params_kw.get('knowledge_base_ids')
|
||||
limit = params_kw.get('limit')
|
||||
offset = params_kw.get('offset', 0)
|
||||
use_rerank = params_kw.get('use_rerank', True)
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
if not all([query, userid, knowledge_base_ids]):
|
||||
debug(f'query, userid 或 knowledge_base_ids 未提供')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"message": "query, userid 或 knowledge_base_ids 未提供",
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
result = await engine.handle_connection("fused_search", {
|
||||
"query": query,
|
||||
"userid": userid,
|
||||
"db_type": db_type,
|
||||
"knowledge_base_ids": knowledge_base_ids,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"use_rerank": use_rerank
|
||||
})
|
||||
debug(f'{result=}')
|
||||
response = {
|
||||
"status": "success",
|
||||
"results": result.get("results", []),
|
||||
"timing": result.get("timing", {}),
|
||||
"collection_name": collection_name
|
||||
}
|
||||
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
error(f'融合搜索失败: {str(e)}')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
async def search_query(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw=}')
|
||||
debug(f'Received search_query params: {params_kw=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
query = params_kw.get('query')
|
||||
userid = params_kw.get('userid')
|
||||
db_type = params_kw.get('db_type', '')
|
||||
knowledge_base_ids = params_kw.get('knowledge_base_ids')
|
||||
limit = params_kw.get('limit')
|
||||
limit = params_kw.get('limit', 5)
|
||||
offset = params_kw.get('offset', 0)
|
||||
use_rerank = params_kw.get('use_rerank', True)
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
@ -522,13 +477,13 @@ async def search_query(request, params_kw, *params, **kw):
|
||||
result = await engine.handle_connection("search_query", {
|
||||
"query": query,
|
||||
"userid": userid,
|
||||
"db_type": db_type,
|
||||
"knowledge_base_ids": knowledge_base_ids,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"use_rerank": use_rerank
|
||||
"use_rerank": use_rerank,
|
||||
"db_type": db_type
|
||||
})
|
||||
debug(f'{result=}')
|
||||
debug(f'Search result: {result=}')
|
||||
response = {
|
||||
"status": "success",
|
||||
"results": result.get("results", []),
|
||||
@ -544,6 +499,51 @@ async def search_query(request, params_kw, *params, **kw):
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
async def fused_search_query(request, params_kw, *params, **kw):
|
||||
debug(f'Received fused_search_query params: {params_kw=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
query = params_kw.get('query')
|
||||
userid = params_kw.get('userid')
|
||||
db_type = params_kw.get('db_type', '')
|
||||
knowledge_base_ids = params_kw.get('knowledge_base_ids')
|
||||
limit = params_kw.get('limit', 5)
|
||||
offset = params_kw.get('offset', 0)
|
||||
use_rerank = params_kw.get('use_rerank', True)
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
if not all([query, userid, knowledge_base_ids]):
|
||||
debug(f'query, userid 或 knowledge_base_ids 未提供')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"message": "query, userid 或 knowledge_base_ids 未提供",
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
result = await engine.handle_connection("fused_search", {
|
||||
"query": query,
|
||||
"userid": userid,
|
||||
"knowledge_base_ids": knowledge_base_ids,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"use_rerank": use_rerank,
|
||||
"db_type": db_type
|
||||
})
|
||||
debug(f'Fused search result: {result=}')
|
||||
response = {
|
||||
"status": "success",
|
||||
"results": result.get("results", []),
|
||||
"timing": result.get("timing", {}),
|
||||
"collection_name": collection_name
|
||||
}
|
||||
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
error(f'融合搜索失败: {str(e)}')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
async def list_user_files(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw=}')
|
||||
se = ServerEnv()
|
||||
|
||||
22
llmengine/db_service.py
Normal file → Executable file
22
llmengine/db_service.py
Normal file → Executable file
@ -276,33 +276,17 @@ async def insert_document(request, params_kw, *params, **kw):
|
||||
debug(f'Received params: {params_kw=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
userid = params_kw.get('userid', '')
|
||||
knowledge_base_id = params_kw.get('knowledge_base_id', '')
|
||||
document_id = params_kw.get('document_id', '')
|
||||
texts = params_kw.get('texts', [])
|
||||
embeddings = params_kw.get('embeddings', [])
|
||||
filename = params_kw.get('filename', '')
|
||||
file_path = params_kw.get('file_path', '')
|
||||
upload_time = params_kw.get('upload_time', '')
|
||||
file_type = params_kw.get('file_type', '')
|
||||
chunks = params_kw.get('chunks', '')
|
||||
db_type = params_kw.get('db_type', '')
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
required_fields = ['userid', 'knowledge_base_id', 'texts', 'embeddings']
|
||||
required_fields = ['chunks']
|
||||
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
|
||||
if missing_fields:
|
||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||
|
||||
result = await engine.handle_connection("insert_document", {
|
||||
"userid": userid,
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"document_id": document_id,
|
||||
"texts": texts,
|
||||
"embeddings": embeddings,
|
||||
"filename": filename,
|
||||
"file_path": file_path,
|
||||
"upload_time": upload_time,
|
||||
"file_type": file_type,
|
||||
"chunks": chunks,
|
||||
"db_type": db_type
|
||||
})
|
||||
debug(f'Insert result: {result=}')
|
||||
|
||||
1097
llmengine/milvus_connection.py
Normal file → Executable file
1097
llmengine/milvus_connection.py
Normal file → Executable file
File diff suppressed because it is too large
Load Diff
128
llmengine/milvus_db.py
Normal file → Executable file
128
llmengine/milvus_db.py
Normal file → Executable file
@ -7,6 +7,7 @@ from typing import Dict, List, Any
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from llmengine.base_db import connection_register, BaseDBConnection
|
||||
import time
|
||||
|
||||
class MilvusDBConnection(BaseDBConnection):
|
||||
_instance = None
|
||||
@ -74,30 +75,13 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
elif action == "delete_collection":
|
||||
return await self._delete_collection(db_type)
|
||||
elif action == "insert_document":
|
||||
userid = params.get("userid", "")
|
||||
knowledge_base_id = params.get("knowledge_base_id", "")
|
||||
document_id = params.get("document_id", str(uuid.uuid4()))
|
||||
texts = params.get("texts", [])
|
||||
embeddings = params.get("embeddings", [])
|
||||
filename = params.get("filename", "")
|
||||
file_path = params.get("file_path", "")
|
||||
upload_time = params.get("upload_time", datetime.now().isoformat())
|
||||
file_type = params.get("file_type", "")
|
||||
if not userid or not knowledge_base_id or not texts or not embeddings:
|
||||
return {"status": "error", "message": "userid、knowledge_base_id、texts 和 embeddings 不能为空",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
if "_" in userid or "_" in knowledge_base_id:
|
||||
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
if len(knowledge_base_id) > 100 or len(userid) > 100:
|
||||
return {"status": "error", "message": "userid 或 knowledge_base_id 的长度应小于 100",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
return await self._insert_document(collection_name, userid, knowledge_base_id, document_id, texts, embeddings,
|
||||
filename, file_path, upload_time, file_type)
|
||||
chunks = params.get("chunks", [])
|
||||
return await self._insert_document(chunks, db_type)
|
||||
elif action == "delete_document":
|
||||
userid = params.get("userid", "")
|
||||
filename = params.get("filename", "")
|
||||
knowledge_base_id = params.get("knowledge_base_id", "")
|
||||
db_type = params.get("db_type", "")
|
||||
if not userid or not filename or not knowledge_base_id:
|
||||
return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
@ -107,7 +91,7 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100:
|
||||
return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
return await self._delete_document(db_type, userid, filename, knowledge_base_id)
|
||||
return await self._delete_document(userid, filename, knowledge_base_id, db_type)
|
||||
elif action == "delete_knowledge_base":
|
||||
userid = params.get("userid", "")
|
||||
knowledge_base_id = params.get("knowledge_base_id", "")
|
||||
@ -127,13 +111,14 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
knowledge_base_ids = params.get("knowledge_base_ids", [])
|
||||
limit = params.get("limit", 5)
|
||||
offset = params.get("offset", 0)
|
||||
db_type = params.get("db_type", "")
|
||||
if not query_vector or not userid or not knowledge_base_ids:
|
||||
return {"status": "error", "message": "query_vector、userid 或 knowledge_base_ids 不能为空",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
if limit < 1 or limit > 16384:
|
||||
return {"status": "error", "message": "limit 必须在 1 到 16384 之间",
|
||||
"collection_name": collection_name, "document_id": "", "status_code": 400}
|
||||
return await self._search_query(collection_name, query_vector, userid, knowledge_base_ids, limit, offset)
|
||||
return await self._search_query(query_vector, userid, knowledge_base_ids, limit, offset, db_type)
|
||||
elif action == "list_user_files":
|
||||
userid = params.get("userid", "")
|
||||
if not userid:
|
||||
@ -300,36 +285,84 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
async def _insert_document(self, collection_name: str, userid: str, knowledge_base_id: str, document_id: str,
|
||||
texts: List[str], embeddings: List[List[float]], filename: str, file_path: str,
|
||||
upload_time: str, file_type: str) -> Dict[str, Any]:
|
||||
async def _insert_document(self, chunks: List[Dict], db_type: str = "") -> Dict[str, Any]:
|
||||
"""插入文档到 Milvus"""
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
document_id = chunks[0]["document_id"] if chunks else ""
|
||||
try:
|
||||
# 检查集合是否存在
|
||||
create_result = await self._create_collection(collection_name.split('_')[-1] if '_' in collection_name else "")
|
||||
create_result = await self._create_collection(db_type)
|
||||
if create_result["status"] == "error":
|
||||
raise RuntimeError(f"集合创建失败: {create_result['message']}")
|
||||
|
||||
# 检查输入数据
|
||||
if len(texts) != len(embeddings):
|
||||
raise ValueError("texts 和 embeddings 的长度必须一致")
|
||||
if not all(isinstance(emb, list) and len(emb) == 1024 for emb in embeddings):
|
||||
raise ValueError("embeddings 必须是长度为 1024 的浮点数列表")
|
||||
if not chunks:
|
||||
raise ValueError("chunks 不能为空")
|
||||
for chunk in chunks:
|
||||
if not isinstance(chunk, dict):
|
||||
raise ValueError("每个 chunk 必须是一个字典")
|
||||
required_fields = ["text", "vector", "document_id", "filename", "file_path", "upload_time", "file_type",
|
||||
"userid", "knowledge_base_id"]
|
||||
if not all(k in chunk for k in required_fields):
|
||||
raise ValueError(f"chunk 缺少必要字段: {', '.join(set(required_fields) - set(chunk.keys()))}")
|
||||
if not isinstance(chunk["vector"], list) or len(chunk["vector"]) != 1024:
|
||||
raise ValueError("vector 必须是长度为 1024 的浮点数列表")
|
||||
|
||||
# 插入 Milvus
|
||||
# 验证 userid 和 knowledge_base_id 一致性
|
||||
if len(set(chunk["userid"] for chunk in chunks)) > 1:
|
||||
raise ValueError("所有 chunk 的 userid 必须一致")
|
||||
if len(set(chunk["knowledge_base_id"] for chunk in chunks)) > 1:
|
||||
raise ValueError("所有 chunk 的 knowledge_base_id 必须一致")
|
||||
if len(set(chunk["filename"] for chunk in chunks)) > 1:
|
||||
raise ValueError("所有 chunk 的 filename 必须一致")
|
||||
|
||||
# 检查是否已存在相同的 userid、knowledge_base_id 和 filename
|
||||
collection = Collection(collection_name)
|
||||
collection.load()
|
||||
expr = f"userid == '{chunks[0]['userid']}' and knowledge_base_id == '{chunks[0]['knowledge_base_id']}' and filename == '{chunks[0]['filename']}'"
|
||||
debug(f"检查重复文档: {expr}")
|
||||
results = collection.query(expr=expr, output_fields=["document_id"], limit=1)
|
||||
if results:
|
||||
debug(
|
||||
f"找到重复文档: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}")
|
||||
return {
|
||||
"status": "error",
|
||||
"document_id": document_id,
|
||||
"collection_name": collection_name,
|
||||
"message": f"文档已存在: userid={chunks[0]['userid']}, knowledge_base_id={chunks[0]['knowledge_base_id']}, filename={chunks[0]['filename']}",
|
||||
"status_code": 400
|
||||
}
|
||||
|
||||
# 提取数据
|
||||
userids = [chunk["userid"] for chunk in chunks]
|
||||
knowledge_base_ids = [chunk["knowledge_base_id"] for chunk in chunks]
|
||||
texts = [chunk["text"] for chunk in chunks]
|
||||
embeddings = [chunk["vector"] for chunk in chunks]
|
||||
document_ids = [chunk["document_id"] for chunk in chunks]
|
||||
filenames = [chunk["filename"] for chunk in chunks]
|
||||
file_paths = [chunk["file_path"] for chunk in chunks]
|
||||
upload_times = [chunk["upload_time"] for chunk in chunks]
|
||||
file_types = [chunk["file_type"] for chunk in chunks]
|
||||
|
||||
# 构造插入数据
|
||||
data = {
|
||||
"userid": [userid] * len(texts),
|
||||
"knowledge_base_id": [knowledge_base_id] * len(texts),
|
||||
"document_id": [document_id] * len(texts),
|
||||
"userid": userids,
|
||||
"knowledge_base_id": knowledge_base_ids,
|
||||
"document_id": document_ids,
|
||||
"text": texts,
|
||||
"vector": embeddings,
|
||||
"filename": [filename] * len(texts),
|
||||
"file_path": [file_path] * len(texts),
|
||||
"upload_time": [upload_time] * len(texts),
|
||||
"file_type": [file_type] * len(texts),
|
||||
"filename": filenames,
|
||||
"file_path": file_paths,
|
||||
"upload_time": upload_times,
|
||||
"file_type": file_types,
|
||||
}
|
||||
|
||||
schema_fields = [field.name for field in collection.schema.fields if field.name != "pk"]
|
||||
debug(f"Schema fields: {schema_fields}")
|
||||
debug(f"Data keys: {list(data.keys())}")
|
||||
if list(data.keys()) != schema_fields:
|
||||
raise ValueError(f"数据字段顺序不匹配,期望: {schema_fields}, 实际: {list(data.keys())}")
|
||||
|
||||
collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"])
|
||||
collection.flush()
|
||||
debug(f"成功插入 {len(texts)} 个文档到集合 {collection_name}")
|
||||
@ -340,8 +373,17 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
"message": f"成功插入 {len(texts)} 个文档到 {collection_name}",
|
||||
"status_code": 200
|
||||
}
|
||||
except MilvusException as e:
|
||||
error(f"Milvus 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
return {
|
||||
"status": "error",
|
||||
"document_id": document_id,
|
||||
"collection_name": collection_name,
|
||||
"message": f"Milvus 插入失败: {str(e)}",
|
||||
"status_code": 400
|
||||
}
|
||||
except Exception as e:
|
||||
error(f"插入文档失败: {str(e)}")
|
||||
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
return {
|
||||
"status": "error",
|
||||
"document_id": document_id,
|
||||
@ -350,7 +392,8 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
"status_code": 400
|
||||
}
|
||||
|
||||
async def _delete_document(self, db_type: str, userid: str, filename: str, knowledge_base_id: str) -> Dict[str, Any]:
|
||||
async def _delete_document(self, userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> Dict[
|
||||
str, Any]:
|
||||
"""删除用户指定文件数据,仅处理 Milvus 记录"""
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
@ -550,9 +593,10 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
"status_code": 400
|
||||
}
|
||||
|
||||
async def _search_query(self, collection_name: str, query_vector: List[float], userid: str,
|
||||
knowledge_base_ids: List[str], limit: int = 5, offset: int = 0) -> Dict[str, Any]:
|
||||
async def _search_query(self, query_vector: List[float], userid: str,
|
||||
knowledge_base_ids: List[str], limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
|
||||
"""基于向量搜索 Milvus 集合"""
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
timing_stats = {}
|
||||
start_time = time.time()
|
||||
try:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user