rag
This commit is contained in:
parent
97c9e0f1fa
commit
043fb80ed4
@ -1,8 +1,3 @@
|
||||
"""
|
||||
RAG 操作的通用函数库
|
||||
包含文档处理、搜索、嵌入等通用操作,供 folderinfo.py 和 ragapi.py 共同使用
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@ -290,9 +285,10 @@ class RagOperations:
|
||||
return nodes_deleted, rels_deleted
|
||||
|
||||
async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[str]:
|
||||
timings: Dict) -> List[str]:
|
||||
"""提取实体"""
|
||||
debug(f"提取查询实体: {query}")
|
||||
start_extract = time.time()
|
||||
entities = await self.api_service.extract_entities(
|
||||
request=request,
|
||||
query=query,
|
||||
@ -300,21 +296,16 @@ class RagOperations:
|
||||
apiname="LTP/small",
|
||||
user=userid
|
||||
)
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.ENTITY_EXTRACT,
|
||||
{'query': query, 'entities_count': len(entities)}
|
||||
)
|
||||
|
||||
timings["entity_extraction"] = time.time() - start_extract
|
||||
debug(f"提取实体: {entities}, 耗时: {timings['entity_extraction']:.3f} 秒")
|
||||
return entities
|
||||
|
||||
async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
|
||||
fiids: List[str], service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
timings: Dict) -> List[Dict]:
|
||||
"""匹配三元组"""
|
||||
debug("开始三元组匹配")
|
||||
start_triplet = time.time()
|
||||
all_triplets = []
|
||||
|
||||
for kb_id in fiids:
|
||||
@ -335,26 +326,40 @@ class RagOperations:
|
||||
all_triplets.extend(triplets)
|
||||
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组")
|
||||
else:
|
||||
error(
|
||||
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
|
||||
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
|
||||
except Exception as e:
|
||||
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
|
||||
continue
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.TRIPLET_MATCH,
|
||||
{'query': query, 'triplets_count': len(all_triplets)}
|
||||
)
|
||||
|
||||
timings["triplet_matching"] = time.time() - start_triplet
|
||||
debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f} 秒")
|
||||
return all_triplets
|
||||
|
||||
async def generate_query_vector(self, request, text: str, service_params: Dict,
|
||||
userid: str, timings: Dict) -> List[float]:
|
||||
"""生成查询向量"""
|
||||
debug(f"生成查询向量: {text[:200]}...")
|
||||
start_vector = time.time()
|
||||
query_vector = await self.api_service.get_embeddings(
|
||||
request=request,
|
||||
texts=[text],
|
||||
upappid=service_params['embedding'],
|
||||
apiname="BAAI/bge-m3",
|
||||
user=userid
|
||||
)
|
||||
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
|
||||
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
|
||||
query_vector = query_vector[0]
|
||||
timings["vector_generation"] = time.time() - start_vector
|
||||
debug(f"生成查询向量耗时: {timings['vector_generation']:.3f} 秒")
|
||||
return query_vector
|
||||
|
||||
async def vector_search(self, request, query_vector: List[float], orgid: str,
|
||||
fiids: List[str], limit: int, service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
timings: Dict) -> List[Dict]:
|
||||
"""向量搜索"""
|
||||
debug("开始向量搜索")
|
||||
start_search = time.time()
|
||||
result = await self.api_service.milvus_search_query(
|
||||
request=request,
|
||||
query_vector=query_vector,
|
||||
@ -371,21 +376,16 @@ class RagOperations:
|
||||
raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}")
|
||||
|
||||
search_results = result.get("results", [])
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.VECTOR_SEARCH,
|
||||
{'results_count': len(search_results)}
|
||||
)
|
||||
|
||||
timings["vector_search"] = time.time() - start_search
|
||||
debug(f"向量搜索耗时: {timings['vector_search']:.3f} 秒")
|
||||
debug(f"从向量数据中搜索到{len(search_results)}条数据")
|
||||
return search_results
|
||||
|
||||
async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
|
||||
service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
|
||||
"""重排序结果"""
|
||||
debug("开始重排序")
|
||||
start_rerank = time.time()
|
||||
reranked_results = await self.api_service.rerank_results(
|
||||
request=request,
|
||||
query=query,
|
||||
@ -395,14 +395,10 @@ class RagOperations:
|
||||
apiname="BAAI/bge-reranker-v2-m3",
|
||||
user=userid
|
||||
)
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.RERANK,
|
||||
{'input_count': len(results), 'output_count': len(reranked_results)}
|
||||
)
|
||||
|
||||
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||
timings["reranking"] = time.time() - start_rerank
|
||||
debug(f"重排序耗时: {timings['reranking']:.3f} 秒")
|
||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
|
||||
return reranked_results
|
||||
|
||||
def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
|
||||
|
||||
@ -139,80 +139,38 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
||||
raise ValueError("无法获取服务参数")
|
||||
|
||||
try:
|
||||
timing_stats = {}
|
||||
timings = {}
|
||||
start_time = time.time()
|
||||
rag_ops = RagOperations()
|
||||
|
||||
entity_extract_start = time.time()
|
||||
query_entities = await rag_ops.extract_entities(request, query, service_params, userid)
|
||||
timing_stats["entity_extraction"] = time.time() - entity_extract_start
|
||||
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
|
||||
|
||||
triplet_match_start = time.time()
|
||||
all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
|
||||
timing_stats["triplet_matching"] = time.time() - triplet_match_start
|
||||
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
|
||||
|
||||
triplet_text_start = time.time()
|
||||
query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
|
||||
all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
|
||||
userid, timings)
|
||||
combined_text = _combine_query_with_triplets(query, all_triplets)
|
||||
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
|
||||
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
|
||||
query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
|
||||
search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
|
||||
userid, timings)
|
||||
|
||||
vector_start = time.time()
|
||||
query_vector = await rag_ops.api_service.get_embeddings(
|
||||
request=request,
|
||||
texts=[combined_text],
|
||||
upappid=service_params['embedding'],
|
||||
apiname="BAAI/bge-m3",
|
||||
user=userid
|
||||
)
|
||||
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
|
||||
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
|
||||
query_vector = query_vector[0]
|
||||
timing_stats["vector_generation"] = time.time() - vector_start
|
||||
debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")
|
||||
|
||||
search_start = time.time()
|
||||
search_limit = limit + 5
|
||||
search_results = await rag_ops.vector_search(
|
||||
request, query_vector, orgid, fiids, search_limit, service_params, userid
|
||||
)
|
||||
timing_stats["vector_search"] = time.time() - search_start
|
||||
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||
debug(f"从向量数据中搜索到{len(search_results)}条数据")
|
||||
|
||||
# 步骤6: 重排序(可选)
|
||||
use_rerank = True
|
||||
if use_rerank and search_results:
|
||||
rerank_start = time.time()
|
||||
debug("开始重排序")
|
||||
reranked_results = await rag_ops.rerank_results(
|
||||
request, combined_text, search_results, limit, service_params, userid
|
||||
)
|
||||
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||
timing_stats["reranking"] = time.time() - rerank_start
|
||||
debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
|
||||
final_results = reranked_results
|
||||
final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
|
||||
userid, timings)
|
||||
else:
|
||||
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
|
||||
|
||||
timing_stats["total_time"] = time.time() - start_time
|
||||
info(f"融合搜索完成,返回 {len(final_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||
|
||||
formatted_results = rag_ops.format_search_results(final_results, limit)
|
||||
info(f"融合搜索完成,返回 {len(formatted_results)} 条结果")
|
||||
timings["total_time"] = time.time() - start_time
|
||||
info(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f} 秒")
|
||||
|
||||
return {
|
||||
"records": formatted_results
|
||||
"records": formatted_results,
|
||||
"timings": timings
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
# 事务管理器会自动执行回滚
|
||||
return {
|
||||
"records": [],
|
||||
"timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
|
||||
"timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user