rag
This commit is contained in:
parent
672e96a09a
commit
b3bf2ba1db
@ -9,7 +9,4 @@ def load_rag():
|
||||
rf.register('docs', docs)
|
||||
rf.register('get_kdbs', get_kdbs)
|
||||
rf.register('fusedsearch', fusedsearch)
|
||||
rf.register('add_user_messages', add_user_messages)
|
||||
rf.register('get_user_memories', get_user_memories)
|
||||
|
||||
|
||||
|
||||
119
rag/ragapi.py
119
rag/ragapi.py
@ -7,7 +7,6 @@ import json
|
||||
import math
|
||||
from rag.service_opts import get_service_params, sor_get_service_params
|
||||
from rag.rag_operations import RagOperations
|
||||
from rag.aslmapi import MemoryManager
|
||||
|
||||
helptext = """kyrag API:
|
||||
|
||||
@ -195,123 +194,5 @@ async def fusedsearch(request, params_kw, *params):
|
||||
}
|
||||
|
||||
|
||||
async def add_user_messages(request, params_kw, *params, **kw):
|
||||
"""添加用户消息到记忆库
|
||||
|
||||
Args:
|
||||
request: HTTP 请求对象
|
||||
params_kw: 请求参数字典,包含 orgid 和 messages
|
||||
*params: 任意位置参数
|
||||
**kw: 任意关键字参数(如 get_module_dbname, get_userorgid)
|
||||
|
||||
Returns:
|
||||
str: JSON 格式的响应,包含 status 和 result 或错误信息
|
||||
"""
|
||||
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
|
||||
orgid = params_kw.get('orgid')
|
||||
messages = params_kw.get('messages')
|
||||
|
||||
if not orgid or not isinstance(orgid, str):
|
||||
error("orgid 参数无效,必须为非空字符串")
|
||||
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
|
||||
|
||||
if not messages or not isinstance(messages, list):
|
||||
error("messages 参数无效,必须为非空列表")
|
||||
return json.dumps({"status": "error", "message": "messages 参数无效,必须为非空列表"})
|
||||
|
||||
try:
|
||||
manager = MemoryManager()
|
||||
result = await manager.add_messages_to_memory(messages, orgid)
|
||||
debug(f"用户消息添加结果: {result}")
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
error(f"添加用户消息失败: {str(e)}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
async def get_user_memories(request, params_kw, *params, **kw):
|
||||
"""获取用户的所有记忆记录
|
||||
|
||||
Args:
|
||||
request: HTTP 请求对象
|
||||
params_kw: 请求参数字典,包含 orgid 和可选的 limit
|
||||
*params: 任意位置参数
|
||||
**kw: 任意关键字参数(如 get_module_dbname, get_userorgid)
|
||||
|
||||
Returns:
|
||||
str: JSON 格式的响应,包含 status 和 result 或错误信息
|
||||
"""
|
||||
debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
|
||||
orgid = params_kw.get('orgid')
|
||||
limit = params_kw.get('limit', 10)
|
||||
|
||||
if not orgid or not isinstance(orgid, str):
|
||||
error("orgid 参数无效,必须为非空字符串")
|
||||
return json.dumps({"status": "error", "message": "orgid 参数无效,必须为非空字符串"})
|
||||
|
||||
try:
|
||||
limit = int(limit) if isinstance(limit, (str, int, float)) else 10
|
||||
manager = MemoryManager()
|
||||
result = await manager.get_all_memories(user_id=orgid, limit=limit)
|
||||
debug(f"用户 {orgid} 的记忆检索结果: {result}")
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
error(f"获取用户 {orgid} 的记忆失败: {str(e)}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||
if fiids:
|
||||
db = DBPools()
|
||||
dbname = kw.get('get_module_dbname')('rag')
|
||||
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
|
||||
if not result:
|
||||
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
|
||||
kdb_orgid = result[0].get('orgid')
|
||||
if kdb_orgid != orgid:
|
||||
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
|
||||
except Exception as e:
|
||||
error(f"orgid 验证失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _validate_fiids_orgid(fiids, orgid, kw):
|
||||
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
|
||||
if fiids:
|
||||
db = DBPools()
|
||||
dbname = kw.get('get_module_dbname')('rag')
|
||||
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
|
||||
if not result:
|
||||
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
|
||||
kdb_orgid = result[0].get('orgid')
|
||||
if kdb_orgid != orgid:
|
||||
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
|
||||
except Exception as e:
|
||||
error(f"orgid 验证失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _combine_query_with_triplets(query, triplets):
|
||||
"""拼接查询文本和三元组文本"""
|
||||
triplet_texts = []
|
||||
for triplet in triplets:
|
||||
head = triplet.get('head', '')
|
||||
type_ = triplet.get('type', '')
|
||||
tail = triplet.get('tail', '')
|
||||
if head and type_ and tail:
|
||||
triplet_texts.append(f"{head} {type_} {tail}")
|
||||
else:
|
||||
debug(f"无效三元组: {triplet}")
|
||||
|
||||
combined_text = query
|
||||
if triplet_texts:
|
||||
combined_text += "".join(triplet_texts)
|
||||
|
||||
debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||
return combined_text
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user