rag/rag/milvus_connection.py

978 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from appPublic.log import debug, error, info
from base_connection import connection_register
from typing import Dict, List, Any
import numpy as np
import aiohttp
from aiohttp import ClientSession, ClientTimeout
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid
from datetime import datetime
from filetxt.loader import fileloader
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import traceback
import asyncio
import re
# Embedding cache (module-level, shared by every MilvusConnection instance):
# maps a chunk/query text to its normalised embedding vector.
EMBED_CACHE: Dict[str, Any] = dict()
class MilvusConnection:
    """RAG connection that fronts the Milvus and Neo4j HTTP services.

    Requests enter through :meth:`handle_connection`, which validates the
    common parameters and dispatches on ``action``. Instances keep no state of
    their own: chunk vectors are stored through the Milvus service, knowledge
    graph triples through the Neo4j service, and embedding, triple extraction,
    entity recognition and reranking are delegated to separate HTTP services
    whose URLs are the class attributes below.
    """

    # Base URLs of the backing services. Class attributes so a deployment (or a
    # test) can override them without editing the methods.
    MILVUS_API_BASE = "http://localhost:8886/v1"
    NEO4J_API_BASE = "http://localhost:8885/v1"
    EMBEDDING_URL = "http://localhost:9998/v1/embeddings"
    TRIPLES_URL = "http://localhost:9991/v1/triples"
    ENTITIES_URL = "http://localhost:9990/v1/entities"
    RERANK_URL = "http://localhost:9997/v1/rerank"

    # Upper bound on the number of entries kept in the module-level
    # EMBED_CACHE; the oldest insertions are evicted first.
    EMBED_CACHE_MAX_SIZE = 10000

    def __init__(self):
        pass

    @staticmethod
    def _collection_name(db_type: str = "") -> str:
        """Return the Milvus collection name for *db_type*: "ragdb" or "ragdb_<db_type>"."""
        return "ragdb" if not db_type else f"ragdb_{db_type}"

    # ------------------------------------------------------------------ HTTP helpers

    async def _post_json(self, base_url: str, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST *params* as JSON to ``{base_url}/{action}`` and return the decoded body.

        A 200 or 400 response is decoded and returned; a 400 body carries the
        service's client-side error, which callers inspect through ``status``.
        Any other status, a transport failure or an undecodable body is logged
        and raised as ``RuntimeError``.
        """
        debug(f"开始 API 请求: action={action}, params={params}")
        try:
            async with ClientSession(timeout=ClientTimeout(total=300)) as session:
                url = f"{base_url}/{action}"
                debug(f"发起 POST 请求: {url}")
                async with session.post(
                    url,
                    headers={"Content-Type": "application/json"},
                    json=params
                ) as response:
                    debug(f"收到响应: status={response.status}, headers={response.headers}")
                    response_text = await response.text()
                    debug(f"响应内容: {response_text}")
                    # Check the status before decoding: error pages are often not
                    # JSON, and decoding first would hide the real status code.
                    if response.status not in (200, 400):
                        error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
                        raise RuntimeError(f"API 调用失败: {response.status}")
                    result = await response.json()
                    debug(f"API 响应内容: {result}")
                    if response.status == 400:
                        # Client-side error: hand the body back to the caller as-is.
                        debug(f"客户端错误,状态码: {response.status}, 返回响应: {result}")
                        return result
                    debug(f"API 调用成功: {action}, 响应: {result}")
                    return result
        except Exception as e:
            error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
            raise RuntimeError(f"API 调用失败: {str(e)}") from e

    @retry(stop=stop_after_attempt(3), reraise=True)
    async def _make_neo4japi_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call Neo4j service endpoint *action*, up to 3 attempts (see ``_post_json``)."""
        return await self._post_json(self.NEO4J_API_BASE, action, params)

    @retry(stop=stop_after_attempt(3), reraise=True)
    async def _make_api_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call Milvus service endpoint *action*, up to 3 attempts (see ``_post_json``)."""
        return await self._post_json(self.MILVUS_API_BASE, action, params)

    # ------------------------------------------------------------------ dispatcher

    async def handle_connection(self, action: str, params: Dict = None) -> Dict:
        """Validate a request and dispatch it to the matching operation.

        Args:
            action: Operation name ("initialize", "get_params", "create_collection",
                "delete_collection", "insert_document", "delete_document",
                "delete_knowledge_base", "search_query", "fused_search",
                "list_user_files", "list_all_knowledge_bases").
            params: Request parameters; ``None``/empty is treated as ``{}``. The
                optional ``db_type`` selects the collection ("ragdb" or
                "ragdb_<db_type>").

        Returns:
            The operation's result. Validation failures, unknown actions and
            unexpected exceptions produce an error dict with ``status_code`` 400.
        """
        # Bound before the try so the generic handler below can always use it,
        # even when reading ``params`` itself fails.
        collection_name = self._collection_name()
        try:
            debug(f"处理操作: action={action}, params={params}")
            if not params:
                params = {}
            # db_type validation shared by every action.
            db_type = params.get("db_type", "")
            collection_name = self._collection_name(db_type)
            if db_type and "_" in db_type:
                return {"status": "error", "message": "db_type 不能包含下划线", "collection_name": collection_name,
                        "document_id": "", "status_code": 400}
            if db_type and len(db_type) > 100:
                return {"status": "error", "message": "db_type 的长度应小于 100", "collection_name": collection_name,
                        "document_id": "", "status_code": 400}

            if action == "initialize":
                return {"status": "success", "message": "Milvus 服务已初始化"}
            elif action == "get_params":
                return {"status": "success", "params": {}}
            elif action == "create_collection":
                return await self._create_collection(db_type)
            elif action == "delete_collection":
                return await self._delete_collection(db_type)
            elif action == "insert_document":
                file_path = params.get("file_path", "")
                userid = params.get("userid", "")
                knowledge_base_id = params.get("knowledge_base_id", "")
                document_id = params.get("document_id", "")
                if not file_path or not userid or not knowledge_base_id or not document_id:
                    return {"status": "error", "message": "file_path、userid document_id和 knowledge_base_id 不能为空",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                if "_" in userid or "_" in knowledge_base_id:
                    return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
                            "collection_name": collection_name, "document_id": document_id, "status_code": 400}
                if len(knowledge_base_id) > 100:
                    return {"status": "error", "message": "knowledge_base_id 的长度应小于 100",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                return await self._insert_document(file_path, userid, knowledge_base_id, document_id, db_type)
            elif action == "delete_document":
                userid = params.get("userid", "")
                file_path = params.get("file_path", "")
                knowledge_base_id = params.get("knowledge_base_id", "")
                document_id = params.get("document_id", "")
                if not userid or not file_path or not knowledge_base_id or not document_id:
                    return {"status": "error", "message": "userid、file_path document_id和 knowledge_base_id 不能为空",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                if "_" in userid or "_" in knowledge_base_id:
                    return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                if len(userid) > 100 or len(file_path) > 255 or len(knowledge_base_id) > 100:
                    return {"status": "error", "message": "userid、file_path 或 knowledge_base_id 的长度超出限制",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                return await self._delete_document(userid, file_path, knowledge_base_id, document_id, db_type)
            elif action == "delete_knowledge_base":
                userid = params.get("userid", "")
                knowledge_base_id = params.get("knowledge_base_id", "")
                if not userid or not knowledge_base_id:
                    return {"status": "error", "message": "userid 和 knowledge_base_id 不能为空",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                if "_" in userid or "_" in knowledge_base_id:
                    return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                if len(userid) > 100 or len(knowledge_base_id) > 100:
                    return {"status": "error", "message": "userid 或 knowledge_base_id 的长度超出限制",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                return await self._delete_knowledge_base(db_type, userid, knowledge_base_id)
            elif action in ("search_query", "fused_search"):
                # Both searches take the same parameters and differ only in the handler.
                query = params.get("query", "")
                userid = params.get("userid", "")
                knowledge_base_ids = params.get("knowledge_base_ids", [])
                limit = params.get("limit", 5)
                offset = params.get("offset", 0)
                use_rerank = params.get("use_rerank", True)
                if not query or not userid or not knowledge_base_ids:
                    return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
                            "collection_name": collection_name, "document_id": "", "status_code": 400}
                search = self._search_query if action == "search_query" else self._fused_search
                return await search(query, userid, knowledge_base_ids, limit, offset, use_rerank, db_type)
            elif action == "list_user_files":
                userid = params.get("userid", "")
                if not userid:
                    return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
                            "document_id": "", "status_code": 400}
                return await self._list_user_files(userid, db_type)
            elif action == "list_all_knowledge_bases":
                return await self._list_all_knowledge_bases(db_type)
            else:
                return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
                        "document_id": "", "status_code": 400}
        except Exception as e:
            error(f"处理操作失败: action={action}, 错误: {str(e)}")
            return {
                "status": "error",
                "message": f"服务器错误: {str(e)}",
                "collection_name": collection_name,
                "document_id": "",
                "status_code": 400
            }

    # ------------------------------------------------------------------ collections

    async def _create_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Ask the Milvus service to create the collection for *db_type*.

        Returns the service response, or an error dict (status_code 400) when
        validation or the request fails.
        """
        collection_name = self._collection_name(db_type)
        try:
            if len(collection_name) > 255:
                raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
            # Same rule as _delete_collection / handle_connection.
            if db_type and "_" in db_type:
                raise ValueError("db_type 不能包含下划线")
            if len(db_type) > 100:
                raise ValueError("db_type 的长度应小于 100")
            debug(f"调用创建集合端点: {collection_name}, 参数: {{'db_type': '{db_type}'}}")
            return await self._make_api_request("createcollection", {"db_type": db_type})
        except Exception as e:
            error(f"创建集合失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "message": f"创建集合失败: {str(e)}",
                "status_code": 400
            }

    async def _delete_collection(self, db_type: str = "") -> Dict:
        """Ask the Milvus service to drop the collection for *db_type*.

        Returns the service response, or an error dict (status_code 400) when
        validation or the request fails.
        """
        collection_name = self._collection_name(db_type)
        try:
            if len(collection_name) > 255:
                raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
            if db_type and "_" in db_type:
                raise ValueError("db_type 不能包含下划线")
            if db_type and len(db_type) > 100:
                raise ValueError("db_type 的长度应小于 100")
            debug(f"调用删除集合端点: {collection_name}")
            return await self._make_api_request("deletecollection", {"db_type": db_type})
        except Exception as e:
            error(f"删除集合失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "message": str(e),
                "status_code": 400
            }

    # ------------------------------------------------------------------ documents

    @staticmethod
    def _dedupe_triples(triples: List[Dict]) -> List[Dict]:
        """Drop case-insensitive duplicates of (head, tail, type), keeping first occurrences in order.

        Triples with the same head/tail but different types are distinct
        relations and are all kept. Entries that are not dicts with string
        'head', 'tail' and 'type' values are logged and skipped.
        """
        unique = []
        seen = set()
        for triple in triples:
            try:
                key = (triple['head'].lower(), triple['tail'].lower(), triple['type'].lower())
            except (TypeError, KeyError, AttributeError):
                debug(f"跳过格式错误的三元组: {triple}")
                continue
            if key not in seen:
                seen.add(key)
                unique.append(triple)
        return unique

    async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str, document_id: str,
                               db_type: str = "") -> Dict[str, Any]:
        """Embed a file into Milvus, then extract knowledge triples from it into Neo4j.

        Steps: validate arguments; load the file and strip characters outside
        CJK/ASCII letters, digits, whitespace and ``.;,``; split into 500-char
        chunks (100 overlap); embed the chunks; insert them through the Milvus
        service; extract triples from every chunk concurrently, de-duplicate
        them and insert them through the Neo4j service.

        Returns:
            A dict with ``status``, ``document_id``, ``collection_name``,
            per-stage ``timings`` (seconds) and ``message``. Validation,
            embedding or Milvus failures yield "error". A Neo4j response whose
            status is not "success" yields "error". An exception while
            extracting/inserting triples after a successful Milvus insert still
            yields "success", with the failure described in ``message``.
        """
        collection_name = self._collection_name(db_type)
        debug(
            f'Inserting document: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
        timings = {}
        start_total = time.time()
        try:
            # --- argument validation ---
            if not userid or not knowledge_base_id:
                raise ValueError("userid 和 knowledge_base_id 不能为空")
            if "_" in userid or "_" in knowledge_base_id:
                raise ValueError("userid 和 knowledge_base_id 不能包含下划线")
            if len(userid) > 100 or len(knowledge_base_id) > 100:
                raise ValueError("userid 或 knowledge_base_id 的长度超出限制")
            if not os.path.exists(file_path):
                raise ValueError(f"文件 {file_path} 不存在")
            supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
            # Use the base name only, so dots in directory names are not taken as an extension.
            base_name = os.path.basename(file_path)
            ext = base_name.rsplit('.', 1)[1].lower() if '.' in base_name else ''
            if ext not in supported_formats:
                raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
            info(f"生成 document_id: {document_id} for file: {file_path}")

            # --- load and clean text ---
            debug(f"加载文件: {file_path}")
            start_load = time.time()
            # ``or ""`` lets a loader that yields None reach the "empty file" error below.
            text = fileloader(file_path) or ""
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
            timings["load_file"] = time.time() - start_load
            debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
            if not text.strip():
                raise ValueError(f"文件 {file_path} 加载为空")

            # --- split into chunks ---
            document = Document(page_content=text)
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=100,
                length_function=len,
            )
            debug("开始分片文件内容")
            start_split = time.time()
            chunks = text_splitter.split_documents([document])
            timings["split_text"] = time.time() - start_split
            debug(
                f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
            if not chunks:
                raise ValueError(f"文件 {file_path} 未生成任何文档块")
            filename = base_name.rsplit('.', 1)[0]
            upload_time = datetime.now().isoformat()

            # --- embeddings ---
            debug("调用嵌入服务生成向量")
            start_embedding = time.time()
            texts = [chunk.page_content for chunk in chunks]
            embeddings = await self._get_embeddings(texts)
            if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
                raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
            timings["generate_embeddings"] = time.time() - start_embedding
            debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")

            # --- Milvus insert (one flat record per chunk) ---
            chunks_data = [
                {
                    "userid": userid,
                    "knowledge_base_id": knowledge_base_id,
                    "text": chunk_text,
                    "vector": vector.tolist(),
                    "document_id": document_id,
                    "filename": filename + '.' + ext,
                    "file_path": file_path,
                    "upload_time": upload_time,
                    "file_type": ext,
                }
                for chunk_text, vector in zip(texts, embeddings)
            ]
            debug(f"调用插入文件端点: {file_path}")
            start_milvus = time.time()
            milvus_result = await self._make_api_request("insertdocument", {
                "chunks": chunks_data,
                "db_type": db_type,
            })
            timings["insert_milvus"] = time.time() - start_milvus
            debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
            if milvus_result.get("status") != "success":
                timings["total"] = time.time() - start_total
                return {
                    "status": "error",
                    "document_id": document_id,
                    "collection_name": collection_name,
                    "timings": timings,
                    "message": milvus_result.get("message", "未知错误"),
                    "status_code": 400
                }

            # --- triple extraction + Neo4j insert ---
            debug("调用三元组抽取服务")
            start_triples = time.time()
            # Bound before the inner try so its handler can always reference them.
            start_neo4j = None
            unique_triples = []
            try:
                debug(f"处理 {len(texts)} 个分片进行三元组抽取")
                outcomes = await asyncio.gather(*(self._extract_triples(t) for t in texts),
                                                return_exceptions=True)
                triples = []
                for chunk_no, outcome in enumerate(outcomes, start=1):
                    if isinstance(outcome, list):
                        triples.extend(outcome)
                        debug(f"分片 {chunk_no} 抽取到 {len(outcome)} 个三元组: {outcome[:5]}")
                    else:
                        # A failed chunk only loses its own triples.
                        error(f"分片 {chunk_no} 处理失败: {str(outcome)}")
                unique_triples = self._dedupe_triples(triples)
                timings["extract_triples"] = time.time() - start_triples
                debug(
                    f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")

                debug(f"抽取到 {len(unique_triples)} 个三元组调用Neo4j服务插入")
                start_neo4j = time.time()
                if unique_triples:
                    neo4j_result = await self._make_neo4japi_request("inserttriples", {
                        "triples": unique_triples,
                        "document_id": document_id,
                        "knowledge_base_id": knowledge_base_id,
                        "userid": userid
                    })
                    debug(f"Neo4j服务响应: {neo4j_result}")
                    if neo4j_result.get("status") != "success":
                        timings["insert_neo4j"] = time.time() - start_neo4j
                        timings["total"] = time.time() - start_total
                        return {
                            "status": "error",
                            "document_id": document_id,
                            "collection_name": collection_name,
                            "timings": timings,
                            "message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
                            "status_code": 400
                        }
                    info(f"文件 {file_path} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
                else:
                    debug(f"文件 {file_path} 未抽取到三元组")
                timings["insert_neo4j"] = time.time() - start_neo4j
                debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
            except Exception as e:
                timings.setdefault("extract_triples", time.time() - start_triples)
                timings["insert_neo4j"] = (time.time() - start_neo4j) if start_neo4j is not None else 0.0
                debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
                timings["total"] = time.time() - start_total
                # Vectors are already in Milvus, so the insert is still reported as a success.
                return {
                    "status": "success",
                    "document_id": document_id,
                    "collection_name": collection_name,
                    "timings": timings,
                    "unique_triples": unique_triples,
                    "message": f"文件 {file_path} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}",
                    "status_code": 200
                }

            timings["total"] = time.time() - start_total
            debug(f"总耗时: {timings['total']:.2f}")
            return {
                "status": "success",
                "document_id": document_id,
                "collection_name": collection_name,
                "timings": timings,
                "unique_triples": unique_triples,
                "message": f"文件 {file_path} 成功嵌入并处理三元组",
                "status_code": 200
            }
        except Exception as e:
            error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            timings["total"] = time.time() - start_total
            return {
                "status": "error",
                "document_id": document_id,
                "collection_name": collection_name,
                "timings": timings,
                "message": f"插入文档失败: {str(e)}",
                "status_code": 400
            }

    # ------------------------------------------------------------------ model services

    def _trim_embed_cache(self) -> None:
        """Evict the oldest EMBED_CACHE entries beyond EMBED_CACHE_MAX_SIZE (dicts keep insertion order)."""
        overflow = len(EMBED_CACHE) - self.EMBED_CACHE_MAX_SIZE
        if overflow > 0:
            for stale in list(EMBED_CACHE)[:overflow]:
                del EMBED_CACHE[stale]

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, RuntimeError)),
        before_sleep=lambda retry_state: debug(f"重试嵌入服务,第 {retry_state.attempt_number} 次"),
        reraise=True,
    )
    async def _get_embeddings(self, texts: List[str]) -> List[np.ndarray]:
        """Return one L2-normalised embedding (numpy array) per entry of *texts*, in order.

        Texts already in the module-level EMBED_CACHE are served from it; the
        remaining distinct texts are embedded in one request and cached. The
        cache is then trimmed to EMBED_CACHE_MAX_SIZE entries, oldest first.

        Raises:
            RuntimeError: when the service call fails or its response is
                malformed (retried up to 3 times with exponential back-off).
        """
        try:
            # Snapshot cache hits before awaiting: another coroutine may evict
            # entries while this one waits for the service.
            resolved: Dict[str, np.ndarray] = {}
            uncached_texts: List[str] = []
            for text in dict.fromkeys(texts):  # de-duplicate, keep order
                cached = EMBED_CACHE.get(text)
                if cached is None:
                    uncached_texts.append(text)
                else:
                    resolved[text] = cached
            if uncached_texts:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        self.EMBEDDING_URL,
                        headers={"Content-Type": "application/json"},
                        json={"input": uncached_texts}
                    ) as response:
                        if response.status != 200:
                            error(f"嵌入服务调用失败,状态码: {response.status}")
                            raise RuntimeError(f"嵌入服务调用失败: {response.status}")
                        result = await response.json()
                if result.get("object") != "list" or not result.get("data"):
                    error(f"嵌入服务响应格式错误: {result}")
                    raise RuntimeError("嵌入服务响应格式错误")
                data = result["data"]
                # zip() would silently drop texts if fewer vectors came back.
                if len(data) != len(uncached_texts):
                    error(f"嵌入服务返回数量不匹配: 期望 {len(uncached_texts)}, 实际 {len(data)}")
                    raise RuntimeError("嵌入服务响应格式错误")
                for text, item in zip(uncached_texts, data):
                    vector = np.asarray(item["embedding"], dtype=float)
                    norm = np.linalg.norm(vector)
                    # Keep an all-zero vector as-is rather than dividing by zero (NaNs).
                    vector = vector / norm if norm > 0 else vector
                    EMBED_CACHE[text] = vector
                    resolved[text] = vector
                debug(f"成功获取 {len(data)} 个新嵌入向量,缓存大小: {len(EMBED_CACHE)}")
                self._trim_embed_cache()
            return [resolved[text] for text in texts]
        except Exception as e:
            error(f"嵌入服务调用失败: {str(e)}")
            raise RuntimeError(f"嵌入服务调用失败: {str(e)}") from e

    async def _extract_triples(self, text: str) -> List[Dict]:
        """Call the triple-extraction service for one chunk, with no timeout.

        Returns the service's ``data`` list (possibly empty when the chunk has no
        triples). Raises ``RuntimeError`` on a non-200 status, a malformed
        response or a transport failure.
        """
        request_id = str(uuid.uuid4())  # correlates the log lines of one request
        start_time = time.time()
        debug(f"Request #{request_id} started for triples extraction")
        try:
            async with aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=30),
                timeout=aiohttp.ClientTimeout(total=None)  # extraction may be slow; wait indefinitely
            ) as session:
                async with session.post(
                    self.TRIPLES_URL,
                    headers={"Content-Type": "application/json; charset=utf-8"},
                    json={"text": text}
                ) as response:
                    elapsed_time = time.time() - start_time
                    debug(f"Request #{request_id} received response, status: {response.status}, took {elapsed_time:.2f} seconds")
                    if response.status != 200:
                        error_text = await response.text()
                        error(f"Request #{request_id} failed, status: {response.status}, response: {error_text}")
                        raise RuntimeError(f"三元组抽取服务调用失败: {response.status}, {error_text}")
                    result = await response.json()
            # An empty list is a valid answer (no triples); only a malformed payload is an error.
            if result.get("object") != "list" or not isinstance(result.get("data"), list):
                error(f"Request #{request_id} invalid response format: {result}")
                raise RuntimeError("三元组抽取服务响应格式错误")
            triples = result["data"]
            debug(f"Request #{request_id} extracted {len(triples)} triples, total time: {elapsed_time:.2f} seconds")
            return triples
        except Exception as e:
            elapsed_time = time.time() - start_time
            error(f"Request #{request_id} failed to extract triples: {str(e)}, took {elapsed_time:.2f} seconds")
            debug(f"Request #{request_id} traceback: {traceback.format_exc()}")
            raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}") from e

    async def _extract_entities(self, query: str) -> List[str]:
        """Return the de-duplicated entities the entity service finds in *query*.

        Best-effort: any failure is logged and yields an empty list.
        """
        try:
            if not query:
                raise ValueError("查询文本不能为空")
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.ENTITIES_URL,
                    headers={"Content-Type": "application/json"},
                    json={"query": query}
                ) as response:
                    if response.status != 200:
                        error(f"实体识别服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"实体识别服务调用失败: {response.status}")
                    result = await response.json()
            # An empty list is a valid answer (no entities); only a malformed payload is an error.
            if result.get("object") != "list" or not isinstance(result.get("data"), list):
                error(f"实体识别服务响应格式错误: {result}")
                raise RuntimeError("实体识别服务响应格式错误")
            unique_entities = list(dict.fromkeys(result["data"]))  # de-duplicate, keep order
            debug(f"成功提取 {len(unique_entities)} 个唯一实体: {unique_entities}")
            return unique_entities
        except Exception as e:
            error(f"实体识别服务调用失败: {str(e)}")
            return []

    async def _rerank_results(self, query: str, results: List[Dict], top_n: int) -> List[Dict]:
        """Rerank *results* (each needs a "text" key) against *query* via the rerank service.

        Each returned item gets a ``rerank_score``; at most ``top_n`` items are
        returned (an invalid ``top_n`` means all). On any failure the original
        *results* are returned unchanged in order.
        """
        try:
            if not results:
                debug("无结果需要重排序")
                return results
            if not isinstance(top_n, int) or top_n < 1:
                debug(f"无效的 top_n 参数: {top_n},使用 len(results)={len(results)}")
                top_n = len(results)
            else:
                top_n = min(top_n, len(results))
            debug(f"重排序 top_n={top_n}, 原始结果数={len(results)}")
            documents = [item["text"] for item in results]
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.RERANK_URL,
                    headers={"Content-Type": "application/json"},
                    json={
                        "model": "rerank-001",
                        "query": query,
                        "documents": documents,
                        "top_n": top_n
                    }
                ) as response:
                    if response.status != 200:
                        error(f"重排序服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"重排序服务调用失败: {response.status}")
                    payload = await response.json()
            if payload.get("object") != "rerank.result" or not payload.get("data"):
                error(f"重排序服务响应格式错误: {payload}")
                raise RuntimeError("重排序服务响应格式错误")
            reranked_results = []
            for item in payload["data"]:
                index = item["index"]
                # Reject out-of-range indices; a negative one would otherwise wrap around.
                if 0 <= index < len(results):
                    results[index]["rerank_score"] = item["relevance_score"]
                    reranked_results.append(results[index])
            debug(f"成功重排序 {len(reranked_results)} 条结果")
            return reranked_results[:top_n]
        except Exception as e:
            error(f"重排序服务调用失败: {str(e)}")
            return results

    # ------------------------------------------------------------------ deletion

    async def _delete_document(self, userid: str, file_path: str, knowledge_base_id: str, document_id: str,
                               db_type: str = "") -> Dict[str, Any]:
        """Delete one document's Milvus records, then (best-effort) its Neo4j data.

        A Milvus failure returns the Milvus response as-is. A Neo4j failure is
        only logged; the result is then "success" with zero Neo4j counts.
        """
        collection_name = self._collection_name(db_type)
        try:
            debug(f"调用删除文件端点: userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}")
            milvus_result = await self._make_api_request("deletedocument", {
                "userid": userid,
                "file_path": file_path,
                "knowledge_base_id": knowledge_base_id,
                "document_id": document_id,
                "db_type": db_type
            })
            if milvus_result.get("status") != "success":
                error(f"Milvus 删除文件失败: {milvus_result.get('message', '未知错误')}")
                return milvus_result

            neo4j_deleted_nodes = 0
            neo4j_deleted_rels = 0
            try:
                debug(f"调用 Neo4j 删除文档端点: document_id={document_id}")
                neo4j_result = await self._make_neo4japi_request("deletedocument", {
                    "document_id": document_id
                })
                if neo4j_result.get("status") == "success":
                    neo4j_deleted_nodes = neo4j_result.get("nodes_deleted", 0)
                    neo4j_deleted_rels = neo4j_result.get("rels_deleted", 0)
                    info(f"成功删除 document_id={document_id} 的 {neo4j_deleted_nodes} 个 Neo4j 节点和 {neo4j_deleted_rels} 个关系")
                else:
                    error(
                        f"Neo4j 删除文档失败: document_id={document_id}, 错误: {neo4j_result.get('message', '未知错误')}")
            except Exception as e:
                error(f"删除 document_id={document_id} 的 Neo4j 数据失败: {str(e)}")
            return {
                "status": "success",
                "collection_name": collection_name,
                "document_id": document_id,
                "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
                "status_code": 200
            }
        except Exception as e:
            error(f"删除文档失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "document_id": document_id,
                "message": f"删除文档失败: {str(e)}",
                "status_code": 400
            }

    async def _delete_knowledge_base(self, db_type: str, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's whole knowledge base from Milvus, then from Neo4j.

        A Milvus failure returns the Milvus response as-is. A Neo4j failure still
        yields "success" (the Milvus data is gone) with the failure in ``message``.
        """
        collection_name = self._collection_name(db_type)
        try:
            debug(f"调用删除知识库端点: userid={userid}, knowledge_base_id={knowledge_base_id}")
            milvus_result = await self._make_api_request("deleteknowledgebase", {
                "userid": userid,
                "knowledge_base_id": knowledge_base_id,
                "db_type": db_type
            })
            if milvus_result.get("status") != "success":
                error(f"Milvus 删除知识库失败: {milvus_result.get('message', '未知错误')}")
                return milvus_result
            deleted_files = milvus_result.get("deleted_files", [])

            neo4j_deleted_nodes = 0
            neo4j_deleted_rels = 0
            try:
                debug(f"调用 Neo4j 删除知识库端点: userid={userid}, knowledge_base_id={knowledge_base_id}")
                neo4j_result = await self._make_neo4japi_request("deleteknowledgebase", {
                    "userid": userid,
                    "knowledge_base_id": knowledge_base_id
                })
            except Exception as e:
                error(f"Neo4j 删除知识库失败: {str(e)}")
                return {
                    "status": "success",
                    "collection_name": collection_name,
                    "deleted_files": deleted_files,
                    "message": f"成功删除 Milvus 知识库,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {str(e)}",
                    "status_code": 200
                }
            if neo4j_result.get("status") != "success":
                error(f"Neo4j 删除知识库失败: {neo4j_result.get('message', '未知错误')}")
                return {
                    "status": "success",
                    "collection_name": collection_name,
                    "deleted_files": deleted_files,
                    "message": f"成功删除 Milvus 知识库,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {neo4j_result.get('message')}",
                    "status_code": 200
                }
            neo4j_deleted_nodes = neo4j_result.get("nodes_deleted", 0)
            neo4j_deleted_rels = neo4j_result.get("rels_deleted", 0)
            info(f"成功删除 {neo4j_deleted_nodes} 个 Neo4j 节点和 {neo4j_deleted_rels} 个关系")

            if not deleted_files and neo4j_deleted_nodes == 0 and neo4j_deleted_rels == 0:
                debug(f"没有删除任何记录userid={userid}, knowledge_base_id={knowledge_base_id}")
                return {
                    "status": "success",
                    "collection_name": collection_name,
                    "deleted_files": [],
                    "message": f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录,无需删除",
                    "status_code": 200
                }
            info(
                f"总计删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}")
            return {
                "status": "success",
                "collection_name": collection_name,
                "deleted_files": deleted_files,
                "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}",
                "status_code": 200
            }
        except Exception as e:
            error(f"删除知识库失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "deleted_files": [],
                "message": f"删除知识库失败: {str(e)}",
                "status_code": 400
            }

    # ------------------------------------------------------------------ search

    @staticmethod
    def _validate_search_params(query: str, userid: str, knowledge_base_ids: List[str], limit: int,
                                offset: int, db_type: str) -> None:
        """Raise ValueError unless the search parameters are usable (shared by both searches)."""
        if not query:
            raise ValueError("查询文本不能为空")
        if not userid:
            raise ValueError("userid 不能为空")
        if "_" in userid or (db_type and "_" in db_type):
            raise ValueError("userid 和 db_type 不能包含下划线")
        if (db_type and len(db_type) > 100) or len(userid) > 100:
            raise ValueError("userid 或 db_type 的长度超出限制")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit 必须在 1 到 16384 之间")
        if offset < 0:
            raise ValueError("offset 不能为负数")
        if limit + offset > 16384:
            raise ValueError("limit + offset 不能超过 16384")
        if not knowledge_base_ids:
            raise ValueError("knowledge_base_ids 不能为空")
        for kb_id in knowledge_base_ids:
            if not isinstance(kb_id, str):
                raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}")
            if len(kb_id) > 100:
                raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}")
            if "_" in kb_id:
                raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}")

    async def _embed_query(self, text: str, timing_stats: Dict[str, float]) -> np.ndarray:
        """Embed a single search text, recording ``vector_generation`` in *timing_stats*."""
        vector_start = time.time()
        vectors = await self._get_embeddings([text])
        if not vectors or not all(len(vec) == 1024 for vec in vectors):
            raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
        return vectors[0]

    async def _rerank_or_strip(self, rerank_query: str, results: List[Dict], limit: int, use_rerank: bool,
                               timing_stats: Dict[str, float]) -> List[Dict]:
        """Rerank *results* (sorted by score, desc) when enabled; otherwise drop any ``rerank_score`` keys."""
        if use_rerank and results:
            rerank_start = time.time()
            debug("开始重排序")
            results = await self._rerank_results(rerank_query, results, limit)
            results = sorted(results, key=lambda r: r.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
            debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in results]}")
            return results
        return [{k: v for k, v in r.items() if k != 'rerank_score'} for r in results]

    async def _search_query(self, query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
                            offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
        """Pure vector search: embed *query*, search Milvus, optionally rerank.

        Returns:
            ``{"results": [...], "timing": {...}}`` with at most *limit* results.
            Any failure, validation included, is logged and yields empty results.
        """
        start_time = time.time()
        timing_stats = {}
        try:
            info(
                f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
            self._validate_search_params(query, userid, knowledge_base_ids, limit, offset, db_type)
            query_vector = await self._embed_query(query, timing_stats)

            search_start = time.time()
            result = await self._make_api_request("searchquery", {
                "query_vector": query_vector.tolist(),
                "userid": userid,
                "knowledge_base_ids": knowledge_base_ids,
                "limit": limit,
                "offset": offset,
                "db_type": db_type
            })
            timing_stats["vector_search"] = time.time() - search_start
            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
            if result.get("status") != "success":
                error(f"纯向量搜索失败: {result.get('message', '未知错误')}")
                return {"results": [], "timing": timing_stats}

            unique_results = await self._rerank_or_strip(query, result.get("results", []), limit, use_rerank,
                                                         timing_stats)
            timing_stats["total_time"] = time.time() - start_time
            info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
            return {"results": unique_results[:limit], "timing": timing_stats}
        except Exception as e:
            error(f"纯向量搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            return {"results": [], "timing": timing_stats}

    async def _fused_search(self, query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
                            offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
        """Graph-augmented vector search.

        Extracts entities from *query*, collects matching triplets from Neo4j
        for each knowledge base, appends them to the query text, then embeds,
        searches and optionally reranks against that combined text.

        Returns:
            ``{"results": [...], "timing": {...}}`` with at most *limit* results.
            Any failure, validation included, is logged and yields empty results.
        """
        start_time = time.time()
        timing_stats = {}
        try:
            info(
                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
            self._validate_search_params(query, userid, knowledge_base_ids, limit, offset, db_type)

            entity_extract_start = time.time()
            query_entities = await self._extract_entities(query)
            timing_stats["entity_extraction"] = time.time() - entity_extract_start
            debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")

            # Triplet matching per knowledge base; one failing KB does not abort the search.
            all_triplets = []
            triplet_match_start = time.time()
            for kb_id in knowledge_base_ids:
                debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
                try:
                    neo4j_result = await self._make_neo4japi_request("matchtriplets", {
                        "query": query,
                        "query_entities": query_entities,
                        "userid": userid,
                        "knowledge_base_id": kb_id
                    })
                    if neo4j_result.get("status") == "success":
                        triplets = neo4j_result.get("triplets", [])
                        all_triplets.extend(triplets)
                        debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
                    else:
                        error(
                            f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
                except Exception as e:
                    error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
            timing_stats["triplet_matching"] = time.time() - triplet_match_start
            debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")

            triplet_text_start = time.time()
            triplet_texts = []
            for triplet in all_triplets:
                # Skip anything that is not a complete head/type/tail mapping.
                if not isinstance(triplet, dict):
                    debug(f"无效三元组: {triplet}")
                    continue
                head = triplet.get('head', '')
                type_ = triplet.get('type', '')
                tail = triplet.get('tail', '')
                if head and type_ and tail:
                    triplet_texts.append(f"{head} {type_} {tail}")
                else:
                    debug(f"无效三元组: {triplet}")
            combined_text = query
            if triplet_texts:
                combined_text += " [三元组] " + "; ".join(triplet_texts)
                debug(
                    f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
            timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
            debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")

            query_vector = await self._embed_query(combined_text, timing_stats)

            search_start = time.time()
            result = await self._make_api_request("searchquery", {
                "query_vector": query_vector.tolist(),
                "userid": userid,
                "knowledge_base_ids": knowledge_base_ids,
                "limit": limit,
                "offset": offset,
                "db_type": db_type
            })
            timing_stats["vector_search"] = time.time() - search_start
            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
            if result.get("status") != "success":
                error(f"融合搜索失败: {result.get('message', '未知错误')}")
                return {"results": [], "timing": timing_stats}

            unique_results = await self._rerank_or_strip(combined_text, result.get("results", []), limit,
                                                         use_rerank, timing_stats)
            timing_stats["total_time"] = time.time() - start_time
            info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
            return {"results": unique_results[:limit], "timing": timing_stats}
        except Exception as e:
            error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            return {"results": [], "timing": timing_stats}

    # ------------------------------------------------------------------ listing

    async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
        """Return the service's ``files_by_knowledge_base`` for *userid*, or ``{}`` on any failure."""
        try:
            info(f"列出用户文件: userid={userid}, db_type={db_type}")
            if not userid:
                raise ValueError("userid 不能为空")
            if "_" in userid or (db_type and "_" in db_type):
                raise ValueError("userid 和 db_type 不能包含下划线")
            if (db_type and len(db_type) > 100) or len(userid) > 100:
                raise ValueError("userid 或 db_type 的长度超出限制")
            result = await self._make_api_request("listuserfiles", {
                "userid": userid,
                "db_type": db_type
            })
            if result.get("status") != "success":
                error(f"列出用户文件失败: {result.get('message', '未知错误')}")
                return {}
            return result.get("files_by_knowledge_base", {})
        except Exception as e:
            error(f"列出用户文件失败: {str(e)}")
            return {}

    async def _list_all_knowledge_bases(self, db_type: str = "") -> Dict[str, Any]:
        """Return the Milvus service's listing of all users' knowledge bases, or an error dict."""
        collection_name = self._collection_name(db_type)
        try:
            info(f"列出所有用户的知识库: db_type={db_type}")
            if db_type and "_" in db_type:
                raise ValueError("db_type 不能包含下划线")
            if db_type and len(db_type) > 100:
                raise ValueError("db_type 的长度应小于 100")
            return await self._make_api_request("listallknowledgebases", {
                "db_type": db_type
            })
        except Exception as e:
            error(f"列出所有用户知识库失败: {str(e)}")
            return {
                "status": "error",
                "users_knowledge_bases": {},
                "collection_name": collection_name,
                "message": f"列出所有用户知识库失败: {str(e)}",
                "status_code": 400
            }
# Import-time side effect: register MilvusConnection with
# base_connection.connection_register under the name 'Rag'
# (presumably the key callers use to look this connection up — verify in base_connection).
connection_register('Rag', MilvusConnection)
info("MilvusConnection registered")