# rag/rag/file.py
# 2025-08-27 14:09:39 +08:00
#
# 381 lines
# 17 KiB
# Python

from rag.api_service import APIService
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, error, info
from sqlor.dbpools import DBPools
import asyncio
import aiohttp
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import os
import re
import time
import uuid
from datetime import datetime
import traceback
from filetxt.loader import fileloader
from ahserver.serverenv import get_serverenv
from typing import List, Dict, Any
# Module-level service client shared by every handler in this file: all
# embedding, Milvus, Neo4j, triple-extraction and rerank calls go through it.
api_service = APIService()
async def get_orgid_by_id(kdb_id):
    """
    Look up the ``orgid`` of the ``kdb`` row whose ``id`` equals *kdb_id*.

    Args:
        kdb_id: Value matched against ``kdb.id``.

    Returns:
        The ``orgid`` of the first matching row, or ``None`` when no row
        matches or the query fails. Failures are logged, never raised.
    """
    db = DBPools()
    # f = get_serverenv("get_module_dbname")
    # dbname = f("rag")
    # NOTE(review): the database name is hard-coded; the commented lines above
    # suggest it was meant to come from the server environment.
    dbname = "kyrag"
    # kdb_id is passed separately in the params dict (presumably bound by
    # sqlor's ${...}$ placeholder), not formatted into the SQL text.
    sql = "SELECT orgid FROM kdb WHERE id = ${id}$"
    try:
        async with db.sqlorContext(dbname) as sor:
            result = await sor.sqlExe(sql, {"id": kdb_id})
            # Use the module logger rather than printing rows to stdout.
            debug(f"kdb orgid lookup id={kdb_id}: {result}")
            if result:
                return result[0].get('orgid')
            return None
    except Exception as e:
        error(f"查询 orgid 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return None
def _dedupe_triples(triples):
    """
    Return *triples* without case-insensitive duplicates, preserving order.

    Two triples are duplicates when their ``head``, ``tail`` and ``type``
    values are equal after ``str.lower()``; the first occurrence is kept.
    Items that are not mappings with string ``head``/``tail``/``type`` values
    are logged and skipped, so one malformed triple from the extraction
    service does not abort the whole Neo4j step.
    """
    unique_triples = []
    seen = set()
    for t in triples:
        try:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
        except (TypeError, KeyError, AttributeError):
            error(f"跳过格式错误的三元组: {t}")
            continue
        if identifier in seen:
            continue
        seen.add(identifier)
        unique_triples.append(t)
    return unique_triples


async def file_uploaded(params_kw):
    """
    Embed a file into Milvus and extract knowledge-graph triples into Neo4j.

    Args:
        params_kw: dict with ``realpath`` (path of the file to ingest),
            ``fiid`` (knowledge-base id, also used to look up the owning
            ``orgid``) and ``id`` (document id).

    Returns:
        dict with ``status``, ``document_id``, ``collection_name``,
        ``timings``, ``message`` and ``status_code``:

        * validation, load, split, embedding or Milvus failures ->
          ``status="error"``, 400;
        * a Neo4j batch reporting non-success -> ``status="error"``, 400;
        * any other exception during triple extraction / Neo4j insert ->
          ``status="success"``, 200, because the vectors are already stored;
        * full success -> ``status="success"``, 200, plus ``userid`` and
          ``unique_triples``.
    """
    debug(f'Received params: {params_kw=}')
    realpath = params_kw.get('realpath', '')
    fiid = params_kw.get('fiid', '')
    # Named doc_id so the builtin id() is not shadowed.
    doc_id = params_kw.get('id', '')
    orgid = await get_orgid_by_id(fiid)
    db_type = ''
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={doc_id}')
    timings = {}
    start_total = time.time()
    try:
        # ---- validate input ----
        if not orgid or not fiid or not doc_id:
            raise ValueError("orgid、fiid 和 id 不能为空")
        if len(orgid) > 32 or len(fiid) > 255:
            raise ValueError("orgid 或 fiid 的长度超出限制")
        if not os.path.isfile(realpath):
            raise ValueError(f"文件 {realpath} 不存在")
        supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
        # splitext only inspects the last path component, so a dot in a
        # directory name is not mistaken for the file extension.
        ext = os.path.splitext(realpath)[1].lstrip('.').lower()
        if ext not in supported_formats:
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")

        # ---- load and clean text ----
        debug(f"加载文件: {realpath}")
        start_load = time.time()
        text = fileloader(realpath) or ''
        # NOTE(review): keeps only CJK ideographs, ASCII letters/digits,
        # whitespace and ".;," -- full-width punctuation such as "，" and "。"
        # is removed too. Confirm that is intended.
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
        if not text.strip():
            raise ValueError(f"文件 {realpath} 加载为空")

        # ---- split into chunks ----
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len)
        debug("开始分片文件内容")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
        if not chunks:
            raise ValueError(f"文件 {realpath} 未生成任何文档块")
        filename = os.path.basename(realpath).rsplit('.', 1)[0]
        upload_time = datetime.now().isoformat()

        # ---- embed chunks ----
        debug("调用嵌入服务生成向量")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []
        for i in range(0, len(texts), 10):  # 10 chunks per embedding request
            embeddings.extend(await api_service.get_embeddings(texts[i:i + 10]) or [])
        # Every chunk needs exactly one vector, otherwise the zip below would
        # silently drop chunks or pair them with the wrong vector.
        if len(embeddings) != len(texts):
            raise ValueError(f"嵌入向量数量 {len(embeddings)} 与文本块数量 {len(texts)} 不一致")
        if not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")

        # ---- insert into Milvus ----
        chunks_data = [
            {
                "userid": orgid,
                "knowledge_base_id": fiid,
                "text": chunk.page_content,
                "vector": vector,
                "document_id": doc_id,
                "filename": filename + '.' + ext,
                "file_path": realpath,
                "upload_time": upload_time,
                "file_type": ext,
            }
            for chunk, vector in zip(chunks, embeddings)
        ]
        debug(f"调用插入文件端点: {realpath}")
        start_milvus = time.time()
        for i in range(0, len(chunks_data), 10):  # 10 rows per insert request
            result = await api_service.milvus_insert_document(chunks_data[i:i + 10], db_type)
            if result.get("status") != "success":
                raise ValueError(result.get("message", "Milvus 插入失败"))
        timings["insert_milvus"] = time.time() - start_milvus
        debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")

        # ---- extract triples and insert into Neo4j (best effort) ----
        debug("调用三元组抽取服务")
        start_triples = time.time()
        # Bound before the try so the except branch can always read them,
        # whichever step fails.
        start_neo4j = None
        unique_triples = []
        try:
            debug(f"处理 {len(texts)} 个分片进行三元组抽取")
            results = await asyncio.gather(
                *(api_service.extract_triples(chunk_text) for chunk_text in texts),
                return_exceptions=True)
            triples = []
            for i, res in enumerate(results):
                if isinstance(res, list):
                    triples.extend(res)
                    debug(f"分片 {i + 1} 抽取到 {len(res)} 个三元组: {res[:5]}")
                else:
                    error(f"分片 {i + 1} 处理失败: {str(res)}")
            unique_triples = _dedupe_triples(triples)
            timings["extract_triples"] = time.time() - start_triples
            debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
            start_neo4j = time.time()
            if unique_triples:
                debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
                for i in range(0, len(unique_triples), 30):  # 30 triples per request
                    batch_triples = unique_triples[i:i + 30]
                    neo4j_result = await api_service.neo4j_insert_triples(batch_triples, doc_id, fiid, orgid)
                    debug(f"Neo4j 服务响应: {neo4j_result}")
                    if neo4j_result.get("status") != "success":
                        timings["insert_neo4j"] = time.time() - start_neo4j
                        timings["total"] = time.time() - start_total
                        return {"status": "error", "document_id": doc_id, "collection_name": collection_name, "timings": timings,
                                "message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}", "status_code": 400}
                    info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
            else:
                debug(f"文件 {realpath} 未抽取到三元组")
            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
        except Exception as e:
            if "extract_triples" not in timings:
                timings["extract_triples"] = time.time() - start_triples
            timings["insert_neo4j"] = time.time() - start_neo4j if start_neo4j is not None else 0.0
            error(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
            timings["total"] = time.time() - start_total
            # Vectors are already in Milvus, so the upload itself still counts
            # as a success.
            return {"status": "success", "document_id": doc_id, "collection_name": collection_name, "timings": timings,
                    "unique_triples": unique_triples,
                    "message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}", "status_code": 200}
        timings["total"] = time.time() - start_total
        debug(f"总耗时: {timings['total']:.2f}")
        return {"status": "success", "userid": orgid, "document_id": doc_id, "collection_name": collection_name, "timings": timings,
                "unique_triples": unique_triples, "message": f"文件 {realpath} 成功嵌入并处理三元组", "status_code": 200}
    except Exception as e:
        error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        timings["total"] = time.time() - start_total
        return {"status": "error", "document_id": doc_id, "collection_name": collection_name, "timings": timings,
                "message": f"插入文档失败: {str(e)}", "status_code": 400}
async def file_deleted(params_kw):
    """
    Delete a document's records from Milvus and Neo4j.

    Args:
        params_kw: dict with ``id`` (document id), ``realpath`` (file path)
            and ``fiid`` (knowledge-base id, used to look up the ``orgid``).
            All three are required and must be non-empty.

    Returns:
        dict with ``status``, ``collection_name``, ``document_id``,
        ``message`` and ``status_code``. Missing fields, an unresolved
        ``orgid`` or a Milvus failure give ``status="error"`` (400). Neo4j
        cleanup is best effort: its failure is logged and appended to the
        message, but the call still reports success (200) because the Milvus
        records were already removed.
    """
    # Named doc_id so the builtin id() is not shadowed.
    doc_id = params_kw.get('id', '')
    realpath = params_kw.get('realpath', '')
    fiid = params_kw.get('fiid', '')
    db_type = ''
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    try:
        required_fields = ['id', 'fiid', 'realpath']
        missing_fields = [field for field in required_fields if not params_kw.get(field)]
        if missing_fields:
            raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
        # Resolve the owner only once the inputs are known to be present, and
        # never issue a Milvus delete with an unknown (None) userid.
        orgid = await get_orgid_by_id(fiid)
        if not orgid:
            raise ValueError(f"未找到知识库 {fiid} 对应的 orgid")
        debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}")
        milvus_result = await api_service.milvus_delete_document(orgid, realpath, fiid, doc_id, db_type)
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus 删除失败"))

        neo4j_deleted_nodes = 0
        neo4j_deleted_rels = 0
        neo4j_error = None
        try:
            debug(f"调用 Neo4j 删除文档端点: document_id={doc_id}")
            neo4j_result = await api_service.neo4j_delete_document(doc_id)
            if neo4j_result.get("status") != "success":
                raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
            neo4j_deleted_nodes = neo4j_result.get("nodes_deleted", 0)
            neo4j_deleted_rels = neo4j_result.get("rels_deleted", 0)
            info(f"成功删除 document_id={doc_id} 的 {neo4j_deleted_nodes} 个 Neo4j 节点和 {neo4j_deleted_rels} 个关系")
        except Exception as e:
            neo4j_error = str(e)
            error(f"删除 document_id={doc_id} 的 Neo4j 数据失败: {str(e)}")

        message = f"成功删除 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系"
        if neo4j_error is not None:
            # Surface the Neo4j failure instead of silently reporting 0 deletions.
            message += f",但 Neo4j 删除失败: {neo4j_error}"
        return {
            "status": "success",
            "collection_name": collection_name,
            "document_id": doc_id,
            "message": message,
            "status_code": 200
        }
    except Exception as e:
        error(f"删除文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {
            "status": "error",
            "collection_name": collection_name,
            "document_id": doc_id,
            "message": f"删除文档失败: {str(e)}",
            "status_code": 400
        }
async def _search_query(query: str, userid: str, knowledge_base_ids: List[str], limit: int = 5,
                        offset: int = 0, use_rerank: bool = True, db_type: str = "") -> Dict[str, Any]:
    """
    Run a pure vector search through the Milvus service, optionally reranked.

    Args:
        query: Text to embed and search for; must be non-empty.
        userid: Owner id forwarded to the search; must be non-empty.
        knowledge_base_ids: Non-empty list of knowledge-base ids, each a str
            of at most 100 characters.
        limit: Page size in 1..16384; ``limit + offset`` may not exceed 16384.
        offset: Number of hits to skip; must be >= 0.
        use_rerank: When true and there are hits, rerank them and sort by
            ``rerank_score`` descending; otherwise drop ``rerank_score`` keys.
        db_type: Only written to the log here; it is not forwarded.

    Returns:
        ``{"results": [...], "timing": {...}}`` with at most ``limit`` hits.
        Any validation or service failure is logged and yields an empty
        ``results`` list plus whatever timings were collected so far.
    """
    t0 = time.time()
    timing_stats = {}
    try:
        info(
            f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
        # ---- argument checks (reported through the except below) ----
        if not query:
            raise ValueError("查询文本不能为空")
        if not userid:
            raise ValueError("userid 不能为空")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit 必须在 1 到 16384 之间")
        if offset < 0:
            raise ValueError("offset 不能为负数")
        if limit + offset > 16384:
            raise ValueError("limit + offset 不能超过 16384")
        if not knowledge_base_ids:
            raise ValueError("knowledge_base_ids 不能为空")
        for kb in knowledge_base_ids:
            if not isinstance(kb, str):
                raise ValueError(f"knowledge_base_id 必须是字符串: {kb}")
            if len(kb) > 100:
                raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb}")

        # ---- embed the query ----
        embed_t0 = time.time()
        vectors = await api_service.get_embeddings([query])
        if not vectors or any(len(vec) != 1024 for vec in vectors):
            raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
        query_vector = vectors[0]
        timing_stats["vector_generation"] = time.time() - embed_t0
        debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f}")

        # ---- vector search ----
        search_t0 = time.time()
        result = await api_service.milvus_search_query(query_vector, userid, knowledge_base_ids, limit, offset)
        timing_stats["vector_search"] = time.time() - search_t0
        debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
        if result.get("status") != "success":
            error(f"纯向量搜索失败: {result.get('message', '未知错误')}")
            return {"results": [], "timing": timing_stats}

        # ---- optional rerank ----
        hits = result.get("results", [])
        if not use_rerank or not hits:
            hits = [{key: val for key, val in hit.items() if key != 'rerank_score'} for hit in hits]
        else:
            rerank_t0 = time.time()
            debug("开始重排序")
            reranked = await api_service.rerank_results(query, hits, limit)
            hits = sorted(reranked, key=lambda item: item.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_t0
            debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
            debug(f"重排序分数分布: {[round(item.get('rerank_score', 0), 3) for item in hits]}")

        timing_stats["total_time"] = time.time() - t0
        info(f"纯向量搜索完成,返回 {len(hits)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
        return {"results": hits[:limit], "timing": timing_stats}
    except Exception as e:
        error(f"纯向量搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {"results": [], "timing": timing_stats}
async def main():
    """
    Manual smoke test for this module.

    Registers the ``kyrag`` database pool with ``DBPools`` and runs
    ``file_uploaded`` on one local file, printing the result. The commented
    sections at the end show how to exercise ``file_deleted`` and
    ``_search_query`` the same way.

    Connection settings and the test file path can be overridden with the
    environment variables ``KYRAG_DB_USER``, ``KYRAG_DB_PASSWORD``,
    ``KYRAG_DB_HOST``, ``KYRAG_DB_NAME`` and ``KYRAG_TEST_FILE``. When they
    are unset, the previously hard-coded values are used.
    """
    getenv = os.environ.get
    dbs = {
        # The pool key must stay "kyrag": get_orgid_by_id opens that name.
        "kyrag": {
            "driver": "aiomysql",
            "async_mode": True,
            "coding": "utf8",
            "maxconn": 100,
            "dbname": "kyrag",
            "kwargs": {
                "user": getenv("KYRAG_DB_USER", "test"),
                "db": getenv("KYRAG_DB_NAME", "kyrag"),
                # NOTE(review): a credential is committed in source as the
                # fallback; prefer supplying it only via KYRAG_DB_PASSWORD.
                "password": getenv("KYRAG_DB_PASSWORD", "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g="),
                "host": getenv("KYRAG_DB_HOST", "db"),
            },
        },
    }
    DBPools(dbs)
    # Test file_uploaded
    print("测试 file_uploaded...")
    test_file_path = getenv("KYRAG_TEST_FILE", "/home/wangmeihua/data/kg.txt")
    test_params_upload = {
        "realpath": test_file_path,
        "fiid": "1",
        "id": "doc1",
    }
    upload_result = await file_uploaded(test_params_upload)
    print(f"file_uploaded 结果: {upload_result}")
    # # Test file_deleted
    # print("测试 file_deleted...")
    # test_params_delete = {
    #     "realpath": test_file_path,
    #     "fiid": "1",
    #     "id": "doc1"
    # }
    # delete_result = await file_deleted(test_params_delete)
    # print(f"file_deleted 结果: {delete_result}")
    # # Test _search_query
    # print("测试 _search_query...")
    # query_result = await _search_query(query="什么是知识融合?", userid="testuser1", knowledge_base_ids=["kb1", "kb2"], limit=5, offset=0, use_rerank=True, db_type="")
    # print(f"_search_query 结果: {query_result}")
if __name__ == "__main__":
    # Run the manual smoke test in main() only when executed as a script.
    asyncio.run(main())