rag

This commit is contained in:
parent 1c8eb52816
commit 97c9e0f1fa
@@ -18,9 +18,18 @@ from filetxt.loader import fileloader,File2Text
from ahserver.serverenv import get_serverenv
from typing import List, Dict, Any
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
import json

from dataclasses import dataclass
from enum import Enum


class RagFileMgr(FileMgr):
    def __init__(self, fiid):
        super().__init__(fiid)
        self.rag_ops = RagOperations()

    async def get_folder_ownerid(self, sor):
        fiid = self.fiid
        recs = await sor.R('kdb', {'id': self.fiid})
@@ -44,205 +53,6 @@ where a.orgid = b.orgid
            return r.quota, r.expired_date
        return None, None

    async def get_doucment_chunks(self, realpath, timings):
        """Load the file and split its text into chunks"""
        debug(f"Loading file: {realpath}")
        start_load = time.time()
        supported_formats = File2Text.supported_types()
        debug(f"Supported file formats: {supported_formats}")
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")
        text = fileloader(realpath)
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")

        if not text or not text.strip():
            raise ValueError(f"File {realpath} loaded empty")

        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len
        )
        debug("Splitting file content")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"Text splitting took {timings['split_text']:.2f} s, chunk count: {len(chunks)}")

        if not chunks:
            raise ValueError(f"File {realpath} produced no document chunks")

        return chunks

    async def docs_embedding(self, request, chunks, service_params, userid, timings):
        """Call the embedding service to generate vectors"""
        debug("Calling the embedding service to generate vectors")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []
        for i in range(0, len(texts), 10):
            batch_texts = texts[i:i + 10]
            batch_embeddings = await APIService().get_embeddings(
                request=request,
                texts=batch_texts,
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
            embeddings.extend(batch_embeddings)

        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("Every embedding must be a list of 1024 floats")

        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"Embedding generation took {timings['generate_embeddings']:.2f} s, embedding count: {len(embeddings)}")
        return embeddings

    async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,
                              db_type, timings):
        """Prepare data and insert it into Milvus"""
        debug(f"Preparing data and calling the document-insert endpoint: {realpath}")
        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        upload_time = datetime.now().isoformat()

        chunks_data = [
            {
                "userid": orgid,
                "knowledge_base_id": fiid,
                "text": chunk.page_content,
                "vector": embeddings[i],
                "document_id": id,
                "filename": filename + '.' + ext,
                "file_path": realpath,
                "upload_time": upload_time,
                "file_type": ext,
            }
            for i, chunk in enumerate(chunks)
        ]

        start_milvus = time.time()
        for i in range(0, len(chunks_data), 10):
            batch_chunks = chunks_data[i:i + 10]
            result = await APIService().milvus_insert_document(
                request=request,
                chunks=batch_chunks,
                db_type=db_type,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
                user=userid
            )
            if result.get("status") != "success":
                raise ValueError(result.get("message", "Milvus insert failed"))

        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
        return chunks_data

    async def get_triples(self, request, chunks, service_params, userid, timings):
        """Call the triple-extraction service"""
        debug("Calling the triple-extraction service")
        start_triples = time.time()
        chunk_texts = [doc.page_content for doc in chunks]
        triples = []
        for i, chunk in enumerate(chunk_texts):
            result = await APIService().extract_triples(
                request=request,
                text=chunk,
                upappid=service_params['triples'],
                apiname="Babelscape/mrebel-large",
                user=userid
            )
            if isinstance(result, list):
                triples.extend(result)
                debug(f"Chunk {i + 1}: extracted {len(result)} triples")
            else:
                error(f"Chunk {i + 1} failed: {str(result)}")

        unique_triples = []
        seen = set()
        for t in triples:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triples.append(t)
            else:
                for existing in unique_triples:
                    if (existing['head'].lower() == t['head'].lower() and
                            existing['tail'].lower() == t['tail'].lower() and
                            len(t['type']) > len(existing['type'])):
                        unique_triples.remove(existing)
                        unique_triples.append(t)
                        debug(f"Replaced triple with a more specific type: {t}")
                        break

        timings["extract_triples"] = time.time() - start_triples
        debug(f"Triple extraction took {timings['extract_triples']:.2f} s, {len(unique_triples)} triples extracted")
        return unique_triples

    async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings):
        """Insert triples into Neo4j"""
        debug(f"Inserting {len(unique_triples)} triples into Neo4j")
        start_neo4j = time.time()
        if unique_triples:
            for i in range(0, len(unique_triples), 30):
                batch_triples = unique_triples[i:i + 30]
                neo4j_result = await APIService().neo4j_insert_triples(
                    request=request,
                    triples=batch_triples,
                    document_id=id,
                    knowledge_base_id=fiid,
                    userid=orgid,
                    upappid=service_params['gdb'],
                    apiname="neo4j/inserttriples",
                    user=userid
                )
                if neo4j_result.get("status") != "success":
                    raise ValueError(f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}")
                info(f"File triples inserted into Neo4j: {neo4j_result.get('message')}")
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j insert took {timings['insert_neo4j']:.2f} s")
        else:
            debug("No triples extracted")
            timings["insert_neo4j"] = 0.0

    async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type):
        """Delete a document from Milvus"""
        debug(f"Calling the document-delete endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
        milvus_result = await APIService().milvus_delete_document(
            request=request,
            userid=orgid,
            file_path=realpath,
            knowledge_base_id=fiid,
            document_id=id,
            db_type=db_type,
            upappid=service_params['vdb'],
            apiname="milvus/deletedocument",
            user=userid
        )
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus delete failed"))

    async def delete_from_neo4j(self, request, id, service_params, userid):
        """Delete a document from Neo4j"""
        debug(f"Calling the Neo4j document-delete endpoint: document_id={id}")
        neo4j_result = await APIService().neo4j_delete_document(
            request=request,
            document_id=id,
            upappid=service_params['gdb'],
            apiname="neo4j/deletedocument",
            user=userid
        )
        if neo4j_result.get("status") != "success":
            raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))
        nodes_deleted = neo4j_result.get("nodes_deleted", 0)
        rels_deleted = neo4j_result.get("rels_deleted", 0)
        info(f"Deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
        return nodes_deleted, rels_deleted

    async def file_uploaded(self, request, ns, userid):
        """Insert the document into Milvus and extract triples into Neo4j"""
@@ -272,11 +82,11 @@ where a.orgid = b.orgid
            if not service_params:
                raise ValueError("Could not obtain service parameters")

            chunks = await self.get_doucment_chunks(realpath, timings)
            embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings)
            await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid, db_type, timings)
            triples = await self.get_triples(request, chunks, service_params, userid, timings)
            await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings)
            chunks = await self.rag_ops.load_and_chunk_document(realpath, timings)
            embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings)
            await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid, db_type, timings)
            triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings)
            await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings)

            timings["total"] = time.time() - start_total
            debug(f"Total time: {timings['total']:.2f} s")
@@ -329,13 +139,13 @@ where a.orgid = b.orgid
                    raise ValueError("Could not obtain service parameters")

                # Delete from Milvus
                await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type)
                await self.rag_ops.delete_from_vector_db(request, orgid, realpath, fiid, id, service_params, userid, db_type)

                # Delete from Neo4j
                neo4j_deleted_nodes = 0
                neo4j_deleted_rels = 0
                try:
                    nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid)
                    nodes_deleted, rels_deleted = await self.rag_ops.delete_from_graph_db(request, id, service_params, userid)
                    neo4j_deleted_nodes += nodes_deleted
                    neo4j_deleted_rels += rels_deleted
                    total_nodes_deleted += nodes_deleted
@@ -369,6 +179,332 @@ where a.orgid = b.orgid
            "message": f"Processed {len(recs)} files, {sum(1 for r in results if r['status'] == 'success')} deleted successfully",
            "status_code": 200 if all(r["status"] == "success" for r in results) else 207
        }
# async def get_doucment_chunks(self, realpath, timings):
#     """Load the file and split its text into chunks"""
#     debug(f"Loading file: {realpath}")
#     start_load = time.time()
#     supported_formats = File2Text.supported_types()
#     debug(f"Supported file formats: {supported_formats}")
#     ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
#     if ext not in supported_formats:
#         raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")
#     text = fileloader(realpath)
#     text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
#     timings["load_file"] = time.time() - start_load
#     debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")
#
#     if not text or not text.strip():
#         raise ValueError(f"File {realpath} loaded empty")
#
#     document = Document(page_content=text)
#     text_splitter = RecursiveCharacterTextSplitter(
#         chunk_size=500,
#         chunk_overlap=100,
#         length_function=len
#     )
#     debug("Splitting file content")
#     start_split = time.time()
#     chunks = text_splitter.split_documents([document])
#     timings["split_text"] = time.time() - start_split
#     debug(f"Text splitting took {timings['split_text']:.2f} s, chunk count: {len(chunks)}")
#
#     if not chunks:
#         raise ValueError(f"File {realpath} produced no document chunks")
#
#     return chunks
#
# async def docs_embedding(self, request, chunks, service_params, userid, timings):
#     """Call the embedding service to generate vectors"""
#     debug("Calling the embedding service to generate vectors")
#     start_embedding = time.time()
#     texts = [chunk.page_content for chunk in chunks]
#     embeddings = []
#     for i in range(0, len(texts), 10):
#         batch_texts = texts[i:i + 10]
#         batch_embeddings = await APIService().get_embeddings(
#             request=request,
#             texts=batch_texts,
#             upappid=service_params['embedding'],
#             apiname="BAAI/bge-m3",
#             user=userid
#         )
#         embeddings.extend(batch_embeddings)
#
#     if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
#         raise ValueError("Every embedding must be a list of 1024 floats")
#
#     timings["generate_embeddings"] = time.time() - start_embedding
#     debug(f"Embedding generation took {timings['generate_embeddings']:.2f} s, embedding count: {len(embeddings)}")
#     return embeddings
#
# async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,
#                           db_type, timings):
#     """Prepare data and insert it into Milvus"""
#     debug(f"Preparing data and calling the document-insert endpoint: {realpath}")
#     filename = os.path.basename(realpath).rsplit('.', 1)[0]
#     ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
#     upload_time = datetime.now().isoformat()
#
#     chunks_data = [
#         {
#             "userid": orgid,
#             "knowledge_base_id": fiid,
#             "text": chunk.page_content,
#             "vector": embeddings[i],
#             "document_id": id,
#             "filename": filename + '.' + ext,
#             "file_path": realpath,
#             "upload_time": upload_time,
#             "file_type": ext,
#         }
#         for i, chunk in enumerate(chunks)
#     ]
#
#     start_milvus = time.time()
#     for i in range(0, len(chunks_data), 10):
#         batch_chunks = chunks_data[i:i + 10]
#         result = await APIService().milvus_insert_document(
#             request=request,
#             chunks=batch_chunks,
#             db_type=db_type,
#             upappid=service_params['vdb'],
#             apiname="milvus/insertdocument",
#             user=userid
#         )
#         if result.get("status") != "success":
#             raise ValueError(result.get("message", "Milvus insert failed"))
#
#     timings["insert_milvus"] = time.time() - start_milvus
#     debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
#     return chunks_data
#
# async def get_triples(self, request, chunks, service_params, userid, timings):
#     """Call the triple-extraction service"""
#     debug("Calling the triple-extraction service")
#     start_triples = time.time()
#     chunk_texts = [doc.page_content for doc in chunks]
#     triples = []
#     for i, chunk in enumerate(chunk_texts):
#         result = await APIService().extract_triples(
#             request=request,
#             text=chunk,
#             upappid=service_params['triples'],
#             apiname="Babelscape/mrebel-large",
#             user=userid
#         )
#         if isinstance(result, list):
#             triples.extend(result)
#             debug(f"Chunk {i + 1}: extracted {len(result)} triples")
#         else:
#             error(f"Chunk {i + 1} failed: {str(result)}")
#
#     unique_triples = []
#     seen = set()
#     for t in triples:
#         identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
#         if identifier not in seen:
#             seen.add(identifier)
#             unique_triples.append(t)
#         else:
#             for existing in unique_triples:
#                 if (existing['head'].lower() == t['head'].lower() and
#                         existing['tail'].lower() == t['tail'].lower() and
#                         len(t['type']) > len(existing['type'])):
#                     unique_triples.remove(existing)
#                     unique_triples.append(t)
#                     debug(f"Replaced triple with a more specific type: {t}")
#                     break
#
#     timings["extract_triples"] = time.time() - start_triples
#     debug(f"Triple extraction took {timings['extract_triples']:.2f} s, {len(unique_triples)} triples extracted")
#     return unique_triples
#
# async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings):
#     """Insert triples into Neo4j"""
#     debug(f"Inserting {len(unique_triples)} triples into Neo4j")
#     start_neo4j = time.time()
#     if unique_triples:
#         for i in range(0, len(unique_triples), 30):
#             batch_triples = unique_triples[i:i + 30]
#             neo4j_result = await APIService().neo4j_insert_triples(
#                 request=request,
#                 triples=batch_triples,
#                 document_id=id,
#                 knowledge_base_id=fiid,
#                 userid=orgid,
#                 upappid=service_params['gdb'],
#                 apiname="neo4j/inserttriples",
#                 user=userid
#             )
#             if neo4j_result.get("status") != "success":
#                 raise ValueError(f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}")
#             info(f"File triples inserted into Neo4j: {neo4j_result.get('message')}")
#         timings["insert_neo4j"] = time.time() - start_neo4j
#         debug(f"Neo4j insert took {timings['insert_neo4j']:.2f} s")
#     else:
#         debug("No triples extracted")
#         timings["insert_neo4j"] = 0.0
#
# async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type):
#     """Delete a document from Milvus"""
#     debug(f"Calling the document-delete endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
#     milvus_result = await APIService().milvus_delete_document(
#         request=request,
#         userid=orgid,
#         file_path=realpath,
#         knowledge_base_id=fiid,
#         document_id=id,
#         db_type=db_type,
#         upappid=service_params['vdb'],
#         apiname="milvus/deletedocument",
#         user=userid
#     )
#     if milvus_result.get("status") != "success":
#         raise ValueError(milvus_result.get("message", "Milvus delete failed"))
#
# async def delete_from_neo4j(self, request, id, service_params, userid):
#     """Delete a document from Neo4j"""
#     debug(f"Calling the Neo4j document-delete endpoint: document_id={id}")
#     neo4j_result = await APIService().neo4j_delete_document(
#         request=request,
#         document_id=id,
#         upappid=service_params['gdb'],
#         apiname="neo4j/deletedocument",
#         user=userid
#     )
#     if neo4j_result.get("status") != "success":
#         raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))
#     nodes_deleted = neo4j_result.get("nodes_deleted", 0)
#     rels_deleted = neo4j_result.get("rels_deleted", 0)
#     info(f"Deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
#     return nodes_deleted, rels_deleted

# async def file_uploaded(self, request, ns, userid):
#     """Insert the document into Milvus and extract triples into Neo4j"""
#     debug(f'Received ns: {ns=}')
#     env = request._run_ns
#     realpath = ns.get('realpath', '')
#     fiid = ns.get('fiid', '')
#     id = ns.get('id', '')
#     orgid = ns.get('ownerid', '')
#     db_type = ''
#
#     debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
#
#     timings = {}
#     start_total = time.time()
#
#     try:
#         if not orgid or not fiid or not id:
#             raise ValueError("orgid, fiid and id must not be empty")
#         if len(orgid) > 32 or len(fiid) > 255:
#             raise ValueError("orgid or fiid exceeds the length limit")
#         if not os.path.exists(realpath):
#             raise ValueError(f"File {realpath} does not exist")
#
#         # Get service parameters
#         service_params = await get_service_params(orgid)
#         if not service_params:
#             raise ValueError("Could not obtain service parameters")
#
#         chunks = await self.get_doucment_chunks(realpath, timings)
#         embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings)
#         await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid, db_type, timings)
#         triples = await self.get_triples(request, chunks, service_params, userid, timings)
#         await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings)
#
#         timings["total"] = time.time() - start_total
#         debug(f"Total time: {timings['total']:.2f} s")
#         return {
#             "status": "success",
#             "userid": orgid,
#             "document_id": id,
#             "collection_name": "ragdb",
#             "timings": timings,
#             "unique_triples": triples,
#             "message": f"File {realpath} embedded and its triples processed successfully",
#             "status_code": 200
#         }
#     except Exception as e:
#         error(f"Document insert failed: {str(e)}, stack: {traceback.format_exc()}")
#         timings["total"] = time.time() - start_total
#         return {
#             "status": "error",
#             "document_id": id,
#             "collection_name": "ragdb",
#             "timings": timings,
#             "message": f"Document insert failed: {str(e)}",
#             "status_code": 400
#         }
#
# async def file_deleted(self, request, recs, userid):
#     """Delete the user-specified file data, including the records in Milvus and Neo4j"""
#     if not isinstance(recs, list):
#         recs = [recs]
#     results = []
#     total_nodes_deleted = 0
#     total_rels_deleted = 0
#
#     for rec in recs:
#         id = rec.get('id', '')
#         realpath = rec.get('realpath', '')
#         fiid = rec.get('fiid', '')
#         orgid = rec.get('ownerid', '')
#         db_type = ''
#         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
#
#         try:
#             required_fields = ['id', 'realpath', 'fiid', 'ownerid']
#             missing_fields = [field for field in required_fields if not rec.get(field, '')]
#             if missing_fields:
#                 raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
#
#             service_params = await get_service_params(orgid)
#             if not service_params:
#                 raise ValueError("Could not obtain service parameters")
#
#             # Delete from Milvus
#             await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type)
#
#             # Delete from Neo4j
#             neo4j_deleted_nodes = 0
#             neo4j_deleted_rels = 0
#             try:
#                 nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid)
#                 neo4j_deleted_nodes += nodes_deleted
#                 neo4j_deleted_rels += rels_deleted
#                 total_nodes_deleted += nodes_deleted
#                 total_rels_deleted += rels_deleted
#             except Exception as e:
#                 error(f"Failed to delete Neo4j data for document_id={id}: {str(e)}")
#
#             results.append({
#                 "status": "success",
#                 "collection_name": collection_name,
#                 "document_id": id,
#                 "message": f"Deleted the Milvus records for file {realpath}, {neo4j_deleted_nodes} Neo4j nodes and {neo4j_deleted_rels} Neo4j relationships",
#                 "status_code": 200
#             })
#
#         except Exception as e:
#             error(f"Failed to delete document {realpath}: {str(e)}, stack: {traceback.format_exc()}")
#             results.append({
#                 "status": "error",
#                 "collection_name": collection_name,
#                 "document_id": id,
#                 "message": f"Failed to delete document {realpath}: {str(e)}",
#                 "status_code": 400
#             })
#
#     return {
#         "status": "success" if all(r["status"] == "success" for r in results) else "partial",
#         "results": results,
#         "total_nodes_deleted": total_nodes_deleted,
#         "total_rels_deleted": total_rels_deleted,
#         "message": f"Processed {len(recs)} files, {sum(1 for r in results if r['status'] == 'success')} deleted successfully",
#         "status_code": 200 if all(r["status"] == "success" for r in results) else 207
#     }


# async def test_ragfilemgr():
#     """Test get_service_params of RagFileMgr"""
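The net effect of the changes above is a delegation pattern: RagFileMgr keeps its FileMgr role, but every pipeline stage (chunking, embedding, vector-DB insert, triple extraction, graph-DB insert, and the two deletes) is now forwarded to the shared RagOperations instance created in __init__. A minimal runnable sketch of that shape follows; the stand-in classes and trimmed signatures are hypothetical, for illustration only (the real methods also take request, service_params, userid, and timings, as in the diff):

import asyncio

class OpsFacade:
    # Stand-in for rag.rag_operations.RagOperations; the real methods call
    # remote embedding / Milvus / Neo4j services.
    async def load_and_chunk_document(self, realpath, timings):
        return [f"chunk of {realpath}"]

    async def generate_embeddings(self, chunks, timings):
        return [[0.0] * 1024 for _ in chunks]   # bge-m3 vectors are length 1024

class FileMgrSketch:
    def __init__(self):
        self.rag_ops = OpsFacade()               # one shared pipeline object

    async def file_uploaded(self, realpath):
        timings = {}
        chunks = await self.rag_ops.load_and_chunk_document(realpath, timings)
        embeddings = await self.rag_ops.generate_embeddings(chunks, timings)
        return len(embeddings)

print(asyncio.run(FileMgrSketch().file_uploaded("a.txt")))   # -> 1

Centralizing the stages in one object is what lets ragapi.py reuse the same code path instead of keeping its own copy, which is exactly the duplication the commented-out blocks above had preserved.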
450  rag/rag_operations.py  (new file)
@@ -0,0 +1,450 @@
"""
|
||||
RAG 操作的通用函数库
|
||||
包含文档处理、搜索、嵌入等通用操作,供 folderinfo.py 和 ragapi.py 共同使用
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import math
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
from appPublic.log import debug, error, info
|
||||
from filetxt.loader import fileloader, File2Text
|
||||
from rag.uapi_service import APIService
|
||||
from rag.service_opts import get_service_params
|
||||
from rag.transaction_manager import TransactionManager, OperationType
|
||||
|
||||
|
||||
class RagOperations:
|
||||
"""RAG 操作类,提供所有通用的 RAG 操作"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_service = APIService()
|
||||
|
||||
async def load_and_chunk_document(self, realpath: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None) -> List[Document]:
|
||||
"""加载文件并进行文本分片"""
|
||||
debug(f"加载文件: {realpath}")
|
||||
start_load = time.time()
|
||||
|
||||
# 检查文件格式支持
|
||||
supported_formats = File2Text.supported_types()
|
||||
debug(f"支持的文件格式:{supported_formats}")
|
||||
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
||||
if ext not in supported_formats:
|
||||
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
||||
|
||||
# 加载文件内容
|
||||
text = fileloader(realpath)
|
||||
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
|
||||
timings["load_file"] = time.time() - start_load
|
||||
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
|
||||
|
||||
if not text or not text.strip():
|
||||
raise ValueError(f"文件 {realpath} 加载为空")
|
||||
|
||||
# 分片处理
|
||||
document = Document(page_content=text)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=500,
|
||||
chunk_overlap=100,
|
||||
length_function=len
|
||||
)
|
||||
debug("开始分片文件内容")
|
||||
start_split = time.time()
|
||||
chunks = text_splitter.split_documents([document])
|
||||
timings["split_text"] = time.time() - start_split
|
||||
debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
|
||||
|
||||
if not chunks:
|
||||
raise ValueError(f"文件 {realpath} 未生成任何文档块")
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.FILE_LOAD,
|
||||
{'realpath': realpath, 'chunks_count': len(chunks)}
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
|
||||
userid: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None) -> List[List[float]]:
|
||||
"""生成嵌入向量"""
|
||||
debug("调用嵌入服务生成向量")
|
||||
start_embedding = time.time()
|
||||
texts = [chunk.page_content for chunk in chunks]
|
||||
embeddings = []
|
||||
|
||||
# 批量处理嵌入
|
||||
for i in range(0, len(texts), 10):
|
||||
batch_texts = texts[i:i + 10]
|
||||
batch_embeddings = await self.api_service.get_embeddings(
|
||||
request=request,
|
||||
texts=batch_texts,
|
||||
upappid=service_params['embedding'],
|
||||
apiname="BAAI/bge-m3",
|
||||
user=userid
|
||||
)
|
||||
embeddings.extend(batch_embeddings)
|
||||
|
||||
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
|
||||
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
|
||||
|
||||
timings["generate_embeddings"] = time.time() - start_embedding
|
||||
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.EMBEDDING,
|
||||
{'embeddings_count': len(embeddings)}
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
|
||||
realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
|
||||
userid: str, db_type: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None):
|
||||
"""插入向量数据库"""
|
||||
debug(f"准备数据并调用插入文件端点: {realpath}")
|
||||
filename = os.path.basename(realpath).rsplit('.', 1)[0]
|
||||
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
||||
upload_time = datetime.now().isoformat()
|
||||
|
||||
chunks_data = [
|
||||
{
|
||||
"userid": orgid,
|
||||
"knowledge_base_id": fiid,
|
||||
"text": chunk.page_content,
|
||||
"vector": embeddings[i],
|
||||
"document_id": id,
|
||||
"filename": filename + '.' + ext,
|
||||
"file_path": realpath,
|
||||
"upload_time": upload_time,
|
||||
"file_type": ext,
|
||||
}
|
||||
for i, chunk in enumerate(chunks)
|
||||
]
|
||||
|
||||
start_milvus = time.time()
|
||||
for i in range(0, len(chunks_data), 10):
|
||||
batch_chunks = chunks_data[i:i + 10]
|
||||
result = await self.api_service.milvus_insert_document(
|
||||
request=request,
|
||||
chunks=batch_chunks,
|
||||
db_type=db_type,
|
||||
upappid=service_params['vdb'],
|
||||
apiname="milvus/insertdocument",
|
||||
user=userid
|
||||
)
|
||||
if result.get("status") != "success":
|
||||
raise ValueError(result.get("message", "Milvus 插入失败"))
|
||||
|
||||
timings["insert_milvus"] = time.time() - start_milvus
|
||||
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
|
||||
|
||||
# 记录事务操作,包含回滚函数
|
||||
if transaction_mgr:
|
||||
async def rollback_vdb_insert(data, context):
|
||||
await self.delete_from_vector_db(
|
||||
context['request'], data['orgid'], data['realpath'],
|
||||
data['fiid'], data['id'], context['service_params'],
|
||||
context['userid'], data['db_type']
|
||||
)
|
||||
return f"已回滚向量数据库插入: {data['id']}"
|
||||
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.VDB_INSERT,
|
||||
{
|
||||
'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
|
||||
'id': id, 'db_type': db_type
|
||||
},
|
||||
rollback_func=rollback_vdb_insert
|
||||
)
|
||||
|
||||
return chunks_data
|
||||
|
||||
async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
|
||||
userid: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
"""抽取三元组"""
|
||||
debug("调用三元组抽取服务")
|
||||
start_triples = time.time()
|
||||
chunk_texts = [doc.page_content for doc in chunks]
|
||||
triples = []
|
||||
|
||||
for i, chunk in enumerate(chunk_texts):
|
||||
result = await self.api_service.extract_triples(
|
||||
request=request,
|
||||
text=chunk,
|
||||
upappid=service_params['triples'],
|
||||
apiname="Babelscape/mrebel-large",
|
||||
user=userid
|
||||
)
|
||||
if isinstance(result, list):
|
||||
triples.extend(result)
|
||||
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
|
||||
else:
|
||||
error(f"分片 {i + 1} 处理失败: {str(result)}")
|
||||
|
||||
# 去重和优化三元组
|
||||
unique_triples = self._deduplicate_triples(triples)
|
||||
|
||||
timings["extract_triples"] = time.time() - start_triples
|
||||
debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.TRIPLES_EXTRACT,
|
||||
{'triples_count': len(unique_triples)}
|
||||
)
|
||||
|
||||
return unique_triples
|
||||
|
||||
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
|
||||
orgid: str, service_params: Dict, userid: str, timings: Dict,
|
||||
transaction_mgr: TransactionManager = None):
|
||||
"""插入图数据库"""
|
||||
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
|
||||
start_neo4j = time.time()
|
||||
|
||||
if triples:
|
||||
for i in range(0, len(triples), 30):
|
||||
batch_triples = triples[i:i + 30]
|
||||
neo4j_result = await self.api_service.neo4j_insert_triples(
|
||||
request=request,
|
||||
triples=batch_triples,
|
||||
document_id=id,
|
||||
knowledge_base_id=fiid,
|
||||
userid=orgid,
|
||||
upappid=service_params['gdb'],
|
||||
apiname="neo4j/inserttriples",
|
||||
user=userid
|
||||
)
|
||||
if neo4j_result.get("status") != "success":
|
||||
raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
|
||||
info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
|
||||
|
||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
||||
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
|
||||
else:
|
||||
debug("未抽取到三元组")
|
||||
timings["insert_neo4j"] = 0.0
|
||||
|
||||
# 记录事务操作,包含回滚函数
|
||||
if transaction_mgr:
|
||||
async def rollback_gdb_insert(data, context):
|
||||
await self.delete_from_graph_db(
|
||||
context['request'], data['id'],
|
||||
context['service_params'], context['userid']
|
||||
)
|
||||
return f"已回滚图数据库插入: {data['id']}"
|
||||
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.GDB_INSERT,
|
||||
{'id': id, 'triples_count': len(triples)},
|
||||
rollback_func=rollback_gdb_insert
|
||||
)
|
||||
|
||||
async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str,
|
||||
id: str, service_params: Dict, userid: str, db_type: str):
|
||||
"""从向量数据库删除文档"""
|
||||
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
||||
milvus_result = await self.api_service.milvus_delete_document(
|
||||
request=request,
|
||||
userid=orgid,
|
||||
file_path=realpath,
|
||||
knowledge_base_id=fiid,
|
||||
document_id=id,
|
||||
db_type=db_type,
|
||||
upappid=service_params['vdb'],
|
||||
apiname="milvus/deletedocument",
|
||||
user=userid
|
||||
)
|
||||
if milvus_result.get("status") != "success":
|
||||
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
|
||||
|
||||
async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
|
||||
"""从图数据库删除文档"""
|
||||
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
||||
neo4j_result = await self.api_service.neo4j_delete_document(
|
||||
request=request,
|
||||
document_id=id,
|
||||
upappid=service_params['gdb'],
|
||||
apiname="neo4j/deletedocument",
|
||||
user=userid
|
||||
)
|
||||
if neo4j_result.get("status") != "success":
|
||||
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
|
||||
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
|
||||
rels_deleted = neo4j_result.get("rels_deleted", 0)
|
||||
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
||||
return nodes_deleted, rels_deleted
|
||||
|
||||
async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[str]:
|
||||
"""提取实体"""
|
||||
debug(f"提取查询实体: {query}")
|
||||
entities = await self.api_service.extract_entities(
|
||||
request=request,
|
||||
query=query,
|
||||
upappid=service_params['entities'],
|
||||
apiname="LTP/small",
|
||||
user=userid
|
||||
)
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.ENTITY_EXTRACT,
|
||||
{'query': query, 'entities_count': len(entities)}
|
||||
)
|
||||
|
||||
return entities
|
||||
|
||||
async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
|
||||
fiids: List[str], service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
"""匹配三元组"""
|
||||
debug("开始三元组匹配")
|
||||
all_triplets = []
|
||||
|
||||
for kb_id in fiids:
|
||||
debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
|
||||
try:
|
||||
neo4j_result = await self.api_service.neo4j_match_triplets(
|
||||
request=request,
|
||||
query=query,
|
||||
query_entities=entities,
|
||||
userid=orgid,
|
||||
knowledge_base_id=kb_id,
|
||||
upappid=service_params['gdb'],
|
||||
apiname="neo4j/matchtriplets",
|
||||
user=userid
|
||||
)
|
||||
if neo4j_result.get("status") == "success":
|
||||
triplets = neo4j_result.get("triplets", [])
|
||||
all_triplets.extend(triplets)
|
||||
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组")
|
||||
else:
|
||||
error(
|
||||
f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
|
||||
except Exception as e:
|
||||
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
|
||||
continue
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.TRIPLET_MATCH,
|
||||
{'query': query, 'triplets_count': len(all_triplets)}
|
||||
)
|
||||
|
||||
return all_triplets
|
||||
|
||||
async def vector_search(self, request, query_vector: List[float], orgid: str,
|
||||
fiids: List[str], limit: int, service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
"""向量搜索"""
|
||||
debug("开始向量搜索")
|
||||
result = await self.api_service.milvus_search_query(
|
||||
request=request,
|
||||
query_vector=query_vector,
|
||||
userid=orgid,
|
||||
knowledge_base_ids=fiids,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
upappid=service_params['vdb'],
|
||||
apiname="mlvus/searchquery",
|
||||
user=userid
|
||||
)
|
||||
|
||||
if result.get("status") != "success":
|
||||
raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}")
|
||||
|
||||
search_results = result.get("results", [])
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.VECTOR_SEARCH,
|
||||
{'results_count': len(search_results)}
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
||||
async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
|
||||
service_params: Dict, userid: str,
|
||||
transaction_mgr: TransactionManager = None) -> List[Dict]:
|
||||
"""重排序结果"""
|
||||
debug("开始重排序")
|
||||
reranked_results = await self.api_service.rerank_results(
|
||||
request=request,
|
||||
query=query,
|
||||
results=results,
|
||||
top_n=top_n,
|
||||
upappid=service_params['reranker'],
|
||||
apiname="BAAI/bge-reranker-v2-m3",
|
||||
user=userid
|
||||
)
|
||||
|
||||
# 记录事务操作
|
||||
if transaction_mgr:
|
||||
transaction_mgr.add_operation(
|
||||
OperationType.RERANK,
|
||||
{'input_count': len(results), 'output_count': len(reranked_results)}
|
||||
)
|
||||
|
||||
return reranked_results
|
||||
|
||||
def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
|
||||
"""去重和优化三元组"""
|
||||
unique_triples = []
|
||||
seen = set()
|
||||
|
||||
for t in triples:
|
||||
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
|
||||
if identifier not in seen:
|
||||
seen.add(identifier)
|
||||
unique_triples.append(t)
|
||||
else:
|
||||
# 如果发现更具体的类型,则替换
|
||||
for existing in unique_triples:
|
||||
if (existing['head'].lower() == t['head'].lower() and
|
||||
existing['tail'].lower() == t['tail'].lower() and
|
||||
len(t['type']) > len(existing['type'])):
|
||||
unique_triples.remove(existing)
|
||||
unique_triples.append(t)
|
||||
debug(f"替换三元组为更具体类型: {t}")
|
||||
break
|
||||
|
||||
return unique_triples
|
||||
|
||||
def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
|
||||
"""格式化搜索结果为统一格式"""
|
||||
formatted_results = []
|
||||
|
||||
for res in results[:limit]:
|
||||
rerank_score = res.get('rerank_score', 0)
|
||||
score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
|
||||
score = max(0.0, min(1.0, score))
|
||||
|
||||
content = res.get('text', '')
|
||||
title = res.get('metadata', {}).get('filename', 'Untitled')
|
||||
document_id = res.get('metadata', {}).get('document_id', '')
|
||||
|
||||
formatted_results.append({
|
||||
"content": content,
|
||||
"title": title,
|
||||
"metadata": {"document_id": document_id, "score": score},
|
||||
})
|
||||
|
||||
return formatted_results
|
||||
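_deduplicate_triples keys each triple on the case-folded (head, tail, type) tuple, keeps first occurrences, and is meant to swap in a triple whose type string is longer (more specific) when a duplicate arrives. A standalone copy of the same algorithm, with logging removed, that can be run directly:

def deduplicate_triples(triples):
    # Same logic as RagOperations._deduplicate_triples above, minus debug().
    unique, seen = [], set()
    for t in triples:
        ident = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
        if ident not in seen:
            seen.add(ident)
            unique.append(t)
        else:
            # Swap in the triple with the longer (more specific) type string.
            for existing in unique:
                if (existing['head'].lower() == t['head'].lower() and
                        existing['tail'].lower() == t['tail'].lower() and
                        len(t['type']) > len(existing['type'])):
                    unique.remove(existing)
                    unique.append(t)
                    break
    return unique

print(deduplicate_triples([
    {'head': 'Milvus', 'type': 'is', 'tail': 'database'},
    {'head': 'milvus', 'type': 'is', 'tail': 'Database'},   # exact duplicate, dropped
]))  # -> [{'head': 'Milvus', 'type': 'is', 'tail': 'database'}]

One observation: because the seen key already includes the type, two triples only collide when their lowercased types are identical, which forces their lengths to be equal; the length comparison in the replacement branch therefore appears unreachable as written.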
479  rag/ragapi.py
@@ -6,6 +6,7 @@ import traceback
import json
import math
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations

helptext = """kyrag API:
@@ -131,14 +132,97 @@ async def fusedsearch(request, params_kw, *params, **kw):
    debug(f"fiids: {fiids}")

    # Verify that the fiids' orgid matches the orgid obtained from orgid = await f()
    await _validate_fiids_orgid(fiids, orgid, kw)

    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("Could not obtain service parameters")

    try:
        timing_stats = {}
        start_time = time.time()
        rag_ops = RagOperations()

        entity_extract_start = time.time()
        query_entities = await rag_ops.extract_entities(request, query, service_params, userid)
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"Extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f} s")

        triplet_match_start = time.time()
        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"Total triplet matching time: {timing_stats['triplet_matching']:.3f} s")

        triplet_text_start = time.time()
        combined_text = _combine_query_with_triplets(query, all_triplets)
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"Triplet-text concatenation took {timing_stats['triplet_text_combine']:.3f} s")

        vector_start = time.time()
        query_vector = await rag_ops.api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a list of 1024 floats")
        query_vector = query_vector[0]
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"Query-vector generation took {timing_stats['vector_generation']:.3f} s")

        search_start = time.time()
        search_limit = limit + 5
        search_results = await rag_ops.vector_search(
            request, query_vector, orgid, fiids, search_limit, service_params, userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"Vector search took {timing_stats['vector_search']:.3f} s")
        debug(f"Vector search returned {len(search_results)} records")

        # Step 6: rerank (optional)
        use_rerank = True
        if use_rerank and search_results:
            rerank_start = time.time()
            debug("Starting rerank")
            reranked_results = await rag_ops.rerank_results(
                request, combined_text, search_results, limit, service_params, userid
            )
            reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"Rerank took {timing_stats['reranking']:.3f} s")
            debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
            final_results = reranked_results
        else:
            final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]

        timing_stats["total_time"] = time.time() - start_time
        info(f"Fused search finished, returning {len(final_results)} results, total time: {timing_stats['total_time']:.3f} s")

        formatted_results = rag_ops.format_search_results(final_results, limit)
        info(f"Fused search finished, returning {len(formatted_results)} results")

        return {
            "records": formatted_results
        }

    except Exception as e:
        error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
        # The transaction manager performs rollback automatically
        return {
            "records": [],
            "timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
            "error": str(e)
        }


async def _validate_fiids_orgid(fiids, orgid, kw):
    """Verify that the orgid of the fiids matches the current user's orgid"""
    if fiids:
        db = DBPools()
        dbname = kw.get('get_module_dbname')('rag')
        sql_opts = """
        SELECT orgid
        FROM kdb
        WHERE id = ${id}$
        """
        sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
        try:
            async with db.sqlorContext(dbname) as sor:
                result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
@@ -149,64 +233,13 @@ async def fusedsearch(request, params_kw, *params, **kw):
                    raise ValueError(f"orgid mismatch: kdb.orgid={kdb_orgid}, user orgid={orgid}")
        except Exception as e:
            error(f"orgid validation failed: {str(e)}")
            return json.dumps({"status": "error", "message": str(e)})
    service_params = await get_service_params(orgid)
            raise

    api_service = APIService()
    start_time = time.time()
    timing_stats = {}
    try:
        info(
            f"Starting fused search: query={query}, userid={orgid}, knowledge_base_ids={fiids}")

        if not query or not orgid or not fiids:
            raise ValueError("query, orgid and knowledge_base_ids must not be empty")

        # Extract entities
        entity_extract_start = time.time()
        query_entities = await api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"Extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f} s")

        # Call the Neo4j service for triplet matching
        all_triplets = []
        triplet_match_start = time.time()
        for kb_id in fiids:
            debug(f"Calling Neo4j triplet matching: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=query_entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"Knowledge base {kb_id} matched {len(triplets)} triplets: {triplets[:5]}")
                else:
                    error(
                        f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
            except Exception as e:
                error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                continue
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"Total triplet matching time: {timing_stats['triplet_matching']:.3f} s")

        # Concatenate triplet text
        triplet_text_start = time.time()
def _combine_query_with_triplets(query, triplets):
    """Concatenate the query text with the triplet texts"""
    triplet_texts = []
    for triplet in all_triplets:
    for triplet in triplets:
        head = triplet.get('head', '')
        type_ = triplet.get('type', '')
        tail = triplet.get('tail', '')
@@ -214,115 +247,162 @@ async def fusedsearch(request, params_kw, *params, **kw):
            triplet_texts.append(f"{head} {type_} {tail}")
        else:
            debug(f"Invalid triplet: {triplet}")

        combined_text = query
        if triplet_texts:
            combined_text += "".join(triplet_texts)
        debug(
            f"Combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"Triplet-text concatenation took {timing_stats['triplet_text_combine']:.3f} s")

        # Convert the combined text into a vector
        vector_start = time.time()
        query_vector = await api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a list of 1024 floats")
        query_vector = query_vector[0]  # take the first vector
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"Query-vector generation took {timing_stats['vector_generation']:.3f} s")

        # Call the search endpoint
        sum = limit + 5
        search_start = time.time()
        debug(f"orgid: {orgid}")
        result = await api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            limit=sum,
            offset=0,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"Vector search took {timing_stats['vector_search']:.3f} s")

        if result.get("status") != "success":
            error(f"Fused search failed: {result.get('message', 'unknown error')}")
            return {"results": [], "timing": timing_stats}

        unique_results = result.get("results", [])
        sum = len(unique_results)
        debug(f"Vector search returned {sum} records")
        use_rerank = True
        if use_rerank and unique_results:
            rerank_start = time.time()
            debug("Starting rerank")
            unique_results = await api_service.rerank_results(
                request=request,
                query=combined_text,
                results=unique_results,
                top_n=limit,
                upappid=service_params['reranker'],
                apiname="BAAI/bge-reranker-v2-m3",
                user=userid
            )
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"Rerank took {timing_stats['reranking']:.3f} s")
            debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]

        timing_stats["total_time"] = time.time() - start_time
        info(f"Fused search finished, returning {len(unique_results)} results, total time: {timing_stats['total_time']:.3f} s")

        # dify_result = []
        # for res in unique_results[:limit]:
        #     content = res.get('text', '')
        #     title = res.get('metadata', {}).get('filename', 'Untitled')
        #     document_id = res.get('metadata', {}).get('document_id', '')
        #     dify_result.append({
        #         'metadata': {'document_id': document_id},
        #         'title': title,
        #         'content': content
        #     })
        # info(f"Fused search finished, returning {len(dify_result)} results, total time: {(time.time() - start_time):.3f} s")
        # debug(f"result: {dify_result}")
        # return dify_result

        dify_records = []
        dify_result = []
        for res in unique_results[:limit]:
            rerank_score = res.get('rerank_score', 0)
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))
            content = res.get('text', '')
            title = res.get('metadata', {}).get('filename', 'Untitled')
            document_id = res.get('metadata', {}).get('document_id', '')
            dify_records.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })
            dify_result.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })
        info(f"Fused search finished, returning {len(dify_records)} results, total time: {(time.time() - start_time):.3f} s")
        debug(f"records: {dify_records}, result: {dify_result}")
        # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
        return {"records": dify_records}
    debug(f"Combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
    return combined_text

# api_service = APIService()
# start_time = time.time()
# timing_stats = {}
# try:
#     info(
#         f"Starting fused search: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
#
#     if not query or not orgid or not fiids:
#         raise ValueError("query, orgid and knowledge_base_ids must not be empty")
#
#     # Extract entities
#     entity_extract_start = time.time()
#     query_entities = await api_service.extract_entities(
#         request=request,
#         query=query,
#         upappid=service_params['entities'],
#         apiname="LTP/small",
#         user=userid
#     )
#     timing_stats["entity_extraction"] = time.time() - entity_extract_start
#     debug(f"Extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f} s")
#
#     # Call the Neo4j service for triplet matching
#     all_triplets = []
#     triplet_match_start = time.time()
#     for kb_id in fiids:
#         debug(f"Calling Neo4j triplet matching: knowledge_base_id={kb_id}")
#         try:
#             neo4j_result = await api_service.neo4j_match_triplets(
#                 request=request,
#                 query=query,
#                 query_entities=query_entities,
#                 userid=orgid,
#                 knowledge_base_id=kb_id,
#                 upappid=service_params['gdb'],
#                 apiname="neo4j/matchtriplets",
#                 user=userid
#             )
#             if neo4j_result.get("status") == "success":
#                 triplets = neo4j_result.get("triplets", [])
#                 all_triplets.extend(triplets)
#                 debug(f"Knowledge base {kb_id} matched {len(triplets)} triplets: {triplets[:5]}")
#             else:
#                 error(
#                     f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
#         except Exception as e:
#             error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
#             continue
#     timing_stats["triplet_matching"] = time.time() - triplet_match_start
#     debug(f"Total triplet matching time: {timing_stats['triplet_matching']:.3f} s")
#
#     # Concatenate triplet text
#     triplet_text_start = time.time()
#     triplet_texts = []
#     for triplet in all_triplets:
#         head = triplet.get('head', '')
#         type_ = triplet.get('type', '')
#         tail = triplet.get('tail', '')
#         if head and type_ and tail:
#             triplet_texts.append(f"{head} {type_} {tail}")
#         else:
#             debug(f"Invalid triplet: {triplet}")
#     combined_text = query
#     if triplet_texts:
#         combined_text += "".join(triplet_texts)
#         debug(
#             f"Combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
#     timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
#     debug(f"Triplet-text concatenation took {timing_stats['triplet_text_combine']:.3f} s")
#
#     # Convert the combined text into a vector
#     vector_start = time.time()
#     query_vector = await api_service.get_embeddings(
#         request=request,
#         texts=[combined_text],
#         upappid=service_params['embedding'],
#         apiname="BAAI/bge-m3",
#         user=userid
#     )
#     if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
#         raise ValueError("The query vector must be a list of 1024 floats")
#     query_vector = query_vector[0]  # take the first vector
#     timing_stats["vector_generation"] = time.time() - vector_start
#     debug(f"Query-vector generation took {timing_stats['vector_generation']:.3f} s")
#
#     # Call the search endpoint
#     sum = limit + 5
#     search_start = time.time()
#     debug(f"orgid: {orgid}")
#     result = await api_service.milvus_search_query(
#         request=request,
#         query_vector=query_vector,
#         userid=orgid,
#         knowledge_base_ids=fiids,
#         limit=sum,
#         offset=0,
#         upappid=service_params['vdb'],
#         apiname="mlvus/searchquery",
#         user=userid
#     )
#     timing_stats["vector_search"] = time.time() - search_start
#     debug(f"Vector search took {timing_stats['vector_search']:.3f} s")
#
#     if result.get("status") != "success":
#         error(f"Fused search failed: {result.get('message', 'unknown error')}")
#         return {"results": [], "timing": timing_stats}
#
#     unique_results = result.get("results", [])
#     sum = len(unique_results)
#     debug(f"Vector search returned {sum} records")
#     use_rerank = True
#     if use_rerank and unique_results:
#         rerank_start = time.time()
#         debug("Starting rerank")
#         unique_results = await api_service.rerank_results(
#             request=request,
#             query=combined_text,
#             results=unique_results,
#             top_n=limit,
#             upappid=service_params['reranker'],
#             apiname="BAAI/bge-reranker-v2-m3",
#             user=userid
#         )
#         unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
#         timing_stats["reranking"] = time.time() - rerank_start
#         debug(f"Rerank took {timing_stats['reranking']:.3f} s")
#         debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
#     else:
#         unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
#
#     timing_stats["total_time"] = time.time() - start_time
#     info(f"Fused search finished, returning {len(unique_results)} results, total time: {timing_stats['total_time']:.3f} s")
#
#     # dify_result = []
#     # for res in unique_results[:limit]:
#     #     content = res.get('text', '')
#     #     title = res.get('metadata', {}).get('filename', 'Untitled')
#     #     document_id = res.get('metadata', {}).get('document_id', '')
#     #     dify_result.append({
#     #         'metadata': {'document_id': document_id},
#     #         'title': title,
#     #         'content': content
#     #     })
#     # info(f"Fused search finished, returning {len(dify_result)} results, total time: {(time.time() - start_time):.3f} s")
#     # debug(f"result: {dify_result}")
#     # return dify_result
#
#     dify_records = []
#     dify_result = []
#     for res in unique_results[:limit]:
#         rerank_score = res.get('rerank_score', 0)
@@ -331,29 +411,52 @@ async def fusedsearch(request, params_kw, *params, **kw):
#         content = res.get('text', '')
#         title = res.get('metadata', {}).get('filename', 'Untitled')
#         document_id = res.get('metadata', {}).get('document_id', '')
#         dify_result.append({
#             "metadata": {
#                 "_source": "konwledge",
#                 "dataset_id":"111111",
#                 "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
#                 "document_id": document_id,
#                 "document_name": "test.docx",
#                 "data_source_type": "upload_file",
#                 "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
#                 "retriever_from": "workflow",
#                 "score": score,
#                 "segment_hit_count": 7,
#                 "segment_word_count": 275,
#                 "segment_position": 5,
#                 "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
#                 "doc_metadata": None,
#                 "position":1
#             },
#         dify_records.append({
#             "content": content,
#             "title": title,
#             "content": content
#             "metadata": {"document_id": document_id, "score": score},
#         })
#         return {"result": dify_result}

    except Exception as e:
        error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
        return {"results": [], "timing": timing_stats}
#         dify_result.append({
#             "content": content,
#             "title": title,
#             "metadata": {"document_id": document_id, "score": score},
#         })
#     info(f"Fused search finished, returning {len(dify_records)} results, total time: {(time.time() - start_time):.3f} s")
#     debug(f"records: {dify_records}, result: {dify_result}")
#     # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
#     return {"records": dify_records}
#
#     # dify_result = []
#     # for res in unique_results[:limit]:
#     #     rerank_score = res.get('rerank_score', 0)
#     #     score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
#     #     score = max(0.0, min(1.0, score))
#     #     content = res.get('text', '')
#     #     title = res.get('metadata', {}).get('filename', 'Untitled')
#     #     document_id = res.get('metadata', {}).get('document_id', '')
#     #     dify_result.append({
#     #         "metadata": {
#     #             "_source": "konwledge",
#     #             "dataset_id":"111111",
#     #             "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
#     #             "document_id": document_id,
#     #             "document_name": "test.docx",
#     #             "data_source_type": "upload_file",
#     #             "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
#     #             "retriever_from": "workflow",
#     #             "score": score,
#     #             "segment_hit_count": 7,
#     #             "segment_word_count": 275,
#     #             "segment_position": 5,
#     #             "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
#     #             "doc_metadata": None,
#     #             "position":1
#     #         },
#     #         "title": title,
#     #         "content": content
#     #     })
#     # return {"result": dify_result}
#
# except Exception as e:
#     error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
#     return {"results": [], "timing": timing_stats}
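Both format_search_results and the inline fusedsearch loop map a reranker score onto [0, 1] the same way: a logistic squash 1 / (1 + e^(-rerank_score)) when a rerank score is present, a 1 - distance fallback otherwise, then a clamp to [0, 1]. A standalone sketch of that mapping:

import math

def to_unit_score(res):
    # Mirrors the score computation in format_search_results / fusedsearch.
    rerank_score = res.get('rerank_score', 0)
    if rerank_score is not None:
        score = 1 / (1 + math.exp(-rerank_score))   # logistic squash into (0, 1)
    else:
        score = 1 - res.get('distance', 0)          # distance-based fallback
    return max(0.0, min(1.0, score))                # clamp to [0, 1]

print(to_unit_score({'rerank_score': 2.0}))                     # ~0.881
print(to_unit_score({'rerank_score': None, 'distance': 0.25}))  # 0.75

One quirk worth knowing: the default in res.get('rerank_score', 0) means a missing key yields sigmoid(0) = 0.5 rather than the distance fallback; the fallback only triggers when the key is explicitly None.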