diff --git a/rag/rag_operations.py b/rag/rag_operations.py
index 493134a..3bd771b 100644
--- a/rag/rag_operations.py
+++ b/rag/rag_operations.py
@@ -1,8 +1,3 @@
-"""
-Shared library of RAG operations.
-Provides common operations such as document processing, search, and embedding, used by both folderinfo.py and ragapi.py.
-"""
-
 import os
 import re
 import time
@@ -290,9 +285,10 @@ class RagOperations:
         return nodes_deleted, rels_deleted
 
     async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
-                               transaction_mgr: TransactionManager = None) -> List[str]:
+                               timings: Dict) -> List[str]:
         """Extract entities."""
         debug(f"Extracting entities for query: {query}")
+        start_extract = time.time()
         entities = await self.api_service.extract_entities(
             request=request,
             query=query,
@@ -300,21 +296,16 @@
             apiname="LTP/small",
             user=userid
         )
-
-        # Record the transaction operation
-        if transaction_mgr:
-            transaction_mgr.add_operation(
-                OperationType.ENTITY_EXTRACT,
-                {'query': query, 'entities_count': len(entities)}
-            )
-
+        timings["entity_extraction"] = time.time() - start_extract
+        debug(f"Extracted entities: {entities}, took {timings['entity_extraction']:.3f} s")
         return entities
 
     async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
                              fiids: List[str], service_params: Dict, userid: str,
-                             transaction_mgr: TransactionManager = None) -> List[Dict]:
+                             timings: Dict) -> List[Dict]:
         """Match triplets."""
         debug("Starting triplet matching")
+        start_triplet = time.time()
         all_triplets = []
 
         for kb_id in fiids:
@@ -335,26 +326,40 @@
                     all_triplets.extend(triplets)
                     debug(f"Knowledge base {kb_id} matched {len(triplets)} triplets")
                 else:
-                    error(
-                        f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
+                    error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
             except Exception as e:
                 error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                 continue
 
-        # Record the transaction operation
-        if transaction_mgr:
-            transaction_mgr.add_operation(
-                OperationType.TRIPLET_MATCH,
-                {'query': query, 'triplets_count': len(all_triplets)}
-            )
-
+        timings["triplet_matching"] = time.time() - start_triplet
+        debug(f"Total triplet matching time: {timings['triplet_matching']:.3f} s")
         return all_triplets
 
+    async def generate_query_vector(self, request, text: str, service_params: Dict,
+                                    userid: str, timings: Dict) -> List[float]:
+        """Generate the query vector."""
+        debug(f"Generating query vector: {text[:200]}...")
+        start_vector = time.time()
+        query_vector = await self.api_service.get_embeddings(
+            request=request,
+            texts=[text],
+            upappid=service_params['embedding'],
+            apiname="BAAI/bge-m3",
+            user=userid
+        )
+        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
+            raise ValueError("The query vector must be a list of floats of length 1024")
+        query_vector = query_vector[0]
+        timings["vector_generation"] = time.time() - start_vector
+        debug(f"Query vector generation took {timings['vector_generation']:.3f} s")
+        return query_vector
+
     async def vector_search(self, request, query_vector: List[float], orgid: str,
                             fiids: List[str], limit: int, service_params: Dict, userid: str,
-                            transaction_mgr: TransactionManager = None) -> List[Dict]:
+                            timings: Dict) -> List[Dict]:
         """Vector search."""
         debug("Starting vector search")
+        start_search = time.time()
         result = await self.api_service.milvus_search_query(
             request=request,
             query_vector=query_vector,
@@ -371,21 +376,16 @@
             raise ValueError(f"Vector search failed: {result.get('message', 'unknown error')}")
 
         search_results = result.get("results", [])
-
-        # Record the transaction operation
-        if transaction_mgr:
-            transaction_mgr.add_operation(
-                OperationType.VECTOR_SEARCH,
-                {'results_count': len(search_results)}
-            )
-
+        timings["vector_search"] = time.time() - start_search
+        debug(f"Vector search took {timings['vector_search']:.3f} s")
+        debug(f"Retrieved {len(search_results)} records from the vector store")
         return search_results
 
     async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
-                             service_params: Dict, userid: str,
-                             transaction_mgr: TransactionManager = None) -> List[Dict]:
+                             service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
         """Rerank results."""
         debug("Starting reranking")
+        start_rerank = time.time()
         reranked_results = await self.api_service.rerank_results(
             request=request,
             query=query,
@@ -395,14 +395,10 @@
             apiname="BAAI/bge-reranker-v2-m3",
             user=userid
         )
-
-        # Record the transaction operation
-        if transaction_mgr:
-            transaction_mgr.add_operation(
-                OperationType.RERANK,
-                {'input_count': len(results), 'output_count': len(reranked_results)}
-            )
-
+        reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
+        timings["reranking"] = time.time() - start_rerank
+        debug(f"Reranking took {timings['reranking']:.3f} s")
+        debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
         return reranked_results
 
     def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
diff --git a/rag/ragapi.py b/rag/ragapi.py
index bbc8084..dca0925 100644
--- a/rag/ragapi.py
+++ b/rag/ragapi.py
@@ -139,80 +139,38 @@ async def fusedsearch(request, params_kw, *params, **kw):
         raise ValueError("Failed to obtain service parameters")
 
     try:
-        timing_stats = {}
+        timings = {}
         start_time = time.time()
         rag_ops = RagOperations()
 
-        entity_extract_start = time.time()
-        query_entities = await rag_ops.extract_entities(request, query, service_params, userid)
-        timing_stats["entity_extraction"] = time.time() - entity_extract_start
-        debug(f"Extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f} s")
-
-        triplet_match_start = time.time()
-        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
-        timing_stats["triplet_matching"] = time.time() - triplet_match_start
-        debug(f"Total triplet matching time: {timing_stats['triplet_matching']:.3f} s")
-
-        triplet_text_start = time.time()
+        query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
+        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
+                                                    userid, timings)
         combined_text = _combine_query_with_triplets(query, all_triplets)
-        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
-        debug(f"Combining triplet text took {timing_stats['triplet_text_combine']:.3f} s")
+        query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
+        search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
+                                                     userid, timings)
 
-        vector_start = time.time()
-        query_vector = await rag_ops.api_service.get_embeddings(
-            request=request,
-            texts=[combined_text],
-            upappid=service_params['embedding'],
-            apiname="BAAI/bge-m3",
-            user=userid
-        )
-        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
-            raise ValueError("The query vector must be a list of floats of length 1024")
-        query_vector = query_vector[0]
-        timing_stats["vector_generation"] = time.time() - vector_start
-        debug(f"Query vector generation took {timing_stats['vector_generation']:.3f} s")
-
-        search_start = time.time()
-        search_limit = limit + 5
-        search_results = await rag_ops.vector_search(
-            request, query_vector, orgid, fiids, search_limit, service_params, userid
-        )
-        timing_stats["vector_search"] = time.time() - search_start
-        debug(f"Vector search took {timing_stats['vector_search']:.3f} s")
-        debug(f"Retrieved {len(search_results)} records from the vector store")
-
-        # Step 6: reranking (optional)
         use_rerank = True
         if use_rerank and search_results:
-            rerank_start = time.time()
-            debug("Starting reranking")
-            reranked_results = await rag_ops.rerank_results(
-                request, combined_text, search_results, limit, service_params, userid
-            )
-            reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
-            timing_stats["reranking"] = time.time() - rerank_start
-            debug(f"Reranking took {timing_stats['reranking']:.3f} s")
-            debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
-            final_results = reranked_results
+            final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
+                                                         userid, timings)
         else:
             final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
 
-        timing_stats["total_time"] = time.time() - start_time
-        info(f"Fused search finished, returning {len(final_results)} results, total time: {timing_stats['total_time']:.3f} s")
-
         formatted_results = rag_ops.format_search_results(final_results, limit)
-        info(f"Fused search finished, returning {len(formatted_results)} results")
+        timings["total_time"] = time.time() - start_time
+        info(f"Fused search finished, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
 
         return {
-            "records": formatted_results
+            "records": formatted_results,
+            "timings": timings
         }
-
     except Exception as e:
        error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
-        # The transaction manager performs the rollback automatically
        return {
            "records": [],
-            "timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
+            "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
            "error": str(e)
        }