rag/rag/ragapi.py

from rag.uapi_service import APIService
from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
import time
import traceback
import json
import math
from rag.service_opts import get_service_params, sor_get_service_params

helptext = """kyrag API:
1. List kdb tables:
path: /v1/get_kdbs
headers: {
    "Content-Type": "application/json"
}
response:
[{"id": "1", "name": "textdb", "description": "text database"}, {"id": "testkdb", "name": "testkdb", "description": ""}, {"id": "Vdtbt3qBfocteit1HIxVH", "name": "trm", "description": ""}]
2. Vector search over text chunks:
path: /v1/fusedsearch
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "What is knowledge extraction?",
    "fiids": ["1"]
}
3. API docs
path: /v1/docs
4. longmemory storage
"""

async def docs(request, params_kw, *params, **kw):
    return helptext

async def get_kdbs(request, params_kw, *params, **kw):
    """Return all kdb records for the caller's org, as JSON."""
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid}{f=}")
    debug(f"params_kw: {params_kw}")
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
        SELECT id, name, description
        FROM kdb
        WHERE orgid = ${orgid}$
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
            if not opts_result:
                error("no records found in the kdb table")
                return json.dumps({"status": "error", "message": "no records found"})
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        error(f"querying the kdb table failed: {str(e)}, stack: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})

async def fusedsearch(request, params_kw, *params, **kw):
    """Fused search: calls out to the service endpoints."""
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid}{f=}")
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')
    # Accept the limit from either source so that both Dify and Coze can call us.
    raw_limit = params_kw.get('limit') or (
        params_kw.get('retrieval_setting', {}).get('top_k')
        if isinstance(params_kw.get('retrieval_setting'), dict)
        else None
    )
    # Normalize to an integer.
    if raw_limit is None:
        limit = 5  # neither source supplied a value: use the default
    elif isinstance(raw_limit, (int, float)):
        limit = int(raw_limit)  # numeric values convert directly
    elif isinstance(raw_limit, str):
        try:
            limit = int(raw_limit)  # parse string values
        except (TypeError, ValueError):
            limit = 5  # default on parse failure
    else:
        limit = 5  # default for any other unexpected type
    debug(f"limit: {limit}")
    raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
    # Normalize to a list of string IDs.
    if raw_fiids is None:
        fiids = []  # neither parameter was supplied
    elif isinstance(raw_fiids, list):
        fiids = [str(item).strip() for item in raw_fiids]  # already a list
    elif isinstance(raw_fiids, str):
        try:
            # Try to parse the string as JSON first.
            parsed = json.loads(raw_fiids)
            if isinstance(parsed, list):
                fiids = [str(item).strip() for item in parsed]  # JSON array -> list of strings
            else:
                # A JSON scalar: treat it as a comma-separated or single-ID string.
                fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
        except json.JSONDecodeError:
            # Not valid JSON: split on commas.
            fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
    elif isinstance(raw_fiids, (int, float)):
        fiids = [str(int(raw_fiids))]  # numeric value -> single-item string list
    else:
        fiids = []  # any other unexpected type
    debug(f"fiids: {fiids}")
    # Verify that every requested kdb belongs to the caller's org.
    if fiids:
        db = DBPools()
        dbname = kw.get('get_module_dbname')('rag')
        sql_opts = """
            SELECT orgid
            FROM kdb
            WHERE id = ${id}$
        """
        try:
            async with db.sqlorContext(dbname) as sor:
                for fiid in fiids:
                    rows = await sor.sqlExe(sql_opts, {"id": fiid})
                    if not rows:
                        raise ValueError(f"no record found for fiid={fiid}")
                    kdb_orgid = rows[0].get('orgid')
                    if kdb_orgid != orgid:
                        raise ValueError(f"orgid mismatch: kdb.orgid={kdb_orgid}, user orgid={orgid}")
        except Exception as e:
            error(f"orgid validation failed: {str(e)}")
            return json.dumps({"status": "error", "message": str(e)})
    service_params = await get_service_params(orgid)
    api_service = APIService()
    start_time = time.time()
    timing_stats = {}
    try:
        info(f"starting fused search: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
        if not query or not orgid or not fiids:
            raise ValueError("query, orgid and knowledge_base_ids must not be empty")
        # Extract entities from the query.
        entity_extract_start = time.time()
        query_entities = await api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f}s")
        # Call the Neo4j service to match triplets in each knowledge base.
        all_triplets = []
        triplet_match_start = time.time()
        for kb_id in fiids:
            debug(f"calling Neo4j triplet matching: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=query_entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"knowledge base {kb_id} matched {len(triplets)} triplets: {triplets[:5]}")
                else:
                    error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
            except Exception as e:
                error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                continue
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"total triplet matching time: {timing_stats['triplet_matching']:.3f}s")
        # Join the matched triplets into text appended to the query.
        triplet_text_start = time.time()
        triplet_texts = []
        for triplet in all_triplets:
            head = triplet.get('head', '')
            type_ = triplet.get('type', '')
            tail = triplet.get('tail', '')
            if head and type_ and tail:
                triplet_texts.append(f"{head} {type_} {tail}")
            else:
                debug(f"invalid triplet: {triplet}")
        combined_text = query
        if triplet_texts:
            # Separate the triplets with spaces so they do not run together.
            combined_text += " " + " ".join(triplet_texts)
        debug(f"combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"triplet text combination took {timing_stats['triplet_text_combine']:.3f}s")
        # Embed the combined text as the query vector. BAAI/bge-m3 produces
        # 1024-dimensional dense vectors, hence the length check below.
        vector_start = time.time()
        query_vector = await api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("the query vector must be a list of 1024 floats")
        query_vector = query_vector[0]  # take the first (and only) vector
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"query vector generation took {timing_stats['vector_generation']:.3f}s")
        # Call the vector search endpoint, over-fetching a few extra
        # candidates so the reranker has room to reorder.
        fetch_limit = limit + 5
        search_start = time.time()
        debug(f"orgid: {orgid}")
        result = await api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            limit=fetch_limit,
            offset=0,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"vector search took {timing_stats['vector_search']:.3f}s")
        if result.get("status") != "success":
            error(f"fused search failed: {result.get('message', 'unknown error')}")
            return {"results": [], "timing": timing_stats}
        unique_results = result.get("results", [])
        debug(f"vector search returned {len(unique_results)} rows")
        use_rerank = True
        if use_rerank and unique_results:
            rerank_start = time.time()
            debug("starting rerank")
            unique_results = await api_service.rerank_results(
                request=request,
                query=combined_text,
                results=unique_results,
                top_n=limit,
                upappid=service_params['reranker'],
                apiname="BAAI/bge-reranker-v2-m3",
                user=userid
            )
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"reranking took {timing_stats['reranking']:.3f}s")
            debug(f"rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
        timing_stats["total_time"] = time.time() - start_time
        info(f"fused search finished, returning {len(unique_results)} results, total time: {timing_stats['total_time']:.3f}s")
        # Shape the top results as Dify-style records. The rerank score is
        # squashed through a sigmoid into (0, 1); without a rerank score we
        # fall back to 1 - distance, then clamp to [0, 1].
        dify_records = []
        for res in unique_results[:limit]:
            rerank_score = res.get('rerank_score', 0)
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))
            content = res.get('text', '')
            title = res.get('metadata', {}).get('filename', 'Untitled')
            document_id = res.get('metadata', {}).get('document_id', '')
            dify_records.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })
        info(f"fused search finished, returning {len(dify_records)} results, total time: {(time.time() - start_time):.3f}s")
        debug(f"records: {dify_records}")
        return {"records": dify_records}
    except Exception as e:
        error(f"fused search failed: {str(e)}, stack: {traceback.format_exc()}")
        return {"results": [], "timing": timing_stats}