rag
This commit is contained in:
parent
1b42987012
commit
1bca72c8fc
@ -124,6 +124,9 @@ async def fusedsearch(request, params_kw, *params):
|
|||||||
debug(f"limit: {limit}")
|
debug(f"limit: {limit}")
|
||||||
raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') #
|
raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') #
|
||||||
|
|
||||||
|
# 验证 fiids的orgid与orgid = await f()是否一致
|
||||||
|
await _validate_fiids_orgid(fiids, orgid, kw)
|
||||||
|
|
||||||
# 标准化为列表格式
|
# 标准化为列表格式
|
||||||
if raw_fiids is None:
|
if raw_fiids is None:
|
||||||
fiids = [] # 两个参数都不存在
|
fiids = [] # 两个参数都不存在
|
||||||
@ -190,7 +193,58 @@ async def fusedsearch(request, params_kw, *params):
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||||
|
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||||
|
if fiids:
|
||||||
|
db = DBPools()
|
||||||
|
dbname = kw.get('get_module_dbname')('rag')
|
||||||
|
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
|
||||||
|
try:
|
||||||
|
async with db.sqlorContext(dbname) as sor:
|
||||||
|
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
|
||||||
|
if not result:
|
||||||
|
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
|
||||||
|
kdb_orgid = result[0].get('orgid')
|
||||||
|
if kdb_orgid != orgid:
|
||||||
|
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
|
||||||
|
except Exception as e:
|
||||||
|
error(f"orgid 验证失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||||
|
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||||
|
if fiids:
|
||||||
|
db = DBPools()
|
||||||
|
dbname = kw.get('get_module_dbname')('rag')
|
||||||
|
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
|
||||||
|
try:
|
||||||
|
async with db.sqlorContext(dbname) as sor:
|
||||||
|
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
|
||||||
|
if not result:
|
||||||
|
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
|
||||||
|
kdb_orgid = result[0].get('orgid')
|
||||||
|
if kdb_orgid != orgid:
|
||||||
|
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
|
||||||
|
except Exception as e:
|
||||||
|
error(f"orgid 验证失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_query_with_triplets(query, triplets):
|
||||||
|
"""拼接查询文本和三元组文本"""
|
||||||
|
triplet_texts = []
|
||||||
|
for triplet in triplets:
|
||||||
|
head = triplet.get('head', '')
|
||||||
|
type_ = triplet.get('type', '')
|
||||||
|
tail = triplet.get('tail', '')
|
||||||
|
if head and type_ and tail:
|
||||||
|
triplet_texts.append(f"{head} {type_} {tail}")
|
||||||
|
else:
|
||||||
|
debug(f"无效三元组: {triplet}")
|
||||||
|
|
||||||
|
combined_text = query
|
||||||
|
if triplet_texts:
|
||||||
|
combined_text += "".join(triplet_texts)
|
||||||
|
|
||||||
|
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||||
|
return combined_text
|
||||||
Loading…
x
Reference in New Issue
Block a user