rag
This commit is contained in:
parent
e590c1084f
commit
b6d3f39081
@ -14,7 +14,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import traceback
|
import traceback
|
||||||
from filetxt.loader import fileloader
|
from filetxt.loader import fileloader,File2Text
|
||||||
from ahserver.serverenv import get_serverenv
|
from ahserver.serverenv import get_serverenv
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from rag.service_opts import get_service_params, sor_get_service_params
|
from rag.service_opts import get_service_params, sor_get_service_params
|
||||||
@ -44,56 +44,20 @@ where a.orgid = b.orgid
|
|||||||
return r.quota, r.expired_date
|
return r.quota, r.expired_date
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def file_uploaded(self, request, ns, userid):
|
async def get_doucment_chunks(self, realpath, timings):
|
||||||
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
"""加载文件并进行文本分片"""
|
||||||
debug(f'Received ns: {ns=}')
|
debug(f"加载文件: {realpath}")
|
||||||
env = request._run_ns
|
start_load = time.time()
|
||||||
realpath = ns.get('realpath', '')
|
supported_formats = File2Text.supported_types()
|
||||||
fiid = ns.get('fiid', '')
|
debug(f"支持的文件格式:{supported_formats}")
|
||||||
id = ns.get('id', '')
|
|
||||||
orgid = ns.get('ownerid', '')
|
|
||||||
hashvalue = ns.get('hashvalue', '')
|
|
||||||
db_type = ''
|
|
||||||
|
|
||||||
api_service = APIService()
|
|
||||||
debug(
|
|
||||||
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
|
||||||
|
|
||||||
timings = {}
|
|
||||||
start_total = time.time()
|
|
||||||
|
|
||||||
service_params = await get_service_params(orgid)
|
|
||||||
chunks = await self.get_doucment_chunks(realpath)
|
|
||||||
embeddings = await self.docs_embedding(chunks)
|
|
||||||
await self.embedding_2_vdb(id, fiid, orgid, realpath, embedding)
|
|
||||||
triples = await self.get_triples(chunks)
|
|
||||||
await self.triple2graphdb(id, fiid, orgid, realpath, triples)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if not orgid or not fiid or not id:
|
|
||||||
raise ValueError("orgid、fiid 和 id 不能为空")
|
|
||||||
if len(orgid) > 32 or len(fiid) > 255:
|
|
||||||
raise ValueError("orgid 或 fiid 的长度超出限制")
|
|
||||||
if not os.path.exists(realpath):
|
|
||||||
raise ValueError(f"文件 {realpath} 不存在")
|
|
||||||
|
|
||||||
# 获取服务参数
|
|
||||||
service_params = await get_service_params(orgid)
|
|
||||||
if not service_params:
|
|
||||||
raise ValueError("无法获取服务参数")
|
|
||||||
|
|
||||||
supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
|
|
||||||
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
||||||
if ext not in supported_formats:
|
if ext not in supported_formats:
|
||||||
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
||||||
|
|
||||||
debug(f"加载文件: {realpath}")
|
|
||||||
start_load = time.time()
|
|
||||||
text = fileloader(realpath)
|
text = fileloader(realpath)
|
||||||
# debug(f"处理后的文件内容是:{text=}")
|
|
||||||
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
|
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
|
||||||
timings["load_file"] = time.time() - start_load
|
timings["load_file"] = time.time() - start_load
|
||||||
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
|
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
|
||||||
|
|
||||||
if not text or not text.strip():
|
if not text or not text.strip():
|
||||||
raise ValueError(f"文件 {realpath} 加载为空")
|
raise ValueError(f"文件 {realpath} 加载为空")
|
||||||
|
|
||||||
@ -101,27 +65,28 @@ where a.orgid = b.orgid
|
|||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=500,
|
chunk_size=500,
|
||||||
chunk_overlap=100,
|
chunk_overlap=100,
|
||||||
length_function=len)
|
length_function=len
|
||||||
|
)
|
||||||
debug("开始分片文件内容")
|
debug("开始分片文件内容")
|
||||||
start_split = time.time()
|
start_split = time.time()
|
||||||
chunks = text_splitter.split_documents([document])
|
chunks = text_splitter.split_documents([document])
|
||||||
timings["split_text"] = time.time() - start_split
|
timings["split_text"] = time.time() - start_split
|
||||||
debug(
|
debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
|
||||||
f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
|
|
||||||
debug(f"分片内容: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
raise ValueError(f"文件 {realpath} 未生成任何文档块")
|
raise ValueError(f"文件 {realpath} 未生成任何文档块")
|
||||||
|
|
||||||
filename = os.path.basename(realpath).rsplit('.', 1)[0]
|
return chunks
|
||||||
upload_time = datetime.now().isoformat()
|
|
||||||
|
|
||||||
|
async def docs_embedding(self, request, chunks, service_params, userid, timings):
|
||||||
|
"""调用嵌入服务生成向量"""
|
||||||
debug("调用嵌入服务生成向量")
|
debug("调用嵌入服务生成向量")
|
||||||
start_embedding = time.time()
|
start_embedding = time.time()
|
||||||
texts = [chunk.page_content for chunk in chunks]
|
texts = [chunk.page_content for chunk in chunks]
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for i in range(0, len(texts), 10): # 每次处理 10 个文本块
|
for i in range(0, len(texts), 10):
|
||||||
batch_texts = texts[i:i + 10]
|
batch_texts = texts[i:i + 10]
|
||||||
batch_embeddings = await api_service.get_embeddings(
|
batch_embeddings = await APIService().get_embeddings(
|
||||||
request=request,
|
request=request,
|
||||||
texts=batch_texts,
|
texts=batch_texts,
|
||||||
upappid=service_params['embedding'],
|
upappid=service_params['embedding'],
|
||||||
@ -129,14 +94,24 @@ where a.orgid = b.orgid
|
|||||||
user=userid
|
user=userid
|
||||||
)
|
)
|
||||||
embeddings.extend(batch_embeddings)
|
embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
|
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
|
||||||
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
|
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
|
||||||
|
|
||||||
timings["generate_embeddings"] = time.time() - start_embedding
|
timings["generate_embeddings"] = time.time() - start_embedding
|
||||||
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
|
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
|
||||||
|
return embeddings
|
||||||
|
|
||||||
chunks_data = []
|
async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,
|
||||||
for i, chunk in enumerate(chunks):
|
db_type, timings):
|
||||||
chunks_data.append({
|
"""准备数据并插入 Milvus"""
|
||||||
|
debug(f"准备数据并调用插入文件端点: {realpath}")
|
||||||
|
filename = os.path.basename(realpath).rsplit('.', 1)[0]
|
||||||
|
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
||||||
|
upload_time = datetime.now().isoformat()
|
||||||
|
|
||||||
|
chunks_data = [
|
||||||
|
{
|
||||||
"userid": orgid,
|
"userid": orgid,
|
||||||
"knowledge_base_id": fiid,
|
"knowledge_base_id": fiid,
|
||||||
"text": chunk.page_content,
|
"text": chunk.page_content,
|
||||||
@ -146,13 +121,14 @@ where a.orgid = b.orgid
|
|||||||
"file_path": realpath,
|
"file_path": realpath,
|
||||||
"upload_time": upload_time,
|
"upload_time": upload_time,
|
||||||
"file_type": ext,
|
"file_type": ext,
|
||||||
})
|
}
|
||||||
|
for i, chunk in enumerate(chunks)
|
||||||
|
]
|
||||||
|
|
||||||
debug(f"调用插入文件端点: {realpath}")
|
|
||||||
start_milvus = time.time()
|
start_milvus = time.time()
|
||||||
for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据
|
for i in range(0, len(chunks_data), 10):
|
||||||
batch_chunks = chunks_data[i:i + 10]
|
batch_chunks = chunks_data[i:i + 10]
|
||||||
result = await api_service.milvus_insert_document(
|
result = await APIService().milvus_insert_document(
|
||||||
request=request,
|
request=request,
|
||||||
chunks=batch_chunks,
|
chunks=batch_chunks,
|
||||||
db_type=db_type,
|
db_type=db_type,
|
||||||
@ -162,23 +138,19 @@ where a.orgid = b.orgid
|
|||||||
)
|
)
|
||||||
if result.get("status") != "success":
|
if result.get("status") != "success":
|
||||||
raise ValueError(result.get("message", "Milvus 插入失败"))
|
raise ValueError(result.get("message", "Milvus 插入失败"))
|
||||||
|
|
||||||
timings["insert_milvus"] = time.time() - start_milvus
|
timings["insert_milvus"] = time.time() - start_milvus
|
||||||
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
|
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
|
||||||
|
return chunks_data
|
||||||
|
|
||||||
if result.get("status") != "success":
|
async def get_triples(self, request, chunks, service_params, userid, timings):
|
||||||
timings["total"] = time.time() - start_total
|
"""调用三元组抽取服务"""
|
||||||
return {"status": "error", "document_id": id, "timings": timings,
|
|
||||||
"message": result.get("message", "未知错误"), "status_code": 400}
|
|
||||||
|
|
||||||
debug("调用三元组抽取服务")
|
debug("调用三元组抽取服务")
|
||||||
start_triples = time.time()
|
start_triples = time.time()
|
||||||
unique_triples = []
|
|
||||||
try:
|
|
||||||
chunk_texts = [doc.page_content for doc in chunks]
|
chunk_texts = [doc.page_content for doc in chunks]
|
||||||
debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
|
|
||||||
triples = []
|
triples = []
|
||||||
for i, chunk in enumerate(chunk_texts):
|
for i, chunk in enumerate(chunk_texts):
|
||||||
result = await api_service.extract_triples(
|
result = await APIService().extract_triples(
|
||||||
request=request,
|
request=request,
|
||||||
text=chunk,
|
text=chunk,
|
||||||
upappid=service_params['triples'],
|
upappid=service_params['triples'],
|
||||||
@ -187,10 +159,11 @@ where a.orgid = b.orgid
|
|||||||
)
|
)
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
triples.extend(result)
|
triples.extend(result)
|
||||||
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
|
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
|
||||||
else:
|
else:
|
||||||
error(f"分片 {i + 1} 处理失败: {str(result)}")
|
error(f"分片 {i + 1} 处理失败: {str(result)}")
|
||||||
|
|
||||||
|
unique_triples = []
|
||||||
seen = set()
|
seen = set()
|
||||||
for t in triples:
|
for t in triples:
|
||||||
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
|
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
|
||||||
@ -208,15 +181,17 @@ where a.orgid = b.orgid
|
|||||||
break
|
break
|
||||||
|
|
||||||
timings["extract_triples"] = time.time() - start_triples
|
timings["extract_triples"] = time.time() - start_triples
|
||||||
debug(
|
debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
|
||||||
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
|
return unique_triples
|
||||||
|
|
||||||
if unique_triples:
|
async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings):
|
||||||
debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
|
"""调用 Neo4j 插入三元组"""
|
||||||
|
debug(f"插入 {len(unique_triples)} 个三元组到 Neo4j")
|
||||||
start_neo4j = time.time()
|
start_neo4j = time.time()
|
||||||
for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组
|
if unique_triples:
|
||||||
|
for i in range(0, len(unique_triples), 30):
|
||||||
batch_triples = unique_triples[i:i + 30]
|
batch_triples = unique_triples[i:i + 30]
|
||||||
neo4j_result = await api_service.neo4j_insert_triples(
|
neo4j_result = await APIService().neo4j_insert_triples(
|
||||||
request=request,
|
request=request,
|
||||||
triples=batch_triples,
|
triples=batch_triples,
|
||||||
document_id=id,
|
document_id=id,
|
||||||
@ -226,41 +201,82 @@ where a.orgid = b.orgid
|
|||||||
apiname="neo4j/inserttriples",
|
apiname="neo4j/inserttriples",
|
||||||
user=userid
|
user=userid
|
||||||
)
|
)
|
||||||
debug(f"Neo4j 服务响应: {neo4j_result}")
|
|
||||||
if neo4j_result.get("status") != "success":
|
if neo4j_result.get("status") != "success":
|
||||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
|
||||||
timings["total"] = time.time() - start_total
|
info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"document_id": id,
|
|
||||||
"collection_name": "ragdb",
|
|
||||||
"timings": timings,
|
|
||||||
"message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
|
|
||||||
"status_code": 400
|
|
||||||
}
|
|
||||||
info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
|
|
||||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
timings["insert_neo4j"] = time.time() - start_neo4j
|
||||||
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
|
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
|
||||||
else:
|
else:
|
||||||
debug(f"文件 {realpath} 未抽取到三元组")
|
debug("未抽取到三元组")
|
||||||
timings["insert_neo4j"] = 0.0
|
timings["insert_neo4j"] = 0.0
|
||||||
|
|
||||||
except Exception as e:
|
async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type):
|
||||||
timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
|
"""调用 Milvus 删除文档"""
|
||||||
timings["extract_triples"]
|
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
||||||
timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
|
milvus_result = await APIService().milvus_delete_document(
|
||||||
"insert_neo4j"]
|
request=request,
|
||||||
debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
userid=orgid,
|
||||||
timings["total"] = time.time() - start_total
|
file_path=realpath,
|
||||||
return {
|
knowledge_base_id=fiid,
|
||||||
"status": "success",
|
document_id=id,
|
||||||
"document_id": id,
|
db_type=db_type,
|
||||||
"collection_name": "ragdb",
|
upappid=service_params['vdb'],
|
||||||
"timings": timings,
|
apiname="milvus/deletedocument",
|
||||||
"unique_triples": unique_triples,
|
user=userid
|
||||||
"message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}",
|
)
|
||||||
"status_code": 200
|
if milvus_result.get("status") != "success":
|
||||||
}
|
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
|
||||||
|
|
||||||
|
async def delete_from_neo4j(self, request, id, service_params, userid):
|
||||||
|
"""调用 Neo4j 删除文档"""
|
||||||
|
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
||||||
|
neo4j_result = await APIService().neo4j_delete_document(
|
||||||
|
request=request,
|
||||||
|
document_id=id,
|
||||||
|
upappid=service_params['gdb'],
|
||||||
|
apiname="neo4j/deletedocument",
|
||||||
|
user=userid
|
||||||
|
)
|
||||||
|
if neo4j_result.get("status") != "success":
|
||||||
|
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
|
||||||
|
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
|
||||||
|
rels_deleted = neo4j_result.get("rels_deleted", 0)
|
||||||
|
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
||||||
|
return nodes_deleted, rels_deleted
|
||||||
|
|
||||||
|
async def file_uploaded(self, request, ns, userid):
|
||||||
|
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
||||||
|
debug(f'Received ns: {ns=}')
|
||||||
|
env = request._run_ns
|
||||||
|
realpath = ns.get('realpath', '')
|
||||||
|
fiid = ns.get('fiid', '')
|
||||||
|
id = ns.get('id', '')
|
||||||
|
orgid = ns.get('ownerid', '')
|
||||||
|
db_type = ''
|
||||||
|
|
||||||
|
debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
||||||
|
|
||||||
|
timings = {}
|
||||||
|
start_total = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not orgid or not fiid or not id:
|
||||||
|
raise ValueError("orgid、fiid 和 id 不能为空")
|
||||||
|
if len(orgid) > 32 or len(fiid) > 255:
|
||||||
|
raise ValueError("orgid 或 fiid 的长度超出限制")
|
||||||
|
if not os.path.exists(realpath):
|
||||||
|
raise ValueError(f"文件 {realpath} 不存在")
|
||||||
|
|
||||||
|
# 获取服务参数
|
||||||
|
service_params = await get_service_params(orgid)
|
||||||
|
if not service_params:
|
||||||
|
raise ValueError("无法获取服务参数")
|
||||||
|
|
||||||
|
chunks = await self.get_doucment_chunks(realpath, timings)
|
||||||
|
embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings)
|
||||||
|
await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings)
|
||||||
|
triples = await self.get_triples(request, chunks, service_params, userid, timings)
|
||||||
|
await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings)
|
||||||
|
|
||||||
timings["total"] = time.time() - start_total
|
timings["total"] = time.time() - start_total
|
||||||
debug(f"总耗时: {timings['total']:.2f} 秒")
|
debug(f"总耗时: {timings['total']:.2f} 秒")
|
||||||
@ -270,11 +286,10 @@ where a.orgid = b.orgid
|
|||||||
"document_id": id,
|
"document_id": id,
|
||||||
"collection_name": "ragdb",
|
"collection_name": "ragdb",
|
||||||
"timings": timings,
|
"timings": timings,
|
||||||
"unique_triples": unique_triples,
|
"unique_triples": triples,
|
||||||
"message": f"文件 {realpath} 成功嵌入并处理三元组",
|
"message": f"文件 {realpath} 成功嵌入并处理三元组",
|
||||||
"status_code": 200
|
"status_code": 200
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||||
timings["total"] = time.time() - start_total
|
timings["total"] = time.time() - start_total
|
||||||
@ -290,9 +305,8 @@ where a.orgid = b.orgid
|
|||||||
async def file_deleted(self, request, recs, userid):
|
async def file_deleted(self, request, recs, userid):
|
||||||
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
|
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
|
||||||
if not isinstance(recs, list):
|
if not isinstance(recs, list):
|
||||||
recs = [recs] # 确保 recs 是列表,即使传入单个记录
|
recs = [recs]
|
||||||
results = []
|
results = []
|
||||||
api_service = APIService()
|
|
||||||
total_nodes_deleted = 0
|
total_nodes_deleted = 0
|
||||||
total_rels_deleted = 0
|
total_rels_deleted = 0
|
||||||
|
|
||||||
@ -310,46 +324,24 @@ where a.orgid = b.orgid
|
|||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||||
|
|
||||||
# 获取服务参数
|
|
||||||
service_params = await get_service_params(orgid)
|
service_params = await get_service_params(orgid)
|
||||||
if not service_params:
|
if not service_params:
|
||||||
raise ValueError("无法获取服务参数")
|
raise ValueError("无法获取服务参数")
|
||||||
|
|
||||||
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
# 调用 Milvus 删除
|
||||||
milvus_result = await api_service.milvus_delete_document(
|
await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type)
|
||||||
request=request,
|
|
||||||
userid=orgid,
|
|
||||||
file_path=realpath,
|
|
||||||
knowledge_base_id=fiid,
|
|
||||||
document_id=id,
|
|
||||||
db_type=db_type,
|
|
||||||
upappid=service_params['vdb'],
|
|
||||||
apiname="milvus/deletedocument",
|
|
||||||
user=userid
|
|
||||||
)
|
|
||||||
|
|
||||||
if milvus_result.get("status") != "success":
|
|
||||||
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
|
|
||||||
|
|
||||||
|
# 调用 Neo4j 删除
|
||||||
neo4j_deleted_nodes = 0
|
neo4j_deleted_nodes = 0
|
||||||
neo4j_deleted_rels = 0
|
neo4j_deleted_rels = 0
|
||||||
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
try:
|
||||||
neo4j_result = await api_service.neo4j_delete_document(
|
nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid)
|
||||||
request=request,
|
|
||||||
document_id=id,
|
|
||||||
upappid=service_params['gdb'],
|
|
||||||
apiname="neo4j/deletedocument",
|
|
||||||
user=userid
|
|
||||||
)
|
|
||||||
if neo4j_result.get("status") != "success":
|
|
||||||
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
|
|
||||||
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
|
|
||||||
rels_deleted = neo4j_result.get("rels_deleted", 0)
|
|
||||||
neo4j_deleted_nodes += nodes_deleted
|
neo4j_deleted_nodes += nodes_deleted
|
||||||
neo4j_deleted_rels += rels_deleted
|
neo4j_deleted_rels += rels_deleted
|
||||||
total_nodes_deleted += nodes_deleted
|
total_nodes_deleted += nodes_deleted
|
||||||
total_rels_deleted += rels_deleted
|
total_rels_deleted += rels_deleted
|
||||||
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
except Exception as e:
|
||||||
|
error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@ -378,31 +370,31 @@ where a.orgid = b.orgid
|
|||||||
"status_code": 200 if all(r["status"] == "success" for r in results) else 207
|
"status_code": 200 if all(r["status"] == "success" for r in results) else 207
|
||||||
}
|
}
|
||||||
|
|
||||||
async def test_ragfilemgr():
|
# async def test_ragfilemgr():
|
||||||
"""测试 RagFileMgr 类的 get_service_params"""
|
# """测试 RagFileMgr 类的 get_service_params"""
|
||||||
print("初始化数据库连接池...")
|
# print("初始化数据库连接池...")
|
||||||
dbs = {
|
# dbs = {
|
||||||
"kyrag": {
|
# "kyrag": {
|
||||||
"driver": "aiomysql",
|
# "driver": "aiomysql",
|
||||||
"async_mode": True,
|
# "async_mode": True,
|
||||||
"coding": "utf8",
|
# "coding": "utf8",
|
||||||
"maxconn": 100,
|
# "maxconn": 100,
|
||||||
"dbname": "kyrag",
|
# "dbname": "kyrag",
|
||||||
"kwargs": {
|
# "kwargs": {
|
||||||
"user": "test",
|
# "user": "test",
|
||||||
"db": "kyrag",
|
# "db": "kyrag",
|
||||||
"password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
|
# "password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
|
||||||
"host": "db"
|
# "host": "db"
|
||||||
}
|
# }
|
||||||
}
|
# }
|
||||||
}
|
# }
|
||||||
DBPools(dbs)
|
# DBPools(dbs)
|
||||||
|
#
|
||||||
ragfilemgr = RagFileMgr()
|
# ragfilemgr = RagFileMgr()
|
||||||
orgid = "04J6VbxLqB_9RPMcgOv_8"
|
# orgid = "04J6VbxLqB_9RPMcgOv_8"
|
||||||
result = await ragfilemgr.get_service_params(orgid)
|
# result = await get_service_params(orgid)
|
||||||
print(f"get_service_params 结果: {result}")
|
# print(f"get_service_params 结果: {result}")
|
||||||
|
#
|
||||||
|
#
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
asyncio.run(test_ragfilemgr())
|
# asyncio.run(test_ragfilemgr())
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
from rag.uapi_service import APIService
|
from rag.uapi_service import APIService
|
||||||
from rag.folderinfo import RagFileMgr
|
|
||||||
from sqlor.dbpools import DBPools
|
from sqlor.dbpools import DBPools
|
||||||
from appPublic.log import debug, error, info
|
from appPublic.log import debug, error, info
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
from rag.service_opts import get_service_params, sor_get_service_params
|
||||||
|
|
||||||
helptext = """kyrag API:
|
helptext = """kyrag API:
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
|||||||
# orgid = "04J6VbxLqB_9RPMcgOv_8"
|
# orgid = "04J6VbxLqB_9RPMcgOv_8"
|
||||||
# userid = "04J6VbxLqB_9RPMcgOv_8"
|
# userid = "04J6VbxLqB_9RPMcgOv_8"
|
||||||
query = params_kw.get('query', '')
|
query = params_kw.get('query', '')
|
||||||
# 统一模式处理 limit 参数
|
# 统一模式处理 limit 参数,为了对接dify和coze
|
||||||
raw_limit = params_kw.get('limit') or (
|
raw_limit = params_kw.get('limit') or (
|
||||||
params_kw.get('retrieval_setting', {}).get('top_k')
|
params_kw.get('retrieval_setting', {}).get('top_k')
|
||||||
if isinstance(params_kw.get('retrieval_setting'), dict)
|
if isinstance(params_kw.get('retrieval_setting'), dict)
|
||||||
@ -103,7 +103,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
|||||||
else:
|
else:
|
||||||
limit = 5 # 其他意外类型使用默认值
|
limit = 5 # 其他意外类型使用默认值
|
||||||
debug(f"limit: {limit}")
|
debug(f"limit: {limit}")
|
||||||
raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
|
raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') #
|
||||||
|
|
||||||
# 标准化为列表格式
|
# 标准化为列表格式
|
||||||
if raw_fiids is None:
|
if raw_fiids is None:
|
||||||
@ -111,8 +111,18 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
|||||||
elif isinstance(raw_fiids, list):
|
elif isinstance(raw_fiids, list):
|
||||||
fiids = [str(item).strip() for item in raw_fiids] # 已经是列表
|
fiids = [str(item).strip() for item in raw_fiids] # 已经是列表
|
||||||
elif isinstance(raw_fiids, str):
|
elif isinstance(raw_fiids, str):
|
||||||
|
# fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
|
||||||
|
try:
|
||||||
|
# 尝试解析 JSON 字符串
|
||||||
|
parsed = json.loads(raw_fiids)
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
fiids = [str(item).strip() for item in parsed] # JSON 数组转为字符串列表
|
||||||
|
else:
|
||||||
# 处理逗号分隔的字符串或单个 ID 字符串
|
# 处理逗号分隔的字符串或单个 ID 字符串
|
||||||
fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
|
fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 如果不是合法 JSON,按逗号分隔
|
||||||
|
fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
|
||||||
elif isinstance(raw_fiids, (int, float)):
|
elif isinstance(raw_fiids, (int, float)):
|
||||||
fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表
|
fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表
|
||||||
else:
|
else:
|
||||||
@ -140,8 +150,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"orgid 验证失败: {str(e)}")
|
error(f"orgid 验证失败: {str(e)}")
|
||||||
return json.dumps({"status": "error", "message": str(e)})
|
return json.dumps({"status": "error", "message": str(e)})
|
||||||
ragfilemgr = RagFileMgr("fiids[0]")
|
service_params = await get_service_params(orgid)
|
||||||
service_params = await ragfilemgr.get_service_params(orgid)
|
|
||||||
|
|
||||||
api_service = APIService()
|
api_service = APIService()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -276,9 +285,19 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
|||||||
timing_stats["total_time"] = time.time() - start_time
|
timing_stats["total_time"] = time.time() - start_time
|
||||||
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||||
|
|
||||||
# debug(f"results: {unique_results[:limit]},timing: {timing_stats}")
|
# dify_result = []
|
||||||
# return {"results": unique_results[:limit], "timing": timing_stats}
|
# for res in unique_results[:limit]:
|
||||||
|
# content = res.get('text', '')
|
||||||
|
# title = res.get('metadata', {}).get('filename', 'Untitled')
|
||||||
|
# document_id = res.get('metadata', {}).get('document_id', '')
|
||||||
|
# dify_result.append({
|
||||||
|
# 'metadata': {'document_id': document_id},
|
||||||
|
# 'title': title,
|
||||||
|
# 'content': content
|
||||||
|
# })
|
||||||
|
# info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
|
||||||
|
# debug(f"result: {dify_result}")
|
||||||
|
# return dify_result
|
||||||
|
|
||||||
dify_records = []
|
dify_records = []
|
||||||
dify_result = []
|
dify_result = []
|
||||||
@ -291,18 +310,50 @@ async def fusedsearch(request, params_kw, *params, **kw):
|
|||||||
document_id = res.get('metadata', {}).get('document_id', '')
|
document_id = res.get('metadata', {}).get('document_id', '')
|
||||||
dify_records.append({
|
dify_records.append({
|
||||||
"content": content,
|
"content": content,
|
||||||
"score": score,
|
"title": title,
|
||||||
"title": title
|
"metadata": {"document_id": document_id, "score": score},
|
||||||
})
|
})
|
||||||
dify_result.append({
|
dify_result.append({
|
||||||
"content": content,
|
"content": content,
|
||||||
"title": title,
|
"title": title,
|
||||||
"metadata": {"document_id": document_id}
|
"metadata": {"document_id": document_id, "score": score},
|
||||||
})
|
})
|
||||||
|
|
||||||
info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
|
info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒")
|
||||||
debug(f"records: {dify_records}, result: {dify_result}")
|
debug(f"records: {dify_records}, result: {dify_result}")
|
||||||
return {"records": dify_records, "result": dify_result, "own":{"results": unique_results[:limit], "timing": timing_stats}}
|
# return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}}
|
||||||
|
return {"records": dify_records}
|
||||||
|
|
||||||
|
# dify_result = []
|
||||||
|
# for res in unique_results[:limit]:
|
||||||
|
# rerank_score = res.get('rerank_score', 0)
|
||||||
|
# score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
|
||||||
|
# score = max(0.0, min(1.0, score))
|
||||||
|
# content = res.get('text', '')
|
||||||
|
# title = res.get('metadata', {}).get('filename', 'Untitled')
|
||||||
|
# document_id = res.get('metadata', {}).get('document_id', '')
|
||||||
|
# dify_result.append({
|
||||||
|
# "metadata": {
|
||||||
|
# "_source": "konwledge",
|
||||||
|
# "dataset_id":"111111",
|
||||||
|
# "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
|
||||||
|
# "document_id": document_id,
|
||||||
|
# "document_name": "test.docx",
|
||||||
|
# "data_source_type": "upload_file",
|
||||||
|
# "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
|
||||||
|
# "retriever_from": "workflow",
|
||||||
|
# "score": score,
|
||||||
|
# "segment_hit_count": 7,
|
||||||
|
# "segment_word_count": 275,
|
||||||
|
# "segment_position": 5,
|
||||||
|
# "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
|
||||||
|
# "doc_metadata": None,
|
||||||
|
# "position":1
|
||||||
|
# },
|
||||||
|
# "title": title,
|
||||||
|
# "content": content
|
||||||
|
# })
|
||||||
|
# return {"result": dify_result}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||||
return {"results": [], "timing": timing_stats}
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from ahserver.serverenv import get_serverenv
|
from ahserver.serverenv import get_serverenv
|
||||||
|
from sqlor.dbpools import DBPools
|
||||||
|
|
||||||
async def sor_get_service_params(sor, orgid):
|
async def sor_get_service_params(sor, orgid):
|
||||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||||
@ -25,7 +26,7 @@ async def sor_get_service_params(sor, orgid):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 手动构造 IN 子句的 ID 列表
|
# 手动构造 IN 子句的 ID 列表
|
||||||
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
|
id_list = [id for id in service_ids] # 确保每个 ID 被单引号包裹
|
||||||
sql_services = """
|
sql_services = """
|
||||||
SELECT id, name, upappid
|
SELECT id, name, upappid
|
||||||
FROM ragservices
|
FROM ragservices
|
||||||
@ -71,5 +72,5 @@ async def get_service_params(orgid):
|
|||||||
db = DBPools()
|
db = DBPools()
|
||||||
dbname = get_serverenv('get_module_dbname')('rag')
|
dbname = get_serverenv('get_module_dbname')('rag')
|
||||||
async with db.sqlorContext(dbname) as sor:
|
async with db.sqlorContext(dbname) as sor:
|
||||||
return await sor_get_server_params(sor, orgid)
|
return await sor_get_service_params(sor, orgid)
|
||||||
return None
|
return None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user