rag
This commit is contained in:
parent
524f1e93d1
commit
cddb4733fb
@ -20,7 +20,7 @@ from typing import List, Dict, Any
|
|||||||
from rag.service_opts import get_service_params, sor_get_service_params
|
from rag.service_opts import get_service_params, sor_get_service_params
|
||||||
from rag.rag_operations import RagOperations
|
from rag.rag_operations import RagOperations
|
||||||
import json
|
import json
|
||||||
|
from rag.transaction_manager import TransactionContext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@ -53,7 +53,6 @@ where a.orgid = b.orgid
|
|||||||
return r.quota, r.expired_date
|
return r.quota, r.expired_date
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
async def file_uploaded(self, request, ns, userid):
|
async def file_uploaded(self, request, ns, userid):
|
||||||
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
||||||
debug(f'Received ns: {ns=}')
|
debug(f'Received ns: {ns=}')
|
||||||
@ -64,53 +63,103 @@ where a.orgid = b.orgid
|
|||||||
orgid = ns.get('ownerid', '')
|
orgid = ns.get('ownerid', '')
|
||||||
db_type = ''
|
db_type = ''
|
||||||
|
|
||||||
debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
debug(
|
||||||
|
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
||||||
|
|
||||||
timings = {}
|
timings = {}
|
||||||
start_total = time.time()
|
start_total = time.time()
|
||||||
|
result = {
|
||||||
|
"status": "error",
|
||||||
|
"userid": orgid,
|
||||||
|
"document_id": id,
|
||||||
|
"collection_name": "ragdb",
|
||||||
|
"timings": timings,
|
||||||
|
"message": "",
|
||||||
|
"status_code": 400
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
# 初始化回滚上下文
|
||||||
if not orgid or not fiid or not id:
|
rollback_context = {
|
||||||
raise ValueError("orgid、fiid 和 id 不能为空")
|
"request": request,
|
||||||
if len(orgid) > 32 or len(fiid) > 255:
|
"userid": userid,
|
||||||
raise ValueError("orgid 或 fiid 的长度超出限制")
|
"service_params": None # 在 try 块中设置
|
||||||
if not os.path.exists(realpath):
|
}
|
||||||
raise ValueError(f"文件 {realpath} 不存在")
|
|
||||||
|
|
||||||
# 获取服务参数
|
async with TransactionContext(f"file_upload_{id}") as transaction_mgr:
|
||||||
service_params = await get_service_params(orgid)
|
# 将 rollback_context 绑定到 TransactionContext
|
||||||
if not service_params:
|
transaction_mgr.transaction_context = rollback_context
|
||||||
raise ValueError("无法获取服务参数")
|
try:
|
||||||
|
# 验证必填字段
|
||||||
|
if not orgid or not fiid or not id:
|
||||||
|
raise ValueError("orgid、fiid 和 id 不能为空")
|
||||||
|
if len(orgid) > 32 or len(fiid) > 255:
|
||||||
|
raise ValueError("orgid 或 fiid 的长度超出限制")
|
||||||
|
if not os.path.exists(realpath):
|
||||||
|
raise ValueError(f"文件 {realpath} 不存在")
|
||||||
|
|
||||||
chunks = await self.rag_ops.load_and_chunk_document(realpath, timings)
|
# 获取服务参数
|
||||||
embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings)
|
service_params = await get_service_params(orgid)
|
||||||
await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings)
|
if not service_params:
|
||||||
triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings)
|
raise ValueError("无法获取服务参数")
|
||||||
await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings)
|
rollback_context["service_params"] = service_params
|
||||||
|
|
||||||
timings["total"] = time.time() - start_total
|
# 加载和分片文档
|
||||||
debug(f"总耗时: {timings['total']:.2f} 秒")
|
chunks = await self.rag_ops.load_and_chunk_document(
|
||||||
return {
|
realpath, timings, transaction_mgr=transaction_mgr
|
||||||
"status": "success",
|
)
|
||||||
"userid": orgid,
|
|
||||||
"document_id": id,
|
# 生成嵌入向量
|
||||||
"collection_name": "ragdb",
|
embeddings = await self.rag_ops.generate_embeddings(
|
||||||
"timings": timings,
|
request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
|
||||||
"unique_triples": triples,
|
)
|
||||||
"message": f"文件 {realpath} 成功嵌入并处理三元组",
|
|
||||||
"status_code": 200
|
# 插入 Milvus
|
||||||
}
|
chunks_data = await self.rag_ops.insert_to_vector_db(
|
||||||
except Exception as e:
|
request, chunks, embeddings, realpath, orgid, fiid, id,
|
||||||
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
service_params, userid, db_type, timings, transaction_mgr=transaction_mgr
|
||||||
timings["total"] = time.time() - start_total
|
)
|
||||||
return {
|
|
||||||
"status": "error",
|
# 抽取三元组
|
||||||
"document_id": id,
|
triples = await self.rag_ops.extract_triples(
|
||||||
"collection_name": "ragdb",
|
request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
|
||||||
"timings": timings,
|
)
|
||||||
"message": f"插入文档失败: {str(e)}",
|
|
||||||
"status_code": 400
|
# 插入 Neo4j
|
||||||
}
|
await self.rag_ops.insert_to_graph_db(
|
||||||
|
request, triples, id, fiid, orgid, service_params, userid, timings, transaction_mgr=transaction_mgr
|
||||||
|
)
|
||||||
|
|
||||||
|
timings["total"] = time.time() - start_total
|
||||||
|
result.update({
|
||||||
|
"status": "success",
|
||||||
|
"unique_triples": triples,
|
||||||
|
"message": f"文件 {realpath} 成功嵌入并处理三元组",
|
||||||
|
"status_code": 200
|
||||||
|
})
|
||||||
|
debug(f"总耗时: {timings['total']:.2f} 秒")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||||
|
timings["total"] = time.time() - start_total
|
||||||
|
result.update({
|
||||||
|
"message": f"插入文档失败: {str(e)}",
|
||||||
|
"timings": timings
|
||||||
|
})
|
||||||
|
# 记录回滚日志
|
||||||
|
rollback_log = {
|
||||||
|
"transaction_id": transaction_mgr.transaction_id,
|
||||||
|
"document_id": id,
|
||||||
|
"realpath": realpath,
|
||||||
|
"action": "rollback",
|
||||||
|
"reason": f"插入失败: {str(e)}",
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
with open('/home/wangmeihua/kyrag/logs/rollback.log', 'a') as f:
|
||||||
|
f.write(f"{json.dumps(rollback_log, ensure_ascii=False)}\n")
|
||||||
|
# 提供回滚上下文
|
||||||
|
raise ValueError(str(e)) from e
|
||||||
|
debug(f"最终结果是:{result}")
|
||||||
|
return result
|
||||||
|
|
||||||
async def file_deleted(self, request, recs, userid):
|
async def file_deleted(self, request, recs, userid):
|
||||||
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
|
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
|
||||||
|
|||||||
@ -148,12 +148,27 @@ class RagOperations:
|
|||||||
# 记录事务操作,包含回滚函数
|
# 记录事务操作,包含回滚函数
|
||||||
if transaction_mgr:
|
if transaction_mgr:
|
||||||
async def rollback_vdb_insert(data, context):
|
async def rollback_vdb_insert(data, context):
|
||||||
await self.delete_from_vector_db(
|
try:
|
||||||
context['request'], data['orgid'], data['realpath'],
|
# 防御性检查
|
||||||
data['fiid'], data['id'], context['service_params'],
|
required_context = ['request', 'service_params', 'userid']
|
||||||
context['userid'], data['db_type']
|
missing_context = [k for k in required_context if k not in context or context[k] is None]
|
||||||
)
|
if missing_context:
|
||||||
return f"已回滚向量数据库插入: {data['id']}"
|
raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
|
||||||
|
|
||||||
|
required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
|
||||||
|
missing_data = [k for k in required_data if k not in data or data[k] is None]
|
||||||
|
if missing_data:
|
||||||
|
raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
|
||||||
|
|
||||||
|
await self.delete_from_vector_db(
|
||||||
|
context['request'], data['orgid'], data['realpath'],
|
||||||
|
data['fiid'], data['id'], context['service_params'],
|
||||||
|
context['userid'], data['db_type']
|
||||||
|
)
|
||||||
|
return f"已回滚向量数据库插入: {data['id']}"
|
||||||
|
except Exception as e:
|
||||||
|
error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
transaction_mgr.add_operation(
|
transaction_mgr.add_operation(
|
||||||
OperationType.VDB_INSERT,
|
OperationType.VDB_INSERT,
|
||||||
@ -207,6 +222,7 @@ class RagOperations:
|
|||||||
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
|
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
|
||||||
orgid: str, service_params: Dict, userid: str, timings: Dict,
|
orgid: str, service_params: Dict, userid: str, timings: Dict,
|
||||||
transaction_mgr: TransactionManager = None):
|
transaction_mgr: TransactionManager = None):
|
||||||
|
|
||||||
"""插入图数据库"""
|
"""插入图数据库"""
|
||||||
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
|
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
|
||||||
start_neo4j = time.time()
|
start_neo4j = time.time()
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class RollbackOperation:
|
|||||||
operation_type: OperationType
|
operation_type: OperationType
|
||||||
data: Dict[str, Any]
|
data: Dict[str, Any]
|
||||||
timestamp: str
|
timestamp: str
|
||||||
rollback_func: Optional[Callable] = None # 自定义回滚函数
|
rollback_func: Optional[Callable] = None
|
||||||
|
|
||||||
|
|
||||||
class TransactionManager:
|
class TransactionManager:
|
||||||
@ -38,7 +38,7 @@ class TransactionManager:
|
|||||||
self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
self.transaction_id: str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
|
||||||
def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
|
def add_operation(self, operation_type: OperationType, data: Dict[str, Any],
|
||||||
rollback_func: Optional[Callable] = None):
|
rollback_func: Optional[Callable] = None):
|
||||||
"""添加操作记录"""
|
"""添加操作记录"""
|
||||||
operation = RollbackOperation(
|
operation = RollbackOperation(
|
||||||
operation_type=operation_type,
|
operation_type=operation_type,
|
||||||
@ -57,6 +57,7 @@ class TransactionManager:
|
|||||||
"""执行所有回滚操作"""
|
"""执行所有回滚操作"""
|
||||||
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
|
debug(f"事务 {self.transaction_id} 开始回滚,共 {len(self.rollback_ops)} 个操作")
|
||||||
rollback_results = []
|
rollback_results = []
|
||||||
|
rollback_context = rollback_context or {}
|
||||||
|
|
||||||
# 按相反顺序执行回滚
|
# 按相反顺序执行回滚
|
||||||
for op in reversed(self.rollback_ops):
|
for op in reversed(self.rollback_ops):
|
||||||
@ -64,6 +65,19 @@ class TransactionManager:
|
|||||||
debug(f"回滚操作: {op.operation_type.value}")
|
debug(f"回滚操作: {op.operation_type.value}")
|
||||||
|
|
||||||
if op.rollback_func:
|
if op.rollback_func:
|
||||||
|
# 验证上下文和数据
|
||||||
|
required_context_keys = ['request', 'service_params', 'userid']
|
||||||
|
missing_context = [k for k in required_context_keys if
|
||||||
|
k not in rollback_context or rollback_context[k] is None]
|
||||||
|
if missing_context:
|
||||||
|
raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
|
||||||
|
|
||||||
|
required_data_keys = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
|
||||||
|
if op.operation_type == OperationType.VDB_INSERT:
|
||||||
|
missing_data = [k for k in required_data_keys if k not in op.data or op.data[k] is None]
|
||||||
|
if missing_data:
|
||||||
|
raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
|
||||||
|
|
||||||
# 执行自定义回滚函数
|
# 执行自定义回滚函数
|
||||||
result = await op.rollback_func(op.data, rollback_context)
|
result = await op.rollback_func(op.data, rollback_context)
|
||||||
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||||
@ -73,7 +87,7 @@ class TransactionManager:
|
|||||||
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
rollback_results.append(f"成功回滚 {op.operation_type.value}: {result}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}")
|
error(f"回滚操作失败 {op.operation_type.value}: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||||
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
|
rollback_results.append(f"回滚失败: {op.operation_type.value} - {str(e)}")
|
||||||
|
|
||||||
return rollback_results
|
return rollback_results
|
||||||
@ -81,12 +95,10 @@ class TransactionManager:
|
|||||||
async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str:
|
async def _default_rollback(self, op: RollbackOperation, context: Dict[str, Any]) -> str:
|
||||||
"""默认回滚处理"""
|
"""默认回滚处理"""
|
||||||
if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING,
|
if op.operation_type in [OperationType.FILE_LOAD, OperationType.EMBEDDING,
|
||||||
OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]:
|
OperationType.TRIPLES_EXTRACT, OperationType.ENTITY_EXTRACT]:
|
||||||
# 内存操作,无需特殊回滚
|
|
||||||
return "内存操作,已自动清理"
|
return "内存操作,已自动清理"
|
||||||
elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE,
|
elif op.operation_type in [OperationType.VDB_INSERT, OperationType.VDB_DELETE,
|
||||||
OperationType.GDB_INSERT, OperationType.GDB_DELETE]:
|
OperationType.GDB_INSERT, OperationType.GDB_DELETE]:
|
||||||
# 数据库操作需要在具体实现中处理
|
|
||||||
return "需要在具体实现中处理数据库回滚"
|
return "需要在具体实现中处理数据库回滚"
|
||||||
else:
|
else:
|
||||||
return f"跳过回滚操作: {op.operation_type.value}"
|
return f"跳过回滚操作: {op.operation_type.value}"
|
||||||
@ -118,7 +130,9 @@ class TransactionContext:
|
|||||||
if exc_type is not None:
|
if exc_type is not None:
|
||||||
# 发生异常,执行回滚
|
# 发生异常,执行回滚
|
||||||
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
|
error(f"事务失败: {self.transaction_name}, 异常: {exc_val}")
|
||||||
rollback_results = await self.transaction_manager.rollback_all()
|
# 使用 transaction_manager 的 rollback_context
|
||||||
|
rollback_context = getattr(self.transaction_manager, 'transaction_context', {})
|
||||||
|
rollback_results = await self.transaction_manager.rollback_all(rollback_context)
|
||||||
error(f"回滚结果: {rollback_results}")
|
error(f"回滚结果: {rollback_results}")
|
||||||
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
info(f"事务回滚完成: {self.transaction_name}, 耗时: {duration:.2f}秒")
|
||||||
else:
|
else:
|
||||||
@ -128,7 +142,6 @@ class TransactionContext:
|
|||||||
self.transaction_manager.clear()
|
self.transaction_manager.clear()
|
||||||
|
|
||||||
|
|
||||||
# 工厂函数
|
|
||||||
def create_transaction_manager() -> TransactionManager:
|
def create_transaction_manager() -> TransactionManager:
|
||||||
"""创建事务管理器"""
|
"""创建事务管理器"""
|
||||||
return TransactionManager()
|
return TransactionManager()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user