This commit is contained in:
yumoqing 2025-09-11 13:46:48 +08:00
parent 2e81611868
commit 9e7ff6c71f

View File

@ -43,26 +43,18 @@ where a.orgid = b.orgid
return r.quota, r.expired_date return r.quota, r.expired_date
return None, None return None, None
async def get_service_params(self,orgid): async def get_service_params(self,sor, orgid):
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """ """ 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
db = DBPools()
dbname = "kyrag"
sql_opts = """ sql_opts = """
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
FROM service_opts FROM service_opts
WHERE orgid = ${orgid}$ WHERE orgid = ${orgid}$
""" """
try:
async with db.sqlorContext(dbname) as sor:
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid}) opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
if not opts_result: if not opts_result:
error(f"未找到 orgid={orgid} 的服务配置") error(f"未找到 orgid={orgid} 的服务配置")
return None return None
opts = opts_result[0] opts = opts_result[0]
except Exception as e:
error(f"查询 service_opts 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return None
# 收集服务 ID # 收集服务 ID
service_ids = set() service_ids = set()
@ -77,14 +69,12 @@ where a.orgid = b.orgid
# 手动构造 IN 子句的 ID 列表 # 手动构造 IN 子句的 ID 列表
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹 id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
sql_services = f""" sql_services = """
SELECT id, name, upappid SELECT id, name, upappid
FROM ragservices FROM ragservices
WHERE id IN ({id_list}) WHERE id IN ${id_list}$
""" """
try: services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
async with db.sqlorContext(dbname) as sor:
services_result = await sor.sqlExe(sql_services, {})
if not services_result: if not services_result:
error(f"未找到服务 ID {service_ids} 的 ragservices 配置") error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
return None return None
@ -118,15 +108,12 @@ where a.orgid = b.orgid
if missing_services: if missing_services:
error(f"未找到以下服务的配置: {missing_services}") error(f"未找到以下服务的配置: {missing_services}")
return None return None
return service_params return service_params
except Exception as e:
error(f"查询 ragservices 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return None
async def file_uploaded(self, request, ns, userid): async def file_uploaded(self, request, ns, userid):
"""将文档插入 Milvus 并抽取三元组到 Neo4j""" """将文档插入 Milvus 并抽取三元组到 Neo4j"""
debug(f'Received ns: {ns=}') debug(f'Received ns: {ns=}')
env = request._run_ns
realpath = ns.get('realpath', '') realpath = ns.get('realpath', '')
fiid = ns.get('fiid', '') fiid = ns.get('fiid', '')
id = ns.get('id', '') id = ns.get('id', '')
@ -151,7 +138,7 @@ where a.orgid = b.orgid
# 检查 hashvalue 是否已存在 # 检查 hashvalue 是否已存在
db = DBPools() db = DBPools()
dbname = "kyrag" dbname = env.get_module_dbname('rag')
sql_check_hash = """ sql_check_hash = """
SELECT hashvalue SELECT hashvalue
FROM file FROM file
@ -172,7 +159,7 @@ where a.orgid = b.orgid
} }
# 获取服务参数 # 获取服务参数
service_params = await self.get_service_params(orgid) service_params = await self.get_service_params(sor, orgid)
if not service_params: if not service_params:
raise ValueError("无法获取服务参数") raise ValueError("无法获取服务参数")
@ -405,7 +392,7 @@ where a.orgid = b.orgid
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
# 获取服务参数 # 获取服务参数
service_params = await self.get_service_params(orgid) service_params = await self.get_service_params(sor, orgid)
if not service_params: if not service_params:
raise ValueError("无法获取服务参数") raise ValueError("无法获取服务参数")