bugfix
This commit is contained in:
parent
9e7ff6c71f
commit
e590c1084f
@ -17,6 +17,7 @@ import traceback
|
||||
from filetxt.loader import fileloader
|
||||
from ahserver.serverenv import get_serverenv
|
||||
from typing import List, Dict, Any
|
||||
from rag.service_opts import get_service_params, sor_get_service_params
|
||||
import json
|
||||
|
||||
class RagFileMgr(FileMgr):
|
||||
@ -43,73 +44,6 @@ where a.orgid = b.orgid
|
||||
return r.quota, r.expired_date
|
||||
return None, None
|
||||
|
||||
async def get_service_params(self,sor, orgid):
|
||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||
sql_opts = """
|
||||
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
|
||||
FROM service_opts
|
||||
WHERE orgid = ${orgid}$
|
||||
"""
|
||||
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
|
||||
if not opts_result:
|
||||
error(f"未找到 orgid={orgid} 的服务配置")
|
||||
return None
|
||||
opts = opts_result[0]
|
||||
|
||||
# 收集服务 ID
|
||||
service_ids = set()
|
||||
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
|
||||
if opts[key]:
|
||||
service_ids.add(opts[key])
|
||||
|
||||
# 检查 service_ids 是否为空
|
||||
if not service_ids:
|
||||
error(f"未找到任何服务 ID for orgid={orgid}")
|
||||
return None
|
||||
|
||||
# 手动构造 IN 子句的 ID 列表
|
||||
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
|
||||
sql_services = """
|
||||
SELECT id, name, upappid
|
||||
FROM ragservices
|
||||
WHERE id IN ${id_list}$
|
||||
"""
|
||||
services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
|
||||
if not services_result:
|
||||
error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
|
||||
return None
|
||||
|
||||
# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
|
||||
service_params = {
|
||||
'embedding': None,
|
||||
'vdb': None,
|
||||
'reranker': None,
|
||||
'triples': None,
|
||||
'gdb': None,
|
||||
'entities': None
|
||||
}
|
||||
for service in services_result:
|
||||
name = service['name']
|
||||
if name == 'bgem3嵌入':
|
||||
service_params['embedding'] = service['upappid']
|
||||
elif name == 'milvus向量检索':
|
||||
service_params['vdb'] = service['upappid']
|
||||
elif name == 'bgem2v3重排':
|
||||
service_params['reranker'] = service['upappid']
|
||||
elif name == 'mrebel三元组抽取':
|
||||
service_params['triples'] = service['upappid']
|
||||
elif name == 'neo4j删除知识库':
|
||||
service_params['gdb'] = service['upappid']
|
||||
elif name == 'small实体抽取':
|
||||
service_params['entities'] = service['upappid']
|
||||
|
||||
# 检查是否所有服务参数都已填充
|
||||
missing_services = [k for k, v in service_params.items() if v is None]
|
||||
if missing_services:
|
||||
error(f"未找到以下服务的配置: {missing_services}")
|
||||
return None
|
||||
return service_params
|
||||
|
||||
async def file_uploaded(self, request, ns, userid):
|
||||
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
||||
debug(f'Received ns: {ns=}')
|
||||
@ -128,6 +62,13 @@ where a.orgid = b.orgid
|
||||
timings = {}
|
||||
start_total = time.time()
|
||||
|
||||
service_params = await get_service_params(orgid)
|
||||
chunks = await self.get_doucment_chunks(realpath)
|
||||
embeddings = await self.docs_embedding(chunks)
|
||||
await self.embedding_2_vdb(id, fiid, orgid, realpath, embedding)
|
||||
triples = await self.get_triples(chunks)
|
||||
await self.triple2graphdb(id, fiid, orgid, realpath, triples)
|
||||
return
|
||||
try:
|
||||
if not orgid or not fiid or not id:
|
||||
raise ValueError("orgid、fiid 和 id 不能为空")
|
||||
@ -136,30 +77,8 @@ where a.orgid = b.orgid
|
||||
if not os.path.exists(realpath):
|
||||
raise ValueError(f"文件 {realpath} 不存在")
|
||||
|
||||
# 检查 hashvalue 是否已存在
|
||||
db = DBPools()
|
||||
dbname = env.get_module_dbname('rag')
|
||||
sql_check_hash = """
|
||||
SELECT hashvalue
|
||||
FROM file
|
||||
WHERE hashvalue = ${hashvalue}$
|
||||
"""
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue})
|
||||
if hash_result:
|
||||
debug(f"文件已存在: hashvalue={hashvalue}")
|
||||
timings["total"] = time.time() - start_total
|
||||
return {
|
||||
"status": "error",
|
||||
"document_id": id,
|
||||
"collection_name": "ragdb",
|
||||
"timings": timings,
|
||||
"message": f"文件已存在: hashvalue={hashvalue}",
|
||||
"status_code": 400
|
||||
}
|
||||
|
||||
# 获取服务参数
|
||||
service_params = await self.get_service_params(sor, orgid)
|
||||
service_params = await get_service_params(orgid)
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
|
||||
@ -392,12 +311,11 @@ where a.orgid = b.orgid
|
||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||
|
||||
# 获取服务参数
|
||||
service_params = await self.get_service_params(sor, orgid)
|
||||
service_params = await get_service_params(orgid)
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
|
||||
debug(
|
||||
f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
||||
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
||||
milvus_result = await api_service.milvus_delete_document(
|
||||
request=request,
|
||||
userid=orgid,
|
||||
@ -415,7 +333,6 @@ where a.orgid = b.orgid
|
||||
|
||||
neo4j_deleted_nodes = 0
|
||||
neo4j_deleted_rels = 0
|
||||
try:
|
||||
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
||||
neo4j_result = await api_service.neo4j_delete_document(
|
||||
request=request,
|
||||
@ -433,8 +350,6 @@ where a.orgid = b.orgid
|
||||
total_nodes_deleted += nodes_deleted
|
||||
total_rels_deleted += rels_deleted
|
||||
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
||||
except Exception as e:
|
||||
error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
|
||||
|
||||
results.append({
|
||||
"status": "success",
|
||||
|
||||
75
rag/service_opts.py
Normal file
75
rag/service_opts.py
Normal file
@ -0,0 +1,75 @@
|
||||
from ahserver.serverenv import get_serverenv
|
||||
|
||||
async def sor_get_service_params(sor, orgid):
|
||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||
sql_opts = """
|
||||
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
|
||||
FROM service_opts
|
||||
WHERE orgid = ${orgid}$
|
||||
"""
|
||||
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
|
||||
if not opts_result:
|
||||
error(f"未找到 orgid={orgid} 的服务配置")
|
||||
return None
|
||||
opts = opts_result[0]
|
||||
|
||||
# 收集服务 ID
|
||||
service_ids = set()
|
||||
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
|
||||
if opts[key]:
|
||||
service_ids.add(opts[key])
|
||||
|
||||
# 检查 service_ids 是否为空
|
||||
if not service_ids:
|
||||
error(f"未找到任何服务 ID for orgid={orgid}")
|
||||
return None
|
||||
|
||||
# 手动构造 IN 子句的 ID 列表
|
||||
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
|
||||
sql_services = """
|
||||
SELECT id, name, upappid
|
||||
FROM ragservices
|
||||
WHERE id IN ${id_list}$
|
||||
"""
|
||||
services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
|
||||
if not services_result:
|
||||
error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
|
||||
return None
|
||||
|
||||
# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
|
||||
service_params = {
|
||||
'embedding': None,
|
||||
'vdb': None,
|
||||
'reranker': None,
|
||||
'triples': None,
|
||||
'gdb': None,
|
||||
'entities': None
|
||||
}
|
||||
for service in services_result:
|
||||
name = service['name']
|
||||
if name == 'bgem3嵌入':
|
||||
service_params['embedding'] = service['upappid']
|
||||
elif name == 'milvus向量检索':
|
||||
service_params['vdb'] = service['upappid']
|
||||
elif name == 'bgem2v3重排':
|
||||
service_params['reranker'] = service['upappid']
|
||||
elif name == 'mrebel三元组抽取':
|
||||
service_params['triples'] = service['upappid']
|
||||
elif name == 'neo4j删除知识库':
|
||||
service_params['gdb'] = service['upappid']
|
||||
elif name == 'small实体抽取':
|
||||
service_params['entities'] = service['upappid']
|
||||
|
||||
# 检查是否所有服务参数都已填充
|
||||
missing_services = [k for k, v in service_params.items() if v is None]
|
||||
if missing_services:
|
||||
error(f"未找到以下服务的配置: {missing_services}")
|
||||
return None
|
||||
return service_params
|
||||
|
||||
async def get_service_params(orgid):
|
||||
db = DBPools()
|
||||
dbname = get_serverenv('get_module_dbname')('rag')
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
return await sor_get_server_params(sor, orgid)
|
||||
return None
|
||||
Loading…
x
Reference in New Issue
Block a user