修改添加/删除文档函数

This commit is contained in:
wangmeihua 2025-08-05 17:20:24 +08:00
parent 1e43b3aaec
commit cca0402255
4 changed files with 169 additions and 289 deletions

View File

@ -357,20 +357,22 @@ async def insert_file(request, params_kw, *params, **kw):
userid = params_kw.get('userid', '') userid = params_kw.get('userid', '')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
knowledge_base_id = params_kw.get('knowledge_base_id', '') knowledge_base_id = params_kw.get('knowledge_base_id', '')
document_id = params_kw.get('document_id', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
required_fields = ['file_path', 'userid', 'knowledge_base_id'] required_fields = ['file_path', 'userid', 'knowledge_base_id', 'document_id']
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
debug( debug(
f'Calling insert_document with: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') f'Calling insert_document with: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
result = await engine.handle_connection("insert_document", { result = await engine.handle_connection("insert_document", {
"file_path": file_path, "file_path": file_path,
"userid": userid, "userid": userid,
"db_type": db_type, "db_type": db_type,
"knowledge_base_id": knowledge_base_id "knowledge_base_id": knowledge_base_id,
"document_id": document_id
}) })
debug(f'Insert result: {result=}') debug(f'Insert result: {result=}')
status = 200 if result.get("status") == "success" else 400 status = 200 if result.get("status") == "success" else 400
@ -380,7 +382,7 @@ async def insert_file(request, params_kw, *params, **kw):
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": str(e) "message": str(e)
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
@ -389,21 +391,23 @@ async def delete_file(request, params_kw, *params, **kw):
se = ServerEnv() se = ServerEnv()
engine = se.engine engine = se.engine
userid = params_kw.get('userid', '') userid = params_kw.get('userid', '')
filename = params_kw.get('filename', '') file_path = params_kw.get('file_path', '')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
knowledge_base_id = params_kw.get('knowledge_base_id', '') knowledge_base_id = params_kw.get('knowledge_base_id', '')
document_id = params_kw.get('document_id', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
required_fields = ['userid', 'filename', 'knowledge_base_id'] required_fields = ['userid', 'file_path', 'knowledge_base_id', 'document_id']
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
debug(f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') debug(f'Calling delete_document with: userid={userid}, file_path={file_path}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
result = await engine.handle_connection("delete_document", { result = await engine.handle_connection("delete_document", {
"userid": userid, "userid": userid,
"filename": filename, "file_path": file_path,
"knowledge_base_id": knowledge_base_id, "knowledge_base_id": knowledge_base_id,
"document_id": document_id,
"db_type": db_type "db_type": db_type
}) })
debug(f'Delete result: {result=}') debug(f'Delete result: {result=}')
@ -414,7 +418,7 @@ async def delete_file(request, params_kw, *params, **kw):
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": str(e), "message": str(e),
"status_code": 400 "status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

View File

@ -307,20 +307,22 @@ async def delete_document(request, params_kw, *params, **kw):
se = ServerEnv() se = ServerEnv()
engine = se.engine engine = se.engine
userid = params_kw.get('userid', '') userid = params_kw.get('userid', '')
filename = params_kw.get('filename', '') file_path = params_kw.get('file_path', '')
knowledge_base_id = params_kw.get('knowledge_base_id', '') knowledge_base_id = params_kw.get('knowledge_base_id', '')
document_id = params_kw.get('document_id', '')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
required_fields = ['userid', 'filename', 'knowledge_base_id'] required_fields = ['userid', 'file_path', 'knowledge_base_id', 'document_id']
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
result = await engine.handle_connection("delete_document", { result = await engine.handle_connection("delete_document", {
"userid": userid, "userid": userid,
"filename": filename, "file_path": file_path,
"knowledge_base_id": knowledge_base_id, "knowledge_base_id": knowledge_base_id,
"document_id": document_id,
"db_type": db_type "db_type": db_type
}) })
debug(f'Delete result: {result=}') debug(f'Delete result: {result=}')
@ -331,7 +333,7 @@ async def delete_document(request, params_kw, *params, **kw):
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": str(e), "message": str(e),
"status_code": 400 "status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

View File

@ -1,10 +1,8 @@
from appPublic.jsonConfig import getConfig
import os import os
from appPublic.log import debug, error, info from appPublic.log import debug, error, info
import yaml
from threading import Lock
from llmengine.base_connection import connection_register from llmengine.base_connection import connection_register
from typing import Dict, List, Any from typing import Dict, List, Any
import numpy as np
import aiohttp import aiohttp
from aiohttp import ClientSession, ClientTimeout from aiohttp import ClientSession, ClientTimeout
from langchain_core.documents import Document from langchain_core.documents import Document
@ -12,49 +10,52 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid import uuid
from datetime import datetime from datetime import datetime
from filetxt.loader import fileloader from filetxt.loader import fileloader
from llmengine.kgc import KnowledgeGraph
import numpy as np
from py2neo import Graph
from scipy.spatial.distance import cosine
import time import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import traceback import traceback
import asyncio import asyncio
import re import re
# 嵌入缓存 # 嵌入缓存
EMBED_CACHE = {} EMBED_CACHE = {}
class MilvusConnection: class MilvusConnection:
_instance = None
_lock = Lock()
def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super(MilvusConnection, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self): def __init__(self):
if self._initialized: pass
return
@retry(stop = stop_after_attempt(3))
async def _make_neo4japi_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
debug(f"开始API请求action={action}, params={params}")
try: try:
config = getConfig() async with ClientSession(timeout=ClientTimeout(total=300)) as session:
self.neo4j_uri = config['neo4j']['uri'] url = f"http://localhost:8885/v1/{action}"
self.neo4j_user = config['neo4j']['user'] debug(f"发起POST请求{url}")
self.neo4j_password = config['neo4j']['password'] async with session.post(
except KeyError as e: url,
error(f"配置文件缺少必要字段: {str(e)}") headers={'Content-Type': 'application/json'},
raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") json=params
self._initialized = True ) as response:
info("Neo4jConnection initialized") debug(f"收到相应: status={response.status}, headers={response.headers}")
respose_text = await response.text()
debug(f"响应内容: {respose_text}")
result = await response.json()
debug(f"API响应内容: {result}")
if response.status == 400:
debug(f"客户端错误,状态码: {response.status},返回响应: {result}")
return result
if response.status != 200:
error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
raise RuntimeError(f"API 调用失败: {response.status}")
debug(f"API 调用成功: {action}, 响应: {result}")
return result
except Exception as e:
error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
raise RuntimeError(f"API 调用失败: {str(e)}")
@retry(stop=stop_after_attempt(3)) @retry(stop=stop_after_attempt(3))
async def _make_api_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]: async def _make_api_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
debug(f"开始 API 请求: action={action}, params={params}") debug(f"开始 API 请求: action={action}, params={params}")
try: try:
async with ClientSession(timeout=ClientTimeout(total=10)) as session: async with ClientSession(timeout=ClientTimeout(total=300)) as session:
url = f"http://localhost:8886/v1/{action}" url = f"http://localhost:8886/v1/{action}"
debug(f"发起 POST 请求: {url}") debug(f"发起 POST 请求: {url}")
async with session.post( async with session.post(
@ -107,30 +108,32 @@ class MilvusConnection:
file_path = params.get("file_path", "") file_path = params.get("file_path", "")
userid = params.get("userid", "") userid = params.get("userid", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
if not file_path or not userid or not knowledge_base_id: document_id = params.get("document_id", "")
return {"status": "error", "message": "file_path、userid 和 knowledge_base_id 不能为空", if not file_path or not userid or not knowledge_base_id or not document_id:
return {"status": "error", "message": "file_path、userid document_id和 knowledge_base_id 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id: if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线", return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": document_id, "status_code": 400}
if len(knowledge_base_id) > 100: if len(knowledge_base_id) > 100:
return {"status": "error", "message": "knowledge_base_id 的长度应小于 100", return {"status": "error", "message": "knowledge_base_id 的长度应小于 100",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._insert_document(file_path, userid, knowledge_base_id, db_type) return await self._insert_document(file_path, userid, knowledge_base_id, document_id, db_type)
elif action == "delete_document": elif action == "delete_document":
userid = params.get("userid", "") userid = params.get("userid", "")
filename = params.get("filename", "") file_path = params.get("file_path", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
if not userid or not filename or not knowledge_base_id: document_id = params.get("document_id", "")
return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空", if not userid or not file_path or not knowledge_base_id or not document_id:
return {"status": "error", "message": "userid、file_path document_id和 knowledge_base_id 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id: if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线", return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100: if len(userid) > 100 or len(file_path) > 255 or len(knowledge_base_id) > 100:
return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制", return {"status": "error", "message": "userid、file_path 或 knowledge_base_id 的长度超出限制",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._delete_document(userid, filename, knowledge_base_id, db_type) return await self._delete_document(userid, file_path, knowledge_base_id, document_id, db_type)
elif action == "delete_knowledge_base": elif action == "delete_knowledge_base":
userid = params.get("userid", "") userid = params.get("userid", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
@ -232,10 +235,9 @@ class MilvusConnection:
"status_code": 400 "status_code": 400
} }
async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str, db_type: str = "") -> Dict[ async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str, document_id:str, db_type: str = "") -> Dict[
str, Any]: str, Any]:
"""将文档插入 Milvus 并抽取三元组到 Neo4j""" """将文档插入 Milvus 并抽取三元组到 Neo4j"""
document_id = str(uuid.uuid4())
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
debug( debug(
f'Inserting document: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}') f'Inserting document: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
@ -377,15 +379,28 @@ class MilvusConnection:
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}") f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
# Neo4j 插入 # Neo4j 插入
debug(f"抽取到 {len(unique_triples)} 个三元组,插入 Neo4j") debug(f"抽取到 {len(unique_triples)} 个三元组,调用Neo4j服务插入")
start_neo4j = time.time() start_neo4j = time.time()
if unique_triples: if unique_triples:
kg = KnowledgeGraph(triples=unique_triples, document_id=document_id, neo4j_result = await self._make_neo4japi_request("inserttriples", {
knowledge_base_id=knowledge_base_id, userid=userid) "triples":unique_triples,
kg.create_graphnodes() "document_id": document_id,
kg.create_graphrels() "knowledge_base_id": knowledge_base_id,
kg.export_data() "userid": userid
info(f"文件 {file_path} 三元组成功插入 Neo4j") })
debug(f"Neo4j服务响应: {neo4j_result}")
if neo4j_result.get("status") != "success":
timings["insert_neo4j"] = time.time() - start_neo4j
timings["total"] = time.time() - start_total
return{
"status": "error",
"document_id": document_id,
"collection_name": collection_name,
"timings": timings,
"message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
"status_code": 400
}
info(f"文件 {file_path} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
else: else:
debug(f"文件 {file_path} 未抽取到三元组") debug(f"文件 {file_path} 未抽取到三元组")
timings["insert_neo4j"] = time.time() - start_neo4j timings["insert_neo4j"] = time.time() - start_neo4j
@ -500,16 +515,17 @@ class MilvusConnection:
debug(f"Request #{request_id} traceback: {traceback.format_exc()}") debug(f"Request #{request_id} traceback: {traceback.format_exc()}")
raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}") raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}")
async def _delete_document(self, userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> Dict[str, Any]: async def _delete_document(self, userid: str, file_path: str, knowledge_base_id: str, document_id:str, db_type: str = "") -> Dict[str, Any]:
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录""" """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
# 调用 Milvus 删除文件端点 # 调用 Milvus 删除文件端点
debug(f"调用删除文件端点: userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}") debug(f"调用删除文件端点: userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}")
milvus_result = await self._make_api_request("deletedocument", { milvus_result = await self._make_api_request("deletedocument", {
"userid": userid, "userid": userid,
"filename": filename, "file_path": file_path,
"knowledge_base_id": knowledge_base_id, "knowledge_base_id": knowledge_base_id,
"document_id": document_id,
"db_type": db_type "db_type": db_type
}) })
@ -517,41 +533,29 @@ class MilvusConnection:
error(f"Milvus 删除文件失败: {milvus_result.get('message', '未知错误')}") error(f"Milvus 删除文件失败: {milvus_result.get('message', '未知错误')}")
return milvus_result return milvus_result
document_ids = milvus_result.get("document_id", "").split(",") if milvus_result.get("document_id") else [] # 调用 Neo4j 删除端点
neo4j_deleted_nodes = 0 neo4j_deleted_nodes = 0
neo4j_deleted_rels = 0 neo4j_deleted_rels = 0
try:
# 删除 Neo4j 数据 debug(f"调用 Neo4j 删除文档端点: document_id={document_id}")
for doc_id in document_ids: neo4j_result = await self._make_neo4japi_request("deletedocument", {
if not doc_id: "document_id": document_id
continue })
try: if neo4j_result.get("status") != "success":
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) error(
query = """ f"Neo4j 删除文档失败: document_id={document_id}, 错误: {neo4j_result.get('message', '未知错误')}")
MATCH (n {document_id: $document_id}) nodes_deleted = neo4j_result.get("nodes_deleted", 0)
OPTIONAL MATCH (n)-[r {document_id: $document_id}]->() rels_deleted = neo4j_result.get("rels_deleted", 0)
WITH collect(r) AS rels, collect(n) AS nodes neo4j_deleted_nodes += nodes_deleted
FOREACH (r IN rels | DELETE r) neo4j_deleted_rels += rels_deleted
FOREACH (n IN nodes | DELETE n) info(f"成功删除 document_id={document_id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types except Exception as e:
""" error(f"删除 document_id={document_id} 的 Neo4j 数据失败: {str(e)}")
result = graph.run(query, document_id=doc_id).data()
nodes_deleted = result[0]['node_count'] if result else 0
rels_deleted = result[0]['rel_count'] if result else 0
rel_types = result[0]['rel_types'] if result else []
info(
f"成功删除 document_id={doc_id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}")
neo4j_deleted_nodes += nodes_deleted
neo4j_deleted_rels += rels_deleted
except Exception as e:
error(f"删除 document_id={doc_id} 的 Neo4j 三元组失败: {str(e)}")
continue
return { return {
"status": "success", "status": "success",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": ",".join(document_ids), "document_id": document_id,
"message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系", "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
"status_code": 200 "status_code": 200
} }
@ -561,7 +565,7 @@ class MilvusConnection:
return { return {
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": f"删除文档失败: {str(e)}", "message": f"删除文档失败: {str(e)}",
"status_code": 400 "status_code": 400
} }
@ -584,30 +588,30 @@ class MilvusConnection:
deleted_files = milvus_result.get("deleted_files", []) deleted_files = milvus_result.get("deleted_files", [])
# 删除 Neo4j 数据 # 新增:调用 Neo4j 删除知识库端点
neo4j_deleted_nodes = 0 neo4j_deleted_nodes = 0
neo4j_deleted_rels = 0 neo4j_deleted_rels = 0
try: try:
debug(f"尝试连接 Neo4j: uri={self.neo4j_uri}, user={self.neo4j_user}") debug(f"调用 Neo4j 删除知识库端点: userid={userid}, knowledge_base_id={knowledge_base_id}")
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) neo4j_result = await self._make_neo4japi_request("deleteknowledgebase", {
debug("Neo4j 连接成功") "userid": userid,
query = """ "knowledge_base_id": knowledge_base_id
MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id}) })
OPTIONAL MATCH (n)-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->() if neo4j_result.get("status") == "success":
WITH collect(r) AS rels, collect(n) AS nodes neo4j_deleted_nodes = neo4j_result.get("nodes_deleted", 0)
FOREACH (r IN rels | DELETE r) neo4j_deleted_rels = neo4j_result.get("rels_deleted", 0)
FOREACH (n IN nodes | DELETE n) info(f"成功删除 {neo4j_deleted_nodes} 个 Neo4j 节点和 {neo4j_deleted_rels} 个关系")
RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types else:
""" error(f"Neo4j 删除知识库失败: {neo4j_result.get('message', '未知错误')}")
result = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id).data() return {
nodes_deleted = result[0]['node_count'] if result else 0 "status": "success",
rels_deleted = result[0]['rel_count'] if result else 0 "collection_name": collection_name,
rel_types = result[0]['rel_types'] if result else [] "deleted_files": deleted_files,
neo4j_deleted_nodes += nodes_deleted "message": f"成功删除 Milvus 知识库,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {neo4j_result.get('message')}",
neo4j_deleted_rels += rels_deleted "status_code": 200
info(f"成功删除 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}") }
except Exception as e: except Exception as e:
error(f"删除 Neo4j 数据失败: {str(e)}") error(f"Neo4j 删除知识库失败: {str(e)}")
return { return {
"status": "success", "status": "success",
"collection_name": collection_name, "collection_name": collection_name,
@ -672,119 +676,6 @@ class MilvusConnection:
error(f"实体识别服务调用失败: {str(e)}") error(f"实体识别服务调用失败: {str(e)}")
return [] return []
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, knowledge_base_id: str) -> List[Dict]:
"""匹配查询实体与 Neo4j 中的三元组"""
start_time = time.time() # 记录开始时间
matched_triplets = []
ENTITY_SIMILARITY_THRESHOLD = 0.8
try:
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
debug(f"已连接到 Neo4j: {self.neo4j_uri}")
neo4j_connect_time = time.time() - start_time
debug(f"Neo4j 连接耗时: {neo4j_connect_time:.3f}")
matched_names = set()
entity_match_start = time.time()
for entity in query_entities:
normalized_entity = entity.lower().strip()
query = """
MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
WHERE toLower(n.name) CONTAINS $entity
OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
ORDER BY sim DESC
LIMIT 100
"""
try:
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, entity=normalized_entity).data()
for record in results:
matched_names.add(record['n.name'])
debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
except Exception as e:
debug(f"模糊匹配实体 {entity} 失败: {str(e)}")
continue
entity_match_time = time.time() - entity_match_start
debug(f"实体匹配耗时: {entity_match_time:.3f}")
triplets = []
if matched_names:
triplet_query_start = time.time()
query = """
MATCH (h {userid: $userid, knowledge_base_id: $knowledge_base_id})-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->(t {userid: $userid, knowledge_base_id: $knowledge_base_id})
WHERE h.name IN $matched_names OR t.name IN $matched_names
RETURN h.name AS head, r.name AS type, t.name AS tail
LIMIT 100
"""
try:
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, matched_names=list(matched_names)).data()
seen = set()
for record in results:
head, type_, tail = record['head'], record['type'], record['tail']
triplet_key = (head.lower(), type_.lower(), tail.lower())
if triplet_key not in seen:
seen.add(triplet_key)
triplets.append({
'head': head,
'type': type_,
'tail': tail,
'head_type': '',
'tail_type': ''
})
debug(f"从 Neo4j 加载三元组: knowledge_base_id={knowledge_base_id}, 数量={len(triplets)}")
except Exception as e:
error(f"检索三元组失败: knowledge_base_id={knowledge_base_id}, 错误: {str(e)}")
return []
triplet_query_time = time.time() - triplet_query_start
debug(f"Neo4j 三元组查询耗时: {triplet_query_time:.3f}")
if not triplets:
debug(f"知识库 knowledge_base_id={knowledge_base_id} 无匹配三元组")
return []
embedding_start = time.time()
texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
embeddings = await self._get_embeddings(texts_to_embed)
entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)}
head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)}
tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)}
debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails")
embedding_time = time.time() - embedding_start
debug(f"嵌入向量生成耗时: {embedding_time:.3f}")
similarity_start = time.time()
for entity in query_entities:
entity_vec = entity_vectors[entity]
for d_triplet in triplets:
d_head_vec = head_vectors[d_triplet['head']]
d_tail_vec = tail_vectors[d_triplet['tail']]
head_similarity = 1 - cosine(entity_vec, d_head_vec)
tail_similarity = 1 - cosine(entity_vec, d_tail_vec)
if head_similarity >= ENTITY_SIMILARITY_THRESHOLD or tail_similarity >= ENTITY_SIMILARITY_THRESHOLD:
matched_triplets.append(d_triplet)
debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
similarity_time = time.time() - similarity_start
debug(f"相似度计算耗时: {similarity_time:.3f}")
unique_matched = []
seen = set()
for t in matched_triplets:
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
if identifier not in seen:
seen.add(identifier)
unique_matched.append(t)
total_time = time.time() - start_time
debug(f"_match_triplets 总耗时: {total_time:.3f}")
info(f"找到 {len(unique_matched)} 个匹配的三元组")
return unique_matched
except Exception as e:
error(f"匹配三元组失败: {str(e)}")
return []
async def _rerank_results(self, query: str, results: List[Dict], top_n: int) -> List[Dict]: async def _rerank_results(self, query: str, results: List[Dict], top_n: int) -> List[Dict]:
"""调用重排序服务""" """调用重排序服务"""
try: try:
@ -936,14 +827,28 @@ class MilvusConnection:
timing_stats["entity_extraction"] = time.time() - entity_extract_start timing_stats["entity_extraction"] = time.time() - entity_extract_start
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}") debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
# 匹配三元组 # 调用 Neo4j 服务进行三元组匹配
all_triplets = [] all_triplets = []
triplet_match_start = time.time() triplet_match_start = time.time()
for kb_id in knowledge_base_ids: for kb_id in knowledge_base_ids:
debug(f"处理知识库: {kb_id}") debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id) try:
debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)}") neo4j_result = await self._make_neo4japi_request("matchtriplets", {
all_triplets.extend(matched_triplets) "query": query,
"query_entities": query_entities,
"userid": userid,
"knowledge_base_id": kb_id
})
if neo4j_result.get("status") == "success":
triplets = neo4j_result.get("triplets", [])
all_triplets.extend(triplets)
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
else:
error(
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
except Exception as e:
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
continue
timing_stats["triplet_matching"] = time.time() - triplet_match_start timing_stats["triplet_matching"] = time.time() - triplet_match_start
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}") debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
@ -977,7 +882,7 @@ class MilvusConnection:
# 调用融合搜索端点 # 调用融合搜索端点
search_start = time.time() search_start = time.time()
result = await self._make_api_request("searchquery", { # 注意:使用 searchquery 端点 result = await self._make_api_request("searchquery", {
"query_vector": query_vector.tolist(), "query_vector": query_vector.tolist(),
"userid": userid, "userid": userid,
"knowledge_base_ids": knowledge_base_ids, "knowledge_base_ids": knowledge_base_ids,

View File

@ -79,19 +79,20 @@ class MilvusDBConnection(BaseDBConnection):
return await self._insert_document(chunks, db_type) return await self._insert_document(chunks, db_type)
elif action == "delete_document": elif action == "delete_document":
userid = params.get("userid", "") userid = params.get("userid", "")
filename = params.get("filename", "") file_path = params.get("file_path", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
document_id = params.get("document_id", "")
db_type = params.get("db_type", "") db_type = params.get("db_type", "")
if not userid or not filename or not knowledge_base_id: if not userid or not file_path or not knowledge_base_id or not document_id:
return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空", return {"status": "error", "message": "userid、file_path document_id和 knowledge_base_id 不能为空1",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id: if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线", return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100: if len(userid) > 100 or len(file_path) > 255 or len(knowledge_base_id) > 100:
return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制", return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._delete_document(userid, filename, knowledge_base_id, db_type) return await self._delete_document(userid, file_path, knowledge_base_id, document_id, db_type)
elif action == "delete_knowledge_base": elif action == "delete_knowledge_base":
userid = params.get("userid", "") userid = params.get("userid", "")
knowledge_base_id = params.get("knowledge_base_id", "") knowledge_base_id = params.get("knowledge_base_id", "")
@ -392,7 +393,7 @@ class MilvusDBConnection(BaseDBConnection):
"status_code": 400 "status_code": 400
} }
async def _delete_document(self, userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> Dict[ async def _delete_document(self, userid: str, file_path: str, knowledge_base_id: str, document_id:str, db_type: str = "") -> Dict[
str, Any]: str, Any]:
"""删除用户指定文件数据,仅处理 Milvus 记录""" """删除用户指定文件数据,仅处理 Milvus 记录"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
@ -402,7 +403,7 @@ class MilvusDBConnection(BaseDBConnection):
return { return {
"status": "success", "status": "success",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": f"集合 {collection_name} 不存在,无需删除", "message": f"集合 {collection_name} 不存在,无需删除",
"status_code": 200 "status_code": 200
} }
@ -416,72 +417,40 @@ class MilvusDBConnection(BaseDBConnection):
return { return {
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": f"加载集合失败: {str(e)}", "message": f"加载集合失败: {str(e)}",
"status_code": 400 "status_code": 400
} }
expr = f"userid == '{userid}' and filename == '{filename}' and knowledge_base_id == '{knowledge_base_id}'"
debug(f"查询表达式: {expr}")
try:
results = collection.query(
expr=expr,
output_fields=["document_id"],
limit=1000
)
if not results:
debug(
f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录")
return {
"status": "success",
"collection_name": collection_name,
"document_id": "",
"message": f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录,无需删除",
"status_code": 200
}
document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
except Exception as e:
error(f"查询 document_id 失败: {str(e)}")
return {
"status": "error",
"collection_name": collection_name,
"document_id": "",
"message": f"查询失败: {str(e)}",
"status_code": 400
}
total_deleted = 0 total_deleted = 0
for doc_id in document_ids: try:
try: delete_expr = f"document_id == '{document_id}'"
delete_expr = f"document_id == '{doc_id}'" debug(f"删除表达式: {delete_expr}")
debug(f"删除表达式: {delete_expr}") delete_result = collection.delete(delete_expr)
delete_result = collection.delete(delete_expr) deleted_count = delete_result.delete_count
deleted_count = delete_result.delete_count total_deleted += deleted_count
total_deleted += deleted_count info(f"成功删除 document_id={document_id}{deleted_count} 条 Milvus 记录")
info(f"成功删除 document_id={doc_id}{deleted_count} 条 Milvus 记录") except Exception as e:
except Exception as e: error(f"删除 document_id={document_id} 的 Milvus 记录失败: {str(e)}")
error(f"删除 document_id={doc_id} 的 Milvus 记录失败: {str(e)}")
continue
if total_deleted == 0: if total_deleted == 0:
debug( debug(
f"没有删除任何 Milvus 记录userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}") f"没有删除任何 Milvus 记录userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}")
return { return {
"status": "success", "status": "success",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": f"没有删除任何记录userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}", "message": f"没有删除任何记录userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}",
"status_code": 200 "status_code": 200
} }
info( info(
f"总计删除 {total_deleted} 条 Milvus 记录userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}") f"总计删除 {total_deleted} 条 Milvus 记录userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}")
return { return {
"status": "success", "status": "success",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": ",".join(document_ids), "document_id": ",".join(document_ids),
"message": f"成功删除 {total_deleted} 条 Milvus 记录userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}", "message": f"成功删除 {total_deleted} 条 Milvus 记录userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}",
"status_code": 200 "status_code": 200
} }
@ -490,7 +459,7 @@ class MilvusDBConnection(BaseDBConnection):
return { return {
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": document_id,
"message": f"删除文档失败: {str(e)}", "message": f"删除文档失败: {str(e)}",
"status_code": 400 "status_code": 400
} }