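"""kyrag HTTP API handlers: kdb listing, fused (vector + graph) search, and API docs."""
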
from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
import time
import traceback
import json
from rag.service_opts import get_service_params
from rag.rag_operations import RagOperations

helptext = """kyrag API:

1. List kdb tables:
path: /v1/get_kdbs
headers: {
    "Content-Type": "application/json"
}
response:
[{"id": "1", "name": "textdb", "description": "文本数据库"}, {"id": "testkdb", "name": "testkdb", "description": ""}, {"id": "Vdtbt3qBfocteit1HIxVH", "name": "trm", "description": ""}]

2. Vector search over text chunks (fused search):
path: /v1/fusedsearch
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "什么是知识抽取。",
    "fiids": ["1"]
}

3. API documentation:
path: /v1/docs

4. Long-term memory (longmemory) storage:
"""

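# A minimal request sketch for /v1/fusedsearch, mirroring the helptext above;
# the host and port are placeholder assumptions:
#
#   curl -X POST http://localhost:8080/v1/fusedsearch \
#        -H "Content-Type: application/json" \
#        -d '{"query": "什么是知识抽取。", "fiids": ["1"]}'
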
async def docs(request, params_kw, *params, **kw):
    """Return the API help text."""
    return helptext

async def get_kdbs(request, params_kw, *params, **kw):
    """Return the full contents of the kdb table as JSON."""
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    debug(f"params_kw: {params_kw}")
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
        SELECT id, name, description
        FROM kdb
        WHERE orgid = ${orgid}$
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
            if not opts_result:
                error("No records found in the kdb table")
                return json.dumps({"status": "error", "message": "No records found"})
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        error(f"Failed to query the kdb table: {str(e)}, stack: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})

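# Note: sqlor-style SQL uses ${name}$ placeholders that sqlExe binds from its
# dict argument, e.g. sor.sqlExe("... WHERE orgid = ${orgid}$", {"orgid": orgid}),
# as in get_kdbs above.
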
async def fusedsearch(request, params_kw, *params, **kw):
    """Fused search: call the service-side endpoints."""
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')
    # Handle the limit parameter uniformly, for Dify and Coze integration
    raw_limit = params_kw.get('limit') or (
        params_kw.get('retrieval_setting', {}).get('top_k')
        if isinstance(params_kw.get('retrieval_setting'), dict)
        else None
    )

    # Normalize to an integer value
    if raw_limit is None:
        limit = 5  # default when neither source is present
    elif isinstance(raw_limit, (int, float)):
        limit = int(raw_limit)  # numeric types convert directly
    elif isinstance(raw_limit, str):
        try:
            # Parse the string as an integer
            limit = int(raw_limit)
        except (TypeError, ValueError):
            limit = 5  # fall back to the default on parse failure
    else:
        limit = 5  # fall back to the default for unexpected types
    debug(f"limit: {limit}")
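    # For reference (a sketch, not exhaustive), both payloads below resolve to
    # limit=3:
    #   {"query": "...", "limit": 3}                        # native/Coze style
    #   {"query": "...", "retrieval_setting": {"top_k": 3}} # Dify external-knowledge style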
    # 'fiids' is the native field; 'knowledge_id' comes from Dify-style requests
    raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')

    # Normalize to a list of string ids
    if raw_fiids is None:
        fiids = []  # neither parameter is present
    elif isinstance(raw_fiids, list):
        fiids = [str(item).strip() for item in raw_fiids]  # already a list
    elif isinstance(raw_fiids, str):
        try:
            # Try to parse as a JSON string first
            parsed = json.loads(raw_fiids)
            if isinstance(parsed, list):
                fiids = [str(item).strip() for item in parsed]  # JSON array to list of strings
            else:
                # Handle a comma-separated string or a single id string
                fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
        except json.JSONDecodeError:
            # Not valid JSON; split on commas
            fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
    elif isinstance(raw_fiids, (int, float)):
        fiids = [str(int(raw_fiids))]  # numeric value to a single-element string list
    else:
        fiids = []  # unexpected types

    debug(f"fiids: {fiids}")

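    # Accepted "fiids"/"knowledge_id" shapes (a sketch): ["1", "2"], '["1", "2"]',
    # "1,2", "1", or 1; all normalize to a list of string ids such as ["1", "2"].
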
    # Verify that the orgid of the given fiids matches the current user's orgid
    await _validate_fiids_orgid(fiids, orgid, kw)

    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("Unable to obtain service parameters")

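    # service_params is assumed to map pipeline stages to upstream service ids;
    # the code below relies on at least an 'embedding' key, e.g. (a sketch):
    #   {"embedding": "<upappid>", ...}
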
    try:
        timing_stats = {}
        start_time = time.time()
        rag_ops = RagOperations()

        entity_extract_start = time.time()
        query_entities = await rag_ops.extract_entities(request, query, service_params, userid)
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"Extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f} s")

        triplet_match_start = time.time()
        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"Triplet matching took {timing_stats['triplet_matching']:.3f} s in total")

        triplet_text_start = time.time()
        combined_text = _combine_query_with_triplets(query, all_triplets)
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"Combining triplet text took {timing_stats['triplet_text_combine']:.3f} s")

        vector_start = time.time()
        query_vector = await rag_ops.api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        # BAAI/bge-m3 dense embeddings are 1024-dimensional, hence the check below
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a list of floats of length 1024")
        query_vector = query_vector[0]
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"Generating the query vector took {timing_stats['vector_generation']:.3f} s")

        search_start = time.time()
        # Over-fetch a few extra candidates so the reranker has more to choose from
        search_limit = limit + 5
        search_results = await rag_ops.vector_search(
            request, query_vector, orgid, fiids, search_limit, service_params, userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"Vector search took {timing_stats['vector_search']:.3f} s")
        debug(f"Retrieved {len(search_results)} rows from the vector store")

        # Reranking (optional)
        use_rerank = True
        if use_rerank and search_results:
            rerank_start = time.time()
            debug("Starting reranking")
            reranked_results = await rag_ops.rerank_results(
                request, combined_text, search_results, limit, service_params, userid
            )
            reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"Reranking took {timing_stats['reranking']:.3f} s")
            debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
            final_results = reranked_results
        else:
            final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]

        timing_stats["total_time"] = time.time() - start_time
        info(f"Fused search finished with {len(final_results)} results, total time {timing_stats['total_time']:.3f} s")

        formatted_results = rag_ops.format_search_results(final_results, limit)
        info(f"Fused search returning {len(formatted_results)} formatted results")

        return {
            "records": formatted_results
        }

    except Exception as e:
        error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
        # The transaction manager performs the rollback automatically
        return {
            "records": [],
            "timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
            "error": str(e)
        }

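# On success, the response is assumed to follow the Dify external-knowledge
# shape produced by format_search_results, roughly:
#   {"records": [{"content": ..., "title": ..., "metadata": {...}}, ...]}
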
async def _validate_fiids_orgid(fiids, orgid, kw):
    """Verify that each fiid's orgid matches the current user's orgid."""
    if fiids:
        db = DBPools()
        dbname = kw.get('get_module_dbname')('rag')
        sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
        try:
            async with db.sqlorContext(dbname) as sor:
                for fiid in fiids:
                    result = await sor.sqlExe(sql_opts, {"id": fiid})
                    if not result:
                        raise ValueError(f"No record found for fiid={fiid}")
                    kdb_orgid = result[0].get('orgid')
                    if kdb_orgid != orgid:
                        raise ValueError(f"orgid mismatch: kdb.orgid={kdb_orgid}, user orgid={orgid}")
        except Exception as e:
            error(f"orgid validation failed: {str(e)}")
            raise

def _combine_query_with_triplets(query, triplets):
    """Concatenate the query text with the matched triplet texts."""
    triplet_texts = []
    for triplet in triplets:
        head = triplet.get('head', '')
        type_ = triplet.get('type', '')
        tail = triplet.get('tail', '')
        if head and type_ and tail:
            triplet_texts.append(f"{head} {type_} {tail}")
        else:
            debug(f"Invalid triplet: {triplet}")

    combined_text = query
    if triplet_texts:
        # Separate the query and each triplet with a space so tokens do not run together
        combined_text += " " + " ".join(triplet_texts)

    debug(f"Combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
    return combined_text

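# Example (a sketch): _combine_query_with_triplets("what is knowledge extraction",
# [{"head": "knowledge extraction", "type": "subfield_of", "tail": "NLP"}])
# returns "what is knowledge extraction knowledge extraction subfield_of NLP".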