rag/rag/ragapi.py
2025-10-10 14:32:13 +08:00

250 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from rag.uapi_service import APIService
from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
import time
import traceback
import json
import math
from rag.service_opts import get_service_params, sor_get_service_params
from rag.rag_operations import RagOperations
# Plain-text API reference; the ``docs`` handler returns this string as-is.
# The request bodies below are illustrative samples for API users.
helptext = """kyrag API:
1. 得到kdb表:
path: /v1/get_kdbs
headers: {
"Content-Type": "application/json"
}
response:
[{"id": "1", "name": "textdb", "description": "文本数据库"}, {"id": "testkdb", "name": "testkdb", "description": ""}, {"id": "Vdtbt3qBfocteit1HIxVH", "name": "trm", "description": ""}]
2. 向量检索文本块:
path: /v1/fusedsearch
headers: {
"Content-Type": "application/json"
}
data: {
"query": "什么是知识抽取。",
"fiids":["1"]
}
3、docs文档
path: /v1/docs
4. 添加用户消息到记忆:
path: /v1/add_user_messages
headers: {
"Content-Type": "application/json"
}
data: {
"orgid": "用户组织ID",
"messages": [{"role": "user", "content": "消息内容"}, ...]
}
5. 获取用户所有记忆:
path: /v1/get_user_memories
headers: {
"Content-Type": "application/json"
}
data: {
"orgid": "用户组织ID",
"limit": 10 # 可选,默认为 10
}
"""
async def docs(request, params_kw, *params, **kw):
    """Serve the plain-text API help.

    All arguments are accepted only to satisfy the handler signature and are
    ignored; the module-level ``helptext`` string is returned unchanged.
    """
    help_body = helptext
    return help_body
async def get_kdbs(request, params_kw, *params, **kw):
    """Return the kdb rows of the caller's organisation as a JSON string.

    Args:
        request: Incoming request object (not used directly).
        params_kw: Request parameters; only logged here.
        *params: Extra positional arguments, ignored.
        **kw: Handler namespace. Must provide ``get_userorgid`` (awaitable
            callable returning the caller's orgid) and ``get_module_dbname``
            (callable mapping a module name to a database name).

    Returns:
        A JSON string. On success it is the list of rows with ``id``,
        ``name`` and ``description``; when no row matches or the query
        fails it is ``{"status": "error", "message": ...}``.
    """
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid}{f=}")
    debug(f"params_kw: {params_kw}")
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    # ${orgid}$ is a named placeholder filled from the dict passed to sqlExe.
    sql_opts = (
        "\n"
        "SELECT id, name, description\n"
        "FROM kdb\n"
        "WHERE orgid = ${orgid}$\n"
    )
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
            if not opts_result:
                error("未找到 kdb 表记录")
                return json.dumps({"status": "error", "message": "未找到记录"})
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        # Handler boundary: log with the stack trace and report the failure
        # to the client instead of letting the exception escape.
        # A second, unreachable ``except Exception`` clause that touched an
        # undefined ``result`` variable used to follow here; it was removed.
        error(f"查询 kdb 表失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})
def _normalize_limit(params_kw, default=5):
    """Return the requested result count as a positive int, else ``default``.

    The value is read from ``limit`` and, when that is missing or falsy, from
    ``retrieval_setting.top_k`` (both spellings are accepted so that dify and
    coze style requests work). Unparseable strings, NaN/infinite floats,
    non-positive numbers and unexpected types all yield ``default``.
    """
    raw_limit = params_kw.get('limit')
    if not raw_limit:
        retrieval_setting = params_kw.get('retrieval_setting')
        raw_limit = (
            retrieval_setting.get('top_k')
            if isinstance(retrieval_setting, dict)
            else None
        )
    if not isinstance(raw_limit, (int, float, str)):
        return default
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError, OverflowError):
        # ValueError: non-numeric string or NaN; OverflowError: +/-infinity.
        return default
    return limit if limit > 0 else default


def _normalize_fiids(params_kw):
    """Return the knowledge-base ids from ``fiids``/``knowledge_id`` as a list of str.

    Accepted shapes: a list, a JSON-array string, a comma-separated string,
    a single id string, or a single number. Anything else yields ``[]``.
    """
    raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
    if isinstance(raw_fiids, list):
        return [str(item).strip() for item in raw_fiids]
    if isinstance(raw_fiids, str):
        try:
            parsed = json.loads(raw_fiids)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            return [str(item).strip() for item in parsed]
        # Not a JSON array: treat as comma-separated ids (or a single id).
        return [part.strip() for part in raw_fiids.split(',') if part.strip()]
    if isinstance(raw_fiids, (int, float)):
        try:
            return [str(int(raw_fiids))]
        except (ValueError, OverflowError):
            # NaN or infinity cannot be an id.
            return []
    return []


async def fusedsearch(request, params_kw, *params):
    """Run the fused (triplet-augmented vector) search for one query.

    Steps: extract entities from the query, match triplets, append the
    triplets to the query text, embed that text, run the vector search with
    ``limit + 5`` candidates, rerank down to ``limit`` and format the result.

    Args:
        request: Incoming request; its ``_run_ns`` namespace must provide
            ``get_userorgid``, ``get_user`` and ``get_module_dbname``.
        params_kw: Request parameters. Reads ``query``, ``limit`` (or
            ``retrieval_setting.top_k``) and ``fiids`` (or ``knowledge_id``).
        *params: Extra positional arguments, ignored.

    Returns:
        ``{"records": [...], "timings": {...}}`` on success. If any pipeline
        step fails, ``records`` is empty, ``timings`` holds only
        ``total_time`` and an ``error`` message is included.

    Raises:
        ValueError: when no service parameters are available for the orgid.
        Exception: anything raised by ``_validate_fiids_orgid`` propagates,
            because validation runs before the guarded pipeline.
    """
    kw = request._run_ns
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid}{f=}")
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')

    limit = _normalize_limit(params_kw)
    debug(f"limit: {limit}")
    fiids = _normalize_fiids(params_kw)
    debug(f"fiids: {fiids}")

    # The requested knowledge bases must belong to the caller's organisation.
    await _validate_fiids_orgid(fiids, orgid, kw)
    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("无法获取服务参数")

    timings = {}
    # Started before the try block so the error path can always report it.
    start_time = time.time()
    try:
        rag_ops = RagOperations()
        query_entities = await rag_ops.extract_entities(
            request, query, service_params, userid, timings)
        all_triplets = await rag_ops.match_triplets(
            request, query, query_entities, orgid, fiids, service_params,
            userid, timings)
        combined_text = _combine_query_with_triplets(query, all_triplets)
        query_vector = await rag_ops.generate_query_vector(
            request, combined_text, service_params, userid, timings)
        # Fetch a few extra candidates so the reranker has something to cut.
        search_results = await rag_ops.vector_search(
            request, query_vector, orgid, fiids, limit + 5, service_params,
            userid, timings)
        use_rerank = True
        if use_rerank and search_results:
            final_results = await rag_ops.rerank_results(
                request, combined_text, search_results, limit, service_params,
                userid, timings)
            debug(f"final_results: {final_results}")
        else:
            final_results = [
                {k: v for k, v in r.items() if k != 'rerank_score'}
                for r in search_results
            ]
        formatted_results = rag_ops.format_search_results(final_results, limit)
        timings["total_time"] = time.time() - start_time
        info(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f}")
        return {
            "records": formatted_results,
            "timings": timings
        }
    except Exception as e:
        error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {
            "records": [],
            "timings": {"total_time": time.time() - start_time},
            "error": str(e)
        }
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
async def _validate_fiids_orgid(fiids, orgid, kw):
"""验证 fiids 的 orgid 与当前用户 orgid 是否一致"""
if fiids:
db = DBPools()
dbname = kw.get('get_module_dbname')('rag')
sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
try:
async with db.sqlorContext(dbname) as sor:
result = await sor.sqlExe(sql_opts, {"id": fiids[0]})
if not result:
raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
kdb_orgid = result[0].get('orgid')
if kdb_orgid != orgid:
raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
except Exception as e:
error(f"orgid 验证失败: {str(e)}")
raise
def _combine_query_with_triplets(query, triplets):
    """Build the search text from the query plus its matched triplets.

    Each triplet with non-empty ``head``, ``type`` and ``tail`` is rendered
    as ``"head type tail"``; incomplete triplets are logged and skipped.
    The query and the rendered triplets are joined with single spaces so the
    tail of one triplet does not run into the head of the next.

    Args:
        query: The user's query text.
        triplets: Iterable of dicts with ``head``/``type``/``tail`` keys;
            ``None`` is treated as no triplets.

    Returns:
        The query unchanged when there is no valid triplet, otherwise the
        query followed by the space-separated triplet texts.
    """
    triplet_texts = []
    for triplet in triplets or []:
        head = triplet.get('head', '')
        type_ = triplet.get('type', '')
        tail = triplet.get('tail', '')
        if head and type_ and tail:
            triplet_texts.append(f"{head} {type_} {tail}")
        else:
            debug(f"无效三元组: {triplet}")
    if triplet_texts:
        # An empty query is skipped so the result has no leading space.
        leading = [query] if query else []
        combined_text = " ".join(leading + triplet_texts)
    else:
        combined_text = query
    debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
    return combined_text