101 lines
3.3 KiB
Python
101 lines
3.3 KiB
Python
from ahserver.serverenv import get_serverenv
|
||
from sqlor.dbpools import DBPools
|
||
from appPublic.log import debug, error, info
|
||
|
||
# service_opts columns that hold ragservices ids, in the same order as the SELECT.
_SERVICE_OPT_COLUMNS = (
    'embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id',
)

# Maps a ragservices.name to the service_params key its upappid fills.
# Both the text (bgem3) and the multimodal (clip) embedding service fill
# 'embedding'; if both appear in the result, the later row wins.
_SERVICE_NAME_TO_PARAM = {
    'bgem3嵌入': 'embedding',
    'milvus向量检索': 'vdb',
    'bgem2v3重排': 'reranker',
    'mrebel三元组抽取': 'triples',
    'neo4j图知识库': 'gdb',
    'small实体抽取': 'entities',
    'clip多模态嵌入服务': 'embedding',
}


async def sor_get_service_params(sor, orgid):
    """Look up the upappid of every RAG service configured for an organization.

    Reads the service_opts row for ``orgid``. Only the first row is used if
    several are returned. It collects the non-empty ragservices ids that row
    references, then fetches those ragservices rows and assigns each row's
    upappid to a role based on the row's ``name``.

    Args:
        sor: an open sqlor session exposing ``async sqlExe(sql, params)``.
        orgid: organization id used to filter service_opts.

    Returns:
        A dict with keys 'embedding', 'vdb', 'reranker', 'triples', 'gdb' and
        'entities' mapped to upappid values. Returns None, after logging an
        error, if any of these is true:
        - no service_opts row exists;
        - every id column is empty;
        - no ragservices rows match;
        - any role is left unfilled.
    """
    sql_opts = """
        SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
        FROM service_opts
        WHERE orgid = ${orgid}$
    """
    opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
    if not opts_result:
        error(f"未找到 orgid={orgid} 的服务配置")
        return None
    opts = opts_result[0]

    # Collect distinct, non-empty service ids. dict.fromkeys keeps first-seen
    # column order, so the IN parameter list is deterministic (a set is not).
    service_ids = list(dict.fromkeys(
        opts[column] for column in _SERVICE_OPT_COLUMNS if opts[column]
    ))
    if not service_ids:
        error(f"未找到任何服务 ID for orgid={orgid}")
        return None

    # The id list is handed to sqlor as a bound parameter; no manual quoting here.
    sql_services = """
        SELECT id, name, upappid
        FROM ragservices
        WHERE id IN ${id_list}$
    """
    services_result = await sor.sqlExe(sql_services, {'id_list': service_ids})
    if not services_result:
        error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
        return None

    # Fill each role from the service whose name maps to it; only upappid is kept.
    service_params = {
        'embedding': None,
        'vdb': None,
        'reranker': None,
        'triples': None,
        'gdb': None,
        'entities': None,
    }
    for service in services_result:
        name = service['name']
        param = _SERVICE_NAME_TO_PARAM.get(name)
        if param is None:
            # Log instead of silently dropping, so a name mismatch is visible
            # when the missing-services error below fires.
            debug(f"未识别的服务名称: {name},已忽略")
            continue
        service_params[param] = service['upappid']

    # Every role must be configured; report the ones that are still empty.
    missing_services = [k for k, v in service_params.items() if v is None]
    if missing_services:
        error(f"未找到以下服务的配置: {missing_services}")
        return None
    return service_params
|
||
|
||
async def get_service_params(orgid):
    """Fetch the RAG service params for ``orgid`` from the 'rag' module database.

    Opens a sqlor session on the database named by the server environment's
    ``get_module_dbname('rag')``, then delegates to sor_get_service_params.
    """
    pools = DBPools()
    debug(f"传入的orgid是:{orgid}")
    module_dbname = get_serverenv('get_module_dbname')
    async with pools.sqlorContext(module_dbname('rag')) as session:
        return await sor_get_service_params(session, orgid)
    # Python only reaches this line if sqlorContext's __aexit__ suppresses an
    # exception raised inside the block.
    return None
|
||
|
||
async def sor_get_embedding_mode(sor, orgid) -> int:
    """Return the embedding mode for ``orgid``: 0 = text only, 1 = multimodal.

    Joins service_opts.embedding_id to embedding_mode.embeddingid and converts
    the first matching row's ``mode`` with int(). When nothing matches, it logs
    at debug level and falls back to 0.
    """
    query = """
        SELECT em.mode
        FROM service_opts so
        JOIN embedding_mode em ON so.embedding_id = em.embeddingid
        WHERE so.orgid = ${orgid}$
    """
    matches = await sor.sqlExe(query, {"orgid": orgid})
    if matches:
        first_row = matches[0]
        return int(first_row.mode)
    debug(f"orgid={orgid} 未配置 embedding_mode,默认为 0(纯文本)")
    return 0
|
||
|
||
async def get_embedding_mode(orgid):
    """Fetch the embedding mode for ``orgid`` from the 'rag' module database.

    Opens a sqlor session on the database named by the server environment's
    ``get_module_dbname('rag')``, then delegates to sor_get_embedding_mode.
    """
    pools = DBPools()
    module_dbname = get_serverenv('get_module_dbname')
    async with pools.sqlorContext(module_dbname('rag')) as session:
        return await sor_get_embedding_mode(session, orgid)
    # Python only reaches this line if sqlorContext's __aexit__ suppresses an
    # exception raised inside the block.
    return None