commit 97c9e0f1fa (parent 1c8eb52816)
wangmeihua 2025-09-12 17:58:21 +08:00
3 changed files with 1102 additions and 413 deletions

View File

@@ -18,9 +18,18 @@ from filetxt.loader import fileloader,File2Text
from ahserver.serverenv import get_serverenv
from typing import List, Dict, Any
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
import json
from dataclasses import dataclass
from enum import Enum
class RagFileMgr(FileMgr):
def __init__(self, fiid):
super().__init__(fiid)
self.rag_ops = RagOperations()
async def get_folder_ownerid(self, sor):
fiid = self.fiid
recs = await sor.R('kdb', {'id': self.fiid})
@@ -44,205 +53,6 @@ where a.orgid = b.orgid
return r.quota, r.expired_date
return None, None
async def get_doucment_chunks(self, realpath, timings):
"""加载文件并进行文本分片"""
debug(f"加载文件: {realpath}")
start_load = time.time()
supported_formats = File2Text.supported_types()
debug(f"支持的文件格式:{supported_formats}")
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
text = fileloader(realpath)
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip():
raise ValueError(f"文件 {realpath} 加载为空")
document = Document(page_content=text)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
length_function=len
)
debug("开始分片文件内容")
start_split = time.time()
chunks = text_splitter.split_documents([document])
timings["split_text"] = time.time() - start_split
debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
if not chunks:
raise ValueError(f"文件 {realpath} 未生成任何文档块")
return chunks
async def docs_embedding(self, request, chunks, service_params, userid, timings):
"""调用嵌入服务生成向量"""
debug("调用嵌入服务生成向量")
start_embedding = time.time()
texts = [chunk.page_content for chunk in chunks]
embeddings = []
for i in range(0, len(texts), 10):
batch_texts = texts[i:i + 10]
batch_embeddings = await APIService().get_embeddings(
request=request,
texts=batch_texts,
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
embeddings.extend(batch_embeddings)
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
timings["generate_embeddings"] = time.time() - start_embedding
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
return embeddings
async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,
db_type, timings):
"""准备数据并插入 Milvus"""
debug(f"准备数据并调用插入文件端点: {realpath}")
filename = os.path.basename(realpath).rsplit('.', 1)[0]
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
upload_time = datetime.now().isoformat()
chunks_data = [
{
"userid": orgid,
"knowledge_base_id": fiid,
"text": chunk.page_content,
"vector": embeddings[i],
"document_id": id,
"filename": filename + '.' + ext,
"file_path": realpath,
"upload_time": upload_time,
"file_type": ext,
}
for i, chunk in enumerate(chunks)
]
start_milvus = time.time()
for i in range(0, len(chunks_data), 10):
batch_chunks = chunks_data[i:i + 10]
result = await APIService().milvus_insert_document(
request=request,
chunks=batch_chunks,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/insertdocument",
user=userid
)
if result.get("status") != "success":
raise ValueError(result.get("message", "Milvus 插入失败"))
timings["insert_milvus"] = time.time() - start_milvus
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
return chunks_data
async def get_triples(self, request, chunks, service_params, userid, timings):
"""调用三元组抽取服务"""
debug("调用三元组抽取服务")
start_triples = time.time()
chunk_texts = [doc.page_content for doc in chunks]
triples = []
for i, chunk in enumerate(chunk_texts):
result = await APIService().extract_triples(
request=request,
text=chunk,
upappid=service_params['triples'],
apiname="Babelscape/mrebel-large",
user=userid
)
if isinstance(result, list):
triples.extend(result)
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
else:
error(f"分片 {i + 1} 处理失败: {str(result)}")
unique_triples = []
seen = set()
for t in triples:
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
if identifier not in seen:
seen.add(identifier)
unique_triples.append(t)
else:
for existing in unique_triples:
if (existing['head'].lower() == t['head'].lower() and
existing['tail'].lower() == t['tail'].lower() and
len(t['type']) > len(existing['type'])):
unique_triples.remove(existing)
unique_triples.append(t)
debug(f"替换三元组为更具体类型: {t}")
break
timings["extract_triples"] = time.time() - start_triples
debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
return unique_triples
async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings):
"""调用 Neo4j 插入三元组"""
debug(f"插入 {len(unique_triples)} 个三元组到 Neo4j")
start_neo4j = time.time()
if unique_triples:
for i in range(0, len(unique_triples), 30):
batch_triples = unique_triples[i:i + 30]
neo4j_result = await APIService().neo4j_insert_triples(
request=request,
triples=batch_triples,
document_id=id,
knowledge_base_id=fiid,
userid=orgid,
upappid=service_params['gdb'],
apiname="neo4j/inserttriples",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
timings["insert_neo4j"] = time.time() - start_neo4j
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
else:
debug("未抽取到三元组")
timings["insert_neo4j"] = 0.0
async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type):
"""调用 Milvus 删除文档"""
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
milvus_result = await APIService().milvus_delete_document(
request=request,
userid=orgid,
file_path=realpath,
knowledge_base_id=fiid,
document_id=id,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/deletedocument",
user=userid
)
if milvus_result.get("status") != "success":
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
async def delete_from_neo4j(self, request, id, service_params, userid):
"""调用 Neo4j 删除文档"""
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
neo4j_result = await APIService().neo4j_delete_document(
request=request,
document_id=id,
upappid=service_params['gdb'],
apiname="neo4j/deletedocument",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0)
info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
return nodes_deleted, rels_deleted
async def file_uploaded(self, request, ns, userid):
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
@@ -272,11 +82,11 @@ where a.orgid = b.orgid
if not service_params:
raise ValueError("无法获取服务参数")
chunks = await self.get_doucment_chunks(realpath, timings)
embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings)
await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings)
triples = await self.get_triples(request, chunks, service_params, userid, timings)
await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings)
chunks = await self.rag_ops.load_and_chunk_document(realpath, timings)
embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings)
await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings)
triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings)
await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings)
timings["total"] = time.time() - start_total
debug(f"总耗时: {timings['total']:.2f}")
@@ -329,13 +139,13 @@ where a.orgid = b.orgid
raise ValueError("无法获取服务参数")
# 调用 Milvus 删除
await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type)
await self.rag_ops.delete_from_vector_db(request, orgid, realpath, fiid, id, service_params, userid, db_type)
# 调用 Neo4j 删除
neo4j_deleted_nodes = 0
neo4j_deleted_rels = 0
try:
nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid)
nodes_deleted, rels_deleted = await self.rag_ops.delete_from_graph_db(request, id, service_params, userid)
neo4j_deleted_nodes += nodes_deleted
neo4j_deleted_rels += rels_deleted
total_nodes_deleted += nodes_deleted
@@ -369,6 +179,332 @@ where a.orgid = b.orgid
"message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')}",
"status_code": 200 if all(r["status"] == "success" for r in results) else 207
}
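
After this refactor, file_uploaded and file_deleted only orchestrate calls into RagOperations (added below in rag/rag_operations.py). For reference, a sketch of the success payload file_uploaded builds; the shape follows the pre-refactor implementation and every concrete value is an illustrative placeholder, not data from the commit:

# Illustrative file_uploaded success payload (shape from the pre-refactor code; values are placeholders)
expected_success_payload = {
    "status": "success",
    "userid": "<orgid>",
    "document_id": "<document id>",
    "collection_name": "ragdb",
    "timings": {"load_file": 0.0, "split_text": 0.0, "generate_embeddings": 0.0,
                "insert_milvus": 0.0, "extract_triples": 0.0, "insert_neo4j": 0.0, "total": 0.0},
    "unique_triples": [],  # deduplicated triples returned by the extraction step
    "message": "文件 <realpath> 成功嵌入并处理三元组",
    "status_code": 200,
}
# Failure responses keep the same keys minus userid/unique_triples and use status_code 400.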

450  rag/rag_operations.py  Normal file
View File

@@ -0,0 +1,450 @@
"""
RAG 操作的通用函数库
包含文档处理、搜索、嵌入等通用操作,供 folderinfo.py 和 ragapi.py 共同使用
"""
import os
import re
import time
import math
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from appPublic.log import debug, error, info
from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType
class RagOperations:
"""RAG 操作类,提供所有通用的 RAG 操作"""
def __init__(self):
self.api_service = APIService()
async def load_and_chunk_document(self, realpath: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[Document]:
"""加载文件并进行文本分片"""
debug(f"加载文件: {realpath}")
start_load = time.time()
# 检查文件格式支持
supported_formats = File2Text.supported_types()
debug(f"支持的文件格式:{supported_formats}")
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
# 加载文件内容
text = fileloader(realpath)
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip():
raise ValueError(f"文件 {realpath} 加载为空")
# 分片处理
document = Document(page_content=text)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
length_function=len
)
debug("开始分片文件内容")
start_split = time.time()
chunks = text_splitter.split_documents([document])
timings["split_text"] = time.time() - start_split
debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
if not chunks:
raise ValueError(f"文件 {realpath} 未生成任何文档块")
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.FILE_LOAD,
{'realpath': realpath, 'chunks_count': len(chunks)}
)
return chunks
async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
userid: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[List[float]]:
"""生成嵌入向量"""
debug("调用嵌入服务生成向量")
start_embedding = time.time()
texts = [chunk.page_content for chunk in chunks]
embeddings = []
# 批量处理嵌入
for i in range(0, len(texts), 10):
batch_texts = texts[i:i + 10]
batch_embeddings = await self.api_service.get_embeddings(
request=request,
texts=batch_texts,
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
embeddings.extend(batch_embeddings)
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
timings["generate_embeddings"] = time.time() - start_embedding
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.EMBEDDING,
{'embeddings_count': len(embeddings)}
)
return embeddings
async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
userid: str, db_type: str, timings: Dict,
transaction_mgr: TransactionManager = None):
"""插入向量数据库"""
debug(f"准备数据并调用插入文件端点: {realpath}")
filename = os.path.basename(realpath).rsplit('.', 1)[0]
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
upload_time = datetime.now().isoformat()
chunks_data = [
{
"userid": orgid,
"knowledge_base_id": fiid,
"text": chunk.page_content,
"vector": embeddings[i],
"document_id": id,
"filename": filename + '.' + ext,
"file_path": realpath,
"upload_time": upload_time,
"file_type": ext,
}
for i, chunk in enumerate(chunks)
]
start_milvus = time.time()
for i in range(0, len(chunks_data), 10):
batch_chunks = chunks_data[i:i + 10]
result = await self.api_service.milvus_insert_document(
request=request,
chunks=batch_chunks,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/insertdocument",
user=userid
)
if result.get("status") != "success":
raise ValueError(result.get("message", "Milvus 插入失败"))
timings["insert_milvus"] = time.time() - start_milvus
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
# 记录事务操作,包含回滚函数
if transaction_mgr:
async def rollback_vdb_insert(data, context):
await self.delete_from_vector_db(
context['request'], data['orgid'], data['realpath'],
data['fiid'], data['id'], context['service_params'],
context['userid'], data['db_type']
)
return f"已回滚向量数据库插入: {data['id']}"
transaction_mgr.add_operation(
OperationType.VDB_INSERT,
{
'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
'id': id, 'db_type': db_type
},
rollback_func=rollback_vdb_insert
)
return chunks_data
async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
userid: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[Dict]:
"""抽取三元组"""
debug("调用三元组抽取服务")
start_triples = time.time()
chunk_texts = [doc.page_content for doc in chunks]
triples = []
for i, chunk in enumerate(chunk_texts):
result = await self.api_service.extract_triples(
request=request,
text=chunk,
upappid=service_params['triples'],
apiname="Babelscape/mrebel-large",
user=userid
)
if isinstance(result, list):
triples.extend(result)
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
else:
error(f"分片 {i + 1} 处理失败: {str(result)}")
# 去重和优化三元组
unique_triples = self._deduplicate_triples(triples)
timings["extract_triples"] = time.time() - start_triples
debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.TRIPLES_EXTRACT,
{'triples_count': len(unique_triples)}
)
return unique_triples
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
orgid: str, service_params: Dict, userid: str, timings: Dict,
transaction_mgr: TransactionManager = None):
"""插入图数据库"""
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
start_neo4j = time.time()
if triples:
for i in range(0, len(triples), 30):
batch_triples = triples[i:i + 30]
neo4j_result = await self.api_service.neo4j_insert_triples(
request=request,
triples=batch_triples,
document_id=id,
knowledge_base_id=fiid,
userid=orgid,
upappid=service_params['gdb'],
apiname="neo4j/inserttriples",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
timings["insert_neo4j"] = time.time() - start_neo4j
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
else:
debug("未抽取到三元组")
timings["insert_neo4j"] = 0.0
# 记录事务操作,包含回滚函数
if transaction_mgr:
async def rollback_gdb_insert(data, context):
await self.delete_from_graph_db(
context['request'], data['id'],
context['service_params'], context['userid']
)
return f"已回滚图数据库插入: {data['id']}"
transaction_mgr.add_operation(
OperationType.GDB_INSERT,
{'id': id, 'triples_count': len(triples)},
rollback_func=rollback_gdb_insert
)
async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str,
id: str, service_params: Dict, userid: str, db_type: str):
"""从向量数据库删除文档"""
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
milvus_result = await self.api_service.milvus_delete_document(
request=request,
userid=orgid,
file_path=realpath,
knowledge_base_id=fiid,
document_id=id,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/deletedocument",
user=userid
)
if milvus_result.get("status") != "success":
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
"""从图数据库删除文档"""
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
neo4j_result = await self.api_service.neo4j_delete_document(
request=request,
document_id=id,
upappid=service_params['gdb'],
apiname="neo4j/deletedocument",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0)
info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
return nodes_deleted, rels_deleted
async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[str]:
"""提取实体"""
debug(f"提取查询实体: {query}")
entities = await self.api_service.extract_entities(
request=request,
query=query,
upappid=service_params['entities'],
apiname="LTP/small",
user=userid
)
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.ENTITY_EXTRACT,
{'query': query, 'entities_count': len(entities)}
)
return entities
async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
fiids: List[str], service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[Dict]:
"""匹配三元组"""
debug("开始三元组匹配")
all_triplets = []
for kb_id in fiids:
debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
try:
neo4j_result = await self.api_service.neo4j_match_triplets(
request=request,
query=query,
query_entities=entities,
userid=orgid,
knowledge_base_id=kb_id,
upappid=service_params['gdb'],
apiname="neo4j/matchtriplets",
user=userid
)
if neo4j_result.get("status") == "success":
triplets = neo4j_result.get("triplets", [])
all_triplets.extend(triplets)
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组")
else:
error(
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
except Exception as e:
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
continue
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.TRIPLET_MATCH,
{'query': query, 'triplets_count': len(all_triplets)}
)
return all_triplets
async def vector_search(self, request, query_vector: List[float], orgid: str,
fiids: List[str], limit: int, service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[Dict]:
"""向量搜索"""
debug("开始向量搜索")
result = await self.api_service.milvus_search_query(
request=request,
query_vector=query_vector,
userid=orgid,
knowledge_base_ids=fiids,
limit=limit,
offset=0,
upappid=service_params['vdb'],
apiname="mlvus/searchquery",
user=userid
)
if result.get("status") != "success":
raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}")
search_results = result.get("results", [])
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.VECTOR_SEARCH,
{'results_count': len(search_results)}
)
return search_results
async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
service_params: Dict, userid: str,
transaction_mgr: TransactionManager = None) -> List[Dict]:
"""重排序结果"""
debug("开始重排序")
reranked_results = await self.api_service.rerank_results(
request=request,
query=query,
results=results,
top_n=top_n,
upappid=service_params['reranker'],
apiname="BAAI/bge-reranker-v2-m3",
user=userid
)
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.RERANK,
{'input_count': len(results), 'output_count': len(reranked_results)}
)
return reranked_results
def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
"""去重和优化三元组"""
unique_triples = []
seen = set()
for t in triples:
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
if identifier not in seen:
seen.add(identifier)
unique_triples.append(t)
else:
# 如果发现更具体的类型,则替换
for existing in unique_triples:
if (existing['head'].lower() == t['head'].lower() and
existing['tail'].lower() == t['tail'].lower() and
len(t['type']) > len(existing['type'])):
unique_triples.remove(existing)
unique_triples.append(t)
debug(f"替换三元组为更具体类型: {t}")
break
return unique_triples
def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
"""格式化搜索结果为统一格式"""
formatted_results = []
for res in results[:limit]:
rerank_score = res.get('rerank_score', 0)
score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
score = max(0.0, min(1.0, score))
content = res.get('text', '')
title = res.get('metadata', {}).get('filename', 'Untitled')
document_id = res.get('metadata', {}).get('document_id', '')
formatted_results.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id, "score": score},
})
return formatted_results
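
A minimal ingestion sketch tying the methods above together, in the order folderinfo.py calls them. It is only an illustration: request, service_params and the ids are assumed to be supplied by the caller, the file path is a placeholder, and the optional TransactionManager (imported above but not shown in this commit) merely records each step plus the rollback callbacks registered by insert_to_vector_db and insert_to_graph_db.

# Illustrative pipeline around RagOperations; identifiers marked as placeholders are not from the commit
from rag.rag_operations import RagOperations

async def ingest_document(request, service_params, userid, orgid, fiid, doc_id, realpath,
                          db_type="", transaction_mgr=None):
    ops = RagOperations()
    timings = {}
    # load the file and split it into 500-char chunks with 100-char overlap
    chunks = await ops.load_and_chunk_document(realpath, timings, transaction_mgr)
    # embed the chunks in batches of 10 (1024-dim vectors expected)
    embeddings = await ops.generate_embeddings(request, chunks, service_params, userid,
                                               timings, transaction_mgr)
    # write chunk/vector records to Milvus; a rollback to delete_from_vector_db is registered
    await ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, doc_id,
                                  service_params, userid, db_type, timings, transaction_mgr)
    # extract and deduplicate triples, then insert them into Neo4j (rollback to delete_from_graph_db)
    triples = await ops.extract_triples(request, chunks, service_params, userid,
                                        timings, transaction_mgr)
    await ops.insert_to_graph_db(request, triples, doc_id, fiid, orgid, service_params,
                                 userid, timings, transaction_mgr)
    return timings

Whether and when the recorded rollbacks actually run is left to the TransactionManager implementation, which this commit only references.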

View File

@@ -6,6 +6,7 @@ import traceback
import json
import math
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
helptext = """kyrag API:
@@ -131,14 +132,97 @@ async def fusedsearch(request, params_kw, *params, **kw):
debug(f"fiids: {fiids}")
# 验证 fiids的orgid与orgid = await f()是否一致
await _validate_fiids_orgid(fiids, orgid, kw)
service_params = await get_service_params(orgid)
if not service_params:
raise ValueError("无法获取服务参数")
try:
timing_stats = {}
start_time = time.time()
rag_ops = RagOperations()
entity_extract_start = time.time()
query_entities = await rag_ops.extract_entities(request, query, service_params, userid)
timing_stats["entity_extraction"] = time.time() - entity_extract_start
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
triplet_match_start = time.time()
all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
timing_stats["triplet_matching"] = time.time() - triplet_match_start
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
triplet_text_start = time.time()
combined_text = _combine_query_with_triplets(query, all_triplets)
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")
vector_start = time.time()
query_vector = await rag_ops.api_service.get_embeddings(
request=request,
texts=[combined_text],
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
query_vector = query_vector[0]
timing_stats["vector_generation"] = time.time() - vector_start
debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
search_start = time.time()
search_limit = limit + 5
search_results = await rag_ops.vector_search(
request, query_vector, orgid, fiids, search_limit, service_params, userid
)
timing_stats["vector_search"] = time.time() - search_start
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
debug(f"从向量数据中搜索到{len(search_results)}条数据")
# 步骤6: 重排序(可选)
use_rerank = True
if use_rerank and search_results:
rerank_start = time.time()
debug("开始重排序")
reranked_results = await rag_ops.rerank_results(
request, combined_text, search_results, limit, service_params, userid
)
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
timing_stats["reranking"] = time.time() - rerank_start
debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
final_results = reranked_results
else:
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
timing_stats["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(final_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
formatted_results = rag_ops.format_search_results(final_results, limit)
info(f"融合搜索完成,返回 {len(formatted_results)} 条结果")
return {
"records": formatted_results
}
except Exception as e:
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
# 事务管理器会自动执行回滚
return {
"records": [],
"timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
"error": str(e)
}
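
The records returned above come straight from rag_ops.format_search_results, so each entry carries content, title and a score normalized to [0, 1] (a sigmoid of rerank_score, falling back to 1 − distance). A sketch of the response shape with placeholder values; on failure the handler instead returns empty records plus a timing and error field, as in the except branch above:

# Illustrative fusedsearch response (structure from format_search_results; values are placeholders)
example_response = {
    "records": [
        {
            "content": "...chunk text...",
            "title": "example.docx",
            "metadata": {"document_id": "<document id>", "score": 0.87},
        }
    ]
}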
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """
SELECT orgid
FROM kdb
WHERE id = ${id}$
"""
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
@@ -149,211 +233,230 @@ async def fusedsearch(request, params_kw, *params, **kw):
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)})
service_params = await get_service_params(orgid)
raise
api_service = APIService()
start_time = time.time()
timing_stats = {}
try:
info(
f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
if not query or not orgid or not fiids:
raise ValueError("query、orgid 和 knowledge_base_ids 不能为空")
# 提取实体
entity_extract_start = time.time()
query_entities = await api_service.extract_entities(
request=request,
query=query,
upappid=service_params['entities'],
apiname="LTP/small",
user=userid
)
timing_stats["entity_extraction"] = time.time() - entity_extract_start
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
# 调用 Neo4j 服务进行三元组匹配
all_triplets = []
triplet_match_start = time.time()
for kb_id in fiids:
debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
try:
neo4j_result = await api_service.neo4j_match_triplets(
request=request,
query=query,
query_entities=query_entities,
userid=orgid,
knowledge_base_id=kb_id,
upappid=service_params['gdb'],
apiname="neo4j/matchtriplets",
user=userid
)
if neo4j_result.get("status") == "success":
triplets = neo4j_result.get("triplets", [])
all_triplets.extend(triplets)
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
else:
error(
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
except Exception as e:
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
continue
timing_stats["triplet_matching"] = time.time() - triplet_match_start
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
# 拼接三元组文本
triplet_text_start = time.time()
triplet_texts = []
for triplet in all_triplets:
head = triplet.get('head', '')
type_ = triplet.get('type', '')
tail = triplet.get('tail', '')
if head and type_ and tail:
triplet_texts.append(f"{head} {type_} {tail}")
else:
debug(f"无效三元组: {triplet}")
combined_text = query
if triplet_texts:
combined_text += "".join(triplet_texts)
debug(
f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")
# 将拼接文本转换为向量
vector_start = time.time()
query_vector = await api_service.get_embeddings(
request=request,
texts=[combined_text],
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
query_vector = query_vector[0] # 取第一个向量
timing_stats["vector_generation"] = time.time() - vector_start
debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
# 调用搜索端点
sum = limit + 5
search_start = time.time()
debug(f"orgid: {orgid}")
result = await api_service.milvus_search_query(
request=request,
query_vector=query_vector,
userid=orgid,
knowledge_base_ids=fiids,
limit=sum,
offset=0,
upappid=service_params['vdb'],
apiname="mlvus/searchquery",
user=userid
)
timing_stats["vector_search"] = time.time() - search_start
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
if result.get("status") != "success":
error(f"融合搜索失败: {result.get('message', '未知错误')}")
return {"results": [], "timing": timing_stats}
unique_results = result.get("results", [])
sum = len(unique_results)
debug(f"从向量数据中搜索到{sum}条数据")
use_rerank = True
if use_rerank and unique_results:
rerank_start = time.time()
debug("开始重排序")
unique_results = await api_service.rerank_results(
request=request,
query=combined_text,
results=unique_results,
top_n=limit,
upappid=service_params['reranker'],
apiname="BAAI/bge-reranker-v2-m3",
user=userid
)
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
timing_stats["reranking"] = time.time() - rerank_start
debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
def _combine_query_with_triplets(query, triplets):
"""拼接查询文本和三元组文本"""
triplet_texts = []
for triplet in triplets:
head = triplet.get('head', '')
type_ = triplet.get('type', '')
tail = triplet.get('tail', '')
if head and type_ and tail:
triplet_texts.append(f"{head} {type_} {tail}")
else:
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
debug(f"无效三元组: {triplet}")
timing_stats["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
combined_text = query
if triplet_texts:
combined_text += "".join(triplet_texts)
# dify_result = []
# for res in unique_results[:limit]:
# content = res.get('text', '')
# title = res.get('metadata', {}).get('filename', 'Untitled')
# document_id = res.get('metadata', {}).get('document_id', '')
# dify_result.append({
# 'metadata': {'document_id': document_id},
# 'title': title,
# 'content': content
# })
# info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
# debug(f"result: {dify_result}")
# return dify_result
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
return combined_text
dify_records = []
dify_result = []
for res in unique_results[:limit]:
rerank_score = res.get('rerank_score', 0)
score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
score = max(0.0, min(1.0, score))
content = res.get('text', '')
title = res.get('metadata', {}).get('filename', 'Untitled')
document_id = res.get('metadata', {}).get('document_id', '')
dify_records.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id, "score": score},
})
dify_result.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id, "score": score},
})
info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f}")
debug(f"records: {dify_records}, result: {dify_result}")
# return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
return {"records": dify_records}
# dify_result = []
# for res in unique_results[:limit]:
# rerank_score = res.get('rerank_score', 0)
# score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
# score = max(0.0, min(1.0, score))
# content = res.get('text', '')
# title = res.get('metadata', {}).get('filename', 'Untitled')
# document_id = res.get('metadata', {}).get('document_id', '')
# dify_result.append({
# "metadata": {
# "_source": "konwledge",
# "dataset_id":"111111",
# "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
# "document_id": document_id,
# "document_name": "test.docx",
# "data_source_type": "upload_file",
# "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
# "retriever_from": "workflow",
# "score": score,
# "segment_hit_count": 7,
# "segment_word_count": 275,
# "segment_position": 5,
# "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
# "doc_metadata": None,
# "position":1
# },
# "title": title,
# "content": content
# })
# return {"result": dify_result}
except Exception as e:
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return {"results": [], "timing": timing_stats}