uapi化
This commit is contained in:
parent
fcb477d7ea
commit
ea205fbb7c
@ -1,5 +1,25 @@
|
||||
from urllib.request import Request
|
||||
|
||||
from appPublic.timeUtils import curDateString
|
||||
form filemgr.filemgr import FileMgr
|
||||
from filemgr.filemgr import FileMgr
|
||||
from rag.uapi_service import APIService
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from appPublic.log import debug, error, info
|
||||
from sqlor.dbpools import DBPools
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
from filetxt.loader import fileloader
|
||||
from ahserver.serverenv import get_serverenv
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
|
||||
class RagFileMgr(FileMgr):
|
||||
async def get_folder_ownerid(self, sor):
|
||||
@ -23,14 +43,425 @@ where a.orgid = b.orgid
|
||||
if len(recs) > 0:
|
||||
r = recs[0]
|
||||
return r.quota, r.expired_date
|
||||
|
||||
|
||||
async def get_service_params(self,orgid):
|
||||
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
|
||||
db = DBPools()
|
||||
dbname = "kyrag"
|
||||
|
||||
sql_opts = """
|
||||
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
|
||||
FROM service_opts
|
||||
WHERE orgid = ${orgid}$
|
||||
"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
|
||||
if not opts_result:
|
||||
error(f"未找到 orgid={orgid} 的服务配置")
|
||||
return None
|
||||
opts = opts_result[0]
|
||||
except Exception as e:
|
||||
error(f"查询 service_opts 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
# 收集服务 ID
|
||||
service_ids = set()
|
||||
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
|
||||
if opts[key]:
|
||||
service_ids.add(opts[key])
|
||||
|
||||
# 检查 service_ids 是否为空
|
||||
if not service_ids:
|
||||
error(f"未找到任何服务 ID for orgid={orgid}")
|
||||
return None
|
||||
|
||||
# 手动构造 IN 子句的 ID 列表
|
||||
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
|
||||
sql_services = f"""
|
||||
SELECT id, name, upappid
|
||||
FROM ragservices
|
||||
WHERE id IN ({id_list})
|
||||
"""
|
||||
try:
|
||||
async with db.sqlorContext(dbname) as sor:
|
||||
services_result = await sor.sqlExe(sql_services, {})
|
||||
if not services_result:
|
||||
error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
|
||||
return None
|
||||
|
||||
# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
|
||||
service_params = {
|
||||
'embedding': None,
|
||||
'vdb': None,
|
||||
'reranker': None,
|
||||
'triples': None,
|
||||
'gdb': None,
|
||||
'entities': None
|
||||
}
|
||||
for service in services_result:
|
||||
name = service['name']
|
||||
if name == 'bgem3嵌入':
|
||||
service_params['embedding'] = service['upappid']
|
||||
elif name == 'milvus向量检索':
|
||||
service_params['vdb'] = service['upappid']
|
||||
elif name == 'bgem2v3重排':
|
||||
service_params['reranker'] = service['upappid']
|
||||
elif name == 'mrebel三元组抽取':
|
||||
service_params['triples'] = service['upappid']
|
||||
elif name == 'neo4j删除知识库':
|
||||
service_params['gdb'] = service['upappid']
|
||||
elif name == 'small实体抽取':
|
||||
service_params['entities'] = service['upappid']
|
||||
|
||||
# 检查是否所有服务参数都已填充
|
||||
missing_services = [k for k, v in service_params.items() if v is None]
|
||||
if missing_services:
|
||||
error(f"未找到以下服务的配置: {missing_services}")
|
||||
return None
|
||||
|
||||
return service_params
|
||||
except Exception as e:
|
||||
error(f"查询 ragservices 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def file_uploaded(self, request, ns, userid):
|
||||
pass
|
||||
"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
|
||||
debug(f'Received ns: {ns=}')
|
||||
realpath = ns.get('realpath', '')
|
||||
fiid = ns.get('fiid', '')
|
||||
id = ns.get('id', '')
|
||||
orgid = ns.get('orgid', '')
|
||||
db_type = ''
|
||||
api_service = APIService()
|
||||
debug(
|
||||
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
|
||||
|
||||
timings = {}
|
||||
start_total = time.time()
|
||||
|
||||
try:
|
||||
if not orgid or not fiid or not id:
|
||||
raise ValueError("orgid、fiid 和 id 不能为空")
|
||||
if len(orgid) > 32 or len(fiid) > 255:
|
||||
raise ValueError("orgid 或 fiid 的长度超出限制")
|
||||
if not os.path.exists(realpath):
|
||||
raise ValueError(f"文件 {realpath} 不存在")
|
||||
|
||||
# 获取服务参数
|
||||
service_params = await get_service_params(orgid)
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
|
||||
supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
|
||||
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
|
||||
if ext not in supported_formats:
|
||||
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
|
||||
|
||||
debug(f"加载文件: {realpath}")
|
||||
start_load = time.time()
|
||||
text = fileloader(realpath)
|
||||
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
|
||||
timings["load_file"] = time.time() - start_load
|
||||
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
|
||||
if not text or not text.strip():
|
||||
raise ValueError(f"文件 {realpath} 加载为空")
|
||||
|
||||
document = Document(page_content=text)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=500,
|
||||
chunk_overlap=100,
|
||||
length_function=len)
|
||||
debug("开始分片文件内容")
|
||||
start_split = time.time()
|
||||
chunks = text_splitter.split_documents([document])
|
||||
timings["split_text"] = time.time() - start_split
|
||||
debug(
|
||||
f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
|
||||
if not chunks:
|
||||
raise ValueError(f"文件 {realpath} 未生成任何文档块")
|
||||
|
||||
filename = os.path.basename(realpath).rsplit('.', 1)[0]
|
||||
upload_time = datetime.now().isoformat()
|
||||
|
||||
debug("调用嵌入服务生成向量")
|
||||
start_embedding = time.time()
|
||||
texts = [chunk.page_content for chunk in chunks]
|
||||
embeddings = []
|
||||
for i in range(0, len(texts), 10): # 每次处理 10 个文本块
|
||||
batch_texts = texts[i:i + 10]
|
||||
batch_embeddings = await api_service.get_embeddings(
|
||||
request=Request,
|
||||
texts=batch_texts,
|
||||
upappid=service_params['embedding'],
|
||||
apiname="BAAI/bge-m3",
|
||||
user=userid
|
||||
)
|
||||
embeddings.extend(batch_embeddings)
|
||||
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
|
||||
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
|
||||
timings["generate_embeddings"] = time.time() - start_embedding
|
||||
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
|
||||
|
||||
chunks_data = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunks_data.append({
|
||||
"userid": orgid,
|
||||
"knowledge_base_id": fiid,
|
||||
"text": chunk.page_content,
|
||||
"vector": embeddings[i],
|
||||
"document_id": id,
|
||||
"filename": filename + '.' + ext,
|
||||
"file_path": realpath,
|
||||
"upload_time": upload_time,
|
||||
"file_type": ext,
|
||||
})
|
||||
|
||||
debug(f"调用插入文件端点: {realpath}")
|
||||
start_milvus = time.time()
|
||||
for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据
|
||||
batch_chunks = chunks_data[i:i + 10]
|
||||
result = await api_service.milvus_insert_document(
|
||||
request=request,
|
||||
chunks=batch_chunks,
|
||||
db_type=db_type,
|
||||
upappid=service_params['vdb'],
|
||||
apiname="milvus/insertdocument", # 固定 apiname
|
||||
user=userid
|
||||
)
|
||||
if result.get("status") != "success":
|
||||
raise ValueError(result.get("message", "Milvus 插入失败"))
|
||||
timings["insert_milvus"] = time.time() - start_milvus
|
||||
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
|
||||
|
||||
if result.get("status") != "success":
|
||||
timings["total"] = time.time() - start_total
|
||||
return {"status": "error", "document_id": id, "timings": timings,
|
||||
"message": result.get("message", "未知错误"), "status_code": 400}
|
||||
|
||||
debug("调用三元组抽取服务")
|
||||
start_triples = time.time()
|
||||
try:
|
||||
chunk_texts = [doc.page_content for doc in chunks]
|
||||
debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
|
||||
tasks = [
|
||||
api_service.extract_triples(
|
||||
request=Request,
|
||||
text=chunk,
|
||||
upappid=service_params['triples'],
|
||||
apiname="Babelscape/mrebel-large", # 固定 apiname
|
||||
user=userid
|
||||
) for chunk in chunk_texts
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
triples = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, list):
|
||||
triples.extend(result)
|
||||
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
|
||||
else:
|
||||
error(f"分片 {i + 1} 处理失败: {str(result)}")
|
||||
|
||||
unique_triples = []
|
||||
seen = set()
|
||||
for t in triples:
|
||||
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
|
||||
if identifier not in seen:
|
||||
seen.add(identifier)
|
||||
unique_triples.append(t)
|
||||
else:
|
||||
for existing in unique_triples:
|
||||
if (existing['head'].lower() == t['head'].lower() and
|
||||
existing['tail'].lower() == t['tail'].lower() and
|
||||
len(t['type']) > len(existing['type'])):
|
||||
unique_triples.remove(existing)
|
||||
unique_triples.append(t)
|
||||
debug(f"替换三元组为更具体类型: {t}")
|
||||
break
|
||||
|
||||
timings["extract_triples"] = time.time() - start_triples
|
||||
debug(
|
||||
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
|
||||
|
||||
debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
|
||||
start_neo4j = time.time()
|
||||
for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组
|
||||
batch_triples = unique_triples[i:i + 30]
|
||||
neo4j_result = await api_service.neo4j_insert_triples(
|
||||
request=Request,
|
||||
triples=batch_triples,
|
||||
document_id=id,
|
||||
knowledge_base_id=fiid,
|
||||
userid=orgid,
|
||||
upappid=service_params['gdb'],
|
||||
apiname="neo4j/inserttriples", # 固定 apiname
|
||||
user=userid
|
||||
)
|
||||
debug(f"Neo4j 服务响应: {neo4j_result}")
|
||||
if neo4j_result.get("status") != "success":
|
||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
||||
timings["total"] = time.time() - start_total
|
||||
return {"status": "error", "document_id": id, "collection_name": "ragdb",
|
||||
"timings": timings,
|
||||
"message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
|
||||
"status_code": 400}
|
||||
info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
|
||||
else:
|
||||
debug(f"文件 {realpath} 未抽取到三元组")
|
||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
||||
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
|
||||
|
||||
except Exception as e:
|
||||
timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
|
||||
timings[
|
||||
"extract_triples"]
|
||||
timings["insert_neo4j"] = time.time() - start_neo4j
|
||||
debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
timings["total"] = time.time() - start_total
|
||||
return {"status": "success", "document_id": id, "collection_name": "ragdb", "timings": timings,
|
||||
"unique_triples": unique_triples,
|
||||
"message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}",
|
||||
"status_code": 200}
|
||||
|
||||
timings["total"] = time.time() - start_total
|
||||
debug(f"总耗时: {timings['total']:.2f} 秒")
|
||||
return {"status": "success", "userid": orgid, "document_id": id, "collection_name": "ragdb",
|
||||
"timings": timings,
|
||||
"unique_triples": unique_triples, "message": f"文件 {realpath} 成功嵌入并处理三元组",
|
||||
"status_code": 200}
|
||||
|
||||
except Exception as e:
|
||||
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
timings["total"] = time.time() - start_total
|
||||
return {"status": "error", "document_id": id, "collection_name": "ragdb", "timings": timings,
|
||||
"message": f"插入文档失败: {str(e)}", "status_code": 400}
|
||||
|
||||
async def file_deleted(self, request, recs, userid):
|
||||
pass
|
||||
|
||||
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
|
||||
if not isinstance(recs, list):
|
||||
recs = [recs] # 确保 recs 是列表,即使传入单个记录
|
||||
results = []
|
||||
api_service = APIService()
|
||||
total_nodes_deleted = 0
|
||||
total_rels_deleted = 0
|
||||
|
||||
for rec in recs:
|
||||
id = rec.get('id', '')
|
||||
realpath = rec.get('realpath', '')
|
||||
fiid = rec.get('fiid', '')
|
||||
orgid = rec.get('orgid', '')
|
||||
db_type = ''
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
|
||||
try:
|
||||
required_fields = ['id', 'realpath', 'fiid', 'orgid']
|
||||
missing_fields = [field for field in required_fields if not rec.get(field, '')]
|
||||
if missing_fields:
|
||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||
|
||||
# 获取服务参数
|
||||
service_params = await self.get_service_params(orgid)
|
||||
if not service_params:
|
||||
raise ValueError("无法获取服务参数")
|
||||
|
||||
debug(
|
||||
f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
|
||||
milvus_result = await api_service.milvus_delete_document(
|
||||
request=request,
|
||||
userid=orgid,
|
||||
file_path=realpath,
|
||||
knowledge_base_id=fiid,
|
||||
document_id=id,
|
||||
db_type=db_type,
|
||||
upappid=service_params['vdb'],
|
||||
apiname="milvus/deletedocument",
|
||||
user=userid
|
||||
)
|
||||
|
||||
if milvus_result.get("status") != "success":
|
||||
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
|
||||
|
||||
neo4j_deleted_nodes = 0
|
||||
neo4j_deleted_rels = 0
|
||||
try:
|
||||
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
|
||||
neo4j_result = await api_service.neo4j_delete_document(
|
||||
request=request,
|
||||
document_id=id,
|
||||
upappid=service_params['gdb'],
|
||||
apiname="neo4j/deletedocument",
|
||||
user=userid
|
||||
)
|
||||
if neo4j_result.get("status") != "success":
|
||||
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
|
||||
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
|
||||
rels_deleted = neo4j_result.get("rels_deleted", 0)
|
||||
neo4j_deleted_nodes += nodes_deleted
|
||||
neo4j_deleted_rels += rels_deleted
|
||||
total_nodes_deleted += nodes_deleted
|
||||
total_rels_deleted += rels_deleted
|
||||
info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
|
||||
except Exception as e:
|
||||
error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
|
||||
|
||||
results.append({
|
||||
"status": "success",
|
||||
"collection_name": collection_name,
|
||||
"document_id": id,
|
||||
"message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
|
||||
"status_code": 200
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
results.append({
|
||||
"status": "error",
|
||||
"collection_name": collection_name,
|
||||
"document_id": id,
|
||||
"message": f"删除文档 {realpath} 失败: {str(e)}",
|
||||
"status_code": 400
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "success" if all(r["status"] == "success" for r in results) else "partial",
|
||||
"results": results,
|
||||
"total_nodes_deleted": total_nodes_deleted,
|
||||
"total_rels_deleted": total_rels_deleted,
|
||||
"message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')} 个",
|
||||
"status_code": 200 if all(r["status"] == "success" for r in results) else 207
|
||||
}
|
||||
|
||||
async def test_ragfilemgr():
|
||||
"""测试 RagFileMgr 类的 get_service_params"""
|
||||
print("初始化数据库连接池...")
|
||||
dbs = {
|
||||
"kyrag": {
|
||||
"driver": "aiomysql",
|
||||
"async_mode": True,
|
||||
"coding": "utf8",
|
||||
"maxconn": 100,
|
||||
"dbname": "kyrag",
|
||||
"kwargs": {
|
||||
"user": "test",
|
||||
"db": "kyrag",
|
||||
"password": "QUZVcXg5V1p1STMybG5Ia6mX9D0v7+g=",
|
||||
"host": "db"
|
||||
}
|
||||
}
|
||||
}
|
||||
DBPools(dbs)
|
||||
|
||||
ragfilemgr = RagFileMgr()
|
||||
orgid = "04J6VbxLqB_9RPMcgOv_8"
|
||||
result = await ragfilemgr.get_service_params(orgid)
|
||||
print(f"get_service_params 结果: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_ragfilemgr())
|
||||
|
||||
|
||||
## usage
|
||||
# mgr = RagFileMgr(fiid)
|
||||
# await mgr.add_file(request, params_kw)
|
||||
|
||||
315
rag/uapi_service.py
Normal file
315
rag/uapi_service.py
Normal file
@ -0,0 +1,315 @@
|
||||
from appPublic.log import debug, error
|
||||
from typing import Dict, Any, List
|
||||
import uuid
|
||||
from ahserver.serverenv import ServerEnv
|
||||
from uapi.init import load_uapi
|
||||
|
||||
load_uapi()
|
||||
|
||||
class APIService:
|
||||
"""处理 API 请求的服务类"""
|
||||
|
||||
# 嵌入服务 (BAAI/bge-m3)
|
||||
async def get_embeddings(self, request, texts: list, upappid: str, apiname: str, user: str) -> list:
|
||||
"""调用嵌入服务获取文本向量"""
|
||||
try:
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"input": texts}
|
||||
d = await uapi.request(upappid, apiname, user, params_kw)
|
||||
if d.get("object") != "list" or not d.get("data"):
|
||||
error(f"嵌入服务响应格式错误: {d}")
|
||||
raise RuntimeError("嵌入服务响应格式错误")
|
||||
embeddings = [item["embedding"] for item in d["data"]]
|
||||
debug(f"成功获取 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
error(f"嵌入服务调用失败: {str(e)}")
|
||||
raise RuntimeError(f"嵌入服务调用失败: {str(e)}")
|
||||
|
||||
# 实体提取服务 (LTP/small)
|
||||
async def extract_entities(self, request, query: str, upappid: str, apiname: str, user: str) -> list:
|
||||
"""调用实体识别服务"""
|
||||
try:
|
||||
if not query:
|
||||
raise ValueError("查询文本不能为空")
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"query": query}
|
||||
d = await uapi.request(upappid, apiname, user, params_kw)
|
||||
if d.get("object") != "list" or not d.get("data"):
|
||||
error(f"实体识别服务响应格式错误: {d}")
|
||||
raise RuntimeError("实体识别服务响应格式错误")
|
||||
entities = d["data"]
|
||||
unique_entities = list(dict.fromkeys(entities))
|
||||
debug(f"成功提取 {len(unique_entities)} 个唯一实体")
|
||||
return unique_entities
|
||||
except Exception as e:
|
||||
error(f"实体识别服务调用失败: {str(e)}")
|
||||
return []
|
||||
|
||||
# 三元组抽取服务 (Babelscape/mrebel-large)
|
||||
async def extract_triples(self, request, text: str, upappid: str, apiname: str, user: str) -> list:
|
||||
"""调用三元组抽取服务"""
|
||||
request_id = str(uuid.uuid4())
|
||||
debug(f"Request #{request_id} started for triples extraction")
|
||||
try:
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"text": text}
|
||||
d = await uapi.request(upappid, apiname, user, params_kw)
|
||||
if d.get("object") != "list":
|
||||
error(f"Request #{request_id} invalid response format: {d}")
|
||||
raise RuntimeError("三元组抽取服务响应格式错误")
|
||||
triples = d["data"]
|
||||
debug(f"Request #{request_id} extracted {len(triples)} triples")
|
||||
return triples
|
||||
except Exception as e:
|
||||
error(f"Request #{request_id} failed to extract triples: {str(e)}")
|
||||
raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}")
|
||||
|
||||
# 重排序服务 (BAAI/bge-reranker-v2-m3)
|
||||
async def rerank_results(self, request, query: str, results: list, top_n: int, upappid: str, apiname: str, user: str) -> list:
|
||||
"""调用重排序服务"""
|
||||
try:
|
||||
if not results:
|
||||
debug("无结果需要重排序")
|
||||
return results
|
||||
|
||||
if not isinstance(top_n, int) or top_n < 1:
|
||||
debug(f"无效的 top_n 参数: {top_n}, 使用 len(results)={len(results)}")
|
||||
top_n = len(results)
|
||||
else:
|
||||
top_n = min(top_n, len(results))
|
||||
|
||||
documents = [result.get("text", str(result)) for result in results]
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"model": "rerank-001",
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": top_n
|
||||
}
|
||||
d = await uapi.request(upappid, apiname, user, params_kw)
|
||||
if d.get("object") != "rerank.result" or not d.get("data"):
|
||||
error(f"重排序服务响应格式错误: {d}")
|
||||
raise RuntimeError("重排序服务响应格式错误")
|
||||
rerank_data = d["data"]
|
||||
reranked_results = []
|
||||
for item in rerank_data:
|
||||
index = item["index"]
|
||||
if index < len(results):
|
||||
results[index]["rerank_score"] = item["relevance_score"]
|
||||
reranked_results.append(results[index])
|
||||
debug(f"成功重排序 {len(reranked_results)} 条结果")
|
||||
return reranked_results[:top_n]
|
||||
except Exception as e:
|
||||
error(f"重排序服务调用失败: {str(e)}")
|
||||
return results
|
||||
|
||||
# Neo4j 服务
|
||||
async def neo4j_docs(self, request, upappid: str, apiname: str, user: str) -> str:
|
||||
"""获取 Neo4j 文档(返回文本格式)"""
|
||||
try:
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {}
|
||||
d = await uapi.request(upappid, apiname, user, params_kw)
|
||||
if d.get("status") != 200:
|
||||
error(f"Neo4j 文档请求失败,状态码: {d.get('status')}")
|
||||
raise RuntimeError(f"Neo4j 文档请求失败: {d.get('status')}")
|
||||
text = d.get("text")
|
||||
debug(f"Neo4j 文档内容: {text}")
|
||||
return text
|
||||
except Exception as e:
|
||||
error(f"Neo4j 文档请求失败: {str(e)}")
|
||||
raise RuntimeError(f"Neo4j 文档请求失败: {str(e)}")
|
||||
|
||||
async def neo4j_initialize(self, request, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""初始化 Neo4j 服务"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def neo4j_insert_triples(self, request, triples: list, document_id: str, knowledge_base_id: str, userid: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""插入三元组到 Neo4j"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"triples": triples,
|
||||
"document_id": document_id,
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"userid": userid
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def neo4j_delete_document(self, request, document_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""删除指定文档"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"document_id": document_id}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def neo4j_delete_knowledgebase(self, request, userid: str, knowledge_base_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""删除用户知识库"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"userid": userid, "knowledge_base_id": knowledge_base_id}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def neo4j_match_triplets(self, request, query: str, query_entities: list, userid: str, knowledge_base_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""根据实体匹配相关三元组"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"query": query,
|
||||
"query_entities": query_entities,
|
||||
"userid": userid,
|
||||
"knowledge_base_id": knowledge_base_id
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
# Milvus 服务
|
||||
async def milvus_create_collection(self, request, upappid: str, apiname: str, user: str, db_type: str = "") -> Dict[str, Any]:
|
||||
"""创建 Milvus 集合"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"db_type": db_type}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_delete_collection(self, request, upappid: str, apiname: str, user: str, db_type: str = "") -> Dict[str, Any]:
|
||||
"""删除 Milvus 集合"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"db_type": db_type}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_insert_document(self, request, chunks: List[Dict], upappid: str, apiname: str, user: str, db_type: str = "") -> Dict[str, Any]:
|
||||
"""添加 Milvus 记录"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"chunks": chunks,
|
||||
"dbtype": db_type
|
||||
}
|
||||
payload = json.dumps(params_kw) # 转换为 JSON 字符串
|
||||
payload_bytes = payload.encode() # 编码为字节
|
||||
payload_size = len(payload_bytes) # 获取字节数
|
||||
debug(f"Request payload size for insertdocument: {payload_size} bytes")
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_delete_document(self, request, userid: str, file_path: str, knowledge_base_id: str, document_id: str, upappid: str, apiname: str, user: str, db_type: str = "") -> Dict[str, Any]:
|
||||
"""删除 Milvus 记录"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"userid": userid,
|
||||
"file_path": file_path,
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"document_id": document_id,
|
||||
"dbtype": db_type
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_delete_knowledgebase(self, request, userid: str, knowledge_base_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""删除 Milvus 知识库"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"userid": userid, "knowledge_base_id": knowledge_base_id}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_search_query(self, request, query_vector: List[float], userid: str, knowledge_base_ids: list, limit: int, offset: int, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""根据用户知识库检索 Milvus"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"query_vector": query_vector,
|
||||
"userid": userid,
|
||||
"knowledge_base_ids": knowledge_base_ids,
|
||||
"limit": limit,
|
||||
"offset": offset
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_list_user_files(self, request, userid: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""列出 Milvus 用户知识库列表"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"userid": userid}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def milvus_list_all_knowledgebases(self, request, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""列出 Milvus 数据库中所有数据"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
# RAG 服务
|
||||
async def rag_create_collection(self, request, upappid: str, apiname: str, user: str, db_type: str = "") -> Dict[str, Any]:
|
||||
"""创建 RAG 集合"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"db_type": db_type}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_delete_collection(self, request, upappid: str, apiname: str, user: str, db_type: str = "") -> Dict[str, Any]:
|
||||
"""删除 RAG 集合"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"db_type": db_type}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_insert_file(self, request, file_path: str, userid: str, knowledge_base_id: str, document_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""添加 RAG 记录"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"file_path": file_path,
|
||||
"userid": userid,
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"document_id": document_id
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_delete_file(self, request, userid: str, file_path: str, knowledge_base_id: str, document_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""删除 RAG 记录"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"userid": userid,
|
||||
"file_path": file_path,
|
||||
"knowledge_base_id": knowledge_base_id,
|
||||
"document_id": document_id
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_delete_knowledgebase(self, request, userid: str, knowledge_base_id: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""删除 RAG 知识库"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"userid": userid, "knowledge_base_id": knowledge_base_id}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_search_query(self, request, query: str, userid: str, knowledge_base_ids: list, limit: int, offset: int, use_rerank: bool, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""根据用户知识库检索 RAG"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"query": query,
|
||||
"userid": userid,
|
||||
"knowledge_base_ids": knowledge_base_ids,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"use_rerank": use_rerank
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_fused_search_query(self, request, query: str, userid: str, knowledge_base_ids: list, limit: int, offset: int, use_rerank: bool, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""根据用户知识库+知识图谱检索 RAG"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {
|
||||
"query": query,
|
||||
"userid": userid,
|
||||
"knowledge_base_ids": knowledge_base_ids,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"use_rerank": use_rerank
|
||||
}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_list_user_files(self, request, userid: str, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""列出 RAG 用户知识库列表"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {"userid": userid}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_list_all_knowledgebases(self, request, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""列出 RAG 数据库中所有数据"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
|
||||
async def rag_docs(self, request, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
|
||||
"""获取 RAG 帮助文档"""
|
||||
uapi = UAPI(request, DictObject(**globals()))
|
||||
params_kw = {}
|
||||
return await uapi.request(upappid, apiname, user, params_kw)
|
||||
Loading…
x
Reference in New Issue
Block a user