import os
import re
import time
import math
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from appPublic.log import debug, error, info
from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType


class RagOperations:
    """RAG operations class providing all common RAG operations."""

    def __init__(self):
        self.api_service = APIService()

    async def load_and_chunk_document(self, realpath: str, timings: Dict,
                                      transaction_mgr: Optional[TransactionManager] = None) -> List[Document]:
        """Load a file and split its text into chunks."""
        debug(f"Loading file: {realpath}")
        start_load = time.time()

        # Check that the file format is supported
        supported_formats = File2Text.supported_types()
        debug(f"Supported file formats: {supported_formats}")
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")

        # Load the file content and strip characters outside the allowed set
        text = fileloader(realpath)
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")
        if not text or not text.strip():
            raise ValueError(f"File {realpath} loaded empty")

        # Split the text into chunks
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len
        )
        debug("Splitting file content into chunks")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"Text splitting took {timings['split_text']:.2f} s, chunk count: {len(chunks)}")
        if not chunks:
            raise ValueError(f"File {realpath} produced no document chunks")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.FILE_LOAD,
                {'realpath': realpath, 'chunks_count': len(chunks)}
            )

        return chunks

    async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                  userid: str, timings: Dict,
                                  transaction_mgr: Optional[TransactionManager] = None) -> List[List[float]]:
        """Generate embedding vectors for the chunks."""
        debug("Calling the embedding service to generate vectors")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []

        # Process embeddings in batches of 10
        for i in range(0, len(texts), 10):
            batch_texts = texts[i:i + 10]
            batch_embeddings = await self.api_service.get_embeddings(
                request=request,
                texts=batch_texts,
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
            embeddings.extend(batch_embeddings)

        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("All embedding vectors must be float lists of length 1024")

        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"Embedding generation took {timings['generate_embeddings']:.2f} s, embedding count: {len(embeddings)}")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'embeddings_count': len(embeddings)}
            )

        return embeddings

    async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
                                  realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
                                  userid: str, db_type: str, timings: Dict,
                                  transaction_mgr: Optional[TransactionManager] = None):
        """Insert document chunks into the vector database."""
        debug(f"Preparing data and calling the insert-file endpoint: {realpath}")
        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        upload_time = datetime.now().isoformat()
        chunks_data = [
            {
                "userid": orgid,
                "knowledge_base_id": fiid,
                "text": chunk.page_content,
                "vector": embeddings[i],
                "document_id": id,
                "filename": filename + '.' + ext,
                "file_path": realpath,
                "upload_time": upload_time,
                "file_type": ext,
            }
            for i, chunk in enumerate(chunks)
        ]

        start_milvus = time.time()
        # Insert in batches of 10
        for i in range(0, len(chunks_data), 10):
            batch_chunks = chunks_data[i:i + 10]
            debug(f"Data passed in: {batch_chunks}")
            result = await self.api_service.milvus_insert_document(
                request=request,
                chunks=batch_chunks,
                db_type=db_type,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
                user=userid
            )
            if result.get("status") != "success":
                raise ValueError(result.get("message", "Milvus insert failed"))

        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")

        # Record the transaction operation, including a rollback function
        if transaction_mgr:
            async def rollback_vdb_insert(data, context):
                try:
                    # Defensive checks on rollback context and data
                    required_context = ['request', 'service_params', 'userid']
                    missing_context = [k for k in required_context if k not in context or context[k] is None]
                    if missing_context:
                        raise ValueError(f"Rollback context missing fields: {', '.join(missing_context)}")
                    required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
                    missing_data = [k for k in required_data if k not in data or data[k] is None]
                    if missing_data:
                        raise ValueError(f"VDB_INSERT data missing fields: {', '.join(missing_data)}")
                    await self.delete_from_vector_db(
                        context['request'], data['orgid'], data['realpath'], data['fiid'],
                        data['id'], context['service_params'], context['userid'], data['db_type']
                    )
                    return f"Rolled back vector database insert: {data['id']}"
                except Exception as e:
                    error(f"Vector database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
                    raise

            transaction_mgr.add_operation(
                OperationType.VDB_INSERT,
                {
                    'orgid': orgid,
                    'realpath': realpath,
                    'fiid': fiid,
                    'id': id,
                    'db_type': db_type
                },
                rollback_func=rollback_vdb_insert
            )

        return chunks_data

    async def insert_to_vector_text(self, request, db_type: str, fields: Dict, service_params: Dict,
                                    userid: str, timings: Dict) -> List[Dict]:
        """Insert a single plain-text record into the vector database, supporting a dynamic schema."""
        chunk_data = {}
        debug("Preparing single plain-text data and calling the insert endpoint")
        start = time.time()
        for key, value in fields.items():
            chunk_data[key] = value
        chunks_data = [chunk_data]
        debug(f"Data passed to the vector database insert: {chunks_data}")

        # Call the Milvus insert
        result = await self.api_service.milvus_insert_document(
            request=request,
            chunks=chunks_data,
            upappid=service_params['vdb'],
            apiname="milvus/insertdocument",
            user=userid,
            db_type=db_type
        )
        if result.get("status") != "success":
            raise ValueError(result.get("message", "Milvus insert failed"))
        debug(f"Successfully inserted plain text into collection {result.get('collection_name')}")

        timings["textinsert"] = time.time() - start
        debug(f"Plain-text insert took {timings['textinsert']:.2f} s")
        return chunks_data

    async def extract_triples(self, request, chunks: List[Document], service_params: Dict, userid: str,
                              timings: Dict, transaction_mgr: Optional[TransactionManager] = None) -> List[Dict]:
        """Extract triples from the chunks."""
        debug("Calling the triple extraction service")
        start_triples = time.time()
        chunk_texts = [doc.page_content for doc in chunks]
        triples = []

        for i, chunk in enumerate(chunk_texts):
            result = await self.api_service.extract_triples(
                request=request,
                text=chunk,
                upappid=service_params['triples'],
                apiname="Babelscape/mrebel-large",
                user=userid
            )
            if isinstance(result, list):
                triples.extend(result)
                debug(f"Chunk {i + 1} yielded {len(result)} triples")
            else:
                error(f"Chunk {i + 1} processing failed: {str(result)}")

        # Deduplicate and refine the triples
        unique_triples = self._deduplicate_triples(triples)
        timings["extract_triples"] = time.time() - start_triples
        debug(f"Triple extraction took {timings['extract_triples']:.2f} s, extracted {len(unique_triples)} triples")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.TRIPLES_EXTRACT,
                {'triples_count': len(unique_triples)}
            )

        return unique_triples

    async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str, orgid: str,
                                 service_params: Dict, userid: str, timings: Dict,
                                 transaction_mgr: Optional[TransactionManager] = None):
        """Insert triples into the graph database."""
        debug(f"Inserting {len(triples)} triples into Neo4j")
        start_neo4j = time.time()

        if triples:
            # Insert in batches of 30
            for i in range(0, len(triples), 30):
                batch_triples = triples[i:i + 30]
                neo4j_result = await self.api_service.neo4j_insert_triples(
                    request=request,
                    triples=batch_triples,
                    document_id=id,
                    knowledge_base_id=fiid,
                    userid=orgid,
                    upappid=service_params['gdb'],
                    apiname="neo4j/inserttriples",
                    user=userid
                )
                if neo4j_result.get("status") != "success":
                    raise ValueError(f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}")
                info(f"File triples successfully inserted into Neo4j: {neo4j_result.get('message')}")
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j insert took {timings['insert_neo4j']:.2f} s")
        else:
            debug("No triples extracted")
            timings["insert_neo4j"] = 0.0

        # Record the transaction operation, including a rollback function
        if transaction_mgr:
            async def rollback_gdb_insert(data, context):
                await self.delete_from_graph_db(
                    context['request'], data['id'], context['service_params'], context['userid']
                )
                return f"Rolled back graph database insert: {data['id']}"

            transaction_mgr.add_operation(
                OperationType.GDB_INSERT,
                {'id': id, 'triples_count': len(triples)},
                rollback_func=rollback_gdb_insert
            )

    async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str, id: str,
                                    service_params: Dict, userid: str, db_type: str):
        """Delete a document from the vector database."""
        debug(f"Calling the delete-file endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
        milvus_result = await self.api_service.milvus_delete_document(
            request=request,
            userid=orgid,
            file_path=realpath,
            knowledge_base_id=fiid,
            document_id=id,
            db_type=db_type,
            upappid=service_params['vdb'],
            apiname="milvus/deletedocument",
            user=userid
        )
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus delete failed"))

    async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
        """Delete a document from the graph database."""
        debug(f"Calling the Neo4j delete-document endpoint: document_id={id}")
        neo4j_result = await self.api_service.neo4j_delete_document(
            request=request,
            document_id=id,
            upappid=service_params['gdb'],
            apiname="neo4j/deletedocument",
            user=userid
        )
        if neo4j_result.get("status") != "success":
            raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))
        nodes_deleted = neo4j_result.get("nodes_deleted", 0)
        rels_deleted = neo4j_result.get("rels_deleted", 0)
        info(f"Successfully deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
        return nodes_deleted, rels_deleted

    async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
                               timings: Dict) -> List[str]:
        """Extract entities from the query."""
        debug(f"Extracting query entities: {query}")
        start_extract = time.time()
        entities = await self.api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timings["entity_extraction"] = time.time() - start_extract
        debug(f"Extracted entities: {entities}, took {timings['entity_extraction']:.3f} s")
        return entities

    async def match_triplets(self, request, query: str, entities: List[str], orgid: str, fiids: List[str],
                             service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Match triplets in the graph database."""
        debug("Starting triplet matching")
        start_triplet = time.time()
        all_triplets = []

        for kb_id in fiids:
            debug(f"Calling Neo4j triplet matching: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await self.api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"Knowledge base {kb_id} matched {len(triplets)} triplets")
                else:
                    error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
            except Exception as e:
                error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                continue

        timings["triplet_matching"] = time.time() - start_triplet
        debug(f"Total triplet matching took {timings['triplet_matching']:.3f} s")
        return all_triplets

    async def generate_query_vector(self, request, text: str, service_params: Dict, userid: str,
                                    timings: Dict) -> List[float]:
        """Generate the query vector."""
        debug(f"Generating query vector: {text[:200]}...")
        start_vector = time.time()
        query_vector = await self.api_service.get_embeddings(
            request=request,
            texts=[text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a float list of length 1024")
        query_vector = query_vector[0]
        timings["vector_generation"] = time.time() - start_vector
        debug(f"Query vector generation took {timings['vector_generation']:.3f} s")
        return query_vector

    async def vector_search(self, request, query_vector: List[float], orgid: str, fiids: List[str],
                            limit: int, service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Vector search."""
        debug("Starting vector search")
        start_search = time.time()
        result = await self.api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            limit=limit,
            offset=0,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )
        if result.get("status") != "success":
            raise ValueError(f"Vector search failed: {result.get('message', 'unknown error')}")
        search_results = result.get("results", [])
        timings["vector_search"] = time.time() - start_search
        debug(f"Vector search took {timings['vector_search']:.3f} s")
        debug(f"Retrieved {len(search_results)} records from the vector database")
        return search_results

    async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
                             service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Rerank the search results."""
        debug("Starting reranking")
        start_rerank = time.time()
        reranked_results = await self.api_service.rerank_results(
            request=request,
            query=query,
            results=results,
            top_n=top_n,
            upappid=service_params['reranker'],
            apiname="BAAI/bge-reranker-v2-m3",
            user=userid
        )
        reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
        timings["reranking"] = time.time() - start_rerank
        debug(f"Reranking took {timings['reranking']:.3f} s")
        debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
        return reranked_results

    def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
        """Deduplicate and refine triples."""
        unique_triples = []
        seen = set()
        for t in triples:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triples.append(t)
            else:
                # If a more specific type is found, replace the existing triple
                for existing in unique_triples:
                    if (existing['head'].lower() == t['head'].lower() and
                            existing['tail'].lower() == t['tail'].lower() and
                            len(t['type']) > len(existing['type'])):
                        unique_triples.remove(existing)
                        unique_triples.append(t)
                        debug(f"Replaced triple with a more specific type: {t}")
                        break
        return unique_triples

    def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
        """Format search results into a unified structure."""
        formatted_results = []
        # Score normalization: squash the rerank score through a sigmoid, or fall
        # back to 1 - distance when no rerank score is present.
        for res in results[:limit]:
            rerank_score = res.get('rerank_score')
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))
            content = res.get('text', '')
            title = res.get('metadata', {}).get('filename', 'Untitled')
            document_id = res.get('metadata', {}).get('document_id', '')
            formatted_results.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })
        return formatted_results
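

# ---------------------------------------------------------------------------
# Illustrative ingestion sketch (kept as comments so nothing runs on import):
# it only shows how the methods above are intended to chain for one file.
# The `request` object, the `service_params` dict, the ids, and the "milvus"
# db_type below are placeholders assumed for the example, not values defined
# in this module; how `service_params` and the TransactionManager are built
# is up to the caller.
#
#   async def ingest_file(request, realpath, orgid, fiid, doc_id, userid,
#                         service_params, transaction_mgr=None):
#       ops = RagOperations()
#       timings: Dict[str, float] = {}
#       chunks = await ops.load_and_chunk_document(realpath, timings, transaction_mgr)
#       embeddings = await ops.generate_embeddings(request, chunks, service_params,
#                                                   userid, timings, transaction_mgr)
#       await ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid,
#                                     doc_id, service_params, userid, "milvus", timings,
#                                     transaction_mgr)
#       triples = await ops.extract_triples(request, chunks, service_params, userid,
#                                           timings, transaction_mgr)
#       await ops.insert_to_graph_db(request, triples, doc_id, fiid, orgid,
#                                    service_params, userid, timings, transaction_mgr)
#       return timings
# ---------------------------------------------------------------------------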