rag/rag/service_opts.py
2025-11-28 16:20:20 +08:00

101 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ahserver.serverenv import get_serverenv
from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
async def sor_get_service_params(sor, orgid):
    """Resolve the ``upappid`` of every RAG service configured for an org.

    Reads the first ``service_opts`` row for ``orgid``, collects the
    non-empty service IDs it references, loads the matching
    ``ragservices`` rows and maps each row to a result slot by its
    ``name`` column.

    Args:
        sor: Object exposing an awaitable ``sqlExe(sql, params)`` whose
            rows can be indexed by column name.
        orgid: Value bound to the ``orgid`` placeholder of the first query.

    Returns:
        A dict with the keys ``embedding``, ``vdb``, ``reranker``,
        ``triples``, ``gdb`` and ``entities`` (each holding an
        ``upappid``), or ``None`` when the org has no ``service_opts``
        row, no service IDs, no matching ``ragservices`` rows, or when
        any of the six slots is left unfilled. Each failure is logged
        through ``error``.
    """
    id_columns = (
        'embedding_id', 'vdb_id', 'reranker_id',
        'triples_id', 'gdb_id', 'entities_id',
    )
    # ragservices.name -> result slot. Two names feed 'embedding'; if both
    # appear in the result set, the row processed last wins.
    slot_by_name = {
        'bgem3嵌入': 'embedding',
        'milvus向量检索': 'vdb',
        'bgem2v3重排': 'reranker',
        'mrebel三元组抽取': 'triples',
        'neo4j图知识库': 'gdb',
        'small实体抽取': 'entities',
        'clip多模态嵌入服务': 'embedding',
    }

    sql_opts = """
        SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
        FROM service_opts
        WHERE orgid = ${orgid}$
    """
    opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
    if not opts_result:
        error(f"未找到 orgid={orgid} 的服务配置")
        return None
    # Only the first row is used; the table is assumed to hold one row per org.
    opts = opts_result[0]

    # Distinct, non-empty service IDs referenced by the org's configuration.
    service_ids = {opts[column] for column in id_columns if opts[column]}
    if not service_ids:
        error(f"未找到任何服务 ID for orgid={orgid}")
        return None

    # The IDs are handed over as a plain list; no quoting is done here.
    # NOTE(review): expanding the list into the IN clause is presumably
    # handled by sor.sqlExe — confirm against the sqlor documentation.
    sql_services = """
        SELECT id, name, upappid
        FROM ragservices
        WHERE id IN ${id_list}$
    """
    services_result = await sor.sqlExe(
        sql_services, {'id_list': list(service_ids)}
    )
    if not services_result:
        error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
        return None

    # Every slot starts as None so that unfilled ones can be reported below.
    service_params = dict.fromkeys(
        ('embedding', 'vdb', 'reranker', 'triples', 'gdb', 'entities')
    )
    for service in services_result:
        slot = slot_by_name.get(service['name'])
        if slot is not None:  # rows with an unrecognised name are ignored
            service_params[slot] = service['upappid']

    missing_services = [k for k, v in service_params.items() if v is None]
    if missing_services:
        error(f"未找到以下服务的配置: {missing_services}")
        return None
    return service_params
async def get_service_params(orgid):
    """Fetch the service parameters of ``orgid`` using the 'rag' module database.

    Creates a ``DBPools`` object, logs the incoming ``orgid``, resolves the
    database name through the server-environment callable
    ``get_module_dbname`` and delegates to ``sor_get_service_params``
    inside a ``sqlorContext`` block.

    Returns:
        Whatever ``sor_get_service_params`` returns, or ``None`` when the
        context block finishes without producing a value.
        NOTE(review): that can presumably only happen if ``sqlorContext``
        suppresses an exception — confirm against sqlor.
    """
    pools = DBPools()
    debug(f"传入的orgid是{orgid}")
    resolve_dbname = get_serverenv('get_module_dbname')
    rag_dbname = resolve_dbname('rag')

    params = None
    async with pools.sqlorContext(rag_dbname) as sor:
        params = await sor_get_service_params(sor, orgid)
    return params
async def sor_get_embedding_mode(sor, orgid) -> int:
    """Return the embedding mode configured for ``orgid``.

    Mode values: 0 = plain text, 1 = multimodal.

    Joins ``service_opts`` with ``embedding_mode`` on the embedding
    service ID and reads the ``mode`` attribute of the first row.

    Args:
        sor: Object exposing an awaitable ``sqlExe(sql, params)`` whose
            rows expose the selected column as the attribute ``mode``.
        orgid: Value bound to the ``orgid`` placeholder of the query.

    Returns:
        The mode as an ``int``. Falls back to 0 (and logs through
        ``debug``) when the query yields no row or the stored mode is
        NULL.

    Raises:
        ValueError: If the stored mode is not NULL and cannot be
            converted by ``int``.
    """
    sql = """
        SELECT em.mode
        FROM service_opts so
        JOIN embedding_mode em ON so.embedding_id = em.embeddingid
        WHERE so.orgid = ${orgid}$
    """
    rows = await sor.sqlExe(sql, {"orgid": orgid})
    # A NULL mode is treated like a missing configuration rather than being
    # passed to int(), which would raise TypeError.
    mode = rows[0].mode if rows else None
    if mode is None:
        debug(f"orgid={orgid} 未配置 embedding_mode默认为 0纯文本")
        return 0
    return int(mode)
async def get_embedding_mode(orgid):
    """Fetch the embedding mode of ``orgid`` using the 'rag' module database.

    Creates a ``DBPools`` object, resolves the database name through the
    server-environment callable ``get_module_dbname`` and delegates to
    ``sor_get_embedding_mode`` inside a ``sqlorContext`` block.

    Returns:
        Whatever ``sor_get_embedding_mode`` returns, or ``None`` when the
        context block finishes without producing a value.
        NOTE(review): that can presumably only happen if ``sqlorContext``
        suppresses an exception — confirm against sqlor.
    """
    pools = DBPools()
    resolve_dbname = get_serverenv('get_module_dbname')
    rag_dbname = resolve_dbname('rag')

    mode = None
    async with pools.sqlorContext(rag_dbname) as sor:
        mode = await sor_get_embedding_mode(sor, orgid)
    return mode