This commit is contained in:
wangmeihua 2025-09-12 15:34:36 +08:00
parent e590c1084f
commit b6d3f39081
3 changed files with 352 additions and 308 deletions

View File

@@ -14,7 +14,7 @@ import time
 import uuid
 from datetime import datetime
 import traceback
-from filetxt.loader import fileloader
+from filetxt.loader import fileloader, File2Text
 from ahserver.serverenv import get_serverenv
 from typing import List, Dict, Any
 from rag.service_opts import get_service_params, sor_get_service_params
@@ -44,56 +44,20 @@ where a.orgid = b.orgid
             return r.quota, r.expired_date
         return None, None

-    async def file_uploaded(self, request, ns, userid):
-        """Insert the document into Milvus and extract triples into Neo4j."""
-        debug(f'Received ns: {ns=}')
-        env = request._run_ns
-        realpath = ns.get('realpath', '')
-        fiid = ns.get('fiid', '')
-        id = ns.get('id', '')
-        orgid = ns.get('ownerid', '')
-        hashvalue = ns.get('hashvalue', '')
-        db_type = ''
-        api_service = APIService()
-        debug(
-            f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
-        timings = {}
-        start_total = time.time()
-        service_params = await get_service_params(orgid)
-        chunks = await self.get_doucment_chunks(realpath)
-        embeddings = await self.docs_embedding(chunks)
-        await self.embedding_2_vdb(id, fiid, orgid, realpath, embedding)
-        triples = await self.get_triples(chunks)
-        await self.triple2graphdb(id, fiid, orgid, realpath, triples)
-        return
-        try:
-            if not orgid or not fiid or not id:
-                raise ValueError("orgid, fiid and id must not be empty")
-            if len(orgid) > 32 or len(fiid) > 255:
-                raise ValueError("orgid or fiid exceeds the length limit")
-            if not os.path.exists(realpath):
-                raise ValueError(f"File {realpath} does not exist")
-            # Fetch service parameters
-            service_params = await get_service_params(orgid)
-            if not service_params:
-                raise ValueError("Unable to fetch service parameters")
-            supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
+    async def get_doucment_chunks(self, realpath, timings):
+        """Load a file and split its text into chunks."""
+        debug(f"Loading file: {realpath}")
+        start_load = time.time()
+        supported_formats = File2Text.supported_types()
+        debug(f"Supported file formats: {supported_formats}")
         ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
         if ext not in supported_formats:
             raise ValueError(f"Unsupported file format: {ext}; supported formats: {', '.join(supported_formats)}")
-        debug(f"Loading file: {realpath}")
-        start_load = time.time()
         text = fileloader(realpath)
-        # debug(f"Processed file content: {text=}")
         text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
         timings["load_file"] = time.time() - start_load
         debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")
         if not text or not text.strip():
             raise ValueError(f"File {realpath} loaded empty")
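One caveat on the extension check kept here: realpath.rsplit('.', 1) looks at the whole path, so a dotted directory name with an extension-less file (e.g. data.v2/README) yields a bogus extension. os.path.splitext only considers the final path component; file_ext below is a hypothetical helper, not part of this commit, that illustrates the difference:

import os

def file_ext(realpath: str) -> str:
    # os.path.splitext only looks at the last path component,
    # so dots in directory names cannot leak into the extension.
    return os.path.splitext(realpath)[1].lstrip('.').lower()

assert file_ext("/data/report.PDF") == "pdf"
assert file_ext("/data/archive.tar.gz") == "gz"   # only the final suffix
assert file_ext("/data.v2/README") == ""          # rsplit('.', 1) would give "v2/readme"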
@@ -101,27 +65,28 @@ where a.orgid = b.orgid
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=500,
             chunk_overlap=100,
-            length_function=len)
+            length_function=len
+        )
         debug("Splitting the file content")
         start_split = time.time()
         chunks = text_splitter.split_documents([document])
         timings["split_text"] = time.time() - start_split
-        debug(
-            f"Text splitting took {timings['split_text']:.2f} s, chunk count: {len(chunks)}, chunk preview: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
-        debug(f"Chunk contents: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
+        debug(f"Text splitting took {timings['split_text']:.2f} s, chunk count: {len(chunks)}")
         if not chunks:
             raise ValueError(f"File {realpath} produced no document chunks")
-        filename = os.path.basename(realpath).rsplit('.', 1)[0]
-        upload_time = datetime.now().isoformat()
+        return chunks
+
+    async def docs_embedding(self, request, chunks, service_params, userid, timings):
+        """Call the embedding service to generate vectors."""
         debug("Calling the embedding service to generate vectors")
         start_embedding = time.time()
         texts = [chunk.page_content for chunk in chunks]
         embeddings = []
-        for i in range(0, len(texts), 10):  # process 10 text chunks per call
+        for i in range(0, len(texts), 10):
             batch_texts = texts[i:i + 10]
-            batch_embeddings = await api_service.get_embeddings(
+            batch_embeddings = await APIService().get_embeddings(
                 request=request,
                 texts=batch_texts,
                 upappid=service_params['embedding'],
@@ -129,14 +94,24 @@ where a.orgid = b.orgid
                 user=userid
             )
             embeddings.extend(batch_embeddings)
         if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
             raise ValueError("Every embedding must be a list of 1024 floats")
         timings["generate_embeddings"] = time.time() - start_embedding
         debug(f"Embedding generation took {timings['generate_embeddings']:.2f} s, embedding count: {len(embeddings)}")
+        return embeddings
+
+    async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,
+                              db_type, timings):
+        """Prepare chunk records and insert them into Milvus."""
+        debug(f"Preparing data for the insert-document endpoint: {realpath}")
+        filename = os.path.basename(realpath).rsplit('.', 1)[0]
+        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
+        upload_time = datetime.now().isoformat()
-        chunks_data = []
-        for i, chunk in enumerate(chunks):
-            chunks_data.append({
+        chunks_data = [
+            {
                 "userid": orgid,
                 "knowledge_base_id": fiid,
                 "text": chunk.page_content,
@@ -146,13 +121,14 @@ where a.orgid = b.orgid
                 "file_path": realpath,
                 "upload_time": upload_time,
                 "file_type": ext,
-            })
+            }
+            for i, chunk in enumerate(chunks)
+        ]
+        debug(f"Calling the insert-document endpoint: {realpath}")
         start_milvus = time.time()
-        for i in range(0, len(chunks_data), 10):  # process 10 records per call
+        for i in range(0, len(chunks_data), 10):
             batch_chunks = chunks_data[i:i + 10]
-            result = await api_service.milvus_insert_document(
+            result = await APIService().milvus_insert_document(
                 request=request,
                 chunks=batch_chunks,
                 db_type=db_type,
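Both docs_embedding and embedding_2_vdb page through their payloads with range(0, len(...), 10) so no single call to the embedding or Milvus endpoint grows unbounded. A generic helper with the same semantics could factor out that slicing; batched here is a hypothetical name, not part of this commit (Python 3.12's itertools.batched offers the same for arbitrary iterables):

from typing import Iterable, List, TypeVar

T = TypeVar("T")

def batched(items: List[T], size: int) -> Iterable[List[T]]:
    # Yield consecutive slices of at most `size` items,
    # equivalent to the range(0, len(items), size) loops above.
    for i in range(0, len(items), size):
        yield items[i:i + size]

assert list(batched([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]
# e.g. the embedding loop could become:
# for batch_texts in batched(texts, 10): ...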
@@ -162,23 +138,19 @@ where a.orgid = b.orgid
             )
             if result.get("status") != "success":
                 raise ValueError(result.get("message", "Milvus insert failed"))
         timings["insert_milvus"] = time.time() - start_milvus
         debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
-        if result.get("status") != "success":
-            timings["total"] = time.time() - start_total
-            return {"status": "error", "document_id": id, "timings": timings,
-                    "message": result.get("message", "unknown error"), "status_code": 400}
+        return chunks_data
+
+    async def get_triples(self, request, chunks, service_params, userid, timings):
+        """Call the triple-extraction service."""
         debug("Calling the triple-extraction service")
         start_triples = time.time()
-        unique_triples = []
-        try:
         chunk_texts = [doc.page_content for doc in chunks]
-        debug(f"Processing {len(chunk_texts)} chunks for triple extraction")
         triples = []
         for i, chunk in enumerate(chunk_texts):
-            result = await api_service.extract_triples(
+            result = await APIService().extract_triples(
                 request=request,
                 text=chunk,
                 upappid=service_params['triples'],
@@ -187,10 +159,11 @@ where a.orgid = b.orgid
             )
             if isinstance(result, list):
                 triples.extend(result)
-                debug(f"Chunk {i + 1} extracted {len(result)} triples: {result[:5]}")
+                debug(f"Chunk {i + 1} extracted {len(result)} triples")
             else:
                 error(f"Chunk {i + 1} failed: {str(result)}")
+        unique_triples = []
         seen = set()
         for t in triples:
             identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
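The deduplication loop this hunk opens keys each triple on the lowercased (head, tail, type) tuple and keeps the first occurrence. As a standalone sketch of the same logic (dedupe_triples is a hypothetical name for illustration):

def dedupe_triples(triples: list[dict]) -> list[dict]:
    # Keep the first occurrence of each (head, tail, type), compared case-insensitively.
    seen = set()
    unique = []
    for t in triples:
        key = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
        if key not in seen:
            seen.add(key)
            unique.append(t)
    return unique

assert dedupe_triples([
    {'head': 'GPU', 'type': 'has', 'tail': 'VRAM'},
    {'head': 'gpu', 'type': 'HAS', 'tail': 'vram'},
]) == [{'head': 'GPU', 'type': 'has', 'tail': 'VRAM'}]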
@@ -208,15 +181,17 @@ where a.orgid = b.orgid
                     break
         timings["extract_triples"] = time.time() - start_triples
-        debug(
-            f"Triple extraction took {timings['extract_triples']:.2f} s, extracted {len(unique_triples)} triples: {unique_triples[:5]}")
+        debug(f"Triple extraction took {timings['extract_triples']:.2f} s, extracted {len(unique_triples)} triples")
+        return unique_triples
+
-        if unique_triples:
-            debug(f"Extracted {len(unique_triples)} triples, calling the Neo4j service to insert them")
+    async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings):
+        """Insert triples into Neo4j."""
+        debug(f"Inserting {len(unique_triples)} triples into Neo4j")
         start_neo4j = time.time()
-        for i in range(0, len(unique_triples), 30):  # insert 30 triples per call
+        if unique_triples:
+            for i in range(0, len(unique_triples), 30):
                 batch_triples = unique_triples[i:i + 30]
-                neo4j_result = await api_service.neo4j_insert_triples(
+                neo4j_result = await APIService().neo4j_insert_triples(
                     request=request,
                     triples=batch_triples,
                     document_id=id,
@@ -226,41 +201,82 @@ where a.orgid = b.orgid
                     apiname="neo4j/inserttriples",
                     user=userid
                 )
-                debug(f"Neo4j service response: {neo4j_result}")
                 if neo4j_result.get("status") != "success":
-                    timings["insert_neo4j"] = time.time() - start_neo4j
-                    timings["total"] = time.time() - start_total
-                    return {
-                        "status": "error",
-                        "document_id": id,
-                        "collection_name": "ragdb",
-                        "timings": timings,
-                        "message": f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}",
-                        "status_code": 400
-                    }
-                info(f"Triples for file {realpath} inserted into Neo4j: {neo4j_result.get('message')}")
+                    raise ValueError(f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}")
+                info(f"File triples inserted into Neo4j: {neo4j_result.get('message')}")
             timings["insert_neo4j"] = time.time() - start_neo4j
             debug(f"Neo4j insert took {timings['insert_neo4j']:.2f} s")
         else:
-            debug(f"No triples extracted from file {realpath}")
+            debug("No triples extracted")
             timings["insert_neo4j"] = 0.0
-        except Exception as e:
-            timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
-                timings["extract_triples"]
-            timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
-                "insert_neo4j"]
-            debug(f"Triple processing or Neo4j insert failed: {str(e)}, stack: {traceback.format_exc()}")
-            timings["total"] = time.time() - start_total
-            return {
-                "status": "success",
-                "document_id": id,
-                "collection_name": "ragdb",
-                "timings": timings,
-                "unique_triples": unique_triples,
-                "message": f"File {realpath} embedded successfully, but triple processing or the Neo4j insert failed: {str(e)}",
-                "status_code": 200
-            }
+
+    async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type):
+        """Delete a document from Milvus."""
+        debug(f"Calling the delete-document endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
+        milvus_result = await APIService().milvus_delete_document(
+            request=request,
+            userid=orgid,
+            file_path=realpath,
+            knowledge_base_id=fiid,
+            document_id=id,
+            db_type=db_type,
+            upappid=service_params['vdb'],
+            apiname="milvus/deletedocument",
+            user=userid
+        )
+        if milvus_result.get("status") != "success":
+            raise ValueError(milvus_result.get("message", "Milvus delete failed"))
+
+    async def delete_from_neo4j(self, request, id, service_params, userid):
+        """Delete a document from Neo4j."""
+        debug(f"Calling the Neo4j delete-document endpoint: document_id={id}")
+        neo4j_result = await APIService().neo4j_delete_document(
+            request=request,
+            document_id=id,
+            upappid=service_params['gdb'],
+            apiname="neo4j/deletedocument",
+            user=userid
+        )
+        if neo4j_result.get("status") != "success":
+            raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))
+        nodes_deleted = neo4j_result.get("nodes_deleted", 0)
+        rels_deleted = neo4j_result.get("rels_deleted", 0)
+        info(f"Deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
+        return nodes_deleted, rels_deleted
+
+    async def file_uploaded(self, request, ns, userid):
+        """Insert the document into Milvus and extract triples into Neo4j."""
+        debug(f'Received ns: {ns=}')
+        env = request._run_ns
+        realpath = ns.get('realpath', '')
+        fiid = ns.get('fiid', '')
+        id = ns.get('id', '')
+        orgid = ns.get('ownerid', '')
+        db_type = ''
+        debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
+        timings = {}
+        start_total = time.time()
+        try:
+            if not orgid or not fiid or not id:
+                raise ValueError("orgid, fiid and id must not be empty")
+            if len(orgid) > 32 or len(fiid) > 255:
+                raise ValueError("orgid or fiid exceeds the length limit")
+            if not os.path.exists(realpath):
+                raise ValueError(f"File {realpath} does not exist")
+            # Fetch service parameters
+            service_params = await get_service_params(orgid)
+            if not service_params:
+                raise ValueError("Unable to fetch service parameters")
+            chunks = await self.get_doucment_chunks(realpath, timings)
+            embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings)
+            await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params,
+                                       userid, db_type, timings)
+            triples = await self.get_triples(request, chunks, service_params, userid, timings)
+            await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings)
             timings["total"] = time.time() - start_total
             debug(f"Total time: {timings['total']:.2f} s")
@@ -270,11 +286,10 @@ where a.orgid = b.orgid
                 "document_id": id,
                 "collection_name": "ragdb",
                 "timings": timings,
-                "unique_triples": unique_triples,
+                "unique_triples": triples,
                 "message": f"File {realpath} embedded and triples processed successfully",
                 "status_code": 200
             }
         except Exception as e:
             error(f"Document insert failed: {str(e)}, stack: {traceback.format_exc()}")
             timings["total"] = time.time() - start_total
@@ -290,9 +305,8 @@ where a.orgid = b.orgid
     async def file_deleted(self, request, recs, userid):
         """Delete the file data a user specifies, including the records in Milvus and Neo4j."""
         if not isinstance(recs, list):
-            recs = [recs]  # ensure recs is a list even when a single record is passed
+            recs = [recs]
         results = []
-        api_service = APIService()
         total_nodes_deleted = 0
         total_rels_deleted = 0
@@ -310,46 +324,24 @@ where a.orgid = b.orgid
                 if missing_fields:
                     raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

-                # Fetch service parameters
                 service_params = await get_service_params(orgid)
                 if not service_params:
                     raise ValueError("Unable to fetch service parameters")
-                debug(f"Calling the delete-document endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
-                milvus_result = await api_service.milvus_delete_document(
-                    request=request,
-                    userid=orgid,
-                    file_path=realpath,
-                    knowledge_base_id=fiid,
-                    document_id=id,
-                    db_type=db_type,
-                    upappid=service_params['vdb'],
-                    apiname="milvus/deletedocument",
-                    user=userid
-                )
-                if milvus_result.get("status") != "success":
-                    raise ValueError(milvus_result.get("message", "Milvus delete failed"))
+                # Delete from Milvus
+                await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type)

                 # Delete from Neo4j
                 neo4j_deleted_nodes = 0
                 neo4j_deleted_rels = 0
-                debug(f"Calling the Neo4j delete-document endpoint: document_id={id}")
-                neo4j_result = await api_service.neo4j_delete_document(
-                    request=request,
-                    document_id=id,
-                    upappid=service_params['gdb'],
-                    apiname="neo4j/deletedocument",
-                    user=userid
-                )
-                if neo4j_result.get("status") != "success":
-                    raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))
-                nodes_deleted = neo4j_result.get("nodes_deleted", 0)
-                rels_deleted = neo4j_result.get("rels_deleted", 0)
-                neo4j_deleted_nodes += nodes_deleted
-                neo4j_deleted_rels += rels_deleted
-                total_nodes_deleted += nodes_deleted
-                total_rels_deleted += rels_deleted
-                info(f"Deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
+                try:
+                    nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid)
+                    neo4j_deleted_nodes += nodes_deleted
+                    neo4j_deleted_rels += rels_deleted
+                    total_nodes_deleted += nodes_deleted
+                    total_rels_deleted += rels_deleted
+                except Exception as e:
+                    error(f"Failed to delete Neo4j data for document_id={id}: {str(e)}")

                 results.append({
                     "status": "success",
@@ -378,31 +370,31 @@ where a.orgid = b.orgid
             "status_code": 200 if all(r["status"] == "success" for r in results) else 207
         }
-async def test_ragfilemgr():
-    """Test RagFileMgr's get_service_params"""
-    print("Initializing the database connection pool...")
-    dbs = {
-        "kyrag": {
-            "driver": "aiomysql",
-            "async_mode": True,
-            "coding": "utf8",
-            "maxconn": 100,
-            "dbname": "kyrag",
-            "kwargs": {
-                "user": "test",
-                "db": "kyrag",
-                "password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
-                "host": "db"
-            }
-        }
-    }
-    DBPools(dbs)
-
-    ragfilemgr = RagFileMgr()
-    orgid = "04J6VbxLqB_9RPMcgOv_8"
-    result = await ragfilemgr.get_service_params(orgid)
-    print(f"get_service_params result: {result}")
-
-
-if __name__ == "__main__":
-    asyncio.run(test_ragfilemgr())
+# async def test_ragfilemgr():
+#     """Test RagFileMgr's get_service_params"""
+#     print("Initializing the database connection pool...")
+#     dbs = {
+#         "kyrag": {
+#             "driver": "aiomysql",
+#             "async_mode": True,
+#             "coding": "utf8",
+#             "maxconn": 100,
+#             "dbname": "kyrag",
+#             "kwargs": {
+#                 "user": "test",
+#                 "db": "kyrag",
+#                 "password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
+#                 "host": "db"
+#             }
+#         }
+#     }
+#     DBPools(dbs)
+#
+#     ragfilemgr = RagFileMgr()
+#     orgid = "04J6VbxLqB_9RPMcgOv_8"
+#     result = await get_service_params(orgid)
+#     print(f"get_service_params result: {result}")
+#
+#
+# if __name__ == "__main__":
+#     asyncio.run(test_ragfilemgr())

View File

@@ -1,11 +1,11 @@
 from rag.uapi_service import APIService
-from rag.folderinfo import RagFileMgr
 from sqlor.dbpools import DBPools
 from appPublic.log import debug, error, info
 import time
 import traceback
 import json
 import math
+from rag.service_opts import get_service_params, sor_get_service_params

 helptext = """kyrag API:
@@ -82,7 +82,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
     # orgid = "04J6VbxLqB_9RPMcgOv_8"
     # userid = "04J6VbxLqB_9RPMcgOv_8"
     query = params_kw.get('query', '')
-    # Handle the limit parameter uniformly across modes
+    # Handle the limit parameter uniformly across modes, for the Dify and Coze integrations
     raw_limit = params_kw.get('limit') or (
         params_kw.get('retrieval_setting', {}).get('top_k')
         if isinstance(params_kw.get('retrieval_setting'), dict)
@@ -103,7 +103,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
     else:
         limit = 5  # default for unexpected types
     debug(f"limit: {limit}")
     raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')

     # Normalize to a list
     if raw_fiids is None:
@@ -111,7 +111,17 @@ async def fusedsearch(request, params_kw, *params, **kw):
     elif isinstance(raw_fiids, list):
         fiids = [str(item).strip() for item in raw_fiids]  # already a list
     elif isinstance(raw_fiids, str):
-        # Handle a comma-separated or single-ID string
-        fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
+        # fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
+        try:
+            # Try to parse as a JSON string first
+            parsed = json.loads(raw_fiids)
+            if isinstance(parsed, list):
+                fiids = [str(item).strip() for item in parsed]  # JSON array -> list of strings
+            else:
+                # Handle a comma-separated or single-ID string
+                fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
+        except json.JSONDecodeError:
+            # Not valid JSON; split on commas
+            fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
     elif isinstance(raw_fiids, (int, float)):
         fiids = [str(int(raw_fiids))]  # numeric -> single-element string list
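The new branch accepts both a JSON-encoded array and a plain comma-separated string for fiids. A self-contained mirror of that logic, with the expected behavior spelled out (parse_fiids is a hypothetical name for illustration):

import json

def parse_fiids(raw: str) -> list[str]:
    # Mirror of the branch above: try a JSON array first, fall back to comma-splitting.
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            return [str(item).strip() for item in parsed]
    except json.JSONDecodeError:
        pass
    return [f.strip() for f in raw.split(',') if f.strip()]

assert parse_fiids('["kb1", "kb2"]') == ['kb1', 'kb2']
assert parse_fiids('kb1,kb2') == ['kb1', 'kb2']
assert parse_fiids('42') == ['42']   # valid JSON but not a list -> comma path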
@@ -140,8 +150,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
     except Exception as e:
         error(f"orgid validation failed: {str(e)}")
         return json.dumps({"status": "error", "message": str(e)})
-    ragfilemgr = RagFileMgr("fiids[0]")
-    service_params = await ragfilemgr.get_service_params(orgid)
+    service_params = await get_service_params(orgid)
     api_service = APIService()

     start_time = time.time()
@@ -276,9 +285,19 @@ async def fusedsearch(request, params_kw, *params, **kw):
         timing_stats["total_time"] = time.time() - start_time
         info(f"Fused search finished, returned {len(unique_results)} results, total time: {timing_stats['total_time']:.3f} s")
-        # debug(f"results: {unique_results[:limit]}, timing: {timing_stats}")
-        # return {"results": unique_results[:limit], "timing": timing_stats}
+        # dify_result = []
+        # for res in unique_results[:limit]:
+        #     content = res.get('text', '')
+        #     title = res.get('metadata', {}).get('filename', 'Untitled')
+        #     document_id = res.get('metadata', {}).get('document_id', '')
+        #     dify_result.append({
+        #         'metadata': {'document_id': document_id},
+        #         'title': title,
+        #         'content': content
+        #     })
+        # info(f"Fused search finished, returned {len(dify_result)} results, total time: {(time.time() - start_time):.3f} s")
+        # debug(f"result: {dify_result}")
+        # return dify_result

         dify_records = []
         dify_result = []
@@ -291,18 +310,50 @@ async def fusedsearch(request, params_kw, *params, **kw):
             document_id = res.get('metadata', {}).get('document_id', '')
             dify_records.append({
                 "content": content,
-                "score": score,
-                "title": title
+                "title": title,
+                "metadata": {"document_id": document_id, "score": score},
             })
             dify_result.append({
                 "content": content,
                 "title": title,
-                "metadata": {"document_id": document_id}
+                "metadata": {"document_id": document_id, "score": score},
             })
         info(f"Fused search finished, returned {len(dify_records)} results, total time: {(time.time() - start_time):.3f} s")
         debug(f"records: {dify_records}, result: {dify_result}")
-        return {"records": dify_records, "result": dify_result, "own": {"results": unique_results[:limit], "timing": timing_stats}}
+        # return {"records": dify_records, "result": dify_result, "own": {"results": unique_results[:limit], "timing": timing_stats}}
+        return {"records": dify_records}
+
+        # dify_result = []
+        # for res in unique_results[:limit]:
+        #     rerank_score = res.get('rerank_score', 0)
+        #     score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
+        #     score = max(0.0, min(1.0, score))
+        #     content = res.get('text', '')
+        #     title = res.get('metadata', {}).get('filename', 'Untitled')
+        #     document_id = res.get('metadata', {}).get('document_id', '')
+        #     dify_result.append({
+        #         "metadata": {
+        #             "_source": "konwledge",
+        #             "dataset_id": "111111",
+        #             "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
+        #             "document_id": document_id,
+        #             "document_name": "test.docx",
+        #             "data_source_type": "upload_file",
+        #             "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
+        #             "retriever_from": "workflow",
+        #             "score": score,
+        #             "segment_hit_count": 7,
+        #             "segment_word_count": 275,
+        #             "segment_position": 5,
+        #             "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
+        #             "doc_metadata": None,
+        #             "position": 1
+        #         },
+        #         "title": title,
+        #         "content": content
+        #     })
+        # return {"result": dify_result}
     except Exception as e:
         error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
         return {"results": [], "timing": timing_stats}

View File

@@ -1,4 +1,5 @@
 from ahserver.serverenv import get_serverenv
+from sqlor.dbpools import DBPools
 async def sor_get_service_params(sor, orgid):
     """Fetch service parameters (upappid only) for an orgid from the database; assumes the service_opts table returns a single record."""
@@ -25,7 +26,7 @@ async def sor_get_service_params(sor, orgid):
         return None

     # Build the ID list for the IN clause
-    id_list = ','.join([f"'{id}'" for id in service_ids])  # wrap each ID in single quotes
+    id_list = [id for id in service_ids]  # pass the IDs as a plain list
     sql_services = """
         SELECT id, name, upappid
         FROM ragservices
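Dropping the hand-quoted ','.join in favor of a plain list suggests the IDs are now bound as query parameters rather than spliced into the SQL string, which sidesteps quoting and injection issues. A minimal sketch of that pattern, assuming an aiomysql-style cursor with %s placeholders (fetch_services is hypothetical, and sqlor's actual binding API may differ):

# Parameterized IN clause: one placeholder per ID, values bound by the driver.
async def fetch_services(cur, service_ids):
    placeholders = ','.join(['%s'] * len(service_ids))
    sql = f"SELECT id, name, upappid FROM ragservices WHERE id IN ({placeholders})"
    await cur.execute(sql, list(service_ids))  # the driver escapes each ID
    return await cur.fetchall()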
@@ -71,5 +72,5 @@ async def get_service_params(orgid):
     db = DBPools()
     dbname = get_serverenv('get_module_dbname')('rag')
     async with db.sqlorContext(dbname) as sor:
-        return await sor_get_server_params(sor, orgid)
+        return await sor_get_service_params(sor, orgid)
     return None