This commit is contained in:
wangmeihua 2025-09-15 18:06:45 +08:00
parent 524f1e93d1
commit cddb4733fb
3 changed files with 135 additions and 57 deletions

View File

@ -20,7 +20,7 @@ from typing import List, Dict, Any
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
import json
from rag.transaction_manager import TransactionContext
from dataclasses import dataclass
from enum import Enum
@ -53,7 +53,6 @@ where a.orgid = b.orgid
return r.quota, r.expired_date
return None, None
async def file_uploaded(self, request, ns, userid):
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
debug(f'Received ns: {ns=}')
@ -64,53 +63,103 @@ where a.orgid = b.orgid
orgid = ns.get('ownerid', '')
db_type = ''
debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
debug(
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
timings = {}
start_total = time.time()
result = {
"status": "error",
"userid": orgid,
"document_id": id,
"collection_name": "ragdb",
"timings": timings,
"message": "",
"status_code": 400
}
try:
if not orgid or not fiid or not id:
raise ValueError("orgid、fiid 和 id 不能为空")
if len(orgid) > 32 or len(fiid) > 255:
raise ValueError("orgid 或 fiid 的长度超出限制")
if not os.path.exists(realpath):
raise ValueError(f"文件 {realpath} 不存在")
# 初始化回滚上下文
rollback_context = {
"request": request,
"userid": userid,
"service_params": None # 在 try 块中设置
}
# 获取服务参数
service_params = await get_service_params(orgid)
if not service_params:
raise ValueError("无法获取服务参数")
async with TransactionContext(f"file_upload_{id}") as transaction_mgr:
# 将 rollback_context 绑定到 TransactionContext
transaction_mgr.transaction_context = rollback_context
try:
# 验证必填字段
if not orgid or not fiid or not id:
raise ValueError("orgid、fiid 和 id 不能为空")
if len(orgid) > 32 or len(fiid) > 255:
raise ValueError("orgid 或 fiid 的长度超出限制")
if not os.path.exists(realpath):
raise ValueError(f"文件 {realpath} 不存在")
chunks = await self.rag_ops.load_and_chunk_document(realpath, timings)
embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings)
await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings)
triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings)
await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings)
# 获取服务参数
service_params = await get_service_params(orgid)
if not service_params:
raise ValueError("无法获取服务参数")
rollback_context["service_params"] = service_params
timings["total"] = time.time() - start_total
debug(f"总耗时: {timings['total']:.2f}")
return {
"status": "success",
"userid": orgid,
"document_id": id,
"collection_name": "ragdb",
"timings": timings,
"unique_triples": triples,
"message": f"文件 {realpath} 成功嵌入并处理三元组",
"status_code": 200
}
except Exception as e:
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total
return {
"status": "error",
"document_id": id,
"collection_name": "ragdb",
"timings": timings,
"message": f"插入文档失败: {str(e)}",
"status_code": 400
}
# 加载和分片文档
chunks = await self.rag_ops.load_and_chunk_document(
realpath, timings, transaction_mgr=transaction_mgr
)
# 生成嵌入向量
embeddings = await self.rag_ops.generate_embeddings(
request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
)
# 插入 Milvus
chunks_data = await self.rag_ops.insert_to_vector_db(
request, chunks, embeddings, realpath, orgid, fiid, id,
service_params, userid, db_type, timings, transaction_mgr=transaction_mgr
)
# 抽取三元组
triples = await self.rag_ops.extract_triples(
request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
)
# 插入 Neo4j
await self.rag_ops.insert_to_graph_db(
request, triples, id, fiid, orgid, service_params, userid, timings, transaction_mgr=transaction_mgr
)
timings["total"] = time.time() - start_total
result.update({
"status": "success",
"unique_triples": triples,
"message": f"文件 {realpath} 成功嵌入并处理三元组",
"status_code": 200
})
debug(f"总耗时: {timings['total']:.2f}")
except Exception as e:
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total
result.update({
"message": f"插入文档失败: {str(e)}",
"timings": timings
})
# 记录回滚日志
rollback_log = {
"transaction_id": transaction_mgr.transaction_id,
"document_id": id,
"realpath": realpath,
"action": "rollback",
"reason": f"插入失败: {str(e)}",
"timestamp": datetime.now().isoformat()
}
with open('/home/wangmeihua/kyrag/logs/rollback.log', 'a') as f:
f.write(f"{json.dumps(rollback_log, ensure_ascii=False)}\n")
# 提供回滚上下文
raise ValueError(str(e)) from e
debug(f"最终结果是:{result}")
return result
async def file_deleted(self, request, recs, userid):
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""

View File

@ -148,12 +148,27 @@ class RagOperations:
# 记录事务操作,包含回滚函数
if transaction_mgr:
async def rollback_vdb_insert(data, context):
    """Undo a vector-database insert by deleting the inserted document.

    Args:
        data: Payload recorded with the operation. Must contain non-None
            'orgid', 'realpath', 'fiid', 'id' and 'db_type' entries.
        context: Rollback context. Must contain non-None 'request',
            'service_params' and 'userid' entries.

    Returns:
        A confirmation message that names the rolled-back document id.

    Raises:
        ValueError: If a required context or data field is missing or None.
        Exception: Any error raised by the delete call is logged and
            re-raised unchanged.
    """
    try:
        # Defensive checks: fail with a precise message before calling the
        # delete routine with incomplete arguments.
        missing_context = [
            key for key in ('request', 'service_params', 'userid')
            if context.get(key) is None
        ]
        if missing_context:
            raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")

        missing_data = [
            key for key in ('orgid', 'realpath', 'fiid', 'id', 'db_type')
            if data.get(key) is None
        ]
        if missing_data:
            raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")

        await self.delete_from_vector_db(
            context['request'], data['orgid'], data['realpath'],
            data['fiid'], data['id'], context['service_params'],
            context['userid'], data['db_type']
        )
        return f"已回滚向量数据库插入: {data['id']}"
    except Exception as e:
        # Log with the document id (when available), then propagate the
        # original exception to the caller.
        error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}")
        raise
transaction_mgr.add_operation(
OperationType.VDB_INSERT,
@ -207,6 +222,7 @@ class RagOperations:
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
orgid: str, service_params: Dict, userid: str, timings: Dict,
transaction_mgr: TransactionManager = None):
"""插入图数据库"""
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
start_neo4j = time.time()

View File

@ -27,7 +27,7 @@ class RollbackOperation:
operation_type: OperationType
data: Dict[str, Any]
timestamp: str
rollback_func: Optional[Callable] = None # 自定义回滚函数
rollback_func: Optional[Callable] = None
class TransactionManager:
@ -38,7 +38,7 @@ class TransactionManager:
self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
rollback_func: Optional[Callable] = None):
rollback_func: Optional[Callable] = None):
"""添加操作记录"""
operation = RollbackOperation(
operation_type=operation_type,
@ -57,6 +57,7 @@ class TransactionManager:
"""执行所有回滚操作"""
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
rollback_results = []
rollback_context = rollback_context or {}
# 按相反顺序执行回滚
for op in reversed(self.rollback_ops):
@ -64,6 +65,19 @@ class TransactionManager:
debug(f"回滚操作: {op.operation_type.value}")
if op.rollback_func:
# 验证上下文和数据
required_context_keys = ['request', 'service_params', 'userid']
missing_context = [k for k in required_context_keys if
k not in rollback_context or rollback_context[k] is None]
if missing_context:
raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
required_data_keys = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
if op.operation_type == OperationType.VDB_INSERT:
missing_data = [k for k in required_data_keys if k not in op.data or op.data[k] is None]
if missing_data:
raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
# 执行自定义回滚函数
result = await op.rollback_func(op.data, rollback_context)
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
@ -73,7 +87,7 @@ class TransactionManager:
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
except Exception as e:
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}")
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}, 堆栈: {traceback.format_exc()}")
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
return rollback_results
@ -81,12 +95,10 @@ class TransactionManager:
async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str:
    """Return a status message for an operation that has no custom rollback.

    No state is changed here; the method only classifies the operation type
    and reports how it was handled.

    Args:
        op: The recorded operation; only ``op.operation_type`` is read.
        context: Rollback context passed by the caller. Not used here.

    Returns:
        A message describing how the operation type was handled.
    """
    # In-memory steps need no special rollback.
    in_memory_ops = (
        OperationType.FILE_LOAD,
        OperationType.EMBEDDING,
        OperationType.TRIPLES_EXTRACT,
        OperationType.ENTITY_EXTRACT,
    )
    # Database steps must be rolled back by the concrete implementation
    # (i.e. through a custom rollback function), not by this default.
    database_ops = (
        OperationType.VDB_INSERT,
        OperationType.VDB_DELETE,
        OperationType.GDB_INSERT,
        OperationType.GDB_DELETE,
    )

    if op.operation_type in in_memory_ops:
        return "内存操作,已自动清理"
    if op.operation_type in database_ops:
        return "需要在具体实现中处理数据库回滚"
    return f"跳过回滚操作: {op.operation_type.value}"
@ -118,7 +130,9 @@ class TransactionContext:
if exc_type is not None:
# 发生异常,执行回滚
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
rollback_results = await self.transaction_manager.rollback_all()
# 使用 transaction_manager 的 rollback_context
rollback_context = getattr(self.transaction_manager, 'transaction_context', {})
rollback_results = await self.transaction_manager.rollback_all(rollback_context)
error(f"回滚结果: {rollback_results}")
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}")
else:
@ -128,7 +142,6 @@ class TransactionContext:
self.transaction_manager.clear()
# 工厂函数
def create_transaction_manager() -> TransactionManager:
    """Create and return a new ``TransactionManager`` instance."""
    return TransactionManager()