250 lines
9.2 KiB
Python
250 lines
9.2 KiB
Python
from rag.uapi_service import APIService
|
||
from sqlor.dbpools import DBPools
|
||
from appPublic.log import debug, error, info
|
||
import time
|
||
import traceback
|
||
import json
|
||
import math
|
||
from rag.service_opts import get_service_params, sor_get_service_params
|
||
from rag.rag_operations import RagOperations
|
||
|
||
# Plain-text API reference for this module; the ``docs`` handler returns it verbatim.
helptext = """kyrag API:

1. 得到kdb表:
path: /v1/get_kdbs
headers: {
    "Content-Type": "application/json"
}
response:
[{"id": "1", "name": "textdb", "description": "文本数据库"}, {"id": "testkdb", "name": "testkdb", "description": ""}, {"id": "Vdtbt3qBfocteit1HIxVH", "name": "trm", "description": ""}]

2. 向量检索文本块:
path: /v1/fusedsearch
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "什么是知识抽取。",
    "fiids":["1"],
    "limit": 5 # 可选,默认为 5
}

3、docs文档
path: /v1/docs

4. 添加用户消息到记忆:
path: /v1/add_user_messages
headers: {
    "Content-Type": "application/json"
}
data: {
    "orgid": "用户组织ID",
    "messages": [{"role": "user", "content": "消息内容"}, ...]
}

5. 获取用户所有记忆:
path: /v1/get_user_memories
headers: {
    "Content-Type": "application/json"
}
data: {
    "orgid": "用户组织ID",
    "limit": 10 # 可选,默认为 10
}
"""


async def docs(request, params_kw, *params, **kw):
    """Return the module's plain-text API reference (``helptext``).

    None of the arguments are used; the same text is returned on every call.
    """
    reference_text = helptext
    return reference_text


async def get_kdbs(request, params_kw, *params, **kw):
    """Return the kdb rows of the caller's organisation as a JSON string.

    The organisation id is obtained by awaiting ``kw['get_userorgid']()``, and the
    database name comes from ``kw['get_module_dbname']('rag')``.

    Returns:
        A JSON string. On success it is the list of rows (``id``, ``name``,
        ``description``) from the ``kdb`` table whose ``orgid`` matches. When no row
        matches, or the query raises, it is ``{"status": "error", "message": ...}``.
    """
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    debug(f"params_kw: {params_kw}")
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
        SELECT id, name, description
        FROM kdb
        WHERE orgid = ${orgid}$
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
            if not opts_result:
                error("未找到 kdb 表记录")
                # ensure_ascii=False keeps every payload of this handler consistent.
                return json.dumps({"status": "error", "message": "未找到记录"}, ensure_ascii=False)
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        # A single handler suffices. The former second ``except`` clause was
        # unreachable and referenced an undefined ``result``.
        error(f"查询 kdb 表失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False)


def _parse_limit(params_kw, default=5):
    """Return the requested number of results as a positive int.

    ``params_kw['limit']`` is preferred. When it is missing or falsy,
    ``params_kw['retrieval_setting']['top_k']`` is used instead (the request shape
    used for the Dify/Coze integration). Ints, floats and numeric strings are
    converted with ``int()``. Anything else yields ``default``: None, other types,
    NaN/inf, unparsable strings and values <= 0.
    """
    retrieval_setting = params_kw.get('retrieval_setting')
    raw_limit = params_kw.get('limit') or (
        retrieval_setting.get('top_k') if isinstance(retrieval_setting, dict) else None
    )
    if not isinstance(raw_limit, (int, float, str)):
        return default  # missing, or an unexpected type
    if isinstance(raw_limit, float) and not math.isfinite(raw_limit):
        return default  # int() raises on NaN / inf
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        return default
    # The limit feeds the vector search (as limit + 5), the rerank and the final
    # formatting, so zero or negative values fall back to the default.
    return limit if limit > 0 else default


def _parse_fiids(params_kw):
    """Normalize ``fiids`` (or ``knowledge_id``) into a list of non-empty id strings.

    Accepted forms are a list, a JSON-array string, a comma-separated string, or a
    single number. Any other value yields an empty list.
    """
    raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
    if raw_fiids is None:
        return []
    if isinstance(raw_fiids, list):
        items = raw_fiids
    elif isinstance(raw_fiids, str):
        try:
            parsed = json.loads(raw_fiids)
        except json.JSONDecodeError:
            parsed = None
        # Only a JSON array is used as-is; anything else is split on commas.
        items = parsed if isinstance(parsed, list) else raw_fiids.split(',')
    elif isinstance(raw_fiids, (int, float)):
        if isinstance(raw_fiids, float) and not math.isfinite(raw_fiids):
            return []  # int() raises on NaN / inf
        return [str(int(raw_fiids))]
    else:
        return []
    return [text for text in (str(item).strip() for item in items) if text]


async def fusedsearch(request, params_kw, *params):
    """Run the fused search pipeline for ``params_kw['query']``.

    The pipeline extracts entities and matches triplets. It appends the triplets to
    the query, embeds the combined text and runs a vector search for ``limit + 5``
    candidates. The candidates are then reranked and formatted down to ``limit``.

    Args:
        request: its ``_run_ns`` mapping must provide the coroutines
            ``get_userorgid`` and ``get_user``, plus ``get_module_dbname``.
        params_kw: request parameters. It reads ``query``, ``limit`` (or
            ``retrieval_setting.top_k``) and ``fiids`` (or ``knowledge_id``).

    Returns:
        ``{"records": [...], "timings": {...}}``. If the pipeline raises, it returns
        ``{"records": [], "timings": {"total_time": ...}, "error": str}``.

    Raises:
        ValueError: if a requested kdb id is unknown or belongs to another org, or if
            no service parameters are available for the org. These checks run
            before the ``try`` below, so their errors propagate to the caller.
    """
    kw = request._run_ns
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")

    query = params_kw.get('query', '')
    limit = _parse_limit(params_kw)
    debug(f"limit: {limit}")
    fiids = _parse_fiids(params_kw)
    debug(f"fiids: {fiids}")

    # Every requested kdb must belong to the caller's organisation.
    await _validate_fiids_orgid(fiids, orgid, kw)

    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("无法获取服务参数")

    timings = {}
    start_time = time.time()
    try:
        rag_ops = RagOperations()

        query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
                                                    userid, timings)
        combined_text = _combine_query_with_triplets(query, all_triplets)
        query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
        # Fetch a few extra candidates so the rerank step has something to choose from.
        search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
                                                     userid, timings)

        use_rerank = True
        if use_rerank and search_results:
            final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
                                                         userid, timings)
            debug(f"final_results: {final_results}")
        else:
            final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results or []]

        formatted_results = rag_ops.format_search_results(final_results, limit)
        timings["total_time"] = time.time() - start_time
        info(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f} 秒")

        return {
            "records": formatted_results,
            "timings": timings
        }
    except Exception as e:
        error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {
            "records": [],
            "timings": {"total_time": time.time() - start_time},
            "error": str(e)
        }


async def _validate_fiids_orgid(fiids, orgid, kw):
    """Ensure every kdb id in ``fiids`` exists and belongs to ``orgid``.

    Args:
        fiids: kdb ids requested by the caller. An empty list or None skips the check.
        orgid: organisation id of the current user.
        kw: handler namespace; ``kw['get_module_dbname']('rag')`` gives the database name.

    Raises:
        ValueError: if an id has no ``kdb`` row, or its ``orgid`` differs from ``orgid``.
        Any exception raised while checking is logged and re-raised unchanged.
    """
    if not fiids:
        return
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
    try:
        async with db.sqlorContext(dbname) as sor:
            # Check every id, not only the first one. Otherwise a list mixing one of
            # the caller's kdbs with another org's kdb would pass validation.
            for fiid in dict.fromkeys(fiids):  # de-duplicate while keeping order
                result = await sor.sqlExe(sql_opts, {"id": fiid})
                if not result:
                    raise ValueError(f"未找到 fiid={fiid} 的记录")
                kdb_orgid = result[0].get('orgid')
                if kdb_orgid != orgid:
                    raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
    except Exception as e:
        error(f"orgid 验证失败: {str(e)}")
        raise


def _combine_query_with_triplets(query, triplets):
    """Append the text of each complete triplet to the query.

    A triplet contributes ``"head type tail"`` only when all three fields are
    non-empty. Incomplete triplets are logged and skipped. The query and the
    triplet texts are separated by single spaces, so the tail of one triplet does
    not fuse with the head of the next one. With no usable triplets the query is
    returned unchanged.
    """
    triplet_texts = []
    for triplet in triplets or []:
        head = triplet.get('head', '')
        type_ = triplet.get('type', '')
        tail = triplet.get('tail', '')
        if head and type_ and tail:
            triplet_texts.append(f"{head} {type_} {tail}")
        else:
            debug(f"无效三元组: {triplet}")

    combined_text = query
    if triplet_texts:
        parts = [query] if query else []  # avoid a leading space for an empty query
        combined_text = " ".join(parts + triplet_texts)

    debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
    return combined_text