This commit is contained in:
wangmeihua 2025-09-09 13:17:22 +08:00
parent ea1a9a084f
commit 516edb5b6a
3 changed files with 481 additions and 428 deletions

View File

@ -20,463 +20,461 @@ from typing import List, Dict, Any
import json import json
class RagFileMgr(FileMgr): class RagFileMgr(FileMgr):
async def get_folder_ownerid(self, sor): async def get_folder_ownerid(self, sor):
fiid = self.fiid fiid = self.fiid
recs = await sor.R('kdb', {'id': self.fiid}) recs = await sor.R('kdb', {'id': self.fiid})
if len(recs) > 0: if len(recs) > 0:
return recs[0].orgid return recs[0].orgid
return None return None
async def get_organization_quota(self, sor, orgid): async def get_organization_quota(self, sor, orgid):
sql = """select a.* from ragquota a, kdb b sql = """select a.* from ragquota a, kdb b
where a.orgid = b.orgid where a.orgid = b.orgid
and b.id = ${id}$ and b.id = ${id}$
and ${today}$ >= a.enabled_date and ${today}$ >= a.enabled_date
and ${today}$ < a.expired_date and ${today}$ < a.expired_date
""" """
recs = await sor.sqlExe(sql, { recs = await sor.sqlExe(sql, {
'id': self.fiid, 'id': self.fiid,
'today': curDateString() 'today': curDateString()
}) })
if len(recs) > 0: if len(recs) > 0:
r = recs[0] r = recs[0]
return r.quota, r.expired_date return r.quota, r.expired_date
return None, None return None, None
async def get_service_params(self,orgid): async def get_service_params(self,orgid):
""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """ """ 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
db = DBPools() db = DBPools()
dbname = "kyrag" dbname = "kyrag"
sql_opts = """ sql_opts = """
SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
FROM service_opts FROM service_opts
WHERE orgid = ${orgid}$ WHERE orgid = ${orgid}$
""" """
try: try:
async with db.sqlorContext(dbname) as sor: async with db.sqlorContext(dbname) as sor:
opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid}) opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
if not opts_result: if not opts_result:
error(f"未找到 orgid={orgid} 的服务配置") error(f"未找到 orgid={orgid} 的服务配置")
return None return None
opts = opts_result[0] opts = opts_result[0]
except Exception as e: except Exception as e:
error(f"查询 service_opts 失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"查询 service_opts 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return None return None
# 收集服务 ID # 收集服务 ID
service_ids = set() service_ids = set()
for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']: for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
if opts[key]: if opts[key]:
service_ids.add(opts[key]) service_ids.add(opts[key])
# 检查 service_ids 是否为空 # 检查 service_ids 是否为空
if not service_ids: if not service_ids:
error(f"未找到任何服务 ID for orgid={orgid}") error(f"未找到任何服务 ID for orgid={orgid}")
return None return None
# 手动构造 IN 子句的 ID 列表 # 手动构造 IN 子句的 ID 列表
id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹 id_list = ','.join([f"'{id}'" for id in service_ids]) # 确保每个 ID 被单引号包裹
sql_services = f""" sql_services = f"""
SELECT id, name, upappid SELECT id, name, upappid
FROM ragservices FROM ragservices
WHERE id IN ({id_list}) WHERE id IN ({id_list})
""" """
try: try:
async with db.sqlorContext(dbname) as sor: async with db.sqlorContext(dbname) as sor:
services_result = await sor.sqlExe(sql_services, {}) services_result = await sor.sqlExe(sql_services, {})
if not services_result: if not services_result:
error(f"未找到服务 ID {service_ids} 的 ragservices 配置") error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
return None return None
# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid # 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
service_params = { service_params = {
'embedding': None, 'embedding': None,
'vdb': None, 'vdb': None,
'reranker': None, 'reranker': None,
'triples': None, 'triples': None,
'gdb': None, 'gdb': None,
'entities': None 'entities': None
} }
for service in services_result: for service in services_result:
name = service['name'] name = service['name']
if name == 'bgem3嵌入': if name == 'bgem3嵌入':
service_params['embedding'] = service['upappid'] service_params['embedding'] = service['upappid']
elif name == 'milvus向量检索': elif name == 'milvus向量检索':
service_params['vdb'] = service['upappid'] service_params['vdb'] = service['upappid']
elif name == 'bgem2v3重排': elif name == 'bgem2v3重排':
service_params['reranker'] = service['upappid'] service_params['reranker'] = service['upappid']
elif name == 'mrebel三元组抽取': elif name == 'mrebel三元组抽取':
service_params['triples'] = service['upappid'] service_params['triples'] = service['upappid']
elif name == 'neo4j删除知识库': elif name == 'neo4j删除知识库':
service_params['gdb'] = service['upappid'] service_params['gdb'] = service['upappid']
elif name == 'small实体抽取': elif name == 'small实体抽取':
service_params['entities'] = service['upappid'] service_params['entities'] = service['upappid']
# 检查是否所有服务参数都已填充 # 检查是否所有服务参数都已填充
missing_services = [k for k, v in service_params.items() if v is None] missing_services = [k for k, v in service_params.items() if v is None]
if missing_services: if missing_services:
error(f"未找到以下服务的配置: {missing_services}") error(f"未找到以下服务的配置: {missing_services}")
return None return None
return service_params return service_params
except Exception as e: except Exception as e:
error(f"查询 ragservices 失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"查询 ragservices 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return None return None
async def file_uploaded(self, request, ns, userid): async def file_uploaded(self, request, ns, userid):
"""将文档插入 Milvus 并抽取三元组到 Neo4j""" """将文档插入 Milvus 并抽取三元组到 Neo4j"""
debug(f'Received ns: {ns=}') debug(f'Received ns: {ns=}')
realpath = ns.get('realpath', '') realpath = ns.get('realpath', '')
fiid = ns.get('fiid', '') fiid = ns.get('fiid', '')
id = ns.get('id', '') id = ns.get('id', '')
orgid = ns.get('ownerid', '') orgid = ns.get('ownerid', '')
hashvalue = ns.get('hashvalue', '') hashvalue = ns.get('hashvalue', '')
db_type = '' db_type = ''
api_service = APIService() api_service = APIService()
debug( debug(
f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}') f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
timings = {} timings = {}
start_total = time.time() start_total = time.time()
try: try:
if not orgid or not fiid or not id: if not orgid or not fiid or not id:
raise ValueError("orgid、fiid 和 id 不能为空") raise ValueError("orgid、fiid 和 id 不能为空")
if len(orgid) > 32 or len(fiid) > 255: if len(orgid) > 32 or len(fiid) > 255:
raise ValueError("orgid 或 fiid 的长度超出限制") raise ValueError("orgid 或 fiid 的长度超出限制")
if not os.path.exists(realpath): if not os.path.exists(realpath):
raise ValueError(f"文件 {realpath} 不存在") raise ValueError(f"文件 {realpath} 不存在")
# 检查 hashvalue 是否已存在 # 检查 hashvalue 是否已存在
db = DBPools() db = DBPools()
dbname = "kyrag" dbname = "kyrag"
sql_check_hash = """ sql_check_hash = """
SELECT hashvalue SELECT hashvalue
FROM file FROM file
WHERE hashvalue = ${hashvalue}$ WHERE hashvalue = ${hashvalue}$
""" """
async with db.sqlorContext(dbname) as sor: async with db.sqlorContext(dbname) as sor:
hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue}) hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue})
if hash_result: if hash_result:
debug(f"文件已存在: hashvalue={hashvalue}") debug(f"文件已存在: hashvalue={hashvalue}")
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
return { return {
"status": "error", "status": "error",
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"message": f"文件已存在: hashvalue={hashvalue}", "message": f"文件已存在: hashvalue={hashvalue}",
"status_code": 400 "status_code": 400
} }
# 获取服务参数 # 获取服务参数
service_params = await self.get_service_params(orgid) service_params = await self.get_service_params(orgid)
if not service_params: if not service_params:
raise ValueError("无法获取服务参数") raise ValueError("无法获取服务参数")
supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'} supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
if ext not in supported_formats: if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
debug(f"加载文件: {realpath}") debug(f"加载文件: {realpath}")
start_load = time.time() start_load = time.time()
text = fileloader(realpath) text = fileloader(realpath)
# debug(f"处理后的文件内容是:{text=}") # debug(f"处理后的文件内容是:{text=}")
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text) text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
timings["load_file"] = time.time() - start_load timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip(): if not text or not text.strip():
raise ValueError(f"文件 {realpath} 加载为空") raise ValueError(f"文件 {realpath} 加载为空")
document = Document(page_content=text) document = Document(page_content=text)
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500, chunk_size=500,
chunk_overlap=100, chunk_overlap=100,
length_function=len) length_function=len)
debug("开始分片文件内容") debug("开始分片文件内容")
start_split = time.time() start_split = time.time()
chunks = text_splitter.split_documents([document]) chunks = text_splitter.split_documents([document])
timings["split_text"] = time.time() - start_split timings["split_text"] = time.time() - start_split
debug( debug(
f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}") f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
if not chunks: debug(f"分片内容: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
raise ValueError(f"文件 {realpath} 未生成任何文档块") if not chunks:
raise ValueError(f"文件 {realpath} 未生成任何文档块")
filename = os.path.basename(realpath).rsplit('.', 1)[0] filename = os.path.basename(realpath).rsplit('.', 1)[0]
upload_time = datetime.now().isoformat() upload_time = datetime.now().isoformat()
debug("调用嵌入服务生成向量") debug("调用嵌入服务生成向量")
start_embedding = time.time() start_embedding = time.time()
texts = [chunk.page_content for chunk in chunks] texts = [chunk.page_content for chunk in chunks]
embeddings = [] embeddings = []
for i in range(0, len(texts), 10): # 每次处理 10 个文本块 for i in range(0, len(texts), 10): # 每次处理 10 个文本块
batch_texts = texts[i:i + 10] batch_texts = texts[i:i + 10]
batch_embeddings = await api_service.get_embeddings( batch_embeddings = await api_service.get_embeddings(
request=request, request=request,
texts=batch_texts, texts=batch_texts,
upappid=service_params['embedding'], upappid=service_params['embedding'],
apiname="BAAI/bge-m3", apiname="BAAI/bge-m3",
user=userid user=userid
) )
embeddings.extend(batch_embeddings) embeddings.extend(batch_embeddings)
if not embeddings or not all(len(vec) == 1024 for vec in embeddings): if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表") raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
timings["generate_embeddings"] = time.time() - start_embedding timings["generate_embeddings"] = time.time() - start_embedding
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}") debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
chunks_data = [] chunks_data = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
chunks_data.append({ chunks_data.append({
"userid": orgid, "userid": orgid,
"knowledge_base_id": fiid, "knowledge_base_id": fiid,
"text": chunk.page_content, "text": chunk.page_content,
"vector": embeddings[i], "vector": embeddings[i],
"document_id": id, "document_id": id,
"filename": filename + '.' + ext, "filename": filename + '.' + ext,
"file_path": realpath, "file_path": realpath,
"upload_time": upload_time, "upload_time": upload_time,
"file_type": ext, "file_type": ext,
}) })
debug(f"调用插入文件端点: {realpath}") debug(f"调用插入文件端点: {realpath}")
start_milvus = time.time() start_milvus = time.time()
for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据 for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据
batch_chunks = chunks_data[i:i + 10] batch_chunks = chunks_data[i:i + 10]
result = await api_service.milvus_insert_document( result = await api_service.milvus_insert_document(
request=request, request=request,
chunks=batch_chunks, chunks=batch_chunks,
db_type=db_type, db_type=db_type,
upappid=service_params['vdb'], upappid=service_params['vdb'],
apiname="milvus/insertdocument", apiname="milvus/insertdocument",
user=userid user=userid
) )
if result.get("status") != "success": if result.get("status") != "success":
raise ValueError(result.get("message", "Milvus 插入失败")) raise ValueError(result.get("message", "Milvus 插入失败"))
timings["insert_milvus"] = time.time() - start_milvus timings["insert_milvus"] = time.time() - start_milvus
debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}") debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f}")
if result.get("status") != "success": if result.get("status") != "success":
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
return {"status": "error", "document_id": id, "timings": timings, return {"status": "error", "document_id": id, "timings": timings,
"message": result.get("message", "未知错误"), "status_code": 400} "message": result.get("message", "未知错误"), "status_code": 400}
debug("调用三元组抽取服务") debug("调用三元组抽取服务")
start_triples = time.time() start_triples = time.time()
unique_triples = [] unique_triples = []
try: try:
chunk_texts = [doc.page_content for doc in chunks] chunk_texts = [doc.page_content for doc in chunks]
debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取") debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取")
tasks = [ triples = []
api_service.extract_triples( for i, chunk in enumerate(chunk_texts):
request=request, result = await api_service.extract_triples(
text=chunk, request=request,
upappid=service_params['triples'], text=chunk,
apiname="Babelscape/mrebel-large", upappid=service_params['triples'],
user=userid apiname="Babelscape/mrebel-large",
) for chunk in chunk_texts user=userid
] )
results = await asyncio.gather(*tasks, return_exceptions=True) if isinstance(result, list):
triples = [] triples.extend(result)
for i, result in enumerate(results): debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
if isinstance(result, list): else:
triples.extend(result) error(f"分片 {i + 1} 处理失败: {str(result)}")
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组: {result[:5]}")
else:
error(f"分片 {i + 1} 处理失败: {str(result)}")
seen = set() seen = set()
for t in triples: for t in triples:
identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower()) identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
if identifier not in seen: if identifier not in seen:
seen.add(identifier) seen.add(identifier)
unique_triples.append(t) unique_triples.append(t)
else: else:
for existing in unique_triples: for existing in unique_triples:
if (existing['head'].lower() == t['head'].lower() and if (existing['head'].lower() == t['head'].lower() and
existing['tail'].lower() == t['tail'].lower() and existing['tail'].lower() == t['tail'].lower() and
len(t['type']) > len(existing['type'])): len(t['type']) > len(existing['type'])):
unique_triples.remove(existing) unique_triples.remove(existing)
unique_triples.append(t) unique_triples.append(t)
debug(f"替换三元组为更具体类型: {t}") debug(f"替换三元组为更具体类型: {t}")
break break
timings["extract_triples"] = time.time() - start_triples timings["extract_triples"] = time.time() - start_triples
debug( debug(
f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}") f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
if unique_triples: if unique_triples:
debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入") debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
start_neo4j = time.time() start_neo4j = time.time()
for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组 for i in range(0, len(unique_triples), 30): # 每次插入 30 个三元组
batch_triples = unique_triples[i:i + 30] batch_triples = unique_triples[i:i + 30]
neo4j_result = await api_service.neo4j_insert_triples( neo4j_result = await api_service.neo4j_insert_triples(
request=request, request=request,
triples=batch_triples, triples=batch_triples,
document_id=id, document_id=id,
knowledge_base_id=fiid, knowledge_base_id=fiid,
userid=orgid, userid=orgid,
upappid=service_params['gdb'], upappid=service_params['gdb'],
apiname="neo4j/inserttriples", apiname="neo4j/inserttriples",
user=userid user=userid
) )
debug(f"Neo4j 服务响应: {neo4j_result}") debug(f"Neo4j 服务响应: {neo4j_result}")
if neo4j_result.get("status") != "success": if neo4j_result.get("status") != "success":
timings["insert_neo4j"] = time.time() - start_neo4j timings["insert_neo4j"] = time.time() - start_neo4j
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
return { return {
"status": "error", "status": "error",
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}", "message": f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}",
"status_code": 400 "status_code": 400
} }
info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}") info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
timings["insert_neo4j"] = time.time() - start_neo4j timings["insert_neo4j"] = time.time() - start_neo4j
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}") debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
else: else:
debug(f"文件 {realpath} 未抽取到三元组") debug(f"文件 {realpath} 未抽取到三元组")
timings["insert_neo4j"] = 0.0 timings["insert_neo4j"] = 0.0
except Exception as e: except Exception as e:
timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \ timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
timings["extract_triples"] timings["extract_triples"]
timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[ timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
"insert_neo4j"] "insert_neo4j"]
debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}") debug(f"处理三元组或 Neo4j 插入失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
return { return {
"status": "success", "status": "success",
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"unique_triples": unique_triples, "unique_triples": unique_triples,
"message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}", "message": f"文件 {realpath} 成功嵌入,但三元组处理或 Neo4j 插入失败: {str(e)}",
"status_code": 200 "status_code": 200
} }
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
debug(f"总耗时: {timings['total']:.2f}") debug(f"总耗时: {timings['total']:.2f}")
return { return {
"status": "success", "status": "success",
"userid": orgid, "userid": orgid,
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"unique_triples": unique_triples, "unique_triples": unique_triples,
"message": f"文件 {realpath} 成功嵌入并处理三元组", "message": f"文件 {realpath} 成功嵌入并处理三元组",
"status_code": 200 "status_code": 200
} }
except Exception as e: except Exception as e:
error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
timings["total"] = time.time() - start_total timings["total"] = time.time() - start_total
return { return {
"status": "error", "status": "error",
"document_id": id, "document_id": id,
"collection_name": "ragdb", "collection_name": "ragdb",
"timings": timings, "timings": timings,
"message": f"插入文档失败: {str(e)}", "message": f"插入文档失败: {str(e)}",
"status_code": 400 "status_code": 400
} }
async def file_deleted(self, request, recs, userid): async def file_deleted(self, request, recs, userid):
"""删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录""" """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录"""
if not isinstance(recs, list): if not isinstance(recs, list):
recs = [recs] # 确保 recs 是列表,即使传入单个记录 recs = [recs] # 确保 recs 是列表,即使传入单个记录
results = [] results = []
api_service = APIService() api_service = APIService()
total_nodes_deleted = 0 total_nodes_deleted = 0
total_rels_deleted = 0 total_rels_deleted = 0
for rec in recs: for rec in recs:
id = rec.get('id', '') id = rec.get('id', '')
realpath = rec.get('realpath', '') realpath = rec.get('realpath', '')
fiid = rec.get('fiid', '') fiid = rec.get('fiid', '')
orgid = rec.get('ownerid', '') orgid = rec.get('ownerid', '')
db_type = '' db_type = ''
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
required_fields = ['id', 'realpath', 'fiid', 'ownerid'] required_fields = ['id', 'realpath', 'fiid', 'ownerid']
missing_fields = [field for field in required_fields if not rec.get(field, '')] missing_fields = [field for field in required_fields if not rec.get(field, '')]
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
# 获取服务参数 # 获取服务参数
service_params = await self.get_service_params(orgid) service_params = await self.get_service_params(orgid)
if not service_params: if not service_params:
raise ValueError("无法获取服务参数") raise ValueError("无法获取服务参数")
debug( debug(
f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}") f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
milvus_result = await api_service.milvus_delete_document( milvus_result = await api_service.milvus_delete_document(
request=request, request=request,
userid=orgid, userid=orgid,
file_path=realpath, file_path=realpath,
knowledge_base_id=fiid, knowledge_base_id=fiid,
document_id=id, document_id=id,
db_type=db_type, db_type=db_type,
upappid=service_params['vdb'], upappid=service_params['vdb'],
apiname="milvus/deletedocument", apiname="milvus/deletedocument",
user=userid user=userid
) )
if milvus_result.get("status") != "success": if milvus_result.get("status") != "success":
raise ValueError(milvus_result.get("message", "Milvus 删除失败")) raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
neo4j_deleted_nodes = 0 neo4j_deleted_nodes = 0
neo4j_deleted_rels = 0 neo4j_deleted_rels = 0
try: try:
debug(f"调用 Neo4j 删除文档端点: document_id={id}") debug(f"调用 Neo4j 删除文档端点: document_id={id}")
neo4j_result = await api_service.neo4j_delete_document( neo4j_result = await api_service.neo4j_delete_document(
request=request, request=request,
document_id=id, document_id=id,
upappid=service_params['gdb'], upappid=service_params['gdb'],
apiname="neo4j/deletedocument", apiname="neo4j/deletedocument",
user=userid user=userid
) )
if neo4j_result.get("status") != "success": if neo4j_result.get("status") != "success":
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败")) raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
nodes_deleted = neo4j_result.get("nodes_deleted", 0) nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0) rels_deleted = neo4j_result.get("rels_deleted", 0)
neo4j_deleted_nodes += nodes_deleted neo4j_deleted_nodes += nodes_deleted
neo4j_deleted_rels += rels_deleted neo4j_deleted_rels += rels_deleted
total_nodes_deleted += nodes_deleted total_nodes_deleted += nodes_deleted
total_rels_deleted += rels_deleted total_rels_deleted += rels_deleted
info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系") info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
except Exception as e: except Exception as e:
error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}") error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
results.append({ results.append({
"status": "success", "status": "success",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": id, "document_id": id,
"message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系", "message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
"status_code": 200 "status_code": 200
}) })
except Exception as e: except Exception as e:
error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
results.append({ results.append({
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": id, "document_id": id,
"message": f"删除文档 {realpath} 失败: {str(e)}", "message": f"删除文档 {realpath} 失败: {str(e)}",
"status_code": 400 "status_code": 400
}) })
return { return {
"status": "success" if all(r["status"] == "success" for r in results) else "partial", "status": "success" if all(r["status"] == "success" for r in results) else "partial",
"results": results, "results": results,
"total_nodes_deleted": total_nodes_deleted, "total_nodes_deleted": total_nodes_deleted,
"total_rels_deleted": total_rels_deleted, "total_rels_deleted": total_rels_deleted,
"message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')}", "message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')}",
"status_code": 200 if all(r["status"] == "success" for r in results) else 207 "status_code": 200 if all(r["status"] == "success" for r in results) else 207
} }
async def test_ragfilemgr(): async def test_ragfilemgr():
"""测试 RagFileMgr 类的 get_service_params""" """测试 RagFileMgr 类的 get_service_params"""
@ -505,11 +503,4 @@ async def test_ragfilemgr():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test_ragfilemgr()) asyncio.run(test_ragfilemgr())
## usage
# mgr = RagFileMgr(fiid)
# await mgr.add_file(request, params_kw)
# await mgr.delete_file(request, file_id)
##

View File

@ -5,6 +5,7 @@ from appPublic.log import debug, error, info
import time import time
import traceback import traceback
import json import json
import math
helptext = """kyrag API: helptext = """kyrag API:
@ -81,11 +82,44 @@ async def fusedsearch(request, params_kw, *params, **kw):
# orgid = "04J6VbxLqB_9RPMcgOv_8" # orgid = "04J6VbxLqB_9RPMcgOv_8"
# userid = "04J6VbxLqB_9RPMcgOv_8" # userid = "04J6VbxLqB_9RPMcgOv_8"
query = params_kw.get('query', '') query = params_kw.get('query', '')
fiids = params_kw.get('fiids', []) # 统一模式处理 limit 参数
limit = int(params_kw.get('limit', 5)) raw_limit = params_kw.get('limit') or (
params_kw.get('retrieval_setting', {}).get('top_k')
if isinstance(params_kw.get('retrieval_setting'), dict)
else None
)
# 标准化为整数值
if raw_limit is None:
limit = 5 # 两个来源都不存在时使用默认值
elif isinstance(raw_limit, (int, float)):
limit = int(raw_limit) # 数值类型直接转换
elif isinstance(raw_limit, str):
try:
# 字符串转换为整数
limit = int(raw_limit)
except (TypeError, ValueError):
limit = 5 # 转换失败使用默认值
else:
limit = 5 # 其他意外类型使用默认值
debug(f"limit: {limit}")
raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
# 标准化为列表格式
if raw_fiids is None:
fiids = [] # 两个参数都不存在
elif isinstance(raw_fiids, list):
fiids = [str(item).strip() for item in raw_fiids] # 已经是列表
elif isinstance(raw_fiids, str):
# 处理逗号分隔的字符串或单个ID字符串
fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
elif isinstance(raw_fiids, (int, float)):
fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表
else:
fiids = [] # 其他意外类型
debug(f"fiids: {fiids}") debug(f"fiids: {fiids}")
if isinstance(fiids, str):
fiids = [f.strip() for f in fiids.split(',') if f.strip()]
# 验证 fiids的orgid与orgid = await f()是否一致 # 验证 fiids的orgid与orgid = await f()是否一致
if fiids: if fiids:
db = DBPools() db = DBPools()
@ -197,6 +231,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
# 调用搜索端点 # 调用搜索端点
sum = limit + 5 sum = limit + 5
search_start = time.time() search_start = time.time()
debug(f"orgid: {orgid}")
result = await api_service.milvus_search_query( result = await api_service.milvus_search_query(
request=request, request=request,
query_vector=query_vector, query_vector=query_vector,
@ -240,8 +275,34 @@ async def fusedsearch(request, params_kw, *params, **kw):
timing_stats["total_time"] = time.time() - start_time timing_stats["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}") info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
return {"results": unique_results[:limit], "timing": timing_stats}
# debug(f"results: {unique_results[:limit]},timing: {timing_stats}")
# return {"results": unique_results[:limit], "timing": timing_stats}
dify_records = []
dify_result = []
for res in unique_results[:limit]:
rerank_score = res.get('rerank_score', 0)
score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
score = max(0.0, min(1.0, score))
content = res.get('text', '')
title = res.get('metadata', {}).get('filename', 'Untitled')
document_id = res.get('metadata', {}).get('document_id', '')
dify_records.append({
"content": content,
"score": score,
"title": title
})
dify_result.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id}
})
info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f}")
debug(f"records: {dify_records}, result: {dify_result}")
return {"records": dify_records, "result": dify_result, "own":{"results": unique_results[:limit], "timing": timing_stats}}
except Exception as e: except Exception as e:
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return {"results": [], "timing": timing_stats} return {"results": [], "timing": timing_stats}

View File

@ -321,6 +321,7 @@ class APIService:
async def milvus_search_query(self, request, query_vector: List[float], userid: str, knowledge_base_ids: list, limit: int, offset: int, upappid: str, apiname: str, user: str) -> Dict[str, Any]: async def milvus_search_query(self, request, query_vector: List[float], userid: str, knowledge_base_ids: list, limit: int, offset: int, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
"""根据用户知识库检索 Milvus""" """根据用户知识库检索 Milvus"""
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
debug(f"userid:{userid}")
debug(f"Request #{request_id} started for Milvus search") debug(f"Request #{request_id} started for Milvus search")
try: try:
uapi = UAPI(request, DictObject(**globals())) uapi = UAPI(request, DictObject(**globals()))