bugfix
This commit is contained in:
parent
9e7ff6c71f
commit
e590c1084f
@ -17,6 +17,7 @@ import traceback
|
|||||||
from filetxt.loader import fileloader
|
from filetxt.loader import fileloader
|
||||||
from ahserver.serverenv import get_serverenv
|
from ahserver.serverenv import get_serverenv
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
from rag.service_opts import get_service_params, sor_get_service_params
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class RagFileMgr(FileMgr):
|
class RagFileMgr(FileMgr):
|
||||||
@ -43,73 +44,6 @@ where a.orgid = b.orgid
|
|||||||
return r.quota, r.expired_date
|
return r.quota, r.expired_date
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def get_service_params(self,sor, orgid):
|
|
||||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
|
||||||
sql_opts = """
|
|
||||||
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
|
|
||||||
FROM service_opts
|
|
||||||
WHERE orgid = ${orgid}$
|
|
||||||
"""
|
|
||||||
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
|
|
||||||
if not opts_result:
|
|
||||||
error(f"未找到 orgid={orgid} 的服务配置")
|
|
||||||
return None
|
|
||||||
opts = opts_result[0]
|
|
||||||
|
|
||||||
# 收集服务 ID
|
|
||||||
service_ids = set()
|
|
||||||
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
|
|
||||||
if opts[key]:
|
|
||||||
service_ids.add(opts[key])
|
|
||||||
|
|
||||||
# 检查 service_ids 是否为空
|
|
||||||
if not service_ids:
|
|
||||||
error(f"未找到任何服务 ID for orgid={orgid}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 手动构造 IN 子句的 ID 列表
|
|
||||||
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
|
|
||||||
sql_services = """
|
|
||||||
SELECT id, name, upappid
|
|
||||||
FROM ragservices
|
|
||||||
WHERE id IN ${id_list}$
|
|
||||||
"""
|
|
||||||
services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
|
|
||||||
if not services_result:
|
|
||||||
error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
|
|
||||||
service_params = {
|
|
||||||
'embedding': None,
|
|
||||||
'vdb': None,
|
|
||||||
'reranker': None,
|
|
||||||
'triples': None,
|
|
||||||
'gdb': None,
|
|
||||||
'entities': None
|
|
||||||
}
|
|
||||||
for service in services_result:
|
|
||||||
name = service['name']
|
|
||||||
if name == 'bgem3嵌入':
|
|
||||||
service_params['embedding'] = service['upappid']
|
|
||||||
elif name == 'milvus向量检索':
|
|
||||||
service_params['vdb'] = service['upappid']
|
|
||||||
elif name == 'bgem2v3重排':
|
|
||||||
service_params['reranker'] = service['upappid']
|
|
||||||
elif name == 'mrebel三元组抽取':
|
|
||||||
service_params['triples'] = service['upappid']
|
|
||||||
elif name == 'neo4j删除知识库':
|
|
||||||
service_params['gdb'] = service['upappid']
|
|
||||||
elif name == 'small实体抽取':
|
|
||||||
service_params['entities'] = service['upappid']
|
|
||||||
|
|
||||||
# 检查是否所有服务参数都已填充
|
|
||||||
missing_services = [k for k, v in service_params.items() if v is None]
|
|
||||||
if missing_services:
|
|
||||||
error(f"未找到以下服务的配置: {missing_services}")
|
|
||||||
return None
|
|
||||||
return service_params
|
|
||||||
|
|
||||||
async def file_uploaded(self, request, ns, userid):
|
async def file_uploaded(self, request, ns, userid):
|
||||||
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
||||||
debug(f'Received ns: {ns=}')
|
debug(f'Received ns: {ns=}')
|
||||||
@ -128,45 +62,30 @@ where a.orgid = b.orgid
|
|||||||
timings = {}
|
timings = {}
|
||||||
start_total = time.time()
|
start_total = time.time()
|
||||||
|
|
||||||
|
service_params = await get_service_params(orgid)
|
||||||
|
chunks = await self.get_doucment_chunks(realpath)
|
||||||
|
embeddings = await self.docs_embedding(chunks)
|
||||||
|
await self.embedding_2_vdb(id, fiid, orgid, realpath, embedding)
|
||||||
|
triples = await self.get_triples(chunks)
|
||||||
|
await self.triple2graphdb(id, fiid, orgid, realpath, triples)
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
if not orgid or not fiid or not id:
|
if not orgid or not fiid or not id:
|
||||||
raise ValueError("orgid、fiid 和 id 不能为空")
|
raise ValueError("orgid、fiid 和 id 不能为空")
|
||||||
if len(orgid) > 32 or len(fiid) > 255:
|
if len(orgid) > 32 or len(fiid) > 255:
|
||||||
raise ValueError("orgid 或 fiid 的长度超出限制")
|
raise ValueError("orgid 或 fiid 的长度超出限制")
|
||||||
if not os.path.exists(realpath):
|
if not os.path.exists(realpath):
|
||||||
raise ValueError(f"文件 {realpath} 不存在")
|
raise ValueError(f"文件 {realpath} 不存在")
|
||||||
|
|
||||||
# 检查 hashvalue 是否已存在
|
|
||||||
db = DBPools()
|
|
||||||
dbname = env.get_module_dbname('rag')
|
|
||||||
sql_check_hash = """
|
|
||||||
SELECT hashvalue
|
|
||||||
FROM file
|
|
||||||
WHERE hashvalue = ${hashvalue}$
|
|
||||||
"""
|
|
||||||
async with db.sqlorContext(dbname) as sor:
|
|
||||||
hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue})
|
|
||||||
if hash_result:
|
|
||||||
debug(f"文件已存在: hashvalue={hashvalue}")
|
|
||||||
timings["total"] = time.time() - start_total
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"document_id": id,
|
|
||||||
"collection_name": "ragdb",
|
|
||||||
"timings": timings,
|
|
||||||
"message": f"文件已存在: hashvalue={hashvalue}",
|
|
||||||
"status_code": 400
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取服务参数
|
# 获取服务参数
|
||||||
service_params = await self.get_service_params(sor, orgid)
|
service_params = await get_service_params(orgid)
|
||||||
if not service_params:
|
if not service_params:
|
||||||
raise ValueError("无法获取服务参数")
|
raise ValueError("无法获取服务参数")
|
||||||
|
|
||||||
supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
|
supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
|
||||||
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
||||||
if ext not in supported_formats:
|
if ext not in supported_formats:
|
||||||
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
||||||
|
|
||||||
debug(f"加载文件: {realpath}")
|
debug(f"加载文件: {realpath}")
|
||||||
start_load = time.time()
|
start_load = time.time()
|
||||||
@ -176,22 +95,22 @@ where a.orgid = b.orgid
|
|||||||
timings["load_file"] = time.time() - start_load
|
timings["load_file"] = time.time() - start_load
|
||||||
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
|
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
|
||||||
if not text or not text.strip():
|
if not text or not text.strip():
|
||||||
raise ValueError(f"文件 {realpath} 加载为空")
|
raise ValueError(f"文件 {realpath} 加载为空")
|
||||||
|
|
||||||
document = Document(page_content=text)
|
document = Document(page_content=text)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=500,
|
chunk_size=500,
|
||||||
chunk_overlap=100,
|
chunk_overlap=100,
|
||||||
length_function=len)
|
length_function=len)
|
||||||
debug("开始分片文件内容")
|
debug("开始分片文件内容")
|
||||||
start_split = time.time()
|
start_split = time.time()
|
||||||
chunks = text_splitter.split_documents([document])
|
chunks = text_splitter.split_documents([document])
|
||||||
timings["split_text"] = time.time() - start_split
|
timings["split_text"] = time.time() - start_split
|
||||||
debug(
|
debug(
|
||||||
f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
|
f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
|
||||||
debug(f"分片内容: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
|
debug(f"分片内容: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
|
||||||
if not chunks:
|
if not chunks:
|
||||||
raise ValueError(f"文件 {realpath} 未生成任何文档块")
|
raise ValueError(f"文件 {realpath} 未生成任何文档块")
|
||||||
|
|
||||||
filename = os.path.basename(realpath).rsplit('.', 1)[0]
|
filename = os.path.basename(realpath).rsplit('.', 1)[0]
|
||||||
upload_time = datetime.now().isoformat()
|
upload_time = datetime.now().isoformat()
|
||||||
@ -201,23 +120,23 @@ where a.orgid = b.orgid
|
|||||||
texts = [chunk.page_content for chunk in chunks]
|
texts = [chunk.page_content for chunk in chunks]
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for i in range(0, len(texts), 10): # 每次处理 10 个文本块
|
for i in range(0, len(texts), 10): # 每次处理 10 个文本块
|
||||||
batch_texts = texts[i:i + 10]
|
batch_texts = texts[i:i + 10]
|
||||||
batch_embeddings = await api_service.get_embeddings(
|
batch_embeddings = await api_service.get_embeddings(
|
||||||
request=request,
|
request=request,
|
||||||
texts=batch_texts,
|
texts=batch_texts,
|
||||||
upappid=service_params['embedding'],
|
upappid=service_params['embedding'],
|
||||||
apiname="BAAI/bge-m3",
|
apiname="BAAI/bge-m3",
|
||||||
user=userid
|
user=userid
|
||||||
)
|
)
|
||||||
embeddings.extend(batch_embeddings)
|
embeddings.extend(batch_embeddings)
|
||||||
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
|
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
|
||||||
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
|
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
|
||||||
timings["generate_embeddings"] = time.time() - start_embedding
|
timings["generate_embeddings"] = time.time() - start_embedding
|
||||||
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
|
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
|
||||||
|
|
||||||
chunks_data = []
|
chunks_data = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
chunks_data.append({
|
chunks_data.append({
|
||||||
"userid": orgid,
|
"userid": orgid,
|
||||||
"knowledge_base_id": fiid,
|
"knowledge_base_id": fiid,
|
||||||
"text": chunk.page_content,
|
"text": chunk.page_content,
|
||||||
@ -227,38 +146,38 @@ where a.orgid = b.orgid
|
|||||||
"file_path": realpath,
|
"file_path": realpath,
|
||||||
"upload_time": upload_time,
|
"upload_time": upload_time,
|
||||||
"file_type": ext,
|
"file_type": ext,
|
||||||
})
|
})
|
||||||
|
|
||||||
debug(f"调用插入文件端点: {realpath}")
|
debug(f"调用插入文件端点: {realpath}")
|
||||||
start_milvus = time.time()
|
start_milvus = time.time()
|
||||||
for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据
|
for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据
|
||||||
batch_chunks = chunks_data[i:i + 10]
|
batch_chunks = chunks_data[i:i + 10]
|
||||||
result = await api_service.milvus_insert_document(
|
result = await api_service.milvus_insert_document(
|
||||||
request=request,
|
request=request,
|
||||||
chunks=batch_chunks,
|
chunks=batch_chunks,
|
||||||
db_type=db_type,
|
db_type=db_type,
|
||||||
upappid=service_params['vdb'],
|
upappid=service_params['vdb'],
|
||||||
apiname="milvus/insertdocument",
|
apiname="milvus/insertdocument",
|
||||||
user=userid
|
user=userid
|
||||||
)
|
)
|
||||||
if result.get("status") != "success":
|
if result.get("status") != "success":
|
||||||
raise ValueError(result.get("message", "Milvus 插入失败"))
|
raise ValueError(result.get("message", "Milvus 插入失败"))
|
||||||
timings["insert_milvus"] = time.time() - start_milvus
|
timings["insert_milvus"] = time.time() - start_milvus
|
||||||
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
|
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
|
||||||
|
|
||||||
if result.get("status") != "success":
|
if result.get("status") != "success":
|
||||||
timings["total"] = time.time() - start_total
|
timings["total"] = time.time() - start_total
|
||||||
return {"status": "error", "document_id": id, "timings": timings,
|
return {"status": "error", "document_id": id, "timings": timings,
|
||||||
"message": result.get("message", "未知错误"), "status_code": 400}
|
"message": result.get("message", "未知错误"), "status_code": 400}
|
||||||
|
|
||||||
debug("调用三元组抽取服务")
|
debug("调用三元组抽取服务")
|
||||||
start_triples = time.time()
|
start_triples = time.time()
|
||||||
unique_triples = []
|
unique_triples = []
|
||||||
try:
|
try:
|
||||||
chunk_texts = [doc.page_content for doc in chunks]
|
chunk_texts = [doc.page_content for doc in chunks]
|
||||||
debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
|
debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
|
||||||
triples = []
|
triples = []
|
||||||
for i, chunk in enumerate(chunk_texts):
|
for i, chunk in enumerate(chunk_texts):
|
||||||
result = await api_service.extract_triples(
|
result = await api_service.extract_triples(
|
||||||
request=request,
|
request=request,
|
||||||
text=chunk,
|
text=chunk,
|
||||||
@ -283,16 +202,16 @@ where a.orgid = b.orgid
|
|||||||
if (existing['head'].lower() == t['head'].lower() and
|
if (existing['head'].lower() == t['head'].lower() and
|
||||||
existing['tail'].lower() == t['tail'].lower() and
|
existing['tail'].lower() == t['tail'].lower() and
|
||||||
len(t['type']) > len(existing['type'])):
|
len(t['type']) > len(existing['type'])):
|
||||||
unique_triples.remove(existing)
|
unique_triples.remove(existing)
|
||||||
unique_triples.append(t)
|
unique_triples.append(t)
|
||||||
debug(f"替换三元组为更具体类型: {t}")
|
debug(f"替换三元组为更具体类型: {t}")
|
||||||
break
|
break
|
||||||
|
|
||||||
timings["extract_triples"] = time.time() - start_triples
|
timings["extract_triples"] = time.time() - start_triples
|
||||||
debug(
|
debug(
|
||||||
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
|
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
|
||||||
|
|
||||||
if unique_triples:
|
if unique_triples:
|
||||||
debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
|
debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
|
||||||
start_neo4j = time.time()
|
start_neo4j = time.time()
|
||||||
for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组
|
for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组
|
||||||
@ -322,11 +241,11 @@ where a.orgid = b.orgid
|
|||||||
info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
|
info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
|
||||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
timings["insert_neo4j"] = time.time() - start_neo4j
|
||||||
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
|
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
|
||||||
else:
|
else:
|
||||||
debug(f"文件 {realpath} 未抽取到三元组")
|
debug(f"文件 {realpath} 未抽取到三元组")
|
||||||
timings["insert_neo4j"] = 0.0
|
timings["insert_neo4j"] = 0.0
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
|
timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
|
||||||
timings["extract_triples"]
|
timings["extract_triples"]
|
||||||
timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
|
timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
|
||||||
@ -386,36 +305,34 @@ where a.orgid = b.orgid
|
|||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
required_fields = ['id', 'realpath', 'fiid', 'ownerid']
|
required_fields = ['id', 'realpath', 'fiid', 'ownerid']
|
||||||
missing_fields = [field for field in required_fields if not rec.get(field, '')]
|
missing_fields = [field for field in required_fields if not rec.get(field, '')]
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||||
|
|
||||||
# 获取服务参数
|
# 获取服务参数
|
||||||
service_params = await self.get_service_params(sor, orgid)
|
service_params = await get_service_params(orgid)
|
||||||
if not service_params:
|
if not service_params:
|
||||||
raise ValueError("无法获取服务参数")
|
raise ValueError("无法获取服务参数")
|
||||||
|
|
||||||
debug(
|
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
||||||
f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
milvus_result = await api_service.milvus_delete_document(
|
||||||
milvus_result = await api_service.milvus_delete_document(
|
request=request,
|
||||||
request=request,
|
userid=orgid,
|
||||||
userid=orgid,
|
file_path=realpath,
|
||||||
file_path=realpath,
|
knowledge_base_id=fiid,
|
||||||
knowledge_base_id=fiid,
|
document_id=id,
|
||||||
document_id=id,
|
db_type=db_type,
|
||||||
db_type=db_type,
|
upappid=service_params['vdb'],
|
||||||
upappid=service_params['vdb'],
|
apiname="milvus/deletedocument",
|
||||||
apiname="milvus/deletedocument",
|
user=userid
|
||||||
user=userid
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if milvus_result.get("status") != "success":
|
if milvus_result.get("status") != "success":
|
||||||
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
|
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
|
||||||
|
|
||||||
neo4j_deleted_nodes = 0
|
neo4j_deleted_nodes = 0
|
||||||
neo4j_deleted_rels = 0
|
neo4j_deleted_rels = 0
|
||||||
try:
|
|
||||||
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
||||||
neo4j_result = await api_service.neo4j_delete_document(
|
neo4j_result = await api_service.neo4j_delete_document(
|
||||||
request=request,
|
request=request,
|
||||||
@ -433,26 +350,24 @@ where a.orgid = b.orgid
|
|||||||
total_nodes_deleted += nodes_deleted
|
total_nodes_deleted += nodes_deleted
|
||||||
total_rels_deleted += rels_deleted
|
total_rels_deleted += rels_deleted
|
||||||
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
||||||
except Exception as e:
|
|
||||||
error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
|
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"collection_name": collection_name,
|
"collection_name": collection_name,
|
||||||
"document_id": id,
|
"document_id": id,
|
||||||
"message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
|
"message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
|
||||||
"status_code": 200
|
"status_code": 200
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||||
results.append({
|
results.append({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"collection_name": collection_name,
|
"collection_name": collection_name,
|
||||||
"document_id": id,
|
"document_id": id,
|
||||||
"message": f"删除文档 {realpath} 失败: {str(e)}",
|
"message": f"删除文档 {realpath} 失败: {str(e)}",
|
||||||
"status_code": 400
|
"status_code": 400
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success" if all(r["status"] == "success" for r in results) else "partial",
|
"status": "success" if all(r["status"] == "success" for r in results) else "partial",
|
||||||
|
|||||||
75
rag/service_opts.py
Normal file
75
rag/service_opts.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from ahserver.serverenv import get_serverenv
|
||||||
|
|
||||||
|
async def sor_get_service_params(sor, orgid):
|
||||||
|
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||||
|
sql_opts = """
|
||||||
|
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
|
||||||
|
FROM service_opts
|
||||||
|
WHERE orgid = ${orgid}$
|
||||||
|
"""
|
||||||
|
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
|
||||||
|
if not opts_result:
|
||||||
|
error(f"未找到 orgid={orgid} 的服务配置")
|
||||||
|
return None
|
||||||
|
opts = opts_result[0]
|
||||||
|
|
||||||
|
# 收集服务 ID
|
||||||
|
service_ids = set()
|
||||||
|
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
|
||||||
|
if opts[key]:
|
||||||
|
service_ids.add(opts[key])
|
||||||
|
|
||||||
|
# 检查 service_ids 是否为空
|
||||||
|
if not service_ids:
|
||||||
|
error(f"未找到任何服务 ID for orgid={orgid}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 手动构造 IN 子句的 ID 列表
|
||||||
|
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
|
||||||
|
sql_services = """
|
||||||
|
SELECT id, name, upappid
|
||||||
|
FROM ragservices
|
||||||
|
WHERE id IN ${id_list}$
|
||||||
|
"""
|
||||||
|
services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
|
||||||
|
if not services_result:
|
||||||
|
error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
|
||||||
|
service_params = {
|
||||||
|
'embedding': None,
|
||||||
|
'vdb': None,
|
||||||
|
'reranker': None,
|
||||||
|
'triples': None,
|
||||||
|
'gdb': None,
|
||||||
|
'entities': None
|
||||||
|
}
|
||||||
|
for service in services_result:
|
||||||
|
name = service['name']
|
||||||
|
if name == 'bgem3嵌入':
|
||||||
|
service_params['embedding'] = service['upappid']
|
||||||
|
elif name == 'milvus向量检索':
|
||||||
|
service_params['vdb'] = service['upappid']
|
||||||
|
elif name == 'bgem2v3重排':
|
||||||
|
service_params['reranker'] = service['upappid']
|
||||||
|
elif name == 'mrebel三元组抽取':
|
||||||
|
service_params['triples'] = service['upappid']
|
||||||
|
elif name == 'neo4j删除知识库':
|
||||||
|
service_params['gdb'] = service['upappid']
|
||||||
|
elif name == 'small实体抽取':
|
||||||
|
service_params['entities'] = service['upappid']
|
||||||
|
|
||||||
|
# 检查是否所有服务参数都已填充
|
||||||
|
missing_services = [k for k, v in service_params.items() if v is None]
|
||||||
|
if missing_services:
|
||||||
|
error(f"未找到以下服务的配置: {missing_services}")
|
||||||
|
return None
|
||||||
|
return service_params
|
||||||
|
|
||||||
|
async def get_service_params(orgid):
|
||||||
|
db = DBPools()
|
||||||
|
dbname = get_serverenv('get_module_dbname')('rag')
|
||||||
|
async with db.sqlorContext(dbname) as sor:
|
||||||
|
return await sor_get_server_params(sor, orgid)
|
||||||
|
return None
|
||||||
Loading…
x
Reference in New Issue
Block a user