rag/rag/init.py
2025-10-09 17:39:59 +08:00

176 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
from appPublic.log import debug, error, info
from ahserver.serverenv import ServerEnv
import aiohttp
from aiohttp import ClientSession, ClientTimeout
import json
from .file import file_uploaded, file_deleted
from .folderinfo import RagFileMgr
from .ragprogram import set_program, get_rag_programs
from .ragllm_utils import get_ragllms_by_catelog
from appPublic.registerfunction import RegisterFunction
from sqlor.dbpools import DBPools
def _connection_error_result(params: dict, message: str, status_code: int) -> dict:
    """Build the uniform error dict returned by ``_make_connection_request``."""
    return {
        "status": "error",
        "message": message,
        "collection_name": params.get("db_type", "ragdb"),
        "document_id": "",
        "status_code": status_code
    }


async def _make_connection_request(action: str, params: dict = None, *,
                                   base_url: str = "http://localhost:8888/v1",
                                   timeout: float = 10) -> dict:
    """POST ``params`` as JSON to the connection service endpoint ``action``.

    Generic helper behind every service-backed function in this module
    (MilvusConnection service interface).

    Args:
        action: Endpoint name appended to ``base_url``, e.g. ``"create_collection"``
            or ``"insert_document"``.
        params: JSON request body; ``None`` is sent as ``{}``.
        base_url: Service root URL (keyword-only). A trailing ``/`` is tolerated.
        timeout: Total request timeout in seconds (keyword-only).

    Returns:
        The decoded JSON response when the HTTP status is 200 (``None`` for an
        empty body). Otherwise an error dict with keys ``status`` ("error"),
        ``message``, ``collection_name`` (``params["db_type"]`` or ``"ragdb"``),
        ``document_id`` ("") and ``status_code`` (the HTTP status, or 500 when
        the request or JSON decoding raised an exception).
    """
    if params is None:
        params = {}
    url = f"{base_url.rstrip('/')}/{action}"
    try:
        debug(f"发起 {action} 请求: params={params}")
        async with ClientSession(timeout=ClientTimeout(total=timeout)) as session:
            async with session.post(
                url,
                headers={"Content-Type": "application/json"},
                json=params
            ) as response:
                response_text = await response.text()
                debug(f"收到 {action} 响应: status={response.status}, content={response_text}")
                if response.status != 200:
                    error(f"{action} 请求失败: 状态码={response.status}, 响应={response_text}")
                    return _connection_error_result(
                        params, f"请求失败: 状态码 {response.status}", response.status
                    )
                # Decode the text already read above instead of response.json():
                # the latter enforces an application/json Content-Type and would
                # turn a valid JSON body served as e.g. text/plain into a 500.
                # An empty body yields None, matching aiohttp's json() behaviour.
                result = json.loads(response_text) if response_text.strip() else None
                info(f"{action} 请求成功: 结果={result}")
                return result
    except Exception as e:
        # Boundary handler: any transport/decoding failure becomes an error dict
        # so callers always receive a dict-shaped response.
        error(f"{action} 请求异常: 错误={str(e)}")
        return _connection_error_result(params, f"服务器错误: {str(e)}", 500)
async def create_collection(db_type: str = "") -> dict:
    """Ask the connection service to create a Milvus collection.

    Args:
        db_type: Forwarded unchanged as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``create_collection``.
    """
    payload = dict(db_type=db_type)
    return await _make_connection_request("create_collection", payload)
async def delete_collection(db_type: str = "") -> dict:
    """Ask the connection service to drop a Milvus collection.

    Args:
        db_type: Forwarded unchanged as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``delete_collection``.
    """
    action = "delete_collection"
    return await _make_connection_request(action, dict(db_type=db_type))
async def insert_document(file_path: str, userid: str, knowledge_base_id: str, db_type: str = "") -> dict:
    """Ask the connection service to ingest one document.

    Per the service contract, the document goes into Milvus and extracted
    triples go into Neo4j; this function only forwards the request.

    Args:
        file_path: Path of the file to ingest.
        userid: Owning user id.
        knowledge_base_id: Target knowledge base id.
        db_type: Forwarded as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``insert_document``.
    """
    request_body = dict(
        file_path=file_path,
        userid=userid,
        knowledge_base_id=knowledge_base_id,
        db_type=db_type,
    )
    return await _make_connection_request("insert_document", request_body)
async def delete_document(userid: str, filename: str, knowledge_base_id: str, db_type: str = "") -> dict:
    """Ask the connection service to remove one document's records.

    Per the service contract, both the Milvus and Neo4j records of the file
    are deleted.

    Args:
        userid: Owning user id.
        filename: Name of the document to delete.
        knowledge_base_id: Knowledge base containing the document.
        db_type: Forwarded as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``delete_document``.
    """
    action = "delete_document"
    body = dict(
        userid=userid,
        filename=filename,
        knowledge_base_id=knowledge_base_id,
        db_type=db_type,
    )
    return await _make_connection_request(action, body)
async def delete_knowledge_base(userid: str, knowledge_base_id: str, db_type: str = "") -> dict:
    """Ask the connection service to remove an entire knowledge base.

    Per the service contract, all Milvus and Neo4j records of the knowledge
    base are deleted.

    Args:
        userid: Owning user id.
        knowledge_base_id: Knowledge base to delete.
        db_type: Forwarded as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``delete_knowledge_base``.
    """
    return await _make_connection_request(
        "delete_knowledge_base",
        dict(userid=userid, knowledge_base_id=knowledge_base_id, db_type=db_type),
    )
async def search_query(query: str, userid: str, knowledge_base_ids: list, limit: int = 5, offset: int = 0,
                       use_rerank: bool = True, db_type: str = "") -> dict:
    """Ask the connection service for a vector-only search.

    Args:
        query: Search text.
        userid: User whose knowledge bases are searched.
        knowledge_base_ids: Knowledge base ids to search.
        limit: Forwarded as ``limit`` (default 5).
        offset: Forwarded as ``offset`` (default 0).
        use_rerank: Forwarded as ``use_rerank`` (default True).
        db_type: Forwarded as ``db_type``.

    Returns:
        The dict produced by ``_make_connection_request`` for ``search_query``.
    """
    search_args = dict(
        query=query,
        userid=userid,
        knowledge_base_ids=knowledge_base_ids,
        limit=limit,
        offset=offset,
        use_rerank=use_rerank,
        db_type=db_type,
    )
    return await _make_connection_request("search_query", search_args)
async def fused_search(query: str, userid: str, knowledge_base_ids: list, limit: int = 5, offset: int = 0,
                       use_rerank: bool = True, db_type: str = "") -> dict:
    """Ask the connection service for a fused (vector + triple) search.

    Args:
        query: Search text.
        userid: User whose knowledge bases are searched.
        knowledge_base_ids: Knowledge base ids to search.
        limit: Forwarded as ``limit`` (default 5).
        offset: Forwarded as ``offset`` (default 0).
        use_rerank: Forwarded as ``use_rerank`` (default True).
        db_type: Forwarded as ``db_type``.

    Returns:
        The dict produced by ``_make_connection_request`` for ``fused_search``.
    """
    action = "fused_search"
    fused_args = dict(
        query=query,
        userid=userid,
        knowledge_base_ids=knowledge_base_ids,
        limit=limit,
        offset=offset,
        use_rerank=use_rerank,
        db_type=db_type,
    )
    return await _make_connection_request(action, fused_args)
async def list_user_files(userid: str, db_type: str = "") -> dict:
    """Ask the connection service for one user's knowledge bases and their files.

    Args:
        userid: User whose knowledge bases are listed.
        db_type: Forwarded as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``list_user_files``.
    """
    return await _make_connection_request(
        "list_user_files", dict(userid=userid, db_type=db_type)
    )
async def list_all_knowledge_bases(db_type: str = "") -> dict:
    """Ask the connection service for every user's knowledge bases and their files.

    Args:
        db_type: Forwarded as the ``db_type`` request field.

    Returns:
        The dict produced by ``_make_connection_request`` for ``list_all_knowledge_bases``.
    """
    action = "list_all_knowledge_bases"
    payload = dict(db_type=db_type)
    return await _make_connection_request(action, payload)
async def docs() -> dict:
    """Call the connection service's ``docs`` action with an empty request body.

    Sends ``POST <service>/docs`` with ``{}`` via ``_make_connection_request``.
    The payload returned by the service is not specified here. NOTE(review):
    presumably service/API documentation; confirm against the service.

    Returns:
        The dict produced by ``_make_connection_request`` for ``docs``.
    """
    return await _make_connection_request("docs", {})
async def get_user_kdbs(request):
    """Return the ``kdb`` rows owned by the requesting user's organisation.

    Args:
        request: Incoming request. Its ``_run_ns`` namespace must provide
            ``DBPools()``, ``get_module_dbname(name)`` and an awaitable
            ``get_userorgid()``.

    Returns:
        Whatever ``sor.sqlExe`` returns for ``kdb`` rows with
        ``ownerid == <user org id>``. Returns ``[]`` if ``sqlorContext``
        suppressed an exception raised while running the query.
    """
    env = request._run_ns
    db = env.DBPools()
    dbname = env.get_module_dbname('rag')
    userorgid = await env.get_userorgid()
    # ${orgid}$ is filled from the dict passed to sqlExe (presumably sqlor's
    # named-parameter placeholder), so the org id is not spliced into the SQL.
    sql = "select * from kdb where ownerid = ${orgid}$"
    async with db.sqlorContext(dbname) as sor:
        return await sor.sqlExe(sql, {'orgid': userorgid})
    # Reached only when sqlorContext swallows an exception raised inside the
    # block. Log it and return an empty result instead of failing on an
    # unbound result variable.
    error(f"get_user_kdbs: kdb query on {dbname} did not complete for orgid={userorgid}")
    return []
def load_rag():
"""
初始化 ServerEnv绑定 MilvusConnection 的所有功能。
"""
env = ServerEnv()
env.create_collection = create_collection
env.delete_collection = delete_collection
env.insert_document = insert_document
env.delete_document = delete_document
env.delete_knowledge_base = delete_knowledge_base
env.search_query = search_query
env.fused_search = fused_search
env.list_user_files = list_user_files
env.list_all_knowledge_bases = list_all_knowledge_bases
env.docs = docs
env.RagFileMgr = RagFileMgr
env.set_program = set_program
env.get_rag_programs = get_rag_programs
env.get_ragllms_by_catelog = get_ragllms_by_catelog
env.get_user_kdbs = get_user_kdbs