From 1bca72c8fce7ad639e14bd0368907531e544684f Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Fri, 10 Oct 2025 14:30:17 +0800 Subject: [PATCH] rag --- rag/ragapi.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/rag/ragapi.py b/rag/ragapi.py index 8fc3469..d35aa67 100644 --- a/rag/ragapi.py +++ b/rag/ragapi.py @@ -124,6 +124,9 @@ async def fusedsearch(request, params_kw, *params): debug(f"limit: {limit}") raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') # + # 验证 fiids的orgid与orgid = await f()是否一致 + await _validate_fiids_orgid(fiids, orgid, kw) + # 标准化为列表格式 if raw_fiids is None: fiids = [] # 两个参数都不存在 @@ -190,7 +193,58 @@ async def fusedsearch(request, params_kw, *params): "error": str(e) } +async def _validate_fiids_orgid(fiids, orgid, kw): + """验证 fiids 的 orgid 与当前用户 orgid 是否一致""" + if fiids: + db = DBPools() + dbname = kw.get('get_module_dbname')('rag') + sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$""" + try: + async with db.sqlorContext(dbname) as sor: + result = await sor.sqlExe(sql_opts, {"id": fiids[0]}) + if not result: + raise ValueError(f"未找到 fiid={fiids[0]} 的记录") + kdb_orgid = result[0].get('orgid') + if kdb_orgid != orgid: + raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}") + except Exception as e: + error(f"orgid 验证失败: {str(e)}") + raise + +async def _validate_fiids_orgid(fiids, orgid, kw): + """验证 fiids 的 orgid 与当前用户 orgid 是否一致""" + if fiids: + db = DBPools() + dbname = kw.get('get_module_dbname')('rag') + sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$""" + try: + async with db.sqlorContext(dbname) as sor: + result = await sor.sqlExe(sql_opts, {"id": fiids[0]}) + if not result: + raise ValueError(f"未找到 fiid={fiids[0]} 的记录") + kdb_orgid = result[0].get('orgid') + if kdb_orgid != orgid: + raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}") + except Exception as e: + error(f"orgid 验证失败: {str(e)}") + raise +def _combine_query_with_triplets(query, triplets): + """拼接查询文本和三元组文本""" + triplet_texts = [] + for triplet in triplets: + head = triplet.get('head', '') + type_ = triplet.get('type', '') + tail = triplet.get('tail', '') + if head and type_ and tail: + triplet_texts.append(f"{head} {type_} {tail}") + else: + debug(f"无效三元组: {triplet}") + combined_text = query + if triplet_texts: + combined_text += "".join(triplet_texts) + debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") + return combined_text \ No newline at end of file