wangmeihua 2025-09-12 15:34:36 +08:00
parent e590c1084f
commit b6d3f39081
3 changed files with 352 additions and 308 deletions

View File

@@ -14,7 +14,7 @@ import time
import uuid
from datetime import datetime
import traceback
from filetxt.loader import fileloader
from filetxt.loader import fileloader, File2Text
from ahserver.serverenv import get_serverenv
from typing import List, Dict, Any
from rag.service_opts import get_service_params, sor_get_service_params
@@ -44,6 +44,206 @@ where a.orgid = b.orgid
return r.quota, r.expired_date
return None, None
async def get_doucment_chunks(self, realpath, timings):
"""加载文件并进行文本分片"""
debug(f"加载文件: {realpath}")
start_load = time.time()
supported_formats = File2Text.supported_types()
debug(f"支持的文件格式:{supported_formats}")
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
text = fileloader(realpath)
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip():
raise ValueError(f"文件 {realpath} 加载为空")
document = Document(page_content=text)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
length_function=len
)
debug("开始分片文件内容")
start_split = time.time()
chunks = text_splitter.split_documents([document])
timings["split_text"] = time.time() - start_split
debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
if not chunks:
raise ValueError(f"文件 {realpath} 未生成任何文档块")
return chunks
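For reference, a standalone illustration of the character filter applied in get_doucment_chunks above: the regex keeps Chinese characters, ASCII letters and digits, whitespace, and . ; , / while dropping other punctuation (the sample text is made up).

import re

sample = "第3章:GPU(显存80GB)——详见 spec/h100.pdf!"
cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', sample)
print(cleaned)  # -> 第3章GPU显存80GB详见 spec/h100.pdf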
async def docs_embedding(self, request, chunks, service_params, userid, timings):
"""调用嵌入服务生成向量"""
debug("调用嵌入服务生成向量")
start_embedding = time.time()
texts = [chunk.page_content for chunk in chunks]
embeddings = []
for i in range(0, len(texts), 10):
batch_texts = texts[i:i + 10]
batch_embeddings = await APIService().get_embeddings(
request=request,
texts=batch_texts,
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
embeddings.extend(batch_embeddings)
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
timings["generate_embeddings"] = time.time() - start_embedding
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
return embeddings
async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,
db_type, timings):
"""准备数据并插入 Milvus"""
debug(f"准备数据并调用插入文件端点: {realpath}")
filename = os.path.basename(realpath).rsplit('.', 1)[0]
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
upload_time = datetime.now().isoformat()
chunks_data = [
{
"userid": orgid,
"knowledge_base_id": fiid,
"text": chunk.page_content,
"vector": embeddings[i],
"document_id": id,
"filename": filename + '.' + ext,
"file_path": realpath,
"upload_time": upload_time,
"file_type": ext,
}
for i, chunk in enumerate(chunks)
]
start_milvus = time.time()
for i in range(0, len(chunks_data), 10):
batch_chunks = chunks_data[i:i + 10]
result = await APIService().milvus_insert_document(
request=request,
chunks=batch_chunks,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/insertdocument",
user=userid
)
if result.get("status") != "success":
raise ValueError(result.get("message", "Milvus 插入失败"))
timings["insert_milvus"] = time.time() - start_milvus
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
return chunks_data
async def get_triples(self, request, chunks, service_params, userid, timings):
"""调用三元组抽取服务"""
debug("调用三元组抽取服务")
start_triples = time.time()
chunk_texts = [doc.page_content for doc in chunks]
triples = []
for i, chunk in enumerate(chunk_texts):
result = await APIService().extract_triples(
request=request,
text=chunk,
upappid=service_params['triples'],
apiname="Babelscape/mrebel-large",
user=userid
)
if isinstance(result, list):
triples.extend(result)
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
else:
error(f"分片 {i + 1} 处理失败: {str(result)}")
unique_triples = []
seen = set()
for t in triples:
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
if identifier not in seen:
seen.add(identifier)
unique_triples.append(t)
else:
for existing in unique_triples:
if (existing['head'].lower() == t['head'].lower() and
existing['tail'].lower() == t['tail'].lower() and
len(t['type']) > len(existing['type'])):
unique_triples.remove(existing)
unique_triples.append(t)
debug(f"替换三元组为更具体类型: {t}")
break
timings["extract_triples"] = time.time() - start_triples
debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
return unique_triples
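The de-duplication above drops case-insensitive exact duplicates and tries to keep the triple whose type string is longer (more specific) when the same head and tail repeat. A simplified sketch of that policy on hypothetical triples (not output of the mrebel service, and not the exact production loop above):

def dedup_triples(triples):
    # Keep one triple per (head, tail); prefer the longer, more specific type.
    best = {}
    for t in triples:
        key = (t['head'].lower(), t['tail'].lower())
        kept = best.get(key)
        if kept is None or len(t['type']) > len(kept['type']):
            best[key] = t
    return list(best.values())

sample = [
    {'head': 'GPT-4', 'type': 'developed by', 'tail': 'OpenAI'},
    {'head': 'gpt-4', 'type': 'developed by', 'tail': 'openai'},
    {'head': 'GPT-4', 'type': 'developed and released by', 'tail': 'OpenAI'},
]
print(dedup_triples(sample))  # one triple remains, carrying the longer type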
async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings):
"""调用 Neo4j 插入三元组"""
debug(f"插入 {len(unique_triples)} 个三元组到 Neo4j")
start_neo4j = time.time()
if unique_triples:
for i in range(0, len(unique_triples), 30):
batch_triples = unique_triples[i:i + 30]
neo4j_result = await APIService().neo4j_insert_triples(
request=request,
triples=batch_triples,
document_id=id,
knowledge_base_id=fiid,
userid=orgid,
upappid=service_params['gdb'],
apiname="neo4j/inserttriples",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
timings["insert_neo4j"] = time.time() - start_neo4j
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
else:
debug("未抽取到三元组")
timings["insert_neo4j"] = 0.0
async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type):
"""调用 Milvus 删除文档"""
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
milvus_result = await APIService().milvus_delete_document(
request=request,
userid=orgid,
file_path=realpath,
knowledge_base_id=fiid,
document_id=id,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/deletedocument",
user=userid
)
if milvus_result.get("status") != "success":
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
async def delete_from_neo4j(self, request, id, service_params, userid):
"""调用 Neo4j 删除文档"""
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
neo4j_result = await APIService().neo4j_delete_document(
request=request,
document_id=id,
upappid=service_params['gdb'],
apiname="neo4j/deletedocument",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0)
info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
return nodes_deleted, rels_deleted
async def file_uploaded(self, request, ns, userid):
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
debug(f'Received ns: {ns=}')
@@ -52,23 +252,13 @@ where a.orgid = b.orgid
fiid = ns.get('fiid', '')
id = ns.get('id', '')
orgid = ns.get('ownerid', '')
hashvalue = ns.get('hashvalue', '')
db_type = ''
api_service = APIService()
debug(
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
timings = {}
start_total = time.time()
service_params = await get_service_params(orgid)
chunks = await self.get_doucment_chunks(realpath)
embeddings = await self.docs_embedding(chunks)
await self.embedding_2_vdb(id, fiid, orgid, realpath, embedding)
triples = await self.get_triples(chunks)
await self.triple2graphdb(id, fiid, orgid, realpath, triples)
return
try:
if not orgid or not fiid or not id:
raise ValueError("orgid、fiid 和 id 不能为空")
@@ -82,217 +272,41 @@ where a.orgid = b.orgid
if not service_params:
raise ValueError("无法获取服务参数")
chunks = await self.get_doucment_chunks(realpath, timings)
embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings)
await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid, db_type, timings)
triples = await self.get_triples(request, chunks, service_params, userid, timings)
await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings)
supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
debug(f"加载文件: {realpath}")
start_load = time.time()
text = fileloader(realpath)
# debug(f"处理后的文件内容是:{text=}")
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip():
raise ValueError(f"文件 {realpath} 加载为空")
document = Document(page_content=text)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
length_function=len)
debug("开始分片文件内容")
start_split = time.time()
chunks = text_splitter.split_documents([document])
timings["split_text"] = time.time() - start_split
debug(
f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
debug(f"分片内容: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
if not chunks:
raise ValueError(f"文件 {realpath} 未生成任何文档块")
filename = os.path.basename(realpath).rsplit('.', 1)[0]
upload_time = datetime.now().isoformat()
debug("调用嵌入服务生成向量")
start_embedding = time.time()
texts = [chunk.page_content for chunk in chunks]
embeddings = []
for i in range(0, len(texts), 10): # 每次处理 10 个文本块
batch_texts = texts[i:i + 10]
batch_embeddings = await api_service.get_embeddings(
request=request,
texts=batch_texts,
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
embeddings.extend(batch_embeddings)
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
timings["generate_embeddings"] = time.time() - start_embedding
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
chunks_data = []
for i, chunk in enumerate(chunks):
chunks_data.append({
"userid": orgid,
"knowledge_base_id": fiid,
"text": chunk.page_content,
"vector": embeddings[i],
"document_id": id,
"filename": filename + '.' + ext,
"file_path": realpath,
"upload_time": upload_time,
"file_type": ext,
})
debug(f"调用插入文件端点: {realpath}")
start_milvus = time.time()
for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据
batch_chunks = chunks_data[i:i + 10]
result = await api_service.milvus_insert_document(
request=request,
chunks=batch_chunks,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/insertdocument",
user=userid
)
if result.get("status") != "success":
raise ValueError(result.get("message", "Milvus 插入失败"))
timings["insert_milvus"] = time.time() - start_milvus
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
if result.get("status") != "success":
timings["total"] = time.time() - start_total
return {"status": "error", "document_id": id, "timings": timings,
"message": result.get("message", "未知错误"), "status_code": 400}
debug("调用三元组抽取服务")
start_triples = time.time()
unique_triples = []
try:
chunk_texts = [doc.page_content for doc in chunks]
debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
triples = []
for i, chunk in enumerate(chunk_texts):
result = await api_service.extract_triples(
request=request,
text=chunk,
upappid=service_params['triples'],
apiname="Babelscape/mrebel-large",
user=userid
)
if isinstance(result, list):
triples.extend(result)
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
else:
error(f"分片 {i + 1} 处理失败: {str(result)}")
seen = set()
for t in triples:
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
if identifier not in seen:
seen.add(identifier)
unique_triples.append(t)
else:
for existing in unique_triples:
if (existing['head'].lower() == t['head'].lower() and
existing['tail'].lower() == t['tail'].lower() and
len(t['type']) > len(existing['type'])):
unique_triples.remove(existing)
unique_triples.append(t)
debug(f"替换三元组为更具体类型: {t}")
break
timings["extract_triples"] = time.time() - start_triples
debug(
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
if unique_triples:
debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
start_neo4j = time.time()
for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组
batch_triples = unique_triples[i:i + 30]
neo4j_result = await api_service.neo4j_insert_triples(
request=request,
triples=batch_triples,
document_id=id,
knowledge_base_id=fiid,
userid=orgid,
upappid=service_params['gdb'],
apiname="neo4j/inserttriples",
user=userid
)
debug(f"Neo4j 服务响应: {neo4j_result}")
if neo4j_result.get("status") != "success":
timings["insert_neo4j"] = time.time() - start_neo4j
timings["total"] = time.time() - start_total
return {
"status": "error",
"document_id": id,
"collection_name": "ragdb",
"timings": timings,
"message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
"status_code": 400
}
info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
timings["insert_neo4j"] = time.time() - start_neo4j
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
else:
debug(f"文件 {realpath} 未抽取到三元组")
timings["insert_neo4j"] = 0.0
except Exception as e:
timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
timings["extract_triples"]
timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
"insert_neo4j"]
debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total
return {
"status": "success",
"document_id": id,
"collection_name": "ragdb",
"timings": timings,
"unique_triples": unique_triples,
"message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}",
"status_code": 200
}
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
debug(f"总耗时: {timings['total']:.2f}") debug(f"总耗时: {timings['total']:.2f}")
return { return {
"status": "success", "status": "success",
"userid": orgid, "userid": orgid,
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"unique_triples": unique_triples, "unique_triples": triples,
"message": f"文件 {realpath} 成功嵌入并处理三元组", "message": f"文件 {realpath} 成功嵌入并处理三元组",
"status_code": 200 "status_code": 200
} }
except Exception as e: except Exception as e:
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
return { return {
"status": "error", "status": "error",
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"message": f"插入文档失败: {str(e)}", "message": f"插入文档失败: {str(e)}",
"status_code": 400 "status_code": 400
} }
async def file_deleted(self, request, recs, userid):
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
if not isinstance(recs, list):
recs = [recs]
results = []
total_nodes_deleted = 0
total_rels_deleted = 0
@@ -310,46 +324,24 @@ where a.orgid = b.orgid
if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
# 获取服务参数
service_params = await get_service_params(orgid)
if not service_params:
raise ValueError("无法获取服务参数")
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}") # 调用 Milvus 删除
milvus_result = await api_service.milvus_delete_document( await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type)
request=request,
userid=orgid,
file_path=realpath,
knowledge_base_id=fiid,
document_id=id,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/deletedocument",
user=userid
)
if milvus_result.get("status") != "success":
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
# 调用 Neo4j 删除
neo4j_deleted_nodes = 0 neo4j_deleted_nodes = 0
neo4j_deleted_rels = 0 neo4j_deleted_rels = 0
debug(f"调用 Neo4j 删除文档端点: document_id={id}") try:
neo4j_result = await api_service.neo4j_delete_document( nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid)
request=request, neo4j_deleted_nodes += nodes_deleted
document_id=id, neo4j_deleted_rels += rels_deleted
upappid=service_params['gdb'], total_nodes_deleted += nodes_deleted
apiname="neo4j/deletedocument", total_rels_deleted += rels_deleted
user=userid except Exception as e:
) error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
if neo4j_result.get("status") != "success":
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0)
neo4j_deleted_nodes += nodes_deleted
neo4j_deleted_rels += rels_deleted
total_nodes_deleted += nodes_deleted
total_rels_deleted += rels_deleted
info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
results.append({
"status": "success",
@@ -361,13 +353,13 @@ where a.orgid = b.orgid
except Exception as e:
error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
results.append({
"status": "error",
"collection_name": collection_name,
"document_id": id,
"message": f"删除文档 {realpath} 失败: {str(e)}",
"status_code": 400
})
return {
"status": "success" if all(r["status"] == "success" for r in results) else "partial",
@@ -378,31 +370,31 @@ where a.orgid = b.orgid
"status_code": 200 if all(r["status"] == "success" for r in results) else 207
}
# async def test_ragfilemgr():
# """测试 RagFileMgr 类的 get_service_params"""
# print("初始化数据库连接池...")
# dbs = {
# "kyrag": {
# "driver": "aiomysql",
# "async_mode": True,
# "coding": "utf8",
# "maxconn": 100,
# "dbname": "kyrag",
# "kwargs": {
# "user": "test",
# "db": "kyrag",
# "password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
# "host": "db"
# }
# }
# }
# DBPools(dbs)
#
# ragfilemgr = RagFileMgr()
# orgid = "04J6VbxLqB_9RPMcgOv_8"
# result = await get_service_params(orgid)
# print(f"get_service_params 结果: {result}")
#
#
# if __name__ == "__main__":
# asyncio.run(test_ragfilemgr())

View File

@@ -1,11 +1,11 @@
from rag.uapi_service import APIService
from rag.folderinfo import RagFileMgr
from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
import time
import traceback
import json
import math
from rag.service_opts import get_service_params, sor_get_service_params
helptext = """kyrag API:
@@ -82,7 +82,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
# orgid = "04J6VbxLqB_9RPMcgOv_8"
# userid = "04J6VbxLqB_9RPMcgOv_8"
query = params_kw.get('query', '')
# 统一模式处理 limit 参数,为了对接dify和coze
raw_limit = params_kw.get('limit') or (
params_kw.get('retrieval_setting', {}).get('top_k')
if isinstance(params_kw.get('retrieval_setting'), dict)
@@ -103,7 +103,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
else:
limit = 5 # 其他意外类型使用默认值
debug(f"limit: {limit}")
raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
# 标准化为列表格式
if raw_fiids is None:
@@ -111,8 +111,18 @@ async def fusedsearch(request, params_kw, *params, **kw):
elif isinstance(raw_fiids, list):
fiids = [str(item).strip() for item in raw_fiids] # 已经是列表
elif isinstance(raw_fiids, str):
# fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
try:
# 尝试解析 JSON 字符串
parsed = json.loads(raw_fiids)
if isinstance(parsed, list):
fiids = [str(item).strip() for item in parsed] # JSON 数组转为字符串列表
else:
# 处理逗号分隔的字符串或单个 ID 字符串
fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
except json.JSONDecodeError:
# 如果不是合法 JSON,按逗号分隔
fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
elif isinstance(raw_fiids, (int, float)):
fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表
else:
@@ -140,8 +150,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)})
ragfilemgr = RagFileMgr("fiids[0]")
service_params = await ragfilemgr.get_service_params(orgid)
service_params = await get_service_params(orgid)
api_service = APIService()
start_time = time.time()
@@ -276,9 +285,19 @@ async def fusedsearch(request, params_kw, *params, **kw):
timing_stats["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
# debug(f"results: {unique_results[:limit]},timing: {timing_stats}")
# return {"results": unique_results[:limit], "timing": timing_stats}
# dify_result = []
# for res in unique_results[:limit]:
# content = res.get('text', '')
# title = res.get('metadata', {}).get('filename', 'Untitled')
# document_id = res.get('metadata', {}).get('document_id', '')
# dify_result.append({
# 'metadata': {'document_id': document_id},
# 'title': title,
# 'content': content
# })
# info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
# debug(f"result: {dify_result}")
# return dify_result
dify_records = []
dify_result = []
@@ -291,18 +310,50 @@ async def fusedsearch(request, params_kw, *params, **kw):
document_id = res.get('metadata', {}).get('document_id', '')
dify_records.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id, "score": score},
})
dify_result.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id, "score": score},
})
info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f}")
debug(f"records: {dify_records}, result: {dify_result}")
# return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
return {"records": dify_records}
# dify_result = []
# for res in unique_results[:limit]:
# rerank_score = res.get('rerank_score', 0)
# score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
# score = max(0.0, min(1.0, score))
# content = res.get('text', '')
# title = res.get('metadata', {}).get('filename', 'Untitled')
# document_id = res.get('metadata', {}).get('document_id', '')
# dify_result.append({
# "metadata": {
# "_source": "konwledge",
# "dataset_id":"111111",
# "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
# "document_id": document_id,
# "document_name": "test.docx",
# "data_source_type": "upload_file",
# "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
# "retriever_from": "workflow",
# "score": score,
# "segment_hit_count": 7,
# "segment_word_count": 275,
# "segment_position": 5,
# "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
# "doc_metadata": None,
# "position":1
# },
# "title": title,
# "content": content
# })
# return {"result": dify_result}
except Exception as e:
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return {"results": [], "timing": timing_stats}

View File

@@ -1,4 +1,5 @@
from ahserver.serverenv import get_serverenv
from sqlor.dbpools import DBPools
async def sor_get_service_params(sor, orgid):
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
@@ -16,8 +17,8 @@ async def sor_get_service_params(sor, orgid):
# 收集服务 ID
service_ids = set()
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
if opts[key]:
service_ids.add(opts[key])
# 检查 service_ids 是否为空
if not service_ids:
@@ -25,7 +26,7 @@ async def sor_get_service_params(sor, orgid):
return None
# 手动构造 IN 子句的 ID 列表
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
id_list = [id for id in service_ids]
sql_services = """
SELECT id, name, upappid
FROM ragservices
@@ -46,19 +47,19 @@ async def sor_get_service_params(sor, orgid):
'entities': None
}
for service in services_result:
name = service['name']
if name == 'bgem3嵌入':
service_params['embedding'] = service['upappid']
elif name == 'milvus向量检索':
service_params['vdb'] = service['upappid']
elif name == 'bgem2v3重排':
service_params['reranker'] = service['upappid']
elif name == 'mrebel三元组抽取':
service_params['triples'] = service['upappid']
elif name == 'neo4j删除知识库':
service_params['gdb'] = service['upappid']
elif name == 'small实体抽取':
service_params['entities'] = service['upappid']
# 检查是否所有服务参数都已填充
missing_services = [k for k, v in service_params.items() if v is None]
@@ -71,5 +72,5 @@ async def get_service_params(orgid):
db = DBPools()
dbname = get_serverenv('get_module_dbname')('rag')
async with db.sqlorContext(dbname) as sor:
return await sor_get_server_params(sor, orgid)
return await sor_get_service_params(sor, orgid)
return None
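For reference, when every service name is matched, get_service_params returns a mapping of this shape (the upappid values below are placeholders, not real IDs):

example_service_params = {
    'embedding': 'upappid-bgem3',
    'vdb': 'upappid-milvus',
    'reranker': 'upappid-bgem2v3',
    'triples': 'upappid-mrebel',
    'gdb': 'upappid-neo4j',
    'entities': 'upappid-small-entities',
}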