This commit is contained in:
wangmeihua 2025-10-10 13:50:29 +08:00
parent 672e96a09a
commit b3bf2ba1db
2 changed files with 0 additions and 122 deletions

View File

@ -9,7 +9,4 @@ def load_rag():
rf.register('docs', docs)
rf.register('get_kdbs', get_kdbs)
rf.register('fusedsearch', fusedsearch)
rf.register('add_user_messages', add_user_messages)
rf.register('get_user_memories', get_user_memories)

View File

@ -7,7 +7,6 @@ import json
import math
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
from rag.aslmapi import MemoryManager
helptext = """kyrag API:
@ -195,123 +194,5 @@ async def fusedsearch(request, params_kw, *params):
}
async def add_user_messages(request, params_kw, *params, **kw):
"""添加用户消息到记忆库
Args:
request: HTTP 请求对象
params_kw: 请求参数字典包含 orgid messages
*params: 任意位置参数
**kw: 任意关键字参数 get_module_dbname, get_userorgid
Returns:
str: JSON 格式的响应包含 status result 或错误信息
"""
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
orgid = params_kw.get('orgid')
messages = params_kw.get('messages')
if not orgid or not isinstance(orgid, str):
error("orgid 参数无效,必须为非空字符串")
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
if not messages or not isinstance(messages, list):
error("messages 参数无效,必须为非空列表")
return json.dumps({"status": "error", "message": "messages 参数无效,必须为非空列表"})
try:
manager = MemoryManager()
result = await manager.add_messages_to_memory(messages, orgid)
debug(f"用户消息添加结果: {result}")
return json.dumps(result, ensure_ascii=False)
except Exception as e:
error(f"添加用户消息失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)})
async def get_user_memories(request, params_kw, *params, **kw):
"""获取用户的所有记忆记录
Args:
request: HTTP 请求对象
params_kw: 请求参数字典包含 orgid 和可选的 limit
*params: 任意位置参数
**kw: 任意关键字参数 get_module_dbname, get_userorgid
Returns:
str: JSON 格式的响应包含 status result 或错误信息
"""
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
orgid = params_kw.get('orgid')
limit = params_kw.get('limit', 10)
if not orgid or not isinstance(orgid, str):
error("orgid 参数无效,必须为非空字符串")
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
try:
limit = int(limit) if isinstance(limit, (str, int, float)) else 10
manager = MemoryManager()
result = await manager.get_all_memories(user_id=orgid, limit=limit)
debug(f"用户 {orgid} 的记忆检索结果: {result}")
return json.dumps(result, ensure_ascii=False)
except Exception as e:
error(f"获取用户 {orgid} 的记忆失败: {str(e)}")
return json.dumps({"status": "error", "message": str(e)})
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
def _combine_query_with_triplets(query, triplets):
"""拼接查询文本和三元组文本"""
triplet_texts = []
for triplet in triplets:
head = triplet.get('head', '')
type_ = triplet.get('type', '')
tail = triplet.get('tail', '')
if head and type_ and tail:
triplet_texts.append(f"{head} {type_} {tail}")
else:
debug(f"无效三元组: {triplet}")
combined_text = query
if triplet_texts:
combined_text += "".join(triplet_texts)
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
return combined_text