This commit is contained in:
wangmeihua 2025-09-15 15:31:04 +08:00
parent 97c9e0f1fa
commit 043fb80ed4
2 changed files with 53 additions and 99 deletions

View File

@ -1,8 +1,3 @@
"""
RAG 操作的通用函数库
包含文档处理搜索嵌入等通用操作 folderinfo.py ragapi.py 共同使用
"""
import os import os
import re import re
import time import time
@ -290,9 +285,10 @@ class RagOperations:
return nodes_deleted, rels_deleted return nodes_deleted, rels_deleted
async def extract_entities(self, request, query: str, service_params: Dict, userid: str, async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[str]: timings: Dict) -> List[str]:
"""提取实体""" """提取实体"""
debug(f"提取查询实体: {query}") debug(f"提取查询实体: {query}")
start_extract = time.time()
entities = await self.api_service.extract_entities( entities = await self.api_service.extract_entities(
request=request, request=request,
query=query, query=query,
@ -300,21 +296,16 @@ class RagOperations:
apiname="LTP/small", apiname="LTP/small",
user=userid user=userid
) )
timings["entity_extraction"] = time.time() - start_extract
# 记录事务操作 debug(f"提取实体: {entities}, 耗时: {timings['entity_extraction']:.3f}")
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.ENTITY_EXTRACT,
{'query': query, 'entities_count': len(entities)}
)
return entities return entities
async def match_triplets(self, request, query: str, entities: List[str], orgid: str, async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
fiids: List[str], service_params: Dict, userid: str, fiids: List[str], service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[Dict]: timings: Dict) -> List[Dict]:
"""匹配三元组""" """匹配三元组"""
debug("开始三元组匹配") debug("开始三元组匹配")
start_triplet = time.time()
all_triplets = [] all_triplets = []
for kb_id in fiids: for kb_id in fiids:
@ -335,26 +326,40 @@ class RagOperations:
all_triplets.extend(triplets) all_triplets.extend(triplets)
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组") debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组")
else: else:
error( error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
except Exception as e: except Exception as e:
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}") error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
continue continue
# 记录事务操作 timings["triplet_matching"] = time.time() - start_triplet
if transaction_mgr: debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f}")
transaction_mgr.add_operation(
OperationType.TRIPLET_MATCH,
{'query': query, 'triplets_count': len(all_triplets)}
)
return all_triplets return all_triplets
async def generate_query_vector(self, request, text: str, service_params: Dict,
userid: str, timings: Dict) -> List[float]:
"""生成查询向量"""
debug(f"生成查询向量: {text[:200]}...")
start_vector = time.time()
query_vector = await self.api_service.get_embeddings(
request=request,
texts=[text],
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
query_vector = query_vector[0]
timings["vector_generation"] = time.time() - start_vector
debug(f"生成查询向量耗时: {timings['vector_generation']:.3f}")
return query_vector
async def vector_search(self, request, query_vector: List[float], orgid: str, async def vector_search(self, request, query_vector: List[float], orgid: str,
fiids: List[str], limit: int, service_params: Dict, userid: str, fiids: List[str], limit: int, service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[Dict]: timings: Dict) -> List[Dict]:
"""向量搜索""" """向量搜索"""
debug("开始向量搜索") debug("开始向量搜索")
start_search = time.time()
result = await self.api_service.milvus_search_query( result = await self.api_service.milvus_search_query(
request=request, request=request,
query_vector=query_vector, query_vector=query_vector,
@ -371,21 +376,16 @@ class RagOperations:
raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}") raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}")
search_results = result.get("results", []) search_results = result.get("results", [])
timings["vector_search"] = time.time() - start_search
# 记录事务操作 debug(f"向量搜索耗时: {timings['vector_search']:.3f}")
if transaction_mgr: debug(f"从向量数据中搜索到{len(search_results)}条数据")
transaction_mgr.add_operation(
OperationType.VECTOR_SEARCH,
{'results_count': len(search_results)}
)
return search_results return search_results
async def rerank_results(self, request, query: str, results: List[Dict], top_n: int, async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
service_params: Dict, userid: str, service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
transaction_mgr: TransactionManager = None) -> List[Dict]:
"""重排序结果""" """重排序结果"""
debug("开始重排序") debug("开始重排序")
start_rerank = time.time()
reranked_results = await self.api_service.rerank_results( reranked_results = await self.api_service.rerank_results(
request=request, request=request,
query=query, query=query,
@ -395,14 +395,10 @@ class RagOperations:
apiname="BAAI/bge-reranker-v2-m3", apiname="BAAI/bge-reranker-v2-m3",
user=userid user=userid
) )
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
# 记录事务操作 timings["reranking"] = time.time() - start_rerank
if transaction_mgr: debug(f"重排序耗时: {timings['reranking']:.3f}")
transaction_mgr.add_operation( debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
OperationType.RERANK,
{'input_count': len(results), 'output_count': len(reranked_results)}
)
return reranked_results return reranked_results
def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]: def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:

View File

@ -139,80 +139,38 @@ async def fusedsearch(request, params_kw, *params, **kw):
raise ValueError("无法获取服务参数") raise ValueError("无法获取服务参数")
try: try:
timing_stats = {} timings = {}
start_time = time.time() start_time = time.time()
rag_ops = RagOperations() rag_ops = RagOperations()
entity_extract_start = time.time() query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
query_entities = await rag_ops.extract_entities(request, query, service_params, userid) all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
timing_stats["entity_extraction"] = time.time() - entity_extract_start userid, timings)
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
triplet_match_start = time.time()
all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
timing_stats["triplet_matching"] = time.time() - triplet_match_start
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
triplet_text_start = time.time()
combined_text = _combine_query_with_triplets(query, all_triplets) combined_text = _combine_query_with_triplets(query, all_triplets)
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}") search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
userid, timings)
vector_start = time.time()
query_vector = await rag_ops.api_service.get_embeddings(
request=request,
texts=[combined_text],
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
query_vector = query_vector[0]
timing_stats["vector_generation"] = time.time() - vector_start
debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
search_start = time.time()
search_limit = limit + 5
search_results = await rag_ops.vector_search(
request, query_vector, orgid, fiids, search_limit, service_params, userid
)
timing_stats["vector_search"] = time.time() - search_start
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
debug(f"从向量数据中搜索到{len(search_results)}条数据")
# 步骤6: 重排序(可选)
use_rerank = True use_rerank = True
if use_rerank and search_results: if use_rerank and search_results:
rerank_start = time.time() final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
debug("开始重排序") userid, timings)
reranked_results = await rag_ops.rerank_results(
request, combined_text, search_results, limit, service_params, userid
)
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
timing_stats["reranking"] = time.time() - rerank_start
debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
final_results = reranked_results
else: else:
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results] final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
timing_stats["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(final_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
formatted_results = rag_ops.format_search_results(final_results, limit) formatted_results = rag_ops.format_search_results(final_results, limit)
info(f"融合搜索完成,返回 {len(formatted_results)} 条结果") timings["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f}")
return { return {
"records": formatted_results "records": formatted_results,
"timings": timings
} }
except Exception as e: except Exception as e:
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
# 事务管理器会自动执行回滚
return { return {
"records": [], "records": [],
"timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
"error": str(e) "error": str(e)
} }