diff --git a/rag/init.py b/rag/init.py index c903eed..89a2cc3 100644 --- a/rag/init.py +++ b/rag/init.py @@ -9,7 +9,4 @@ def load_rag(): rf.register('docs', docs) rf.register('get_kdbs', get_kdbs) rf.register('fusedsearch', fusedsearch) - rf.register('add_user_messages', add_user_messages) - rf.register('get_user_memories', get_user_memories) - diff --git a/rag/ragapi.py b/rag/ragapi.py index cfbd795..e3aa762 100644 --- a/rag/ragapi.py +++ b/rag/ragapi.py @@ -7,7 +7,6 @@ import json import math from rag.service_opts import get_service_params, sor_get_service_params from rag.rag_operations import RagOperations -from rag.aslmapi import MemoryManager helptext = """kyrag API: @@ -195,123 +194,5 @@ async def fusedsearch(request, params_kw, *params): } -async def add_user_messages(request, params_kw, *params, **kw): - """添加用户消息到记忆库 - - Args: - request: HTTP 请求对象 - params_kw: 请求参数字典,包含 orgid 和 messages - *params: 任意位置参数 - **kw: 任意关键字参数(如 get_module_dbname, get_userorgid) - - Returns: - str: JSON 格式的响应,包含 status 和 result 或错误信息 - """ - debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}") - orgid = params_kw.get('orgid') - messages = params_kw.get('messages') - - if not orgid or not isinstance(orgid, str): - error("orgid 参数无效,必须为非空字符串") - return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"}) - - if not messages or not isinstance(messages, list): - error("messages 参数无效,必须为非空列表") - return json.dumps({"status": "error", "message": "messages 参数无效,必须为非空列表"}) - - try: - manager = MemoryManager() - result = await manager.add_messages_to_memory(messages, orgid) - debug(f"用户消息添加结果: {result}") - return json.dumps(result, ensure_ascii=False) - except Exception as e: - error(f"添加用户消息失败: {str(e)}") - return json.dumps({"status": "error", "message": str(e)}) - -async def get_user_memories(request, params_kw, *params, **kw): - """获取用户的所有记忆记录 - - Args: - request: HTTP 请求对象 - params_kw: 请求参数字典,包含 orgid 和可选的 limit - *params: 任意位置参数 - **kw: 任意关键字参数(如 get_module_dbname, get_userorgid) - - Returns: - str: JSON 格式的响应,包含 status 和 result 或错误信息 - """ - debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}") - orgid = params_kw.get('orgid') - limit = params_kw.get('limit', 10) - - if not orgid or not isinstance(orgid, str): - error("orgid 参数无效,必须为非空字符串") - return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"}) - - try: - limit = int(limit) if isinstance(limit, (str, int, float)) else 10 - manager = MemoryManager() - result = await manager.get_all_memories(user_id=orgid, limit=limit) - debug(f"用户 {orgid} 的记忆检索结果: {result}") - return json.dumps(result, ensure_ascii=False) - except Exception as e: - error(f"获取用户 {orgid} 的记忆失败: {str(e)}") - return json.dumps({"status": "error", "message": str(e)}) - -async def _validate_fiids_orgid(fiids, orgid, kw): - """验证 fiids 的 orgid 与当前用户 orgid 是否一致""" - if fiids: - db = DBPools() - dbname = kw.get('get_module_dbname')('rag') - sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$""" - try: - async with db.sqlorContext(dbname) as sor: - result = await sor.sqlExe(sql_opts, {"id": fiids[0]}) - if not result: - raise ValueError(f"未找到 fiid={fiids[0]} 的记录") - kdb_orgid = result[0].get('orgid') - if kdb_orgid != orgid: - raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}") - except Exception as e: - error(f"orgid 验证失败: {str(e)}") - raise - -async def _validate_fiids_orgid(fiids, orgid, kw): - """验证 fiids 的 orgid 与当前用户 orgid 是否一致""" - if fiids: - db = DBPools() - dbname = kw.get('get_module_dbname')('rag') - sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$""" - try: - async with db.sqlorContext(dbname) as sor: - result = await sor.sqlExe(sql_opts, {"id": fiids[0]}) - if not result: - raise ValueError(f"未找到 fiid={fiids[0]} 的记录") - kdb_orgid = result[0].get('orgid') - if kdb_orgid != orgid: - raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}") - except Exception as e: - error(f"orgid 验证失败: {str(e)}") - raise - - -def _combine_query_with_triplets(query, triplets): - """拼接查询文本和三元组文本""" - triplet_texts = [] - for triplet in triplets: - head = triplet.get('head', '') - type_ = triplet.get('type', '') - tail = triplet.get('tail', '') - if head and type_ and tail: - triplet_texts.append(f"{head} {type_} {tail}") - else: - debug(f"无效三元组: {triplet}") - - combined_text = query - if triplet_texts: - combined_text += "".join(triplet_texts) - - debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") - return combined_text