Compare commits

...

2 Commits

Author SHA1 Message Date
f1fdb502be rag 2025-10-10 14:32:13 +08:00
1bca72c8fc rag 2025-10-10 14:30:17 +08:00

View File

@ -149,6 +149,9 @@ async def fusedsearch(request, params_kw, *params):
debug(f"fiids: {fiids}")
# 验证 fiids的orgid与orgid = await f()是否一致
await _validate_fiids_orgid(fiids, orgid, kw)
service_params = await get_service_params(orgid)
if not service_params:
raise ValueError("无法获取服务参数")
@ -190,7 +193,58 @@ async def fusedsearch(request, params_kw, *params):
"error": str(e)
}
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
def _combine_query_with_triplets(query, triplets):
"""拼接查询文本和三元组文本"""
triplet_texts = []
for triplet in triplets:
head = triplet.get('head', '')
type_ = triplet.get('type', '')
tail = triplet.get('tail', '')
if head and type_ and tail:
triplet_texts.append(f"{head} {type_} {tail}")
else:
debug(f"无效三元组: {triplet}")
combined_text = query
if triplet_texts:
combined_text += "".join(triplet_texts)
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
return combined_text