This commit is contained in:
wangmeihua 2025-09-15 15:32:24 +08:00
parent 043fb80ed4
commit 68a9c43390
5 changed files with 0 additions and 2388 deletions

View File

@@ -1,350 +0,0 @@
from appPublic.log import debug, error
from typing import Dict, Any, List
import aiohttp
from aiohttp import ClientSession, ClientTimeout
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import traceback
import uuid
import json
class APIService:
    """Async HTTP client for the platform's backing microservices.

    Wraps the embedding, entity-extraction, triple-extraction, reranking,
    Neo4j graph, Milvus vector and RAG HTTP services behind coroutine
    methods.  All service URLs are hard-coded public HTTPS endpoints.
    """

    # Generic POST helper, retried up to 3 times with exponential backoff on
    # network errors (aiohttp.ClientError) and non-200 statuses (surfaced as
    # RuntimeError below, which is also in the retry list).
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, RuntimeError)),
        before_sleep=lambda retry_state: debug(f"重试 API 请求,第 {retry_state.attempt_number}")
    )
    async def _make_request(self, url: str, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST *params* as JSON to *url* and return the decoded JSON body.

        *action* is used for logging only.  A 400 response is returned
        as-is to the caller (no retry); any other non-200 status raises
        RuntimeError, which triggers a retry.  The overall request timeout
        is 300 seconds.
        """
        debug(f"开始 API 请求: action={action}, params={params}, url={url}")
        try:
            async with ClientSession(timeout=ClientTimeout(total=300)) as session:
                async with session.post(
                    url,
                    headers={"Content-Type": "application/json"},
                    json=params
                ) as response:
                    debug(f"收到响应: status={response.status}, headers={response.headers}")
                    response_text = await response.text()
                    debug(f"响应内容: {response_text}")
                    result = await response.json()
                    debug(f"API 响应内容: {result}")
                    if response.status == 400:
                        # Client errors are passed through so the caller can
                        # inspect the service's error payload.
                        debug(f"客户端错误,状态码: {response.status}, 返回响应: {result}")
                        return result
                    if response.status != 200:
                        error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
                        raise RuntimeError(f"API 调用失败: {response.status}")
                    return result
        except Exception as e:
            # NOTE: every failure (including ClientError) is re-raised as
            # RuntimeError so the @retry predicate above still matches.
            error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
            raise RuntimeError(f"API 调用失败: {str(e)}")

    # Embedding service (BAAI/bge-m3)
    async def get_embeddings(self, texts: list) -> list:
        """Return one embedding vector per input text.

        *texts* may be a single string or a list of strings; a bare string
        is wrapped in a list before the request.  Raises RuntimeError on
        any failure.  No explicit timeout is set on this session.
        """
        try:
            async with ClientSession() as session:
                async with session.post(
                    "https://embedding.opencomputing.net:10443/v1/embeddings",  # public (external) endpoint
                    headers={"Content-Type": "application/json"},
                    json={"input": texts if isinstance(texts, list) else [texts]}
                ) as response:
                    if response.status != 200:
                        error(f"嵌入服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"嵌入服务调用失败: {response.status}")
                    result = await response.json()
                    if result.get("object") != "list" or not result.get("data"):
                        error(f"嵌入服务响应格式错误: {result}")
                        raise RuntimeError("嵌入服务响应格式错误")
                    embeddings = [item["embedding"] for item in result["data"]]
                    debug(f"成功获取 {len(embeddings)} 个嵌入向量")
                    return embeddings
        except Exception as e:
            error(f"嵌入服务调用失败: {str(e)}")
            raise RuntimeError(f"嵌入服务调用失败: {str(e)}")

    # Entity extraction service (LTP/small)
    async def extract_entities(self, query: str) -> list:
        """Extract named entities from *query*, deduplicated, order preserved.

        Unlike the other service wrappers this swallows all failures and
        returns an empty list instead of raising.
        """
        try:
            if not query:
                raise ValueError("查询文本不能为空")
            async with ClientSession() as session:
                async with session.post(
                    "https://entities.opencomputing.net:10443/v1/entities",  # public (external) endpoint
                    headers={"Content-Type": "application/json"},
                    json={"query": query}
                ) as response:
                    if response.status != 200:
                        error(f"实体识别服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"实体识别服务调用失败: {response.status}")
                    result = await response.json()
                    if result.get("object") != "list" or not result.get("data"):
                        error(f"实体识别服务响应格式错误: {result}")
                        raise RuntimeError("实体识别服务响应格式错误")
                    entities = result["data"]
                    # dict.fromkeys keeps first-occurrence order while
                    # removing duplicates.
                    unique_entities = list(dict.fromkeys(entities))
                    debug(f"成功提取 {len(unique_entities)} 个唯一实体")
                    return unique_entities
        except Exception as e:
            error(f"实体识别服务调用失败: {str(e)}")
            return []

    # Triple extraction service (Babelscape/mrebel-large)
    async def extract_triples(self, text: str) -> list:
        """Extract relation triples from *text*; raises RuntimeError on failure.

        Uses an unbounded total timeout (extraction can be slow) and a
        connection pool capped at 30.  A per-call UUID is logged so
        concurrent requests can be correlated in the logs.
        """
        request_id = str(uuid.uuid4())
        debug(f"Request #{request_id} started for triples extraction")
        try:
            async with ClientSession(
                connector=aiohttp.TCPConnector(limit=30),
                timeout=ClientTimeout(total=None)
            ) as session:
                async with session.post(
                    "https://triples.opencomputing.net:10443/v1/triples",  # public (external) endpoint
                    headers={"Content-Type": "application/json; charset=utf-8"},
                    json={"text": text}
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        error(f"Request #{request_id} failed, status: {response.status}, response: {error_text}")
                        raise RuntimeError(f"三元组抽取服务调用失败: {response.status}, {error_text}")
                    result = await response.json()
                    if result.get("object") != "list":
                        error(f"Request #{request_id} invalid response format: {result}")
                        raise RuntimeError("三元组抽取服务响应格式错误")
                    triples = result["data"]
                    debug(f"Request #{request_id} extracted {len(triples)} triples")
                    return triples
        except Exception as e:
            error(f"Request #{request_id} failed to extract triples: {str(e)}")
            raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}")

    # Reranking service (BAAI/bge-reranker-v2-m3)
    async def rerank_results(self, query: str, results: list, top_n: int) -> list:
        """Rerank *results* against *query* and return the top *top_n*.

        NOTE: mutates the input dicts in place by attaching a
        "rerank_score" key.  An invalid *top_n* falls back to
        len(results).  On any failure the original *results* list is
        returned unchanged (best-effort semantics).
        """
        try:
            if not results:
                debug("无结果需要重排序")
                return results
            if not isinstance(top_n, int) or top_n < 1:
                debug(f"无效的 top_n 参数: {top_n}, 使用 len(results)={len(results)}")
                top_n = len(results)
            else:
                top_n = min(top_n, len(results))
            # Fall back to str(result) for items without a "text" field.
            documents = [result.get("text", str(result)) for result in results]
            async with ClientSession() as session:
                async with session.post(
                    "https://reranker.opencomputing.net:10443/v1/rerank",  # public (external) endpoint
                    headers={"Content-Type": "application/json"},
                    json={
                        "model": "rerank-001",
                        "query": query,
                        "documents": documents,
                        "top_n": top_n
                    }
                ) as response:
                    if response.status != 200:
                        error(f"重排序服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"重排序服务调用失败: {response.status}")
                    result = await response.json()
                    if result.get("object") != "rerank.result" or not result.get("data"):
                        error(f"重排序服务响应格式错误: {result}")
                        raise RuntimeError("重排序服务响应格式错误")
                    rerank_data = result["data"]
                    reranked_results = []
                    for item in rerank_data:
                        index = item["index"]
                        # Guard against out-of-range indices from the service.
                        if index < len(results):
                            results[index]["rerank_score"] = item["relevance_score"]
                            reranked_results.append(results[index])
                    debug(f"成功重排序 {len(reranked_results)} 条结果")
                    return reranked_results[:top_n]
        except Exception as e:
            error(f"重排序服务调用失败: {str(e)}")
            return results

    # Neo4j service
    async def neo4j_docs(self) -> str:
        """Fetch the Neo4j service help document as plain text (not JSON)."""
        async with ClientSession(timeout=ClientTimeout(total=300)) as session:
            async with session.get("https://graphdb.opencomputing.net:10443/docs") as response:
                if response.status != 200:
                    error(f"Neo4j 文档请求失败,状态码: {response.status}")
                    raise RuntimeError(f"Neo4j 文档请求失败: {response.status}")
                text = await response.text()  # plain-text body, not JSON
                debug(f"Neo4j 文档内容: {text}")
                return text

    async def neo4j_initialize(self) -> Dict[str, Any]:
        """Initialize the Neo4j service."""
        return await self._make_request("https://graphdb.opencomputing.net:10443/v1/initialize", "initialize", {})

    async def neo4j_insert_triples(self, triples: list, document_id: str, knowledge_base_id: str, userid: str) -> Dict[str, Any]:
        """Insert *triples* into Neo4j under the given document/KB/user."""
        params = {
            "triples": triples,
            "document_id": document_id,
            "knowledge_base_id": knowledge_base_id,
            "userid": userid
        }
        return await self._make_request("https://graphdb.opencomputing.net:10443/v1/inserttriples", "inserttriples", params)

    async def neo4j_delete_document(self, document_id: str) -> Dict[str, Any]:
        """Delete a document's graph data from Neo4j."""
        return await self._make_request("https://graphdb.opencomputing.net:10443/v1/deletedocument", "deletedocument", {"document_id": document_id})

    async def neo4j_delete_knowledgebase(self, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's knowledge base from Neo4j."""
        return await self._make_request("https://graphdb.opencomputing.net:10443/v1/deleteknowledgebase", "deleteknowledgebase",
                                        {"userid": userid, "knowledge_base_id": knowledge_base_id})

    async def neo4j_match_triplets(self, query: str, query_entities: list, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Match triples in Neo4j related to the given query entities."""
        params = {
            "query": query,
            "query_entities": query_entities,
            "userid": userid,
            "knowledge_base_id": knowledge_base_id
        }
        return await self._make_request("https://graphdb.opencomputing.net:10443/v1/matchtriplets", "matchtriplets", params)

    # Milvus service
    async def milvus_create_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Create a Milvus collection (default collection when db_type is empty)."""
        params = {"db_type": db_type} if db_type else {}
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/createcollection", "createcollection", params)

    async def milvus_delete_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Delete a Milvus collection (default collection when db_type is empty)."""
        params = {"db_type": db_type} if db_type else {}
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/deletecollection", "deletecollection", params)

    async def milvus_insert_document(self, chunks: List[Dict], db_type: str = "") -> Dict[str, Any]:
        """Insert document chunks into Milvus.

        NOTE(review): the service expects the key "dbtype" here (no
        underscore), unlike "db_type" elsewhere — confirm against the
        vectordb API before changing.
        """
        params = {
            "chunks": chunks,
            "dbtype": db_type
        }
        # Log the request body size (diagnostics for large payloads only).
        payload = json.dumps(params)  # serialize to a JSON string
        payload_bytes = payload.encode()  # encode to bytes
        payload_size = len(payload_bytes)  # size in bytes
        debug(f"Request payload size for insertdocument: {payload_size} bytes")
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/insertdocument", "insertdocument", params)

    async def milvus_delete_document(self, userid: str, file_path: str, knowledge_base_id: str, document_id:str, db_type: str = "") -> Dict[str, Any]:
        """Delete Milvus records for one document."""
        params = {
            "userid": userid,
            "file_path": file_path,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id,
            "dbtype": db_type
        }
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/deletedocument", "deletedocument", params)

    async def milvus_delete_knowledgebase(self, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's knowledge base from Milvus."""
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/deleteknowledgebase", "deleteknowledgebase",
                                        {"userid": userid, "knowledge_base_id": knowledge_base_id})

    async def milvus_search_query(self, query_vector: List[float], userid: str, knowledge_base_ids: list, limit: int, offset: int) -> Dict[str, Any]:
        """Vector search in Milvus over the user's knowledge bases."""
        params = {
            "query_vector": query_vector,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset
        }
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/searchquery", "searchquery", params)

    async def milvus_list_user_files(self, userid: str) -> Dict[str, Any]:
        """List a user's files stored in Milvus."""
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/listuserfiles", "listuserfiles", {"userid": userid})

    async def milvus_list_all_knowledgebases(self) -> Dict[str, Any]:
        """List all knowledge bases stored in Milvus."""
        return await self._make_request("https://vectordb.opencomputing.net:10443/v1/listallknowledgebases", "listallknowledgebases", {})

    # RAG service
    async def rag_create_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Create a RAG collection (default collection when db_type is empty)."""
        params = {"db_type": db_type} if db_type else {}
        return await self._make_request("https://rag.opencomputing.net:10443/v1/createcollection", "createcollection", params)

    async def rag_delete_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Delete a RAG collection (default collection when db_type is empty)."""
        params = {"db_type": db_type} if db_type else {}
        return await self._make_request("https://rag.opencomputing.net:10443/v1/deletecollection", "deletecollection", params)

    async def rag_insert_file(self, file_path: str, userid: str, knowledge_base_id: str, document_id: str) -> Dict[str, Any]:
        """Insert one file into the RAG store."""
        params = {
            "file_path": file_path,
            "userid": userid,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id
        }
        return await self._make_request("https://rag.opencomputing.net:10443/v1/insertfile", "insertfile", params)

    async def rag_delete_file(self, userid: str, file_path: str, knowledge_base_id: str, document_id: str) -> Dict[str, Any]:
        """Delete one file from the RAG store."""
        params = {
            "userid": userid,
            "file_path": file_path,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id
        }
        return await self._make_request("https://rag.opencomputing.net:10443/v1/deletefile", "deletefile", params)

    async def rag_delete_knowledgebase(self, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's knowledge base from the RAG store."""
        return await self._make_request("https://rag.opencomputing.net:10443/v1/deleteknowledgebase", "deleteknowledgebase",
                                        {"userid": userid, "knowledge_base_id": knowledge_base_id})

    async def rag_search_query(self, query: str, userid: str, knowledge_base_ids: list, limit: int, offset: int,
                               use_rerank: bool) -> Dict[str, Any]:
        """Vector-only RAG search over the user's knowledge bases."""
        params = {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank
        }
        return await self._make_request("https://rag.opencomputing.net:10443/v1/searchquery", "searchquery", params)

    async def rag_fused_search_query(self, query: str, userid: str, knowledge_base_ids: list, limit: int, offset: int,
                                     use_rerank: bool) -> Dict[str, Any]:
        """Fused RAG search combining vector store and knowledge graph."""
        params = {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank
        }
        return await self._make_request("https://rag.opencomputing.net:10443/v1/fusedsearchquery", "fusedsearchquery", params)

    async def rag_list_user_files(self, userid: str) -> Dict[str, Any]:
        """List a user's files stored in the RAG store."""
        return await self._make_request("https://rag.opencomputing.net:10443/v1/listuserfiles", "listuserfiles", {"userid": userid})

    async def rag_list_all_knowledgebases(self) -> Dict[str, Any]:
        """List all knowledge bases stored in the RAG store."""
        return await self._make_request("https://rag.opencomputing.net:10443/v1/listallknowledgebases", "listallknowledgebases", {})

    async def rag_docs(self) -> Dict[str, Any]:
        """Fetch the RAG service help document."""
        return await self._make_request("https://rag.opencomputing.net:10443/v1/docs", "docs", {})

View File

@@ -1,27 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict
from appPublic.log import debug, error, info, exception
# Global registry mapping a connection key to its connection class.
connection_pathMap = {}


def connection_register(connection_key, Klass):
    """Register *Klass* as the connection class for *connection_key*.

    Later registrations under the same key overwrite earlier ones.
    """
    # Mutating the module-level dict needs no `global` declaration.
    connection_pathMap[connection_key] = Klass
    info(f"Registered {connection_key} with class {Klass}")
def get_connection_class(connection_path):
    """Return the connection class registered under *connection_path*.

    Raises:
        Exception: when no class has been registered for the path.
    """
    debug(f"connection_pathMap: {connection_pathMap}")
    klass = connection_pathMap.get(connection_path)
    if klass is not None:
        return klass
    # Unknown path: log and fail loudly so misconfiguration is visible.
    error(f"{connection_path} has not mapping to a connection class")
    raise Exception(f"{connection_path} has not mapping to a connection class")
class BaseConnection(ABC):
    """Abstract interface for database connection backends."""

    @abstractmethod
    async def handle_connection(self, action: str, params: Dict = None) -> Dict:
        """Execute the database operation named *action* with *params*.

        Concrete subclasses dispatch on *action* (e.g. collection
        creation) and return a result dict.
        """
        ...

View File

@@ -1,653 +0,0 @@
import llmengine.milvus_connection
from traceback import format_exc
import argparse
from aiohttp import web
from llmengine.base_connection import get_connection_class
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, error, info
from ahserver.serverenv import ServerEnv
from ahserver.webapp import webserver
import os
import json
# Plain-text API reference served verbatim by the /docs endpoint; the string
# content is user-facing runtime data, so it must not be altered casually.
helptext = """Milvus Connection Service API (using pymilvus Collection API):
1. Create Collection Endpoint:
path: /v1/createcollection
method: POST
headers: {"Content-Type": "application/json"}
data: {
"db_type": "textdb" // 可选若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 创建成功"}
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"}
2. Delete Collection Endpoint:
path: /v1/deletecollection
method: POST
headers: {"Content-Type": "application/json"}
data: {
"db_type": "textdb" // 可选若不提供则删除默认集合 ragdb
}
response:
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 删除成功"}
- Success (collection does not exist): HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 不存在,无需删除"}
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"}
3. Insert File Endpoint:
path: /v1/insertfile
method: POST
headers: {"Content-Type": "application/json"}
data: {
"file_path": "/path/to/file.txt", // 必填文件路径
"userid": "user123", // 必填用户 ID
"db_type": "textdb", // 可选若不提供则使用默认集合 ragdb
"knowledge_base_id": "kb123" // 必填知识库 ID
}
response:
- Success: HTTP 200, {"status": "success", "document_id": "<uuid>", "collection_name": "ragdb" or "ragdb_textdb", "message": "文件 <file_path> 成功嵌入并处理三元组", "status_code": 200}
- Success (triples failed): HTTP 200, {"status": "success", "document_id": "<uuid>", "collection_name": "ragdb" or "ragdb_textdb", "message": "文件 <file_path> 成功嵌入,但三元组处理失败: <error>", "status_code": 200}
- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
4. Delete Document Endpoint:
path: /v1/deletefile
method: POST
headers: {"Content-Type": "application/json"}
data: {
"userid": "user123", // 必填用户 ID
"filename": "file.txt", // 必填文件名
"db_type": "textdb", // 可选若不提供则使用默认集合 ragdb
"knowledge_base_id": "kb123" // 必填知识库 ID
}
response:
- Success: HTTP 200, {"status": "success", "document_id": "<uuid1,uuid2>", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 <count> 条 Milvus 记录,<nodes> 个 Neo4j 节点,<rels> 个 Neo4j 关系userid=<userid>, filename=<filename>", "status_code": 200}
- Success (no records): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=<userid>, filename=<filename>, knowledge_base_id=<knowledge_base_id> 的记录,无需删除", "status_code": 200}
- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 <collection_name> 不存在,无需删除", "status_code": 200}
- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
5. Fused Search Query Endpoint:
path: /v1/fusedsearchquery
method: POST
headers: {"Content-Type": "application/json"}
data: {
"query": "苹果公司在北京开设新店",
"userid": "user1",
"db_type": "textdb", // 可选若不提供则使用默认集合 ragdb
"knowledge_base_ids": ["kb123"],
"limit": 5,
"offset": 0,
"use_rerank": true
}
response:
- Success: HTTP 200, {
"status": "success",
"results": [
{
"text": "<完整文本内容>",
"distance": 0.95,
"source": "fused_query_with_triplets",
"rerank_score": 0.92, // use_rerank=true
"metadata": {
"userid": "user1",
"document_id": "<uuid>",
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
}
},
...
],
"timing": {
"collection_load": <float>, // 集合加载耗时
"entity_extraction": <float>, // 实体提取耗时
"triplet_matching": <float>, // 三元组匹配耗时
"triplet_text_combine": <float>, // 拼接三元组文本耗时
"embedding_generation": <float>, // 嵌入向量生成耗时
"vector_search": <float>, // 向量搜索耗时
"deduplication": <float>, // 去重耗时
"reranking": <float>, // 重排序耗时 use_rerank=true
"total_time": <float> // 总耗时
},
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
6. Search Query Endpoint:
path: /v1/searchquery
method: POST
headers: {"Content-Type": "application/json"}
data: {
"query": "知识图谱的知识融合是什么?",
"userid": "user1",
"db_type": "textdb", // 可选若不提供则使用默认集合 ragdb
"knowledge_base_ids": ["kb123"],
"limit": 5,
"offset": 0,
"use_rerank": true
}
response:
- Success: HTTP 200, {
"status": "success",
"results": [
{
"text": "<完整文本内容>",
"distance": 0.95,
"source": "vector_query",
"rerank_score": 0.92, // use_rerank=true
"metadata": {
"userid": "user1",
"document_id": "<uuid>",
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
}
},
...
],
"timing": {
"collection_load": <float>, // 集合加载耗时
"embedding_generation": <float>, // 嵌入向量生成耗时
"vector_search": <float>, // 向量搜索耗时
"deduplication": <float>, // 去重耗时
"reranking": <float>, // 重排序耗时 use_rerank=true
"total_time": <float> // 总耗时
},
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
7. List User Files Endpoint:
path: /v1/listuserfiles
method: POST
headers: {"Content-Type": "application/json"}
data: {
"userid": "user1",
"db_type": "textdb" // 可选若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, {
"status": "success",
"files_by_knowledge_base": {
"kb123": [
{
"document_id": "<uuid>",
"filename": "file1.txt",
"file_path": "/path/to/file1.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt",
"knowledge_base_id": "kb123"
},
...
],
"kb456": [
{
"document_id": "<uuid>",
"filename": "file2.pdf",
"file_path": "/path/to/file2.pdf",
"upload_time": "<iso_timestamp>",
"file_type": "pdf",
"knowledge_base_id": "kb456"
},
...
]
},
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
8. Connection Endpoint (for compatibility):
path: /v1/connection
method: POST
headers: {"Content-Type": "application/json"}
data: {
"action": "<initialize|get_params|create_collection|delete_collection|insert_document|delete_document|fused_search|search_query|list_user_files>",
"params": {...}
}
response:
- Success: HTTP 200, {"status": "success", ...}
- Error: HTTP 400, {"status": "error", "message": "<error message>"}
9. Docs Endpoint:
path: /docs
method: GET
response: This help text
10. Delete Knowledge Base Endpoint:
path: /v1/deleteknowledgebase
method: POST
headers: {"Content-Type": "application/json"}
data: {
"userid": "user123", // 必填用户 ID
"knowledge_base_id": "kb123",// 必填知识库 ID
"db_type": "textdb" // 可选若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, {"status": "success", "document_id": "<uuid1,uuid2>", "filename": "<filename1,filename2>", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 <count> 条 Milvus 记录,<nodes> 个 Neo4j 节点,<rels> 个 Neo4j 关系userid=<userid>, knowledge_base_id=<knowledge_base_id>", "status_code": 200}
- Success (no records): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=<userid>, knowledge_base_id=<knowledge_base_id> 的记录,无需删除", "status_code": 200}
- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 <collection_name> 不存在,无需删除", "status_code": 200}
- Error: HTTP 400, {"status": "error", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
10. List All Knowledge Bases Endpoint:
path: /v1/listallknowledgebases
method: POST
headers: {"Content-Type": "application/json"}
data: {
"db_type": "textdb" // 可选若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, {
"status": "success",
"users_knowledge_bases": {
"user1": {
"kb123": [
{
"document_id": "<uuid>",
"filename": "file1.txt",
"file_path": "/path/to/file1.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt",
"knowledge_base_id": "kb123"
},
...
],
"kb456": [
{
"document_id": "<uuid>",
"filename": "file2.pdf",
"file_path": "/path/to/file2.pdf",
"upload_time": "<iso_timestamp>",
"file_type": "pdf",
"knowledge_base_id": "kb456"
},
...
]
},
"user2": {...}
},
"collection_name": "ragdb" or "ragdb_textdb",
"message": "成功列出 <count> 个用户的知识库和文件",
"status_code": 200
}
- Error: HTTP 400, {
"status": "error",
"users_knowledge_bases": {},
"collection_name": "ragdb" or "ragdb_textdb",
"message": "<error message>",
"status_code": 400
}
"""
def init():
    """Register every HTTP action handler on the global function router."""
    handlers = {
        'createcollection': create_collection,
        'deletecollection': delete_collection,
        'insertfile': insert_file,
        'deletefile': delete_file,
        'deleteknowledgebase': delete_knowledge_base,
        'fusedsearchquery': fused_search_query,
        'searchquery': search_query,
        'listuserfiles': list_user_files,
        'listallknowledgebases': list_all_knowledge_bases,
        'connection': handle_connection,
        'docs': docs,
    }
    rf = RegisterFunction()
    # dict preserves insertion order, so registration order is unchanged.
    for action, handler in handlers.items():
        rf.register(action, handler)
async def docs(request, params_kw, *params, **kw):
    """Serve the module-level help text as a plain-text HTTP response."""
    body = helptext
    return web.Response(text=body, content_type='text/plain')
async def not_implemented(request, params_kw, *params, **kw):
    """Placeholder handler returning HTTP 501 for unimplemented endpoints."""
    payload = {
        "status": "error",
        "message": "功能尚未实现"
    }
    return web.json_response(
        payload,
        dumps=lambda obj: json.dumps(obj, ensure_ascii=False),
        status=501,
    )
async def create_collection(request, params_kw, *params, **kw):
    """HTTP handler: create the Milvus collection for the optional db_type."""
    debug(f'{params_kw=}')
    engine = ServerEnv().engine
    db_type = params_kw.get('db_type', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        result = await engine.handle_connection("create_collection", {"db_type": db_type})
        debug(f'{result=}')
        return web.json_response(result, dumps=dumps)
    except Exception as e:
        error(f'创建集合失败: {str(e)}')
        failure = {
            "status": "error",
            "collection_name": collection_name,
            "message": str(e)
        }
        return web.json_response(failure, dumps=dumps, status=400)
async def delete_collection(request, params_kw, *params, **kw):
    """HTTP handler: delete the Milvus collection for the optional db_type."""
    debug(f'{params_kw=}')
    engine = ServerEnv().engine
    db_type = params_kw.get('db_type', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        result = await engine.handle_connection("delete_collection", {"db_type": db_type})
        debug(f'{result=}')
        return web.json_response(result, dumps=dumps)
    except Exception as e:
        error(f'删除集合失败: {str(e)}')
        failure = {
            "status": "error",
            "collection_name": collection_name,
            "message": str(e)
        }
        return web.json_response(failure, dumps=dumps, status=400)
async def insert_file(request, params_kw, *params, **kw):
    """HTTP handler: embed a file into the vector store (and its triples)."""
    debug(f'Received params: {params_kw=}')
    engine = ServerEnv().engine
    file_path = params_kw.get('file_path', '')
    userid = params_kw.get('userid', '')
    db_type = params_kw.get('db_type', '')
    knowledge_base_id = params_kw.get('knowledge_base_id', '')
    document_id = params_kw.get('document_id', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        # Reject the request early if any required field is absent or empty.
        missing_fields = [name for name in ('file_path', 'userid', 'knowledge_base_id', 'document_id')
                          if not params_kw.get(name)]
        if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
        debug(
            f'Calling insert_document with: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
        result = await engine.handle_connection("insert_document", {
            "file_path": file_path,
            "userid": userid,
            "db_type": db_type,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id
        })
        debug(f'Insert result: {result=}')
        ok = result.get("status") == "success"
        return web.json_response(result, dumps=dumps, status=200 if ok else 400)
    except Exception as e:
        error(f'插入文件失败: {str(e)}')
        failure = {
            "status": "error",
            "collection_name": collection_name,
            "document_id": document_id,
            "message": str(e)
        }
        return web.json_response(failure, dumps=dumps, status=400)
async def delete_file(request, params_kw, *params, **kw):
    """HTTP handler: delete one document's records from the vector store."""
    debug(f'Received delete_file params: {params_kw=}')
    engine = ServerEnv().engine
    userid = params_kw.get('userid', '')
    file_path = params_kw.get('file_path', '')
    db_type = params_kw.get('db_type', '')
    knowledge_base_id = params_kw.get('knowledge_base_id', '')
    document_id = params_kw.get('document_id', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        # Reject the request early if any required field is absent or empty.
        missing_fields = [name for name in ('userid', 'file_path', 'knowledge_base_id', 'document_id')
                          if not params_kw.get(name)]
        if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
        debug(f'Calling delete_document with: userid={userid}, file_path={file_path}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
        result = await engine.handle_connection("delete_document", {
            "userid": userid,
            "file_path": file_path,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id,
            "db_type": db_type
        })
        debug(f'Delete result: {result=}')
        ok = result.get("status") == "success"
        return web.json_response(result, dumps=dumps, status=200 if ok else 400)
    except Exception as e:
        error(f'删除文件失败: {str(e)}')
        failure = {
            "status": "error",
            "collection_name": collection_name,
            "document_id": document_id,
            "message": str(e),
            "status_code": 400
        }
        return web.json_response(failure, dumps=dumps, status=400)
async def delete_knowledge_base(request, params_kw, *params, **kw):
    """HTTP handler: delete an entire knowledge base for a user."""
    debug(f'Received delete_knowledge_base params: {params_kw=}')
    engine = ServerEnv().engine
    userid = params_kw.get('userid', '')
    knowledge_base_id = params_kw.get('knowledge_base_id', '')
    db_type = params_kw.get('db_type', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        # Reject the request early if any required field is absent or empty.
        missing_fields = [name for name in ('userid', 'knowledge_base_id')
                          if not params_kw.get(name)]
        if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
        debug(
            f'Calling delete_knowledge_base with: userid={userid}, knowledge_base_id={knowledge_base_id}, db_type={db_type}')
        result = await engine.handle_connection("delete_knowledge_base", {
            "userid": userid,
            "knowledge_base_id": knowledge_base_id,
            "db_type": db_type
        })
        debug(f'Delete knowledge base result: {result=}')
        ok = result.get("status") == "success"
        return web.json_response(result, dumps=dumps, status=200 if ok else 400)
    except Exception as e:
        error(f'删除知识库失败: {str(e)}')
        failure = {
            "status": "error",
            "collection_name": collection_name,
            "document_id": "",
            "filename": "",
            "message": str(e),
            "status_code": 400
        }
        return web.json_response(failure, dumps=dumps, status=400)
async def search_query(request, params_kw, *params, **kw):
    """HTTP handler: pure vector search over the user's knowledge bases."""
    debug(f'Received search_query params: {params_kw=}')
    engine = ServerEnv().engine
    query = params_kw.get('query')
    userid = params_kw.get('userid')
    db_type = params_kw.get('db_type', '')
    knowledge_base_ids = params_kw.get('knowledge_base_ids')
    limit = params_kw.get('limit', 5)
    offset = params_kw.get('offset', 0)
    use_rerank = params_kw.get('use_rerank', True)
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        # All three of query/userid/knowledge_base_ids are mandatory.
        if not (query and userid and knowledge_base_ids):
            debug(f'query, userid 或 knowledge_base_ids 未提供')
            return web.json_response({
                "status": "error",
                "message": "query, userid 或 knowledge_base_ids 未提供",
                "collection_name": collection_name
            }, dumps=dumps, status=400)
        result = await engine.handle_connection("search_query", {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank,
            "db_type": db_type
        })
        debug(f'Search result: {result=}')
        payload = {
            "status": "success",
            "results": result.get("results", []),
            "timing": result.get("timing", {}),
            "collection_name": collection_name
        }
        return web.json_response(payload, dumps=dumps)
    except Exception as e:
        error(f'纯向量搜索失败: {str(e)}')
        failure = {
            "status": "error",
            "message": str(e),
            "collection_name": collection_name
        }
        return web.json_response(failure, dumps=dumps, status=400)
async def fused_search_query(request, params_kw, *params, **kw):
    """HTTP handler for fused (vector + graph) search.

    Same contract as search_query, but dispatches the engine's "fused_search"
    action instead of the pure vector one.
    """
    debug(f'Received fused_search_query params: {params_kw=}')
    env = ServerEnv()
    engine = env.engine
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    query = params_kw.get('query')
    userid = params_kw.get('userid')
    db_type = params_kw.get('db_type', '')
    knowledge_base_ids = params_kw.get('knowledge_base_ids')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    try:
        # Reject the request early when any mandatory field is missing/empty.
        if not (query and userid and knowledge_base_ids):
            debug(f'query, userid 或 knowledge_base_ids 未提供')
            return web.json_response({
                "status": "error",
                "message": "query, userid 或 knowledge_base_ids 未提供",
                "collection_name": collection_name
            }, dumps=dumps, status=400)
        result = await engine.handle_connection("fused_search", {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": params_kw.get('limit', 5),
            "offset": params_kw.get('offset', 0),
            "use_rerank": params_kw.get('use_rerank', True),
            "db_type": db_type
        })
        debug(f'Fused search result: {result=}')
        return web.json_response({
            "status": "success",
            "results": result.get("results", []),
            "timing": result.get("timing", {}),
            "collection_name": collection_name
        }, dumps=dumps)
    except Exception as e:
        error(f'融合搜索失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "message": str(e),
            "collection_name": collection_name
        }, dumps=dumps, status=400)
async def list_user_files(request, params_kw, *params, **kw):
    """HTTP handler: list a user's files grouped by knowledge base."""
    debug(f'{params_kw=}')
    env = ServerEnv()
    engine = env.engine
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    userid = params_kw.get('userid')
    db_type = params_kw.get('db_type', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    try:
        if not userid:
            debug(f'userid 未提供')
            return web.json_response({
                "status": "error",
                "message": "userid 未提供",
                "collection_name": collection_name
            }, dumps=dumps, status=400)
        result = await engine.handle_connection("list_user_files", {
            "userid": userid,
            "db_type": db_type
        })
        debug(f'{result=}')
        return web.json_response({
            "status": "success",
            "files_by_knowledge_base": result,
            "collection_name": collection_name
        }, dumps=dumps)
    except Exception as e:
        error(f'列出用户文件失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "message": str(e),
            "collection_name": collection_name
        }, dumps=dumps, status=400)
async def list_all_knowledge_bases(request, params_kw, *params, **kw):
    """HTTP handler: list every user's knowledge bases for the given db_type.

    Unlike the other handlers, the HTTP status mirrors the engine-provided
    status_code (default 200).
    """
    debug(f'{params_kw=}')
    env = ServerEnv()
    engine = env.engine
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    db_type = params_kw.get('db_type', '')
    collection_name = f"ragdb_{db_type}" if db_type else "ragdb"
    try:
        result = await engine.handle_connection("list_all_knowledge_bases", {
            "db_type": db_type
        })
        debug(f'{result=}')
        response = {
            "status": result.get("status", "success"),
            "users_knowledge_bases": result.get("users_knowledge_bases", {}),
            "collection_name": collection_name,
            "message": result.get("message", ""),
            "status_code": result.get("status_code", 200)
        }
        return web.json_response(response, dumps=dumps, status=response["status_code"])
    except Exception as e:
        error(f'列出所有用户知识库失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "users_knowledge_bases": {},
            "collection_name": collection_name,
            "message": str(e),
            "status_code": 400
        }, dumps=dumps, status=400)
async def handle_connection(request, params_kw, *params, **kw):
    """HTTP handler: generic dispatch — reads action/params from the JSON body
    and forwards them verbatim to the engine."""
    debug(f'{params_kw=}')
    env = ServerEnv()
    engine = env.engine
    dumps = lambda obj: json.dumps(obj, ensure_ascii=False)
    try:
        data = await request.json()
        action = data.get('action')
        if not action:
            debug(f'action 未提供')
            return web.json_response({
                "status": "error",
                "message": "action 参数未提供"
            }, dumps=dumps, status=400)
        result = await engine.handle_connection(action, data.get('params', {}))
        debug(f'{result=}')
        return web.json_response(result, dumps=dumps)
    except Exception as e:
        error(f'处理连接操作失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "message": str(e)
        }, dumps=dumps, status=400)
def main():
    """CLI entry point: build the engine from the given connection class path
    and start the web server on the requested workdir/port."""
    parser = argparse.ArgumentParser(prog="Milvus Connection Service")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port', default='8888')
    parser.add_argument('connection_path')
    args = parser.parse_args()
    debug(f"Arguments: {args}")
    engine_class = get_connection_class(args.connection_path)
    env = ServerEnv()
    env.engine = engine_class()
    work_dir = args.workdir or os.getcwd()
    debug(f'{args=}')
    webserver(init, work_dir, args.port)


if __name__ == '__main__':
    main()

View File

@ -1,380 +0,0 @@
from rag.api_service import APIService
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, error, info
from sqlor.dbpools import DBPools
import asyncio
import aiohttp
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import os
import re
import time
import uuid
from datetime import datetime
import traceback
from filetxt.loader import fileloader
from ahserver.serverenv import get_serverenv
from typing import List, Dict, Any
api_service = APIService()
async def get_orgid_by_id(kdb_id):
    """Look up the orgid that owns the kdb row with primary key *kdb_id*.

    Returns the orgid string, or None when the row does not exist or the
    query fails (errors are logged, never raised).
    """
    db = DBPools()
    # NOTE(review): the db name is hard-coded; the commented-out
    # get_module_dbname("rag") lookup suggests it should come from config.
    dbname = "kyrag"
    sql = "SELECT orgid FROM kdb WHERE id = ${id}$"
    try:
        async with db.sqlorContext(dbname) as sor:
            result = await sor.sqlExe(sql, {"id": kdb_id})
            # FIX: use the logger instead of a bare print() for the row dump.
            debug(f'kdb orgid lookup: {result=}')
            if result:
                return result[0].get('orgid')
            return None
    except Exception as e:
        error(f"查询 orgid 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return None
async def file_uploaded(params_kw):
    """Embed an uploaded document into Milvus and extract its triples into Neo4j.

    Expects in *params_kw*: ``realpath`` (file on disk), ``fiid`` (knowledge
    base id, used to resolve the owning orgid) and ``id`` (document id).
    Returns a status dict with per-stage ``timings`` and a ``status_code``.
    """
    debug(f'Received params: {params_kw=}')
    realpath = params_kw.get('realpath', '')
    fiid = params_kw.get('fiid', '')
    id = params_kw.get('id', '')
    orgid = await get_orgid_by_id(fiid)
    db_type = ''
    debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
    timings = {}
    start_total = time.time()
    try:
        if not orgid or not fiid or not id:
            # FIX: this debug call used to sit *after* the raise and was unreachable.
            debug(f'orgid、fiid 和 id 不能为空')
            raise ValueError("orgid、fiid 和 id 不能为空")
        if len(orgid) > 32 or len(fiid) > 255:
            raise ValueError("orgid 或 fiid 的长度超出限制")
        if not os.path.exists(realpath):
            raise ValueError(f"文件 {realpath} 不存在")
        supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
        # Load and sanitize the file contents (keep CJK, alphanumerics, basic punctuation).
        debug(f"加载文件: {realpath}")
        start_load = time.time()
        text = fileloader(realpath)
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
        if not text or not text.strip():
            raise ValueError(f"文件 {realpath} 加载为空")
        # Split into overlapping chunks.
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len)
        debug("开始分片文件内容")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
        if not chunks:
            raise ValueError(f"文件 {realpath} 未生成任何文档块")
        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        upload_time = datetime.now().isoformat()
        # Generate embeddings in batches of 10 chunks per request.
        debug("调用嵌入服务生成向量")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []
        for i in range(0, len(texts), 10):
            batch_texts = texts[i:i + 10]
            batch_embeddings = await api_service.get_embeddings(batch_texts)
            embeddings.extend(batch_embeddings)
        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
        # Build the flattened per-chunk payload for Milvus.
        chunks_data = []
        for i, chunk in enumerate(chunks):
            chunks_data.append({
                "userid": orgid,
                "knowledge_base_id": fiid,
                "text": chunk.page_content,
                "vector": embeddings[i],
                "document_id": id,
                "filename": filename + '.' + ext,
                "file_path": realpath,
                "upload_time": upload_time,
                "file_type": ext,
            })
        # Insert into Milvus, 10 rows per request; any failed batch raises.
        debug(f"调用插入文件端点: {realpath}")
        start_milvus = time.time()
        for i in range(0, len(chunks_data), 10):
            batch_chunks = chunks_data[i:i + 10]
            result = await api_service.milvus_insert_document(batch_chunks, db_type)
            if result.get("status") != "success":
                raise ValueError(result.get("message", "Milvus 插入失败"))
        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
        # (FIX: removed a redundant post-loop status check — a failed batch
        # already raises inside the loop above, so it could never trigger.)
        debug("调用三元组抽取服务")
        start_triples = time.time()
        # FIX: bind these before the inner try so the except path and the
        # returns below can never fail with NameError when extraction dies early.
        unique_triples = []
        start_neo4j = start_triples
        try:
            chunk_texts = [doc.page_content for doc in chunks]
            debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
            tasks = [api_service.extract_triples(chunk) for chunk in chunk_texts]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            triples = []
            for i, result in enumerate(results):
                if isinstance(result, list):
                    triples.extend(result)
                    debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
                else:
                    error(f"分片 {i + 1} 处理失败: {str(result)}")
            # De-duplicate (head, tail, type), preferring the longest relation type.
            seen = set()
            for t in triples:
                identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
                if identifier not in seen:
                    seen.add(identifier)
                    unique_triples.append(t)
                else:
                    for existing in unique_triples:
                        if (existing['head'].lower() == t['head'].lower() and
                                existing['tail'].lower() == t['tail'].lower() and
                                len(t['type']) > len(existing['type'])):
                            unique_triples.remove(existing)
                            unique_triples.append(t)
                            debug(f"替换三元组为更具体类型: {t}")
                            break
            timings["extract_triples"] = time.time() - start_triples
            debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
            debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
            start_neo4j = time.time()
            # FIX: guard with an explicit emptiness check (consistent with
            # MilvusConnection._insert_document) instead of a for/else, which
            # also fired after a successful non-empty loop.
            if unique_triples:
                for i in range(0, len(unique_triples), 30):  # 30 triples per request
                    batch_triples = unique_triples[i:i + 30]
                    neo4j_result = await api_service.neo4j_insert_triples(batch_triples, id, fiid, orgid)
                    debug(f"Neo4j 服务响应: {neo4j_result}")
                    if neo4j_result.get("status") != "success":
                        timings["insert_neo4j"] = time.time() - start_neo4j
                        timings["total"] = time.time() - start_total
                        return {"status": "error", "document_id": id, "collection_name": "ragdb", "timings": timings,
                                "message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}", "status_code": 400}
                    info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
            else:
                debug(f"文件 {realpath} 未抽取到三元组")
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
        except Exception as e:
            # Triple extraction / Neo4j failure is non-fatal: the embeddings
            # are already in Milvus, so report success with a warning message.
            timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else timings["extract_triples"]
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            timings["total"] = time.time() - start_total
            return {"status": "success", "document_id": id, "collection_name": "ragdb", "timings": timings,
                    "unique_triples": unique_triples,
                    "message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}", "status_code": 200}
        timings["total"] = time.time() - start_total
        debug(f"总耗时: {timings['total']:.2f}")
        return {"status": "success", "userid": orgid, "document_id": id, "collection_name": "ragdb", "timings": timings,
                "unique_triples": unique_triples, "message": f"文件 {realpath} 成功嵌入并处理三元组", "status_code": 200}
    except Exception as e:
        error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        timings["total"] = time.time() - start_total
        return {"status": "error", "document_id": id, "collection_name": "ragdb", "timings": timings,
                "message": f"插入文档失败: {str(e)}", "status_code": 400}
async def file_deleted(params_kw):
    """Remove a file's data: its Milvus rows and (best-effort) its Neo4j triples.

    Milvus deletion failure is fatal (error result); Neo4j deletion failure is
    only logged and the call still reports success.
    """
    id = params_kw.get('id', '')
    realpath = params_kw.get('realpath', '')
    fiid = params_kw.get('fiid', '')
    orgid = await get_orgid_by_id(fiid)
    db_type = ''
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    try:
        missing_fields = [name for name in ('id', 'fiid', 'realpath') if not params_kw.get(name)]
        if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
        debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}")
        milvus_result = await api_service.milvus_delete_document(orgid, realpath, fiid, id, db_type)
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
        neo4j_deleted_nodes = 0
        neo4j_deleted_rels = 0
        try:
            debug(f"调用 Neo4j 删除文档端点: document_id={id}")
            neo4j_result = await api_service.neo4j_delete_document(id)
            if neo4j_result.get("status") != "success":
                raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
            nodes_deleted = neo4j_result.get("nodes_deleted", 0)
            rels_deleted = neo4j_result.get("rels_deleted", 0)
            neo4j_deleted_nodes += nodes_deleted
            neo4j_deleted_rels += rels_deleted
            info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
        except Exception as e:
            # Best effort: orphaned graph data is tolerated, only logged.
            error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
        return {
            "status": "success",
            "collection_name": collection_name,
            "document_id": id,
            "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
            "status_code": 200
        }
    except Exception as e:
        error(f"删除文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {
            "status": "error",
            "collection_name": collection_name,
            "document_id": id,
            "message": f"删除文档失败: {str(e)}",
            "status_code": 400
        }
async def _search_query(query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
                        offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
    """Pure vector search through the service endpoints.

    Embeds *query*, runs the Milvus vector search, optionally reranks, and
    returns {"results": [...], "timing": {...}}. Failures are logged and an
    empty result set is returned rather than raising.
    """
    start_time = time.time()
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    timing = {}
    try:
        info(
            f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
        if not query:
            raise ValueError("查询文本不能为空")
        if not userid:
            raise ValueError("userid 不能为空")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit 必须在 1 到 16384 之间")
        if offset < 0:
            raise ValueError("offset 不能为负数")
        if limit + offset > 16384:
            raise ValueError("limit + offset 不能超过 16384")
        if not knowledge_base_ids:
            raise ValueError("knowledge_base_ids 不能为空")
        for kb_id in knowledge_base_ids:
            if not isinstance(kb_id, str):
                raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}")
            if len(kb_id) > 100:
                raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}")
        # Embed the query text.
        t_embed = time.time()
        vectors = await api_service.get_embeddings([query])
        if not vectors or not all(len(vec) == 1024 for vec in vectors):
            raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
        timing["vector_generation"] = time.time() - t_embed
        debug(f"生成查询向量耗时: {timing['vector_generation']:.3f} 秒")
        # Run the vector search.
        t_search = time.time()
        result = await api_service.milvus_search_query(vectors[0], userid, knowledge_base_ids, limit, offset)
        timing["vector_search"] = time.time() - t_search
        debug(f"向量搜索耗时: {timing['vector_search']:.3f} 秒")
        if result.get("status") != "success":
            error(f"纯向量搜索失败: {result.get('message', '未知错误')}")
            return {"results": [], "timing": timing}
        hits = result.get("results", [])
        if use_rerank and hits:
            t_rerank = time.time()
            debug("开始重排序")
            hits = await api_service.rerank_results(query, hits, limit)
            hits = sorted(hits, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing["reranking"] = time.time() - t_rerank
            debug(f"重排序耗时: {timing['reranking']:.3f} 秒")
            debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in hits]}")
        else:
            # Strip any rerank_score keys when reranking is disabled.
            hits = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in hits]
        timing["total_time"] = time.time() - start_time
        info(f"纯向量搜索完成,返回 {len(hits)} 条结果,总耗时: {timing['total_time']:.3f} 秒")
        return {"results": hits[:limit], "timing": timing}
    except Exception as e:
        error(f"纯向量搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {"results": [], "timing": timing}
async def main():
    """Manual smoke test: configure the DB pool and run file_uploaded once.

    FIX: removed the large block of commented-out file_deleted/_search_query
    test code (dead code); keep test variations in VCS history instead.
    """
    # NOTE(review): database credentials are hard-coded here — move them to
    # config/environment before this leaves a development machine.
    dbs = {
        "kyrag": {
            "driver": "aiomysql",
            "async_mode": True,
            "coding": "utf8",
            "maxconn": 100,
            "dbname": "kyrag",
            "kwargs": {
                "user": "test",
                "db": "kyrag",
                "password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
                "host": "db"
            }
        }
    }
    DBPools(dbs)
    print("测试 file_uploaded...")
    test_file_path = "/home/wangmeihua/data/kg.txt"
    test_params_upload = {
        "realpath": test_file_path,
        "fiid": "1",
        "id": "doc1"
    }
    upload_result = await file_uploaded(test_params_upload)
    print(f"file_uploaded 结果: {upload_result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -1,978 +0,0 @@
import os
from appPublic.log import debug, error, info
from base_connection import connection_register
from typing import Dict, List, Any
import numpy as np
import aiohttp
from aiohttp import ClientSession, ClientTimeout
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid
from datetime import datetime
from filetxt.loader import fileloader
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import traceback
import asyncio
import re
# Embedding cache: maps raw text -> normalized numpy vector.
# NOTE(review): unbounded — grows for the lifetime of the process.
EMBED_CACHE = {}
class MilvusConnection:
    """Facade that proxies Milvus and Neo4j operations to local HTTP services.

    Request helpers below target hard-coded endpoints: localhost:8885 (Neo4j
    service) and localhost:8886 (Milvus service) — confirm against deployment.
    """

    def __init__(self):
        # Stateless: all configuration lives in the hard-coded service URLs.
        pass
@retry(stop = stop_after_attempt(3))
async def _make_neo4japi_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
debug(f"开始API请求action={action}, params={params}")
try:
async with ClientSession(timeout=ClientTimeout(total=300)) as session:
url = f"http://localhost:8885/v1/{action}"
debug(f"发起POST请求{url}")
async with session.post(
url,
headers={'Content-Type': 'application/json'},
json=params
) as response:
debug(f"收到相应: status={response.status}, headers={response.headers}")
respose_text = await response.text()
debug(f"响应内容: {respose_text}")
result = await response.json()
debug(f"API响应内容: {result}")
if response.status == 400:
debug(f"客户端错误,状态码: {response.status},返回响应: {result}")
return result
if response.status != 200:
error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
raise RuntimeError(f"API 调用失败: {response.status}")
debug(f"API 调用成功: {action}, 响应: {result}")
return result
except Exception as e:
error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
raise RuntimeError(f"API 调用失败: {str(e)}")
@retry(stop=stop_after_attempt(3))
async def _make_api_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
debug(f"开始 API 请求: action={action}, params={params}")
try:
async with ClientSession(timeout=ClientTimeout(total=300)) as session:
url = f"http://localhost:8886/v1/{action}"
debug(f"发起 POST 请求: {url}")
async with session.post(
url,
headers={"Content-Type": "application/json"},
json=params
) as response:
debug(f"收到响应: status={response.status}, headers={response.headers}")
response_text = await response.text()
debug(f"响应内容: {response_text}")
result = await response.json()
debug(f"API 响应内容: {result}")
if response.status == 400: # 客户端错误,直接返回
debug(f"客户端错误,状态码: {response.status}, 返回响应: {result}")
return result
if response.status != 200:
error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
raise RuntimeError(f"API 调用失败: {response.status}")
debug(f"API 调用成功: {action}, 响应: {result}")
return result
except Exception as e:
error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
raise RuntimeError(f"API 调用失败: {str(e)}")
async def handle_connection(self, action: str, params: Dict = None) -> Dict:
"""处理数据库操作"""
try:
debug(f"处理操作: action={action}, params={params}")
if not params:
params = {}
# 通用 db_type 验证
db_type = params.get("db_type", "")
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
if db_type and "_" in db_type:
return {"status": "error", "message": "db_type 不能包含下划线", "collection_name": collection_name,
"document_id": "", "status_code": 400}
if db_type and len(db_type) > 100:
return {"status": "error", "message": "db_type 的长度应小于 100", "collection_name": collection_name,
"document_id": "", "status_code": 400}
if action == "initialize":
return {"status": "success", "message": "Milvus 服务已初始化"}
elif action == "get_params":
return {"status": "success", "params": {}}
elif action == "create_collection":
return await self._create_collection(db_type)
elif action == "delete_collection":
return await self._delete_collection(db_type)
elif action == "insert_document":
file_path = params.get("file_path", "")
userid = params.get("userid", "")
knowledge_base_id = params.get("knowledge_base_id", "")
document_id = params.get("document_id", "")
if not file_path or not userid or not knowledge_base_id or not document_id:
return {"status": "error", "message": "file_path、userid document_id和 knowledge_base_id 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": document_id, "status_code": 400}
if len(knowledge_base_id) > 100:
return {"status": "error", "message": "knowledge_base_id 的长度应小于 100",
"collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._insert_document(file_path, userid, knowledge_base_id, document_id, db_type)
elif action == "delete_document":
userid = params.get("userid", "")
file_path = params.get("file_path", "")
knowledge_base_id = params.get("knowledge_base_id", "")
document_id = params.get("document_id", "")
if not userid or not file_path or not knowledge_base_id or not document_id:
return {"status": "error", "message": "userid、file_path document_id和 knowledge_base_id 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if len(userid) > 100 or len(file_path) > 255 or len(knowledge_base_id) > 100:
return {"status": "error", "message": "userid、file_path 或 knowledge_base_id 的长度超出限制",
"collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._delete_document(userid, file_path, knowledge_base_id, document_id, db_type)
elif action == "delete_knowledge_base":
userid = params.get("userid", "")
knowledge_base_id = params.get("knowledge_base_id", "")
if not userid or not knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if "_" in userid or "_" in knowledge_base_id:
return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线",
"collection_name": collection_name, "document_id": "", "status_code": 400}
if len(userid) > 100 or len(knowledge_base_id) > 100:
return {"status": "error", "message": "userid 或 knowledge_base_id 的长度超出限制",
"collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._delete_knowledge_base(db_type, userid, knowledge_base_id)
elif action == "search_query":
query = params.get("query", "")
userid = params.get("userid", "")
knowledge_base_ids = params.get("knowledge_base_ids", [])
limit = params.get("limit", 5)
offset = params.get("offset", 0)
db_type = params.get("db_type", "")
use_rerank = params.get("use_rerank", True)
if not query or not userid or not knowledge_base_ids:
return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._search_query(query, userid, knowledge_base_ids, limit, offset, use_rerank, db_type)
elif action == "fused_search":
query = params.get("query", "")
userid = params.get("userid", "")
knowledge_base_ids = params.get("knowledge_base_ids", [])
limit = params.get("limit", 5)
offset = params.get("offset", 0)
db_type = params.get("db_type", "")
use_rerank = params.get("use_rerank", True)
if not query or not userid or not knowledge_base_ids:
return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._fused_search(query, userid, knowledge_base_ids, limit, offset, use_rerank, db_type)
elif action == "list_user_files":
userid = params.get("userid", "")
if not userid:
return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
"document_id": "", "status_code": 400}
return await self._list_user_files(userid, db_type)
elif action == "list_all_knowledge_bases":
return await self._list_all_knowledge_bases(db_type)
else:
return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
"document_id": "", "status_code": 400}
except Exception as e:
error(f"处理操作失败: action={action}, 错误: {str(e)}")
return {
"status": "error",
"message": f"服务器错误: {str(e)}",
"collection_name": collection_name,
"document_id": "",
"status_code": 400
}
async def _create_collection(self, db_type: str = "") -> Dict[str, Any]:
"""创建 Milvus 集合"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
if len(collection_name) > 255:
raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
if len(db_type) > 100:
raise ValueError("db_type 的长度应小于 100")
debug(f"调用创建集合端点: {collection_name}, 参数: {{'db_type': '{db_type}'}}")
result = await self._make_api_request("createcollection", {"db_type": db_type})
return result
except Exception as e:
error(f"创建集合失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return {
"status": "error",
"collection_name": collection_name,
"message": f"创建集合失败: {str(e)}",
"status_code": 400
}
async def _delete_collection(self, db_type: str = "") -> Dict:
"""删除 Milvus 集合通过服务化端点"""
try:
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
if len(collection_name) > 255:
raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
if db_type and "_" in db_type:
raise ValueError("db_type 不能包含下划线")
if db_type and len(db_type) > 100:
raise ValueError("db_type 的长度应小于 100")
debug(f"调用删除集合端点: {collection_name}")
result = await self._make_api_request("deletecollection", {"db_type": db_type})
return result
except Exception as e:
error(f"删除集合失败: {str(e)}")
return {
"status": "error",
"collection_name": collection_name,
"message": str(e),
"status_code": 400
}
    async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str, document_id: str, db_type: str = "") -> Dict[str, Any]:
        """Insert the document at *file_path* into Milvus and extract its triples into Neo4j.

        Pipeline: validate args -> load & sanitize text -> split into chunks ->
        embed -> insert into Milvus -> extract triples -> insert into Neo4j.
        Returns a status dict with per-stage ``timings`` and ``status_code``.
        Milvus failure is fatal (error result); triple/Neo4j failure after a
        successful Milvus insert is reported as success with a warning message.
        """
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        debug(
            f'Inserting document: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
        timings = {}
        start_total = time.time()
        try:
            # Validate arguments.
            if not userid or not knowledge_base_id:
                raise ValueError("userid 和 knowledge_base_id 不能为空")
            if "_" in userid or "_" in knowledge_base_id:
                raise ValueError("userid 和 knowledge_base_id 不能包含下划线")
            if len(userid) > 100 or len(knowledge_base_id) > 100:
                raise ValueError("userid 或 knowledge_base_id 的长度超出限制")
            if not os.path.exists(file_path):
                raise ValueError(f"文件 {file_path} 不存在")
            supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
            ext = file_path.rsplit('.', 1)[1].lower() if '.' in file_path else ''
            if ext not in supported_formats:
                raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
            info(f"生成 document_id: {document_id} for file: {file_path}")
            # Load the file and strip everything except CJK, alphanumerics and
            # basic punctuation.
            debug(f"加载文件: {file_path}")
            start_load = time.time()
            text = fileloader(file_path)
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
            timings["load_file"] = time.time() - start_load
            debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
            if not text or not text.strip():
                raise ValueError(f"文件 {file_path} 加载为空")
            # Split the text into overlapping chunks.
            document = Document(page_content=text)
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=100,
                length_function=len,
            )
            debug("开始分片文件内容")
            start_split = time.time()
            chunks = text_splitter.split_documents([document])
            timings["split_text"] = time.time() - start_split
            debug(
                f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
            if not chunks:
                raise ValueError(f"文件 {file_path} 未生成任何文档块")
            filename = os.path.basename(file_path).rsplit('.', 1)[0]
            upload_time = datetime.now().isoformat()
            # Generate embedding vectors (numpy arrays, see _get_embeddings).
            debug("调用嵌入服务生成向量")
            start_embedding = time.time()
            texts = [chunk.page_content for chunk in chunks]
            embeddings = await self._get_embeddings(texts)
            if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
                raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
            timings["generate_embeddings"] = time.time() - start_embedding
            debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
            # Build the flattened per-chunk payload.
            chunks_data = []
            for i, chunk in enumerate(chunks):
                chunks_data.append({
                    "userid": userid,
                    "knowledge_base_id": knowledge_base_id,
                    "text": chunk.page_content,
                    "vector": embeddings[i].tolist(),
                    "document_id": document_id,
                    "filename": filename + '.' + ext,
                    "file_path": file_path,
                    "upload_time": upload_time,
                    "file_type": ext,
                })
            # Call the Milvus insert endpoint.
            debug(f"调用插入文件端点: {file_path}")
            start_milvus = time.time()
            result = await self._make_api_request("insertdocument", {
                "chunks": chunks_data,
                "db_type": db_type,
            })
            timings["insert_milvus"] = time.time() - start_milvus
            debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
            if result.get("status") != "success":
                timings["total"] = time.time() - start_total
                return {
                    "status": "error",
                    "document_id": document_id,
                    "collection_name": collection_name,
                    "timings": timings,
                    "message": result.get("message", "未知错误"),
                    "status_code": 400
                }
            # Triple extraction.
            # NOTE(review): if the gather below raises before unique_triples /
            # start_neo4j are bound, the except path and the returns that
            # reference them would NameError — confirm and consider binding
            # both before this try (as done in file_uploaded).
            debug("调用三元组抽取服务")
            start_triples = time.time()
            try:
                chunk_texts = [doc.page_content for doc in chunks]
                debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
                tasks = [self._extract_triples(chunk) for chunk in chunk_texts]
                results = await asyncio.gather(*tasks, return_exceptions=True)
                triples = []
                for i, result in enumerate(results):
                    if isinstance(result, list):
                        triples.extend(result)
                        debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
                    else:
                        error(f"分片 {i + 1} 处理失败: {str(result)}")
                # De-duplicate (head, tail, type), preferring the longest
                # (most specific) relation type for a duplicate head/tail pair.
                unique_triples = []
                seen = set()
                for t in triples:
                    identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
                    if identifier not in seen:
                        seen.add(identifier)
                        unique_triples.append(t)
                    else:
                        for existing in unique_triples:
                            if (existing['head'].lower() == t['head'].lower() and
                                    existing['tail'].lower() == t['tail'].lower() and
                                    len(t['type']) > len(existing['type'])):
                                unique_triples.remove(existing)
                                unique_triples.append(t)
                                debug(f"替换三元组为更具体类型: {t}")
                                break
                timings["extract_triples"] = time.time() - start_triples
                debug(
                    f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
                # Insert into Neo4j (skipped when nothing was extracted).
                debug(f"抽取到 {len(unique_triples)} 个三元组调用Neo4j服务插入")
                start_neo4j = time.time()
                if unique_triples:
                    neo4j_result = await self._make_neo4japi_request("inserttriples", {
                        "triples": unique_triples,
                        "document_id": document_id,
                        "knowledge_base_id": knowledge_base_id,
                        "userid": userid
                    })
                    debug(f"Neo4j服务响应: {neo4j_result}")
                    if neo4j_result.get("status") != "success":
                        timings["insert_neo4j"] = time.time() - start_neo4j
                        timings["total"] = time.time() - start_total
                        return {
                            "status": "error",
                            "document_id": document_id,
                            "collection_name": collection_name,
                            "timings": timings,
                            "message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
                            "status_code": 400
                        }
                    info(f"文件 {file_path} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
                else:
                    debug(f"文件 {file_path} 未抽取到三元组")
                timings["insert_neo4j"] = time.time() - start_neo4j
                debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
            except Exception as e:
                # Non-fatal: the embeddings are already in Milvus, so report
                # success with a warning message.
                timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
                    timings["extract_triples"]
                timings["insert_neo4j"] = time.time() - start_neo4j
                debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
                timings["total"] = time.time() - start_total
                return {
                    "status": "success",
                    "document_id": document_id,
                    "collection_name": collection_name,
                    "timings": timings,
                    "unique_triples": unique_triples,
                    "message": f"文件 {file_path} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}",
                    "status_code": 200
                }
            timings["total"] = time.time() - start_total
            debug(f"总耗时: {timings['total']:.2f}")
            return {
                "status": "success",
                "document_id": document_id,
                "collection_name": collection_name,
                "timings": timings,
                "unique_triples": unique_triples,
                "message": f"文件 {file_path} 成功嵌入并处理三元组",
                "status_code": 200
            }
        except Exception as e:
            error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            timings["total"] = time.time() - start_total
            return {
                "status": "error",
                "document_id": document_id,
                "collection_name": collection_name,
                "timings": timings,
                "message": f"插入文档失败: {str(e)}",
                "status_code": 400
            }
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_exception_type((aiohttp.ClientError, RuntimeError)),
before_sleep=lambda retry_state: debug(f"重试三元组抽取服务,第 {retry_state.attempt_number}")
)
async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""调用嵌入服务获取文本的向量,带缓存"""
try:
# 检查缓存
uncached_texts = [text for text in texts if text not in EMBED_CACHE]
if uncached_texts:
async with aiohttp.ClientSession() as session:
async with session.post(
"http://localhost:9998/v1/embeddings",
headers={"Content-Type": "application/json"},
json={"input": uncached_texts}
) as response:
if response.status != 200:
error(f"嵌入服务调用失败,状态码: {response.status}")
raise RuntimeError(f"嵌入服务调用失败: {response.status}")
result = await response.json()
if result.get("object") != "list" or not result.get("data"):
error(f"嵌入服务响应格式错误: {result}")
raise RuntimeError("嵌入服务响应格式错误")
embeddings = [item["embedding"] for item in result["data"]]
for text, embedding in zip(uncached_texts, embeddings):
EMBED_CACHE[text] = np.array(embedding) / np.linalg.norm(embedding)
debug(f"成功获取 {len(embeddings)} 个新嵌入向量,缓存大小: {len(EMBED_CACHE)}")
# 返回缓存中的嵌入
return [EMBED_CACHE[text] for text in texts]
except Exception as e:
error(f"嵌入服务调用失败: {str(e)}")
raise RuntimeError(f"嵌入服务调用失败: {str(e)}")
async def _extract_triples(self, text: str) -> List[Dict]:
"""调用三元组抽取服务,无超时限制"""
request_id = str(uuid.uuid4()) # 为每个请求生成唯一 ID
start_time = time.time()
debug(f"Request #{request_id} started for triples extraction")
try:
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=30),
timeout=aiohttp.ClientTimeout(total=None) # 无限等待
) as session:
async with session.post(
"http://localhost:9991/v1/triples",
headers={"Content-Type": "application/json; charset=utf-8"},
json={"text": text}
) as response:
elapsed_time = time.time() - start_time
debug(f"Request #{request_id} received response, status: {response.status}, took {elapsed_time:.2f} seconds")
if response.status != 200:
error_text = await response.text()
error(f"Request #{request_id} failed, status: {response.status}, response: {error_text}")
raise RuntimeError(f"三元组抽取服务调用失败: {response.status}, {error_text}")
result = await response.json()
if result.get("object") != "list" or not result.get("data"):
error(f"Request #{request_id} invalid response format: {result}")
raise RuntimeError("三元组抽取服务响应格式错误")
triples = result["data"]
debug(f"Request #{request_id} extracted {len(triples)} triples, total time: {elapsed_time:.2f} seconds")
return triples
except Exception as e:
elapsed_time = time.time() - start_time
error(f"Request #{request_id} failed to extract triples: {str(e)}, took {elapsed_time:.2f} seconds")
debug(f"Request #{request_id} traceback: {traceback.format_exc()}")
raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}")
async def _delete_document(self, userid: str, file_path: str, knowledge_base_id: str, document_id:str, db_type: str = "") -> Dict[str, Any]:
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
# 调用 Milvus 删除文件端点
debug(f"调用删除文件端点: userid={userid}, file_path={file_path}, knowledge_base_id={knowledge_base_id}, document_id={document_id}")
milvus_result = await self._make_api_request("deletedocument", {
"userid": userid,
"file_path": file_path,
"knowledge_base_id": knowledge_base_id,
"document_id": document_id,
"db_type": db_type
})
if milvus_result.get("status") != "success":
error(f"Milvus 删除文件失败: {milvus_result.get('message', '未知错误')}")
return milvus_result
# 调用 Neo4j 删除端点
neo4j_deleted_nodes = 0
neo4j_deleted_rels = 0
try:
debug(f"调用 Neo4j 删除文档端点: document_id={document_id}")
neo4j_result = await self._make_neo4japi_request("deletedocument", {
"document_id": document_id
})
if neo4j_result.get("status") != "success":
error(
f"Neo4j 删除文档失败: document_id={document_id}, 错误: {neo4j_result.get('message', '未知错误')}")
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0)
neo4j_deleted_nodes += nodes_deleted
neo4j_deleted_rels += rels_deleted
info(f"成功删除 document_id={document_id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
except Exception as e:
error(f"删除 document_id={document_id} 的 Neo4j 数据失败: {str(e)}")
return {
"status": "success",
"collection_name": collection_name,
"document_id": document_id,
"message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
"status_code": 200
}
except Exception as e:
error(f"删除文档失败: {str(e)}")
return {
"status": "error",
"collection_name": collection_name,
"document_id": document_id,
"message": f"删除文档失败: {str(e)}",
"status_code": 400
}
    async def _delete_knowledge_base(self, db_type: str, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's whole knowledge base from Milvus, then from Neo4j.

        Milvus deletion is authoritative: if it fails, its error result is
        returned unchanged. Neo4j deletion is attempted afterwards; when it
        fails, the call still reports overall "success" (the Milvus rows are
        already gone) and records the Neo4j error in the message.
        """
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        try:
            # Step 1: delete the knowledge base from Milvus via the service endpoint.
            debug(f"调用删除知识库端点: userid={userid}, knowledge_base_id={knowledge_base_id}")
            milvus_result = await self._make_api_request("deleteknowledgebase", {
                "userid": userid,
                "knowledge_base_id": knowledge_base_id,
                "db_type": db_type
            })
            if milvus_result.get("status") != "success":
                error(f"Milvus 删除知识库失败: {milvus_result.get('message', '未知错误')}")
                return milvus_result
            deleted_files = milvus_result.get("deleted_files", [])
            # Step 2: delete the matching graph data from Neo4j (best-effort).
            neo4j_deleted_nodes = 0
            neo4j_deleted_rels = 0
            try:
                debug(f"调用 Neo4j 删除知识库端点: userid={userid}, knowledge_base_id={knowledge_base_id}")
                neo4j_result = await self._make_neo4japi_request("deleteknowledgebase", {
                    "userid": userid,
                    "knowledge_base_id": knowledge_base_id
                })
                if neo4j_result.get("status") == "success":
                    neo4j_deleted_nodes = neo4j_result.get("nodes_deleted", 0)
                    neo4j_deleted_rels = neo4j_result.get("rels_deleted", 0)
                    info(f"成功删除 {neo4j_deleted_nodes} 个 Neo4j 节点和 {neo4j_deleted_rels} 个关系")
                else:
                    error(f"Neo4j 删除知识库失败: {neo4j_result.get('message', '未知错误')}")
                    # Milvus rows are already deleted, so still report success,
                    # carrying the Neo4j failure inside the message text.
                    return {
                        "status": "success",
                        "collection_name": collection_name,
                        "deleted_files": deleted_files,
                        "message": f"成功删除 Milvus 知识库,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {neo4j_result.get('message')}",
                        "status_code": 200
                    }
            except Exception as e:
                error(f"Neo4j 删除知识库失败: {str(e)}")
                # Same best-effort policy for exceptions raised by the Neo4j call.
                return {
                    "status": "success",
                    "collection_name": collection_name,
                    "deleted_files": deleted_files,
                    "message": f"成功删除 Milvus 知识库,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {str(e)}",
                    "status_code": 200
                }
            # Nothing was deleted anywhere: report an explicit no-op success.
            if not deleted_files and neo4j_deleted_nodes == 0 and neo4j_deleted_rels == 0:
                debug(f"没有删除任何记录userid={userid}, knowledge_base_id={knowledge_base_id}")
                return {
                    "status": "success",
                    "collection_name": collection_name,
                    "deleted_files": [],
                    "message": f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录,无需删除",
                    "status_code": 200
                }
            info(
                f"总计删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}")
            return {
                "status": "success",
                "collection_name": collection_name,
                "deleted_files": deleted_files,
                "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,删除文件: {deleted_files}",
                "status_code": 200
            }
        except Exception as e:
            error(f"删除知识库失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "deleted_files": [],
                "message": f"删除知识库失败: {str(e)}",
                "status_code": 400
            }
async def _extract_entities(self, query: str) -> List[str]:
"""调用实体识别服务"""
try:
if not query:
raise ValueError("查询文本不能为空")
async with aiohttp.ClientSession() as session:
async with session.post(
"http://localhost:9990/v1/entities",
headers={"Content-Type": "application/json"},
json={"query": query}
) as response:
if response.status != 200:
error(f"实体识别服务调用失败,状态码: {response.status}")
raise RuntimeError(f"实体识别服务调用失败: {response.status}")
result = await response.json()
if result.get("object") != "list" or not result.get("data"):
error(f"实体识别服务响应格式错误: {result}")
raise RuntimeError("实体识别服务响应格式错误")
entities = result["data"]
unique_entities = list(dict.fromkeys(entities)) # 去重
debug(f"成功提取 {len(unique_entities)} 个唯一实体: {unique_entities}")
return unique_entities
except Exception as e:
error(f"实体识别服务调用失败: {str(e)}")
return []
async def _rerank_results(self, query: str, results: List[Dict], top_n: int) -> List[Dict]:
"""调用重排序服务"""
try:
if not results:
debug("无结果需要重排序")
return results
if not isinstance(top_n, int) or top_n < 1:
debug(f"无效的 top_n 参数: {top_n},使用 len(results)={len(results)}")
top_n = len(results)
else:
top_n = min(top_n, len(results))
debug(f"重排序 top_n={top_n}, 原始结果数={len(results)}")
documents = [result["text"] for result in results]
async with aiohttp.ClientSession() as session:
async with session.post(
"http://localhost:9997/v1/rerank",
headers={"Content-Type": "application/json"},
json={
"model": "rerank-001",
"query": query,
"documents": documents,
"top_n": top_n
}
) as response:
if response.status != 200:
error(f"重排序服务调用失败,状态码: {response.status}")
raise RuntimeError(f"重排序服务调用失败: {response.status}")
result = await response.json()
if result.get("object") != "rerank.result" or not result.get("data"):
error(f"重排序服务响应格式错误: {result}")
raise RuntimeError("重排序服务响应格式错误")
rerank_data = result["data"]
reranked_results = []
for item in rerank_data:
index = item["index"]
if index < len(results):
results[index]["rerank_score"] = item["relevance_score"]
reranked_results.append(results[index])
debug(f"成功重排序 {len(reranked_results)} 条结果")
return reranked_results[:top_n]
except Exception as e:
error(f"重排序服务调用失败: {str(e)}")
return results
    async def _search_query(self, query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
                            offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
        """Pure vector search through the service endpoints.

        Pipeline: validate arguments -> embed *query* -> call the "searchquery"
        endpoint -> optionally rerank. Returns ``{"results": [...], "timing": {...}}``;
        every failure is swallowed and reported as an empty result list.
        """
        start_time = time.time()
        # NOTE(review): collection_name is computed but never used in this method.
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        timing_stats = {}  # per-stage wall-clock durations, returned to the caller
        try:
            info(
                f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
            # --- argument validation; 16384 mirrors the backing store's max
            # search window for limit + offset ---
            if not query:
                raise ValueError("查询文本不能为空")
            if not userid:
                raise ValueError("userid 不能为空")
            if "_" in userid or (db_type and "_" in db_type):
                raise ValueError("userid 和 db_type 不能包含下划线")
            if (db_type and len(db_type) > 100) or len(userid) > 100:
                raise ValueError("userid 或 db_type 的长度超出限制")
            if limit <= 0 or limit > 16384:
                raise ValueError("limit 必须在 1 到 16384 之间")
            if offset < 0:
                raise ValueError("offset 不能为负数")
            if limit + offset > 16384:
                raise ValueError("limit + offset 不能超过 16384")
            if not knowledge_base_ids:
                raise ValueError("knowledge_base_ids 不能为空")
            for kb_id in knowledge_base_ids:
                if not isinstance(kb_id, str):
                    raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}")
                if len(kb_id) > 100:
                    raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}")
                if "_" in kb_id:
                    raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}")
            # Convert the query text into an embedding vector.
            vector_start = time.time()
            query_vector = await self._get_embeddings([query])
            # Embeddings are expected to be 1024-dimensional vectors.
            if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
                raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
            query_vector = query_vector[0]  # single query -> first (only) vector
            timing_stats["vector_generation"] = time.time() - vector_start
            debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
            # Call the pure vector search endpoint.
            search_start = time.time()
            result = await self._make_api_request("searchquery", {
                "query_vector": query_vector.tolist(),
                "userid": userid,
                "knowledge_base_ids": knowledge_base_ids,
                "limit": limit,
                "offset": offset,
                "db_type": db_type
            })
            timing_stats["vector_search"] = time.time() - search_start
            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
            if result.get("status") != "success":
                error(f"纯向量搜索失败: {result.get('message', '未知错误')}")
                return {"results": [], "timing": timing_stats}
            unique_results = result.get("results", [])
            if use_rerank and unique_results:
                # Optional second stage: rerank and sort by rerank_score (desc).
                rerank_start = time.time()
                debug("开始重排序")
                unique_results = await self._rerank_results(query, unique_results, limit)
                unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
                timing_stats["reranking"] = time.time() - rerank_start
                debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
                debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
            else:
                # Strip any stale rerank_score keys when reranking is disabled.
                unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
            timing_stats["total_time"] = time.time() - start_time
            info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
            return {"results": unique_results[:limit], "timing": timing_stats}
        except Exception as e:
            error(f"纯向量搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            return {"results": [], "timing": timing_stats}
    async def _fused_search(self, query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
                            offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
        """Fused (graph + vector) search through the service endpoints.

        Pipeline: extract entities from *query* -> match triples per knowledge
        base in Neo4j -> append matched triples to the query text -> embed the
        combined text -> vector search -> optional rerank. Returns
        ``{"results": [...], "timing": {...}}``; failures yield empty results.
        """
        start_time = time.time()
        # NOTE(review): collection_name is computed but never used in this method.
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        timing_stats = {}  # per-stage wall-clock durations, returned to the caller
        try:
            info(
                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
            if not query or not userid or not knowledge_base_ids:
                raise ValueError("query、userid 和 knowledge_base_ids 不能为空")
            if "_" in userid or (db_type and "_" in db_type):
                raise ValueError("userid 和 db_type 不能包含下划线")
            if (db_type and len(db_type) > 100) or len(userid) > 100:
                raise ValueError("db_type 或 userid 的长度超出限制")
            if limit < 1 or limit > 16384 or offset < 0:
                raise ValueError("limit 必须在 1 到 16384 之间offset 必须大于或等于 0")
            # Extract entities from the query (best-effort; may be empty).
            entity_extract_start = time.time()
            query_entities = await self._extract_entities(query)
            timing_stats["entity_extraction"] = time.time() - entity_extract_start
            debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
            # Match triples in Neo4j, one request per knowledge base; a failure
            # for one knowledge base does not abort the others.
            all_triplets = []
            triplet_match_start = time.time()
            for kb_id in knowledge_base_ids:
                debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
                try:
                    neo4j_result = await self._make_neo4japi_request("matchtriplets", {
                        "query": query,
                        "query_entities": query_entities,
                        "userid": userid,
                        "knowledge_base_id": kb_id
                    })
                    if neo4j_result.get("status") == "success":
                        triplets = neo4j_result.get("triplets", [])
                        all_triplets.extend(triplets)
                        debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
                    else:
                        error(
                            f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
                except Exception as e:
                    error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
                    continue
            timing_stats["triplet_matching"] = time.time() - triplet_match_start
            debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
            # Render matched triples as "head type tail" strings and append them
            # to the query text to enrich the embedding.
            triplet_text_start = time.time()
            triplet_texts = []
            for triplet in all_triplets:
                head = triplet.get('head', '')
                type_ = triplet.get('type', '')
                tail = triplet.get('tail', '')
                if head and type_ and tail:
                    triplet_texts.append(f"{head} {type_} {tail}")
                else:
                    debug(f"无效三元组: {triplet}")
            combined_text = query
            if triplet_texts:
                combined_text += " [三元组] " + "; ".join(triplet_texts)
            debug(
                f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
            timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
            debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")
            # Embed the combined (query + triples) text.
            vector_start = time.time()
            query_vector = await self._get_embeddings([combined_text])
            # Embeddings are expected to be 1024-dimensional vectors.
            if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
                raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
            query_vector = query_vector[0]  # single text -> first (only) vector
            timing_stats["vector_generation"] = time.time() - vector_start
            debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")
            # Vector search; reuses the same "searchquery" endpoint as
            # _search_query, only the embedded text differs.
            search_start = time.time()
            result = await self._make_api_request("searchquery", {
                "query_vector": query_vector.tolist(),
                "userid": userid,
                "knowledge_base_ids": knowledge_base_ids,
                "limit": limit,
                "offset": offset,
                "db_type": db_type
            })
            timing_stats["vector_search"] = time.time() - search_start
            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
            if result.get("status") != "success":
                error(f"融合搜索失败: {result.get('message', '未知错误')}")
                return {"results": [], "timing": timing_stats}
            unique_results = result.get("results", [])
            if use_rerank and unique_results:
                # Rerank against the combined text, then sort by score (desc).
                rerank_start = time.time()
                debug("开始重排序")
                unique_results = await self._rerank_results(combined_text, unique_results, limit)
                unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
                timing_stats["reranking"] = time.time() - rerank_start
                debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
                debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
            else:
                # Strip any stale rerank_score keys when reranking is disabled.
                unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
            timing_stats["total_time"] = time.time() - start_time
            info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
            return {"results": unique_results[:limit], "timing": timing_stats}
        except Exception as e:
            error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            return {"results": [], "timing": timing_stats}
async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
"""列出用户的所有知识库及其文件,按 knowledge_base_id 分组"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
info(f"列出用户文件: userid={userid}, db_type={db_type}")
if not userid:
raise ValueError("userid 不能为空")
if "_" in userid or (db_type and "_" in db_type):
raise ValueError("userid 和 db_type 不能包含下划线")
if (db_type and len(db_type) > 100) or len(userid) > 100:
raise ValueError("userid 或 db_type 的长度超出限制")
# 调用列出用户文件端点
result = await self._make_api_request("listuserfiles", {
"userid": userid,
"db_type": db_type
})
if result.get("status") != "success":
error(f"列出用户文件失败: {result.get('message', '未知错误')}")
return {}
return result.get("files_by_knowledge_base", {})
except Exception as e:
error(f"列出用户文件失败: {str(e)}")
return {}
async def _list_all_knowledge_bases(self, db_type: str = "") -> Dict[str, Any]:
"""列出数据库中所有用户的知识库及其文件,按用户分组"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
info(f"列出所有用户的知识库: db_type={db_type}")
if db_type and "_" in db_type:
raise ValueError("db_type 不能包含下划线")
if db_type and len(db_type) > 100:
raise ValueError("db_type 的长度应小于 100")
# 调用列出所有知识库端点
result = await self._make_api_request("listallknowledgebases", {
"db_type": db_type
})
return result
except Exception as e:
error(f"列出所有用户知识库失败: {str(e)}")
return {
"status": "error",
"users_knowledge_bases": {},
"collection_name": collection_name,
"message": f"列出所有用户知识库失败: {str(e)}",
"status_code": 400
}
# Register this connection class under the 'Rag' driver name so the framework
# can construct it from configuration at startup.
connection_register('Rag', MilvusConnection)
info("MilvusConnection registered")