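"""Shared RAG operations.

RagOperations bundles the common building blocks used by the RAG service:
document loading and chunking, embedding generation, vector-database (Milvus)
and graph-database (Neo4j) insertion and deletion, entity extraction, triplet
matching, vector search, reranking, and result formatting.
"""
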
import os
import re
import time
import math
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from appPublic.log import debug, error, info
from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType


class RagOperations:
    """RAG operations class providing all of the shared RAG operations."""

    def __init__(self):
        self.api_service = APIService()

    async def load_and_chunk_document(self, realpath: str, timings: Dict,
                                      transaction_mgr: TransactionManager = None) -> List[Document]:
        """Load a file and split its text into chunks."""
        debug(f"Loading file: {realpath}")
        start_load = time.time()

        # Check that the file format is supported
        supported_formats = File2Text.supported_types()
        debug(f"Supported file formats: {supported_formats}")
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")

        # Load the file content
        text = fileloader(realpath)
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"File loading took {timings['load_file']:.2f} s, text length: {len(text)}")

        if not text or not text.strip():
            raise ValueError(f"File {realpath} loaded empty")

        # Split into chunks
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len
        )
        debug("Splitting file content into chunks")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"Text splitting took {timings['split_text']:.2f} s, number of chunks: {len(chunks)}")

        if not chunks:
            raise ValueError(f"File {realpath} produced no document chunks")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.FILE_LOAD,
                {'realpath': realpath, 'chunks_count': len(chunks)}
            )

        return chunks

    async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                  userid: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None) -> List[List[float]]:
        """Generate embedding vectors for the chunks."""
        debug("Calling the embedding service to generate vectors")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []

        # Embed in batches of 10 texts
        for i in range(0, len(texts), 10):
            batch_texts = texts[i:i + 10]
            batch_embeddings = await self.api_service.get_embeddings(
                request=request,
                texts=batch_texts,
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
            embeddings.extend(batch_embeddings)

        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("Every embedding must be a list of 1024 floats")

        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"Embedding generation took {timings['generate_embeddings']:.2f} s, number of embeddings: {len(embeddings)}")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'embeddings_count': len(embeddings)}
            )

        return embeddings

    async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
                                  realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
                                  userid: str, db_type: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None):
        """Insert the chunks and their embeddings into the vector database."""
        debug(f"Preparing data and calling the file-insert endpoint: {realpath}")
        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        upload_time = datetime.now().isoformat()

        chunks_data = [
            {
                "userid": orgid,
                "knowledge_base_id": fiid,
                "text": chunk.page_content,
                "vector": embeddings[i],
                "document_id": id,
                "filename": filename + '.' + ext,
                "file_path": realpath,
                "upload_time": upload_time,
                "file_type": ext,
            }
            for i, chunk in enumerate(chunks)
        ]

        start_milvus = time.time()
        for i in range(0, len(chunks_data), 10):
            batch_chunks = chunks_data[i:i + 10]
            debug(f"Data being inserted: {batch_chunks}")
            result = await self.api_service.milvus_insert_document(
                request=request,
                chunks=batch_chunks,
                db_type=db_type,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
                user=userid
            )
            if result.get("status") != "success":
                raise ValueError(result.get("message", "Milvus insert failed"))

        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")

        # Record the transaction operation, including a rollback function
        if transaction_mgr:
            async def rollback_vdb_insert(data, context):
                try:
                    # Defensive checks
                    required_context = ['request', 'service_params', 'userid']
                    missing_context = [k for k in required_context if k not in context or context[k] is None]
                    if missing_context:
                        raise ValueError(f"Rollback context is missing fields: {', '.join(missing_context)}")

                    required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
                    missing_data = [k for k in required_data if k not in data or data[k] is None]
                    if missing_data:
                        raise ValueError(f"VDB_INSERT data is missing fields: {', '.join(missing_data)}")

                    await self.delete_from_vector_db(
                        context['request'], data['orgid'], data['realpath'],
                        data['fiid'], data['id'], context['service_params'],
                        context['userid'], data['db_type']
                    )
                    return f"Rolled back vector database insert: {data['id']}"
                except Exception as e:
                    error(f"Vector database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
                    raise

            transaction_mgr.add_operation(
                OperationType.VDB_INSERT,
                {
                    'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
                    'id': id, 'db_type': db_type
                },
                rollback_func=rollback_vdb_insert
            )

        return chunks_data

    async def insert_to_vector_text(self, request, db_type: str, fields: Dict,
                                    service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Insert a single plain text into the vector database, supporting a dynamic schema."""
        debug("Preparing single plain-text data and calling the insert endpoint")
        start = time.time()
        # Copy the caller-provided fields verbatim (dynamic schema)
        chunks_data = [dict(fields)]
        debug(f"Data passed to the vector database insert: {chunks_data}")

        # Call the Milvus insert
        result = await self.api_service.milvus_insert_document(
            request=request,
            chunks=chunks_data,
            upappid=service_params['vdb'],
            apiname="milvus/insertdocument",
            user=userid,
            db_type=db_type
        )
        if result.get("status") != "success":
            raise ValueError(result.get("message", "Milvus insert failed"))

        debug(f"Successfully inserted plain text into collection {result.get('collection_name')}")
        timings["textinsert"] = time.time() - start
        debug(f"Plain-text insert took {timings['textinsert']:.2f} s")

        return chunks_data

    async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
                              userid: str, timings: Dict,
                              transaction_mgr: TransactionManager = None) -> List[Dict]:
        """Extract triples from the chunks."""
        debug("Calling the triple extraction service")
        start_triples = time.time()
        chunk_texts = [doc.page_content for doc in chunks]
        triples = []

        for i, chunk in enumerate(chunk_texts):
            result = await self.api_service.extract_triples(
                request=request,
                text=chunk,
                upappid=service_params['triples'],
                apiname="Babelscape/mrebel-large",
                user=userid
            )
            if isinstance(result, list):
                triples.extend(result)
                debug(f"Chunk {i + 1} yielded {len(result)} triples")
            else:
                error(f"Chunk {i + 1} failed: {str(result)}")

        # Deduplicate and refine the triples
        unique_triples = self._deduplicate_triples(triples)

        timings["extract_triples"] = time.time() - start_triples
        debug(f"Triple extraction took {timings['extract_triples']:.2f} s, extracted {len(unique_triples)} triples")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.TRIPLES_EXTRACT,
                {'triples_count': len(unique_triples)}
            )

        return unique_triples

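    # Note: downstream consumers (_deduplicate_triples, neo4j_insert_triples) assume
    # each triple is a dict carrying at least 'head', 'type' and 'tail' keys.
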
    async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
                                 orgid: str, service_params: Dict, userid: str, timings: Dict,
                                 transaction_mgr: TransactionManager = None):
        """Insert the triples into the graph database."""
        debug(f"Inserting {len(triples)} triples into Neo4j")
        start_neo4j = time.time()

        if triples:
            for i in range(0, len(triples), 30):
                batch_triples = triples[i:i + 30]
                neo4j_result = await self.api_service.neo4j_insert_triples(
                    request=request,
                    triples=batch_triples,
                    document_id=id,
                    knowledge_base_id=fiid,
                    userid=orgid,
                    upappid=service_params['gdb'],
                    apiname="neo4j/inserttriples",
                    user=userid
                )
                if neo4j_result.get("status") != "success":
                    raise ValueError(f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}")
                info(f"File triples successfully inserted into Neo4j: {neo4j_result.get('message')}")

            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j insert took {timings['insert_neo4j']:.2f} s")
        else:
            debug("No triples were extracted")
            timings["insert_neo4j"] = 0.0

        # Record the transaction operation, including a rollback function
        if transaction_mgr:
            async def rollback_gdb_insert(data, context):
                await self.delete_from_graph_db(
                    context['request'], data['id'],
                    context['service_params'], context['userid']
                )
                return f"Rolled back graph database insert: {data['id']}"

            transaction_mgr.add_operation(
                OperationType.GDB_INSERT,
                {'id': id, 'triples_count': len(triples)},
                rollback_func=rollback_gdb_insert
            )

    async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str,
                                    id: str, service_params: Dict, userid: str, db_type: str):
        """Delete a document from the vector database."""
        debug(f"Calling the file-delete endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
        milvus_result = await self.api_service.milvus_delete_document(
            request=request,
            userid=orgid,
            file_path=realpath,
            knowledge_base_id=fiid,
            document_id=id,
            db_type=db_type,
            upappid=service_params['vdb'],
            apiname="milvus/deletedocument",
            user=userid
        )
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus delete failed"))

    async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
        """Delete a document from the graph database."""
        debug(f"Calling the Neo4j document-delete endpoint: document_id={id}")
        neo4j_result = await self.api_service.neo4j_delete_document(
            request=request,
            document_id=id,
            upappid=service_params['gdb'],
            apiname="neo4j/deletedocument",
            user=userid
        )
        if neo4j_result.get("status") != "success":
            raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))

        nodes_deleted = neo4j_result.get("nodes_deleted", 0)
        rels_deleted = neo4j_result.get("rels_deleted", 0)
        info(f"Successfully deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
        return nodes_deleted, rels_deleted

    async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
                               timings: Dict) -> List[str]:
        """Extract entities from the query."""
        debug(f"Extracting query entities: {query}")
        start_extract = time.time()
        entities = await self.api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timings["entity_extraction"] = time.time() - start_extract
        debug(f"Extracted entities: {entities}, took {timings['entity_extraction']:.3f} s")
        return entities

    async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
                             fiids: List[str], service_params: Dict, userid: str,
                             timings: Dict) -> List[Dict]:
        """Match triplets against the knowledge graph."""
        debug("Starting triplet matching")
        start_triplet = time.time()
        all_triplets = []

        for kb_id in fiids:
            debug(f"Calling Neo4j triplet matching: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await self.api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"Knowledge base {kb_id} matched {len(triplets)} triplets")
                else:
                    error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
            except Exception as e:
                error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                continue

        timings["triplet_matching"] = time.time() - start_triplet
        debug(f"Total triplet matching took {timings['triplet_matching']:.3f} s")
        return all_triplets

    async def generate_query_vector(self, request, text: str, service_params: Dict,
                                    userid: str, timings: Dict) -> List[float]:
        """Generate the query vector."""
        debug(f"Generating query vector: {text[:200]}...")
        start_vector = time.time()
        query_vector = await self.api_service.get_embeddings(
            request=request,
            texts=[text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a list of 1024 floats")
        query_vector = query_vector[0]
        timings["vector_generation"] = time.time() - start_vector
        debug(f"Query vector generation took {timings['vector_generation']:.3f} s")
        return query_vector

    async def vector_search(self, request, query_vector: List[float], orgid: str,
                            fiids: List[str], limit: int, service_params: Dict, userid: str,
                            timings: Dict) -> List[Dict]:
        """Run a vector search."""
        debug("Starting vector search")
        start_search = time.time()
        result = await self.api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            limit=limit,
            offset=0,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )

        if result.get("status") != "success":
            raise ValueError(f"Vector search failed: {result.get('message', 'unknown error')}")

        search_results = result.get("results", [])
        timings["vector_search"] = time.time() - start_search
        debug(f"Vector search took {timings['vector_search']:.3f} s")
        debug(f"Retrieved {len(search_results)} records from the vector database")
        return search_results

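    # Note: rerank_results and format_search_results read each hit as a dict with
    # 'text', an optional 'distance'/'rerank_score', and a 'metadata' dict holding
    # 'filename' and 'document_id'.
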
    async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
                             service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Rerank the search results."""
        debug("Starting reranking")
        start_rerank = time.time()
        reranked_results = await self.api_service.rerank_results(
            request=request,
            query=query,
            results=results,
            top_n=top_n,
            upappid=service_params['reranker'],
            apiname="BAAI/bge-reranker-v2-m3",
            user=userid
        )
        reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
        timings["reranking"] = time.time() - start_rerank
        debug(f"Reranking took {timings['reranking']:.3f} s")
        debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
        return reranked_results

    def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
        """Deduplicate and refine the triples."""
        unique_triples = []
        seen = set()

        for t in triples:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triples.append(t)
            else:
                # If a more specific type is found, replace the existing triple
                for existing in unique_triples:
                    if (existing['head'].lower() == t['head'].lower() and
                            existing['tail'].lower() == t['tail'].lower() and
                            len(t['type']) > len(existing['type'])):
                        unique_triples.remove(existing)
                        unique_triples.append(t)
                        debug(f"Replaced triple with a more specific type: {t}")
                        break

        return unique_triples

    def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
        """Format the search results into a unified structure."""
        formatted_results = []
        # Score normalization: map rerank scores through a sigmoid into [0, 1];
        # fall back to 1 - distance when no rerank score is present.
        for res in results[:limit]:
            rerank_score = res.get('rerank_score')
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))

            content = res.get('text', '')
            title = res.get('metadata', {}).get('filename', 'Untitled')
            document_id = res.get('metadata', {}).get('document_id', '')

            formatted_results.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })

        return formatted_results
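

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module's API).
# The request object, ids, and the service_params mapping are placeholders;
# service_params is assumed to carry the upstream app ids used above
# ('embedding', 'vdb', 'gdb', 'triples', 'entities', 'reranker').
# ---------------------------------------------------------------------------
async def _example_pipeline(request, realpath, orgid, fiid, doc_id, userid,
                            service_params, db_type, query):
    """Sketch of the ingestion and retrieval flows built from RagOperations."""
    ops = RagOperations()
    timings = {}

    # Ingestion: load/chunk -> embed -> vector DB -> triples -> graph DB
    chunks = await ops.load_and_chunk_document(realpath, timings)
    embeddings = await ops.generate_embeddings(request, chunks, service_params, userid, timings)
    await ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid,
                                  doc_id, service_params, userid, db_type, timings)
    triples = await ops.extract_triples(request, chunks, service_params, userid, timings)
    await ops.insert_to_graph_db(request, triples, doc_id, fiid, orgid,
                                 service_params, userid, timings)

    # Retrieval: entities -> graph triplets; query vector -> vector search -> rerank -> format
    entities = await ops.extract_entities(request, query, service_params, userid, timings)
    triplets = await ops.match_triplets(request, query, entities, orgid, [fiid],
                                        service_params, userid, timings)
    query_vector = await ops.generate_query_vector(request, query, service_params, userid, timings)
    hits = await ops.vector_search(request, query_vector, orgid, [fiid], 10,
                                   service_params, userid, timings)
    reranked = await ops.rerank_results(request, query, hits, 5, service_params, userid, timings)
    return ops.format_search_results(reranked, 5), triplets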