rag/rag/rag_operations.py
import os
import re
import time
import math
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from appPublic.log import debug, error, info
from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType
class RagOperations:
"""RAG 操作类,提供所有通用的 RAG 操作"""
def __init__(self):
self.api_service = APIService()

    async def load_and_chunk_document(self, realpath: str, timings: Dict,
                                      transaction_mgr: TransactionManager = None) -> List[Document]:
        """Load a file and split its text into chunks."""
        debug(f"加载文件: {realpath}")
        start_load = time.time()
        # Check that the file format is supported
        supported_formats = File2Text.supported_types()
        debug(f"支持的文件格式:{supported_formats}")
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
        # Load the file content and strip characters outside the allowed set
        text = fileloader(realpath)
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
        if not text or not text.strip():
            raise ValueError(f"文件 {realpath} 加载为空")
        # Split into chunks
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len
        )
        debug("开始分片文件内容")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
        if not chunks:
            raise ValueError(f"文件 {realpath} 未生成任何文档块")
        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.FILE_LOAD,
                {'realpath': realpath, 'chunks_count': len(chunks)}
            )
        return chunks
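
    # Minimal usage sketch (commented out; assumes a running event loop and a writable
    # `timings` dict, and the file path is a hypothetical example):
    #
    #   ops = RagOperations()
    #   timings = {}
    #   chunks = await ops.load_and_chunk_document("/data/docs/sample.txt", timings)
    #   # With chunk_size=500 and chunk_overlap=100, consecutive chunks share up to
    #   # 100 characters of context.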

    async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                  userid: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None) -> List[List[float]]:
        """Generate embedding vectors for the given chunks."""
        debug("调用嵌入服务生成向量")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []
        # Process embeddings in batches of 10 texts
        for i in range(0, len(texts), 10):
            batch_texts = texts[i:i + 10]
            batch_embeddings = await self.api_service.get_embeddings(
                request=request,
                texts=batch_texts,
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
            embeddings.extend(batch_embeddings)
        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'embeddings_count': len(embeddings)}
            )
        return embeddings
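
    # Batching sketch: with 23 chunks and a batch size of 10, the loop above issues
    # 3 embedding requests (10 + 10 + 3 texts). Every returned vector must be
    # 1024-dimensional (the BAAI/bge-m3 dense embedding size), otherwise a ValueError
    # is raised before anything is written downstream.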

    async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
                                  realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
                                  userid: str, db_type: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None):
        """Insert chunks and their vectors into the vector database."""
        debug(f"准备数据并调用插入文件端点: {realpath}")
        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        upload_time = datetime.now().isoformat()
        chunks_data = [
            {
                "userid": orgid,
                "knowledge_base_id": fiid,
                "text": chunk.page_content,
                "vector": embeddings[i],
                "document_id": id,
                "filename": filename + '.' + ext,
                "file_path": realpath,
                "upload_time": upload_time,
                "file_type": ext,
            }
            for i, chunk in enumerate(chunks)
        ]
        start_milvus = time.time()
        for i in range(0, len(chunks_data), 10):
            batch_chunks = chunks_data[i:i + 10]
            debug(f"传入的数据是:{batch_chunks}")
            result = await self.api_service.milvus_insert_document(
                request=request,
                chunks=batch_chunks,
                db_type=db_type,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
                user=userid
            )
            if result.get("status") != "success":
                raise ValueError(result.get("message", "Milvus 插入失败"))
        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
        # Record the transaction operation together with a rollback function
        if transaction_mgr:
            async def rollback_vdb_insert(data, context):
                try:
                    # Defensive checks on the rollback context and data
                    required_context = ['request', 'service_params', 'userid']
                    missing_context = [k for k in required_context if k not in context or context[k] is None]
                    if missing_context:
                        raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
                    required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
                    missing_data = [k for k in required_data if k not in data or data[k] is None]
                    if missing_data:
                        raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
                    await self.delete_from_vector_db(
                        context['request'], data['orgid'], data['realpath'],
                        data['fiid'], data['id'], context['service_params'],
                        context['userid'], data['db_type']
                    )
                    return f"已回滚向量数据库插入: {data['id']}"
                except Exception as e:
                    error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}")
                    raise

            transaction_mgr.add_operation(
                OperationType.VDB_INSERT,
                {
                    'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
                    'id': id, 'db_type': db_type
                },
                rollback_func=rollback_vdb_insert
            )
        return chunks_data
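
    # Shape of one chunks_data entry built above (field names come from this method;
    # the values below are hypothetical placeholders):
    #
    #   {
    #       "userid": "org-001",
    #       "knowledge_base_id": "kb-001",
    #       "text": "……chunk text……",
    #       "vector": [0.01, -0.23, ...],   # 1024 floats from generate_embeddings()
    #       "document_id": "doc-001",
    #       "filename": "report.pdf",
    #       "file_path": "/data/docs/report.pdf",
    #       "upload_time": "2025-01-01T00:00:00",
    #       "file_type": "pdf",
    #   }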

    async def insert_to_vector_text(self, request,
                                    db_type: str, fields: Dict, service_params: Dict,
                                    userid: str, timings: Dict) -> List[Dict]:
        """Insert a single plain-text record into the vector database, supporting a dynamic schema."""
        chunk_data = {}
        debug("准备单一纯文本数据并调用插入端点")
        start = time.time()
        for key, value in fields.items():
            chunk_data[key] = value
        chunks_data = [chunk_data]
        debug(f"向量库插入传入的数据是:{chunks_data}")
        # Call the Milvus insert endpoint
        result = await self.api_service.milvus_insert_document(
            request=request,
            chunks=chunks_data,
            upappid=service_params['vdb'],
            apiname="milvus/insertdocument",
            user=userid,
            db_type=db_type
        )
        if result.get("status") != "success":
            raise ValueError(result.get("message", "Milvus 插入失败"))
        debug(f"成功插入纯文本到集合 {result.get('collection_name')}")
        timings["textinsert"] = time.time() - start
        debug(f"插入纯文本耗时: {timings['textinsert']:.2f}")
        return chunks_data
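
    # Dynamic-schema sketch: `fields` is forwarded verbatim as one row, so the caller
    # decides the column set. A hypothetical call might pass
    #   fields = {"userid": "org-001", "knowledge_base_id": "kb-001",
    #             "text": "some note", "vector": embedding_1024}
    # where embedding_1024 is a single 1024-float vector (e.g. one element of the list
    # returned by generate_embeddings()).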

    async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
                              userid: str, timings: Dict,
                              transaction_mgr: TransactionManager = None) -> List[Dict]:
        """Extract triples from the chunks."""
        debug("调用三元组抽取服务")
        start_triples = time.time()
        chunk_texts = [doc.page_content for doc in chunks]
        triples = []
        for i, chunk in enumerate(chunk_texts):
            result = await self.api_service.extract_triples(
                request=request,
                text=chunk,
                upappid=service_params['triples'],
                apiname="Babelscape/mrebel-large",
                user=userid
            )
            if isinstance(result, list):
                triples.extend(result)
                debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
            else:
                error(f"分片 {i + 1} 处理失败: {str(result)}")
        # Deduplicate and refine the triples
        unique_triples = self._deduplicate_triples(triples)
        timings["extract_triples"] = time.time() - start_triples
        debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.TRIPLES_EXTRACT,
                {'triples_count': len(unique_triples)}
            )
        return unique_triples
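
    # Expected triple shape (keys inferred from _deduplicate_triples below; the values
    # here are hypothetical):
    #
    #   {"head": "BYD", "type": "headquartered in", "tail": "Shenzhen"}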

    async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
                                 orgid: str, service_params: Dict, userid: str, timings: Dict,
                                 transaction_mgr: TransactionManager = None):
        """Insert triples into the graph database."""
        debug(f"插入 {len(triples)} 个三元组到 Neo4j")
        start_neo4j = time.time()
        if triples:
            # Insert in batches of 30 triples
            for i in range(0, len(triples), 30):
                batch_triples = triples[i:i + 30]
                neo4j_result = await self.api_service.neo4j_insert_triples(
                    request=request,
                    triples=batch_triples,
                    document_id=id,
                    knowledge_base_id=fiid,
                    userid=orgid,
                    upappid=service_params['gdb'],
                    apiname="neo4j/inserttriples",
                    user=userid
                )
                if neo4j_result.get("status") != "success":
                    raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
                info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
        else:
            debug("未抽取到三元组")
            timings["insert_neo4j"] = 0.0
        # Record the transaction operation together with a rollback function
        if transaction_mgr:
            async def rollback_gdb_insert(data, context):
                await self.delete_from_graph_db(
                    context['request'], data['id'],
                    context['service_params'], context['userid']
                )
                return f"已回滚图数据库插入: {data['id']}"

            transaction_mgr.add_operation(
                OperationType.GDB_INSERT,
                {'id': id, 'triples_count': len(triples)},
                rollback_func=rollback_gdb_insert
            )

    async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str,
                                    id: str, service_params: Dict, userid: str, db_type: str):
        """Delete a document from the vector database."""
        debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
        milvus_result = await self.api_service.milvus_delete_document(
            request=request,
            userid=orgid,
            file_path=realpath,
            knowledge_base_id=fiid,
            document_id=id,
            db_type=db_type,
            upappid=service_params['vdb'],
            apiname="milvus/deletedocument",
            user=userid
        )
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus 删除失败"))

    async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
        """Delete a document from the graph database."""
        debug(f"调用 Neo4j 删除文档端点: document_id={id}")
        neo4j_result = await self.api_service.neo4j_delete_document(
            request=request,
            document_id=id,
            upappid=service_params['gdb'],
            apiname="neo4j/deletedocument",
            user=userid
        )
        if neo4j_result.get("status") != "success":
            raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
        nodes_deleted = neo4j_result.get("nodes_deleted", 0)
        rels_deleted = neo4j_result.get("rels_deleted", 0)
        info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
        return nodes_deleted, rels_deleted

    async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
                               timings: Dict) -> List[str]:
        """Extract entities from the query."""
        debug(f"提取查询实体: {query}")
        start_extract = time.time()
        entities = await self.api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timings["entity_extraction"] = time.time() - start_extract
        debug(f"提取实体: {entities}, 耗时: {timings['entity_extraction']:.3f}")
        return entities

    async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
                             fiids: List[str], service_params: Dict, userid: str,
                             timings: Dict) -> List[Dict]:
        """Match triples against the graph database."""
        debug("开始三元组匹配")
        start_triplet = time.time()
        all_triplets = []
        for kb_id in fiids:
            debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await self.api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组")
                else:
                    error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
            except Exception as e:
                error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
                continue
        timings["triplet_matching"] = time.time() - start_triplet
        debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f}")
        return all_triplets

    async def generate_query_vector(self, request, text: str, service_params: Dict,
                                    userid: str, timings: Dict) -> List[float]:
        """Generate the query embedding vector."""
        debug(f"生成查询向量: {text[:200]}...")
        start_vector = time.time()
        query_vector = await self.api_service.get_embeddings(
            request=request,
            texts=[text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
        query_vector = query_vector[0]
        timings["vector_generation"] = time.time() - start_vector
        debug(f"生成查询向量耗时: {timings['vector_generation']:.3f}")
        return query_vector

    async def vector_search(self, request, query_vector: List[float], orgid: str,
                            fiids: List[str], limit: int, service_params: Dict, userid: str,
                            timings: Dict) -> List[Dict]:
        """Run a vector similarity search."""
        debug("开始向量搜索")
        start_search = time.time()
        result = await self.api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            limit=limit,
            offset=0,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )
        if result.get("status") != "success":
            raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}")
        search_results = result.get("results", [])
        timings["vector_search"] = time.time() - start_search
        debug(f"向量搜索耗时: {timings['vector_search']:.3f}")
        debug(f"从向量数据中搜索到 {len(search_results)} 条数据")
        return search_results

    async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
                             service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Rerank search results."""
        debug("开始重排序")
        start_rerank = time.time()
        reranked_results = await self.api_service.rerank_results(
            request=request,
            query=query,
            results=results,
            top_n=top_n,
            upappid=service_params['reranker'],
            apiname="BAAI/bge-reranker-v2-m3",
            user=userid
        )
        reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
        timings["reranking"] = time.time() - start_rerank
        debug(f"重排序耗时: {timings['reranking']:.3f}")
        debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
        return reranked_results

    def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
        """Deduplicate and refine triples."""
        unique_triples = []
        seen = set()
        for t in triples:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triples.append(t)
            else:
                # If a more specific relation type is found, replace the existing triple
                for existing in unique_triples:
                    if (existing['head'].lower() == t['head'].lower() and
                            existing['tail'].lower() == t['tail'].lower() and
                            len(t['type']) > len(existing['type'])):
                        unique_triples.remove(existing)
                        unique_triples.append(t)
                        debug(f"替换三元组为更具体类型: {t}")
                        break
        return unique_triples
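
    # Worked example (hypothetical triples): with input
    #   [{"head": "A", "type": "part of", "tail": "B"},
    #    {"head": "A", "type": "part of", "tail": "B"}]
    # the second triple's (head, tail, type) identifier is already in `seen`, so only
    # one copy survives. When a repeated identifier coexists with a shorter-typed
    # triple for the same head/tail pair, that shorter-typed triple is replaced by the
    # more specific one.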

    def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
        """Format search results into a unified structure."""
        formatted_results = []
        # Normalize scores to [0, 1]
        for res in results[:limit]:
            rerank_score = res.get('rerank_score')
            # Squash the reranker score with a sigmoid when present; otherwise fall back
            # to 1 - distance from the vector search.
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))
            content = res.get('text', '')
            title = res.get('metadata', {}).get('filename', 'Untitled')
            document_id = res.get('metadata', {}).get('document_id', '')
            formatted_results.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })
        return formatted_results
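

# Worked example of the score normalization above (values are illustrative):
#   rerank_score =  2.2  ->  1 / (1 + e**-2.2)  ≈ 0.900
#   rerank_score =  0.0  ->  0.5
#   rerank_score = -1.0  ->  1 / (1 + e**1.0)   ≈ 0.269
# A result that only carries a vector-search `distance` of 0.35 would instead get
# 1 - 0.35 = 0.65; either way the score is clamped to [0, 1].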