rag
This commit is contained in:
parent
524f1e93d1
commit
cddb4733fb
@ -20,7 +20,7 @@ from typing import List, Dict, Any
|
||||
from rag.service_opts import get_service_params, sor_get_service_params
|
||||
from rag.rag_operations import RagOperations
|
||||
import json
|
||||
|
||||
from rag.transaction_manager import TransactionContext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
@ -53,7 +53,6 @@ where a.orgid = b.orgid
|
||||
return r.quota, r.expired_date
|
||||
return None, None
|
||||
|
||||
|
||||
async def file_uploaded(self, request, ns, userid):
|
||||
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
||||
debug(f'Received ns: {ns=}')
|
||||
@ -64,12 +63,33 @@ where a.orgid = b.orgid
|
||||
orgid = ns.get('ownerid', '')
|
||||
db_type = ''
|
||||
|
||||
debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
||||
debug(
|
||||
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
||||
|
||||
timings = {}
|
||||
start_total = time.time()
|
||||
result = {
|
||||
"status": "error",
|
||||
"userid": orgid,
|
||||
"document_id": id,
|
||||
"collection_name": "ragdb",
|
||||
"timings": timings,
|
||||
"message": "",
|
||||
"status_code": 400
|
||||
}
|
||||
|
||||
# 初始化回滚上下文
|
||||
rollback_context = {
|
||||
"request": request,
|
||||
"userid": userid,
|
||||
"service_params": None # 在 try 块中设置
|
||||
}
|
||||
|
||||
async with TransactionContext(f"file_upload_{id}") as transaction_mgr:
|
||||
# 将 rollback_context 绑定到 TransactionContext
|
||||
transaction_mgr.transaction_context = rollback_context
|
||||
try:
|
||||
# 验证必填字段
|
||||
if not orgid or not fiid or not id:
|
||||
raise ValueError("orgid、fiid 和 id 不能为空")
|
||||
if len(orgid) > 32 or len(fiid) > 255:
|
||||
@ -81,36 +101,65 @@ where a.orgid = b.orgid
|
||||
service_params = await get_service_params(orgid)
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
rollback_context["service_params"] = service_params
|
||||
|
||||
chunks = await self.rag_ops.load_and_chunk_document(realpath, timings)
|
||||
embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings)
|
||||
await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings)
|
||||
triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings)
|
||||
await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings)
|
||||
# 加载和分片文档
|
||||
chunks = await self.rag_ops.load_and_chunk_document(
|
||||
realpath, timings, transaction_mgr=transaction_mgr
|
||||
)
|
||||
|
||||
# 生成嵌入向量
|
||||
embeddings = await self.rag_ops.generate_embeddings(
|
||||
request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
|
||||
)
|
||||
|
||||
# 插入 Milvus
|
||||
chunks_data = await self.rag_ops.insert_to_vector_db(
|
||||
request, chunks, embeddings, realpath, orgid, fiid, id,
|
||||
service_params, userid, db_type, timings, transaction_mgr=transaction_mgr
|
||||
)
|
||||
|
||||
# 抽取三元组
|
||||
triples = await self.rag_ops.extract_triples(
|
||||
request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
|
||||
)
|
||||
|
||||
# 插入 Neo4j
|
||||
await self.rag_ops.insert_to_graph_db(
|
||||
request, triples, id, fiid, orgid, service_params, userid, timings, transaction_mgr=transaction_mgr
|
||||
)
|
||||
|
||||
timings["total"] = time.time() - start_total
|
||||
debug(f"总耗时: {timings['total']:.2f} 秒")
|
||||
return {
|
||||
result.update({
|
||||
"status": "success",
|
||||
"userid": orgid,
|
||||
"document_id": id,
|
||||
"collection_name": "ragdb",
|
||||
"timings": timings,
|
||||
"unique_triples": triples,
|
||||
"message": f"文件 {realpath} 成功嵌入并处理三元组",
|
||||
"status_code": 200
|
||||
}
|
||||
})
|
||||
debug(f"总耗时: {timings['total']:.2f} 秒")
|
||||
|
||||
except Exception as e:
|
||||
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
timings["total"] = time.time() - start_total
|
||||
return {
|
||||
"status": "error",
|
||||
"document_id": id,
|
||||
"collection_name": "ragdb",
|
||||
"timings": timings,
|
||||
result.update({
|
||||
"message": f"插入文档失败: {str(e)}",
|
||||
"status_code": 400
|
||||
"timings": timings
|
||||
})
|
||||
# 记录回滚日志
|
||||
rollback_log = {
|
||||
"transaction_id": transaction_mgr.transaction_id,
|
||||
"document_id": id,
|
||||
"realpath": realpath,
|
||||
"action": "rollback",
|
||||
"reason": f"插入失败: {str(e)}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
with open('/home/wangmeihua/kyrag/logs/rollback.log', 'a') as f:
|
||||
f.write(f"{json.dumps(rollback_log, ensure_ascii=False)}\n")
|
||||
# 提供回滚上下文
|
||||
raise ValueError(str(e)) from e
|
||||
debug(f"最终结果是:{result}")
|
||||
return result
|
||||
|
||||
async def file_deleted(self, request, recs, userid):
|
||||
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
|
||||
|
||||
@ -148,12 +148,27 @@ class RagOperations:
|
||||
# 记录事务操作,包含回滚函数
|
||||
if transaction_mgr:
|
||||
async def rollback_vdb_insert(data, context):
    """Undo a vector-database insert by deleting the document's records.

    Args:
        data: Mapping recorded with the operation; must hold non-None
            'orgid', 'realpath', 'fiid', 'id' and 'db_type'.
        context: Rollback context mapping; must hold non-None 'request',
            'service_params' and 'userid'.

    Returns:
        A confirmation message that embeds ``data['id']``.

    Raises:
        ValueError: When a required key is absent or None in ``context``
            or ``data``.
        Exception: Any failure is logged through ``error`` and re-raised
            unchanged.

    Note:
        ``self`` is a free variable here — presumably the instance of the
        enclosing method; confirm against the surrounding definition.
    """

    def _absent(mapping, names):
        # Collect every required name that is either not present or None.
        lacking = []
        for name in names:
            if name not in mapping or mapping[name] is None:
                lacking.append(name)
        return lacking

    try:
        # Defensive check: the rollback context must be complete first.
        absent_context = _absent(context, ['request', 'service_params', 'userid'])
        if absent_context:
            raise ValueError(f"回滚上下文缺少字段: {', '.join(absent_context)}")

        # Then the recorded operation data needed to locate the records.
        absent_data = _absent(data, ['orgid', 'realpath', 'fiid', 'id', 'db_type'])
        if absent_data:
            raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(absent_data)}")

        await self.delete_from_vector_db(
            context['request'],
            data['orgid'],
            data['realpath'],
            data['fiid'],
            data['id'],
            context['service_params'],
            context['userid'],
            data['db_type'],
        )
    except Exception as exc:
        # Log with whatever document id is available, then propagate.
        error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(exc)}")
        raise
    else:
        return f"已回滚向量数据库插入: {data['id']}"
|
||||
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.VDB_INSERT,
|
||||
@ -207,6 +222,7 @@ class RagOperations:
|
||||
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
|
||||
orgid: str, service_params: Dict, userid: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None):
|
||||
|
||||
"""插入图数据库"""
|
||||
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
|
||||
start_neo4j = time.time()
|
||||
|
||||
@ -27,7 +27,7 @@ class RollbackOperation:
|
||||
operation_type: OperationType
|
||||
data: Dict[str, Any]
|
||||
timestamp: str
|
||||
rollback_func: Optional[Callable] = None # 自定义回滚函数
|
||||
rollback_func: Optional[Callable] = None
|
||||
|
||||
|
||||
class TransactionManager:
|
||||
@ -57,6 +57,7 @@ class TransactionManager:
|
||||
"""执行所有回滚操作"""
|
||||
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
|
||||
rollback_results = []
|
||||
rollback_context = rollback_context or {}
|
||||
|
||||
# 按相反顺序执行回滚
|
||||
for op in reversed(self.rollback_ops):
|
||||
@ -64,6 +65,19 @@ class TransactionManager:
|
||||
debug(f"回滚操作: {op.operation_type.value}")
|
||||
|
||||
if op.rollback_func:
|
||||
# 验证上下文和数据
|
||||
required_context_keys = ['request', 'service_params', 'userid']
|
||||
missing_context = [k for k in required_context_keys if
|
||||
k not in rollback_context or rollback_context[k] is None]
|
||||
if missing_context:
|
||||
raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
|
||||
|
||||
required_data_keys = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
|
||||
if op.operation_type == OperationType.VDB_INSERT:
|
||||
missing_data = [k for k in required_data_keys if k not in op.data or op.data[k] is None]
|
||||
if missing_data:
|
||||
raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
|
||||
|
||||
# 执行自定义回滚函数
|
||||
result = await op.rollback_func(op.data, rollback_context)
|
||||
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||
@ -73,7 +87,7 @@ class TransactionManager:
|
||||
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||
|
||||
except Exception as e:
|
||||
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}")
|
||||
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
|
||||
|
||||
return rollback_results
|
||||
@ -82,11 +96,9 @@ class TransactionManager:
|
||||
"""默认回滚处理"""
|
||||
if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING,
|
||||
OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]:
|
||||
# 内存操作,无需特殊回滚
|
||||
return "内存操作,已自动清理"
|
||||
elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE,
|
||||
OperationType.GDB_INSERT, OperationType.GDB_DELETE]:
|
||||
# 数据库操作需要在具体实现中处理
|
||||
return "需要在具体实现中处理数据库回滚"
|
||||
else:
|
||||
return f"跳过回滚操作: {op.operation_type.value}"
|
||||
@ -118,7 +130,9 @@ class TransactionContext:
|
||||
if exc_type is not None:
|
||||
# 发生异常,执行回滚
|
||||
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
|
||||
rollback_results = await self.transaction_manager.rollback_all()
|
||||
# 使用 transaction_manager 的 rollback_context
|
||||
rollback_context = getattr(self.transaction_manager, 'transaction_context', {})
|
||||
rollback_results = await self.transaction_manager.rollback_all(rollback_context)
|
||||
error(f"回滚结果: {rollback_results}")
|
||||
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
||||
else:
|
||||
@ -128,7 +142,6 @@ class TransactionContext:
|
||||
self.transaction_manager.clear()
|
||||
|
||||
|
||||
# Factory function
def create_transaction_manager() -> TransactionManager:
    """Build and return a new ``TransactionManager`` instance."""
    manager = TransactionManager()
    return manager
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user