增加添加/删除文档功能
This commit is contained in:
parent
d52440560d
commit
64e42705bf
282
rag/api_service.py
Normal file
282
rag/api_service.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
from appPublic.log import debug, error
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
|
||||||
|
class APIService:
    """Async HTTP client for the micro-services used by the RAG pipeline.

    Wraps the embedding (BAAI/bge-m3), entity-extraction (LTP/small),
    triple-extraction (Babelscape/mrebel-large) and rerank
    (BAAI/bge-reranker-v2-m3) services, plus the Neo4j graph service and
    the Milvus/RAG vector service.
    """

    # Endpoints were hard-coded at every call site in the original; keeping
    # them as class attributes lets a deployment override them in one place.
    EMBEDDING_URL = "http://localhost:9998/v1/embeddings"
    ENTITY_URL = "http://localhost:9990/v1/entities"
    TRIPLES_URL = "http://localhost:9991/v1/triples"
    RERANK_URL = "http://localhost:9997/v1/rerank"
    NEO4J_BASE = "http://localhost:8885"
    RAG_BASE = "http://localhost:8888"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, RuntimeError)),
        before_sleep=lambda retry_state: debug(f"重试 API 请求,第 {retry_state.attempt_number} 次")
    )
    async def _make_request(self, url: str, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST *params* as JSON to *url* and return the decoded JSON body.

        A 400 response is returned to the caller as-is (client-error payload);
        any other non-200 status raises RuntimeError, which the @retry
        decorator above retries up to 3 times with exponential backoff.
        """
        debug(f"开始 API 请求: action={action}, params={params}, url={url}")
        try:
            async with ClientSession(timeout=ClientTimeout(total=300)) as session:
                async with session.post(
                    url,
                    headers={"Content-Type": "application/json"},
                    json=params
                ) as response:
                    debug(f"收到响应: status={response.status}, headers={response.headers}")
                    response_text = await response.text()
                    debug(f"响应内容: {response_text}")
                    result = await response.json()
                    debug(f"API 响应内容: {result}")
                    if response.status == 400:
                        # Client errors carry a structured payload the caller
                        # is expected to inspect; do not raise.
                        debug(f"客户端错误,状态码: {response.status}, 返回响应: {result}")
                        return result
                    if response.status != 200:
                        error(f"API 调用失败,动作: {action}, 状态码: {response.status}, 响应: {response_text}")
                        raise RuntimeError(f"API 调用失败: {response.status}")
                    return result
        except Exception as e:
            error(f"API 调用失败: {action}, 错误: {str(e)}, 堆栈: {traceback.format_exc()}")
            # Chain the original exception so the real cause survives the re-wrap.
            raise RuntimeError(f"API 调用失败: {str(e)}") from e

    async def make_milvus_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call a Milvus/RAG endpoint by action name (e.g. 'insertdocument').

        Added because rag/file.py in this commit calls this method, which was
        otherwise undefined (AttributeError at runtime).
        """
        return await self._make_request(f"{self.RAG_BASE}/v1/{action}", action, params)

    async def make_neo4j_request(self, action: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call a Neo4j endpoint by action name (e.g. 'inserttriples').

        Added because rag/file.py in this commit calls this method, which was
        otherwise undefined (AttributeError at runtime).
        """
        return await self._make_request(f"{self.NEO4J_BASE}/v1/{action}", action, params)

    # Embedding service (BAAI/bge-m3)
    async def get_embeddings(self, texts: list) -> list:
        """Return one embedding vector per input text.

        *texts* may be a single string or a list of strings.
        Raises RuntimeError on any transport or response-format failure.
        """
        try:
            async with ClientSession() as session:
                async with session.post(
                    self.EMBEDDING_URL,
                    headers={"Content-Type": "application/json"},
                    json={"input": texts if isinstance(texts, list) else [texts]}
                ) as response:
                    if response.status != 200:
                        error(f"嵌入服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"嵌入服务调用失败: {response.status}")
                    result = await response.json()
                    if result.get("object") != "list" or not result.get("data"):
                        error(f"嵌入服务响应格式错误: {result}")
                        raise RuntimeError("嵌入服务响应格式错误")
                    embeddings = [item["embedding"] for item in result["data"]]
                    debug(f"成功获取 {len(embeddings)} 个嵌入向量")
                    return embeddings
        except Exception as e:
            error(f"嵌入服务调用失败: {str(e)}")
            raise RuntimeError(f"嵌入服务调用失败: {str(e)}") from e

    # Entity-extraction service (LTP/small)
    async def extract_entities(self, query: str) -> list:
        """Return the unique entities recognised in *query*.

        Best-effort: returns an empty list on any failure instead of raising,
        so callers can degrade gracefully.
        """
        try:
            if not query:
                raise ValueError("查询文本不能为空")
            async with ClientSession() as session:
                async with session.post(
                    self.ENTITY_URL,
                    headers={"Content-Type": "application/json"},
                    json={"query": query}
                ) as response:
                    if response.status != 200:
                        error(f"实体识别服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"实体识别服务调用失败: {response.status}")
                    result = await response.json()
                    if result.get("object") != "list" or not result.get("data"):
                        error(f"实体识别服务响应格式错误: {result}")
                        raise RuntimeError("实体识别服务响应格式错误")
                    entities = result["data"]
                    # dict.fromkeys de-duplicates while preserving order.
                    unique_entities = list(dict.fromkeys(entities))
                    debug(f"成功提取 {len(unique_entities)} 个唯一实体")
                    return unique_entities
        except Exception as e:
            error(f"实体识别服务调用失败: {str(e)}")
            return []

    # Triple-extraction service (Babelscape/mrebel-large)
    async def extract_triples(self, text: str) -> list:
        """Extract (head, type, tail) triples from *text*.

        No total timeout is set because extraction on long chunks can be slow;
        the connector caps concurrent connections at 30.
        Raises RuntimeError on any failure.
        """
        request_id = str(uuid.uuid4())  # correlates log lines for this call
        debug(f"Request #{request_id} started for triples extraction")
        try:
            async with ClientSession(
                connector=aiohttp.TCPConnector(limit=30),
                timeout=ClientTimeout(total=None)
            ) as session:
                async with session.post(
                    self.TRIPLES_URL,
                    headers={"Content-Type": "application/json; charset=utf-8"},
                    json={"text": text}
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        error(f"Request #{request_id} failed, status: {response.status}, response: {error_text}")
                        raise RuntimeError(f"三元组抽取服务调用失败: {response.status}, {error_text}")
                    result = await response.json()
                    if result.get("object") != "list" or not result.get("data"):
                        error(f"Request #{request_id} invalid response format: {result}")
                        raise RuntimeError("三元组抽取服务响应格式错误")
                    triples = result["data"]
                    debug(f"Request #{request_id} extracted {len(triples)} triples")
                    return triples
        except Exception as e:
            error(f"Request #{request_id} failed to extract triples: {str(e)}")
            raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}") from e

    # Rerank service (BAAI/bge-reranker-v2-m3)
    async def rerank_results(self, query: str, results: list, top_n: int) -> list:
        """Rerank *results* against *query* and return the best *top_n*.

        NOTE: mutates the input dicts by adding a 'rerank_score' key.
        Best-effort: returns the original *results* unchanged on any failure.
        """
        try:
            if not results:
                debug("无结果需要重排序")
                return results

            # Clamp top_n to a sane value instead of failing the request.
            if not isinstance(top_n, int) or top_n < 1:
                debug(f"无效的 top_n 参数: {top_n}, 使用 len(results)={len(results)}")
                top_n = len(results)
            else:
                top_n = min(top_n, len(results))

            documents = [result.get("text", str(result)) for result in results]
            async with ClientSession() as session:
                async with session.post(
                    self.RERANK_URL,
                    headers={"Content-Type": "application/json"},
                    json={
                        "model": "rerank-001",
                        "query": query,
                        "documents": documents,
                        "top_n": top_n
                    }
                ) as response:
                    if response.status != 200:
                        error(f"重排序服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"重排序服务调用失败: {response.status}")
                    result = await response.json()
                    if result.get("object") != "rerank.result" or not result.get("data"):
                        error(f"重排序服务响应格式错误: {result}")
                        raise RuntimeError("重排序服务响应格式错误")
                    rerank_data = result["data"]
                    reranked_results = []
                    for item in rerank_data:
                        index = item["index"]
                        # Guard against out-of-range indices from the service.
                        if index < len(results):
                            results[index]["rerank_score"] = item["relevance_score"]
                            reranked_results.append(results[index])
                    debug(f"成功重排序 {len(reranked_results)} 条结果")
                    return reranked_results[:top_n]
        except Exception as e:
            error(f"重排序服务调用失败: {str(e)}")
            return results

    # Neo4j service
    async def neo4j_docs(self) -> Dict[str, Any]:
        """Fetch the Neo4j service's help document."""
        return await self._make_request(f"{self.NEO4J_BASE}/docs", "docs", {})

    async def neo4j_initialize(self) -> Dict[str, Any]:
        """Initialize the Neo4j service."""
        return await self._make_request(f"{self.NEO4J_BASE}/v1/initialize", "initialize", {})

    async def neo4j_insert_triples(self, triples: list, document_id: str, knowledge_base_id: str, userid: str) -> Dict[str, Any]:
        """Insert extracted triples into Neo4j, tagged with document/KB/user ids."""
        params = {
            "triples": triples,
            "document_id": document_id,
            "knowledge_base_id": knowledge_base_id,
            "userid": userid
        }
        return await self._make_request(f"{self.NEO4J_BASE}/v1/inserttriples", "inserttriples", params)

    async def neo4j_delete_document(self, document_id: str) -> Dict[str, Any]:
        """Delete all Neo4j data belonging to one document."""
        return await self._make_request(f"{self.NEO4J_BASE}/v1/deletedocument", "deletedocument", {"document_id": document_id})

    async def neo4j_delete_knowledgebase(self, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's knowledge base from Neo4j."""
        return await self._make_request(f"{self.NEO4J_BASE}/v1/deleteknowledgebase", "deleteknowledgebase",
                                        {"userid": userid, "knowledge_base_id": knowledge_base_id})

    async def neo4j_match_triplets(self, query: str, query_entities: list, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Match triples related to the given entities in a user's knowledge base."""
        params = {
            "query": query,
            "query_entities": query_entities,
            "userid": userid,
            "knowledge_base_id": knowledge_base_id
        }
        return await self._make_request(f"{self.NEO4J_BASE}/v1/matchtriplets", "matchtriplets", params)

    # RAG (Milvus) service
    async def rag_create_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Create a collection; *db_type* selects a non-default collection when set."""
        params = {"db_type": db_type} if db_type else {}
        return await self._make_request(f"{self.RAG_BASE}/v1/createcollection", "createcollection", params)

    async def rag_delete_collection(self, db_type: str = "") -> Dict[str, Any]:
        """Delete a collection; *db_type* selects a non-default collection when set."""
        params = {"db_type": db_type} if db_type else {}
        return await self._make_request(f"{self.RAG_BASE}/v1/deletecollection", "deletecollection", params)

    async def rag_insert_file(self, file_path: str, userid: str, knowledge_base_id: str, document_id: str) -> Dict[str, Any]:
        """Insert one file's records into the vector store."""
        params = {
            "file_path": file_path,
            "userid": userid,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id
        }
        return await self._make_request(f"{self.RAG_BASE}/v1/insertfile", "insertfile", params)

    async def rag_delete_file(self, userid: str, file_path: str, knowledge_base_id: str, document_id: str) -> Dict[str, Any]:
        """Delete one file's records from the vector store."""
        params = {
            "userid": userid,
            "file_path": file_path,
            "knowledge_base_id": knowledge_base_id,
            "document_id": document_id
        }
        return await self._make_request(f"{self.RAG_BASE}/v1/deletefile", "deletefile", params)

    async def rag_delete_knowledgebase(self, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
        """Delete a user's knowledge base from the vector store."""
        return await self._make_request(f"{self.RAG_BASE}/v1/deleteknowledgebase", "deleteknowledgebase",
                                        {"userid": userid, "knowledge_base_id": knowledge_base_id})

    async def rag_search_query(self, query: str, userid: str, knowledge_base_ids: list, limit: int, offset: int,
                               use_rerank: bool) -> Dict[str, Any]:
        """Vector search across the user's knowledge bases."""
        params = {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank
        }
        return await self._make_request(f"{self.RAG_BASE}/v1/searchquery", "searchquery", params)

    async def rag_fused_search_query(self, query: str, userid: str, knowledge_base_ids: list, limit: int, offset: int,
                                     use_rerank: bool) -> Dict[str, Any]:
        """Fused vector + knowledge-graph search across the user's knowledge bases."""
        params = {
            "query": query,
            "userid": userid,
            "knowledge_base_ids": knowledge_base_ids,
            "limit": limit,
            "offset": offset,
            "use_rerank": use_rerank
        }
        return await self._make_request(f"{self.RAG_BASE}/v1/fusedsearchquery", "fusedsearchquery", params)

    async def rag_list_user_files(self, userid: str) -> Dict[str, Any]:
        """List one user's knowledge-base files."""
        return await self._make_request(f"{self.RAG_BASE}/v1/listuserfiles", "listuserfiles", {"userid": userid})

    async def rag_list_all_knowledgebases(self) -> Dict[str, Any]:
        """List every knowledge base in the store."""
        return await self._make_request(f"{self.RAG_BASE}/v1/listallknowledgebases", "listallknowledgebases", {})

    async def rag_docs(self) -> Dict[str, Any]:
        """Fetch the RAG service's help document."""
        return await self._make_request(f"{self.RAG_BASE}/v1/docs", "docs", {})
|
||||||
283
rag/file.py
Normal file
283
rag/file.py
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
from api_service import APIService
|
||||||
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.log import debug, error, info
|
||||||
|
from sqlor.dbpools import DBPools
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from filetxt.loader import fileloader
|
||||||
|
|
||||||
|
def init():
    """Register the upload/delete hooks with the framework's function registry."""
    handlers = {
        'fileuploaded': file_uploaded,
        'filedeleted': file_deleted,
    }
    registry = RegisterFunction()
    for event_name, handler in handlers.items():
        registry.register(event_name, handler)
|
||||||
|
|
||||||
|
async def get_orgid_by_id(kdb_id):
    """Look up the orgid of the kdb record whose id is *kdb_id*.

    Returns the orgid value, or None when the record is missing or the
    query fails (failures are logged, never raised).
    """
    # NOTE(review): database credentials are hard-coded here; they should be
    # loaded from deployment configuration instead.
    dbs = {
        "cfae": {
            "driver": "mysql.connector",
            "coding": "utf8",
            "dbname": "cfae",
            "kwargs": {
                "user": "test",
                "db": "cfae",
                "password": "test123",
                "host": "localhost"
            }
        }
    }
    loop = asyncio.get_event_loop()
    # Presumably DBPools keeps global state: the first (discarded) construction
    # registers the pools and the no-arg call returns the shared registry —
    # TODO confirm against the sqlor docs before removing either line.
    pool = DBPools(dbs, loop)
    db = DBPools()
    try:
        # get_module_dbname is neither defined nor imported in this module, so
        # the original code raised NameError unconditionally.  Fall back to the
        # only pool configured above when it is absent at runtime.
        dbname = get_module_dbname('rag')
    except NameError:
        dbname = 'cfae'
    sql = "SELECT orgid FROM kdb WHERE id = %s"
    try:
        async with db.sqlorContext(dbname) as sor:
            result = await sor.sql(sql, (kdb_id,))
            if result and result.rows:
                return result.rows[0][0]
            return None
    except Exception as e:
        error(f"查询 orgid 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return None
|
||||||
|
|
||||||
|
def _dedupe_triples(triples):
    """De-duplicate triples on lower-cased (head, tail, type).

    Preserves the original replacement quirk: when an exact duplicate arrives,
    any already-kept triple with the same head/tail but a shorter type string
    is replaced by the newcomer (treating a longer type as "more specific").
    """
    unique_triples = []
    seen = set()
    for t in triples:
        identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
        if identifier not in seen:
            seen.add(identifier)
            unique_triples.append(t)
        else:
            for existing in unique_triples:
                if (existing['head'].lower() == t['head'].lower() and
                        existing['tail'].lower() == t['tail'].lower() and
                        len(t['type']) > len(existing['type'])):
                    unique_triples.remove(existing)
                    unique_triples.append(t)
                    debug(f"替换三元组为更具体类型: {t}")
                    break
    return unique_triples


async def file_uploaded(params_kw):
    """Chunk an uploaded file, embed it into Milvus and push triples to Neo4j.

    *params_kw* must carry 'realpath' (file on disk), 'fiid' (knowledge-base
    id) and 'id' (document id).  Returns a status dict including per-stage
    timings.  A Neo4j failure after a successful Milvus insert is reported as
    success with an explanatory message (status_code 200).
    """
    debug(f'Received params: {params_kw=}')
    api_service = APIService()
    realpath = params_kw.get('realpath', '')
    fiid = params_kw.get('fiid', '')
    doc_id = params_kw.get('id', '')  # renamed from `id` (shadowed the builtin)
    orgid = await get_orgid_by_id(doc_id)
    db_type = ''
    debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={doc_id}')

    timings = {}
    start_total = time.time()
    # Bound up front: the original referenced these in except paths before
    # they were assigned, which could raise NameError on early failures.
    unique_triples = []
    start_neo4j = start_total

    try:
        if not orgid or not fiid or not doc_id:
            raise ValueError("orgid、fiid 和 id 不能为空")
        if len(orgid) > 32 or len(fiid) > 255:
            raise ValueError("orgid 或 fiid 的长度超出限制")
        if not os.path.exists(realpath):
            raise ValueError(f"文件 {realpath} 不存在")

        supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")

        debug(f"加载文件: {realpath}")
        start_load = time.time()
        text = fileloader(realpath)
        # Strip everything except CJK, alphanumerics, whitespace and . ; , \n
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
        if not text or not text.strip():
            raise ValueError(f"文件 {realpath} 加载为空")

        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len)
        debug("开始分片文件内容")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
        if not chunks:
            raise ValueError(f"文件 {realpath} 未生成任何文档块")

        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        upload_time = datetime.now().isoformat()

        debug("调用嵌入服务生成向量")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = await api_service.get_embeddings(texts)
        # bge-m3 produces 1024-dim vectors; anything else means a bad response.
        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")

        chunks_data = [{
            "userid": orgid,
            "knowledge_base_id": fiid,
            "text": chunk.page_content,
            "vector": embeddings[i],
            "document_id": doc_id,
            "filename": filename + '.' + ext,
            "file_path": realpath,
            "upload_time": upload_time,
            "file_type": ext,
        } for i, chunk in enumerate(chunks)]

        debug(f"调用插入文件端点: {realpath}")
        start_milvus = time.time()
        # The original called api_service.make_milvus_request(), which does not
        # exist on APIService; route through the generic request helper.
        result = await api_service._make_request(
            "http://localhost:8888/v1/insertdocument", "insertdocument",
            {"chunks": chunks_data, "db_type": db_type})
        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")

        if result.get("status") != "success":
            timings["total"] = time.time() - start_total
            return {"status": "error", "document_id": doc_id, "timings": timings, "message": result.get("message", "未知错误"), "status_code": 400}

        debug("调用三元组抽取服务")
        start_triples = time.time()
        try:
            chunk_texts = [doc.page_content for doc in chunks]
            debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
            tasks = [api_service.extract_triples(chunk) for chunk in chunk_texts]
            # return_exceptions=True: one failed chunk must not abort the rest.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            triples = []
            for i, result in enumerate(results):
                if isinstance(result, list):
                    triples.extend(result)
                    debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
                else:
                    error(f"分片 {i + 1} 处理失败: {str(result)}")

            unique_triples = _dedupe_triples(triples)

            timings["extract_triples"] = time.time() - start_triples
            debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")

            debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
            start_neo4j = time.time()
            if unique_triples:
                # The original called api_service.make_neo4j_request(), which
                # does not exist; use the existing typed wrapper.
                neo4j_result = await api_service.neo4j_insert_triples(unique_triples, doc_id, fiid, orgid)
                debug(f"Neo4j 服务响应: {neo4j_result}")
                if neo4j_result.get("status") != "success":
                    timings["insert_neo4j"] = time.time() - start_neo4j
                    timings["total"] = time.time() - start_total
                    return {"status": "error", "document_id": doc_id, "collection_name": "ragdb", "timings": timings,
                            "message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}", "status_code": 400}
                info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
            else:
                debug(f"文件 {realpath} 未抽取到三元组")
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")

        except Exception as e:
            timings.setdefault("extract_triples", time.time() - start_triples)
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            timings["total"] = time.time() - start_total
            # Embedding succeeded, so this is still reported as success.
            return {"status": "success", "document_id": doc_id, "collection_name": "ragdb", "timings": timings,
                    "unique_triples": unique_triples,
                    "message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}", "status_code": 200}

        timings["total"] = time.time() - start_total
        debug(f"总耗时: {timings['total']:.2f} 秒")
        return {"status": "success", "document_id": doc_id, "collection_name": "ragdb", "timings": timings,
                "unique_triples": unique_triples, "message": f"文件 {realpath} 成功嵌入并处理三元组", "status_code": 200}

    except Exception as e:
        error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        timings["total"] = time.time() - start_total
        return {"status": "error", "document_id": doc_id, "collection_name": "ragdb", "timings": timings,
                "message": f"插入文档失败: {str(e)}", "status_code": 400}
|
||||||
|
|
||||||
|
async def file_deleted(params_kw):
    """Delete one document's data from Milvus and (best-effort) from Neo4j.

    *params_kw* must carry 'id', 'fiid' and 'realpath'.  A Neo4j failure is
    logged but does not fail the overall deletion; a Milvus failure does.
    """
    api_service = APIService()
    doc_id = params_kw.get('id', '')  # renamed from `id` (shadowed the builtin)
    realpath = params_kw.get('realpath', '')
    fiid = params_kw.get('fiid', '')
    orgid = await get_orgid_by_id(doc_id)
    db_type = ''
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    # Bound before the try block: the original assigned document_id inside the
    # try, so the except handler raised NameError when validation failed.
    document_id = doc_id
    try:
        required_fields = ['id', 'fiid', 'realpath']
        missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
        if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")

        debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}")
        # The original called api_service.make_milvus_request(), which does not
        # exist on APIService; route through the generic request helper.
        milvus_result = await api_service._make_request(
            "http://localhost:8888/v1/deletedocument", "deletedocument", {
                "userid": orgid,
                "file_path": realpath,
                "knowledge_base_id": fiid,
                "document_id": document_id,
                "db_type": db_type
            })

        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus 删除失败"))

        neo4j_deleted_nodes = 0
        neo4j_deleted_rels = 0
        try:
            debug(f"调用 Neo4j 删除文档端点: document_id={document_id}")
            # The original called api_service.make_neo4j_request(), which does
            # not exist; use the existing typed wrapper.
            neo4j_result = await api_service.neo4j_delete_document(document_id)
            if neo4j_result.get("status") != "success":
                raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
            nodes_deleted = neo4j_result.get("nodes_deleted", 0)
            rels_deleted = neo4j_result.get("rels_deleted", 0)
            neo4j_deleted_nodes += nodes_deleted
            neo4j_deleted_rels += rels_deleted
            info(f"成功删除 document_id={document_id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
        except Exception as e:
            # Best effort: Milvus data is already gone, report success anyway.
            error(f"删除 document_id={document_id} 的 Neo4j 数据失败: {str(e)}")

        return {
            "status": "success",
            "collection_name": collection_name,
            "document_id": document_id,
            "message": f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
            "status_code": 200
        }

    except Exception as e:
        error(f"删除文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {
            "status": "error",
            "collection_name": collection_name,
            "document_id": document_id,
            "message": f"删除文档失败: {str(e)}",
            "status_code": 400
        }
|
||||||
|
|
||||||
|
async def main():
    """Smoke test: resolve and print the orgid for a sample kdb id."""
    kdb_id = "textdb"
    orgid = await get_orgid_by_id(kdb_id)
    message = f"找到的 orgid: {orgid}" if orgid else "未找到对应的 orgid"
    print(message)


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
Loading…
x
Reference in New Issue
Block a user