From 97c9e0f1fa55efe643f68af1d635a942024767f4 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Fri, 12 Sep 2025 17:58:21 +0800 Subject: [PATCH] rag --- rag/folderinfo.py | 548 ++++++++++++++++++++++++++---------------- rag/rag_operations.py | 450 ++++++++++++++++++++++++++++++++++ rag/ragapi.py | 517 +++++++++++++++++++++++---------------- 3 files changed, 1102 insertions(+), 413 deletions(-) create mode 100644 rag/rag_operations.py diff --git a/rag/folderinfo.py b/rag/folderinfo.py index ec66736..d858441 100644 --- a/rag/folderinfo.py +++ b/rag/folderinfo.py @@ -18,9 +18,18 @@ from filetxt.loader import fileloader,File2Text from ahserver.serverenv import get_serverenv from typing import List, Dict, Any from rag.service_opts import get_service_params, sor_get_service_params +from rag.rag_operations import RagOperations import json +from dataclasses import dataclass +from enum import Enum + + class RagFileMgr(FileMgr): + def __init__(self, fiid): + super().__init__(fiid) + self.rag_ops = RagOperations() + async def get_folder_ownerid(self, sor): fiid = self.fiid recs = await sor.R('kdb', {'id': self.fiid}) @@ -44,205 +53,6 @@ where a.orgid = b.orgid return r.quota, r.expired_date return None, None - async def get_doucment_chunks(self, realpath, timings): - """加载文件并进行文本分片""" - debug(f"加载文件: {realpath}") - start_load = time.time() - supported_formats = File2Text.supported_types() - debug(f"支持的文件格式:{supported_formats}") - ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' - if ext not in supported_formats: - raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") - text = fileloader(realpath) - text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) - timings["load_file"] = time.time() - start_load - debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") - - if not text or not text.strip(): - raise ValueError(f"文件 {realpath} 加载为空") - - document = Document(page_content=text) - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, - chunk_overlap=100, - length_function=len - ) - debug("开始分片文件内容") - start_split = time.time() - chunks = text_splitter.split_documents([document]) - timings["split_text"] = time.time() - start_split - debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}") - - if not chunks: - raise ValueError(f"文件 {realpath} 未生成任何文档块") - - return chunks - - async def docs_embedding(self, request, chunks, service_params, userid, timings): - """调用嵌入服务生成向量""" - debug("调用嵌入服务生成向量") - start_embedding = time.time() - texts = [chunk.page_content for chunk in chunks] - embeddings = [] - for i in range(0, len(texts), 10): - batch_texts = texts[i:i + 10] - batch_embeddings = await APIService().get_embeddings( - request=request, - texts=batch_texts, - upappid=service_params['embedding'], - apiname="BAAI/bge-m3", - user=userid - ) - embeddings.extend(batch_embeddings) - - if not embeddings or not all(len(vec) == 1024 for vec in embeddings): - raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表") - - timings["generate_embeddings"] = time.time() - start_embedding - debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}") - return embeddings - - async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid, - db_type, timings): - """准备数据并插入 Milvus""" - debug(f"准备数据并调用插入文件端点: {realpath}") - filename = os.path.basename(realpath).rsplit('.', 1)[0] - ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' - upload_time = datetime.now().isoformat() - - chunks_data = [ - { - "userid": orgid, - "knowledge_base_id": fiid, - "text": chunk.page_content, - "vector": embeddings[i], - "document_id": id, - "filename": filename + '.' + ext, - "file_path": realpath, - "upload_time": upload_time, - "file_type": ext, - } - for i, chunk in enumerate(chunks) - ] - - start_milvus = time.time() - for i in range(0, len(chunks_data), 10): - batch_chunks = chunks_data[i:i + 10] - result = await APIService().milvus_insert_document( - request=request, - chunks=batch_chunks, - db_type=db_type, - upappid=service_params['vdb'], - apiname="milvus/insertdocument", - user=userid - ) - if result.get("status") != "success": - raise ValueError(result.get("message", "Milvus 插入失败")) - - timings["insert_milvus"] = time.time() - start_milvus - debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒") - return chunks_data - - async def get_triples(self, request, chunks, service_params, userid, timings): - """调用三元组抽取服务""" - debug("调用三元组抽取服务") - start_triples = time.time() - chunk_texts = [doc.page_content for doc in chunks] - triples = [] - for i, chunk in enumerate(chunk_texts): - result = await APIService().extract_triples( - request=request, - text=chunk, - upappid=service_params['triples'], - apiname="Babelscape/mrebel-large", - user=userid - ) - if isinstance(result, list): - triples.extend(result) - debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组") - else: - error(f"分片 {i + 1} 处理失败: {str(result)}") - - unique_triples = [] - seen = set() - for t in triples: - identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower()) - if identifier not in seen: - seen.add(identifier) - unique_triples.append(t) - else: - for existing in unique_triples: - if (existing['head'].lower() == t['head'].lower() and - existing['tail'].lower() == t['tail'].lower() and - len(t['type']) > len(existing['type'])): - unique_triples.remove(existing) - unique_triples.append(t) - debug(f"替换三元组为更具体类型: {t}") - break - - timings["extract_triples"] = time.time() - start_triples - debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组") - return unique_triples - - async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings): - """调用 Neo4j 插入三元组""" - debug(f"插入 {len(unique_triples)} 个三元组到 Neo4j") - start_neo4j = time.time() - if unique_triples: - for i in range(0, len(unique_triples), 30): - batch_triples = unique_triples[i:i + 30] - neo4j_result = await APIService().neo4j_insert_triples( - request=request, - triples=batch_triples, - document_id=id, - knowledge_base_id=fiid, - userid=orgid, - upappid=service_params['gdb'], - apiname="neo4j/inserttriples", - user=userid - ) - if neo4j_result.get("status") != "success": - raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}") - info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}") - timings["insert_neo4j"] = time.time() - start_neo4j - debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒") - else: - debug("未抽取到三元组") - timings["insert_neo4j"] = 0.0 - - async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type): - """调用 Milvus 删除文档""" - debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}") - milvus_result = await APIService().milvus_delete_document( - request=request, - userid=orgid, - file_path=realpath, - knowledge_base_id=fiid, - document_id=id, - db_type=db_type, - upappid=service_params['vdb'], - apiname="milvus/deletedocument", - user=userid - ) - if milvus_result.get("status") != "success": - raise ValueError(milvus_result.get("message", "Milvus 删除失败")) - - async def delete_from_neo4j(self, request, id, service_params, userid): - """调用 Neo4j 删除文档""" - debug(f"调用 Neo4j 删除文档端点: document_id={id}") - neo4j_result = await APIService().neo4j_delete_document( - request=request, - document_id=id, - upappid=service_params['gdb'], - apiname="neo4j/deletedocument", - user=userid - ) - if neo4j_result.get("status") != "success": - raise ValueError(neo4j_result.get("message", "Neo4j 删除失败")) - nodes_deleted = neo4j_result.get("nodes_deleted", 0) - rels_deleted = neo4j_result.get("rels_deleted", 0) - info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系") - return nodes_deleted, rels_deleted async def file_uploaded(self, request, ns, userid): """将文档插入 Milvus 并抽取三元组到 Neo4j""" @@ -272,11 +82,11 @@ where a.orgid = b.orgid if not service_params: raise ValueError("无法获取服务参数") - chunks = await self.get_doucment_chunks(realpath, timings) - embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings) - await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings) - triples = await self.get_triples(request, chunks, service_params, userid, timings) - await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings) + chunks = await self.rag_ops.load_and_chunk_document(realpath, timings) + embeddings = await self.rag_ops.generate_embeddings(request, chunks, service_params, userid, timings) + await self.rag_ops.insert_to_vector_db(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings) + triples = await self.rag_ops.extract_triples(request, chunks, service_params, userid, timings) + await self.rag_ops.insert_to_graph_db(request, triples, id, fiid, orgid, service_params, userid, timings) timings["total"] = time.time() - start_total debug(f"总耗时: {timings['total']:.2f} 秒") @@ -329,13 +139,13 @@ where a.orgid = b.orgid raise ValueError("无法获取服务参数") # 调用 Milvus 删除 - await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type) + await self.rag_ops.delete_from_vector_db(request, orgid, realpath, fiid, id, service_params, userid, db_type) # 调用 Neo4j 删除 neo4j_deleted_nodes = 0 neo4j_deleted_rels = 0 try: - nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid) + nodes_deleted, rels_deleted = await self.rag_ops.delete_from_graph_db(request, id, service_params, userid) neo4j_deleted_nodes += nodes_deleted neo4j_deleted_rels += rels_deleted total_nodes_deleted += nodes_deleted @@ -369,6 +179,332 @@ where a.orgid = b.orgid "message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')} 个", "status_code": 200 if all(r["status"] == "success" for r in results) else 207 } + # async def get_doucment_chunks(self, realpath, timings): + # """加载文件并进行文本分片""" + # debug(f"加载文件: {realpath}") + # start_load = time.time() + # supported_formats = File2Text.supported_types() + # debug(f"支持的文件格式:{supported_formats}") + # ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' + # if ext not in supported_formats: + # raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") + # text = fileloader(realpath) + # text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) + # timings["load_file"] = time.time() - start_load + # debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") + # + # if not text or not text.strip(): + # raise ValueError(f"文件 {realpath} 加载为空") + # + # document = Document(page_content=text) + # text_splitter = RecursiveCharacterTextSplitter( + # chunk_size=500, + # chunk_overlap=100, + # length_function=len + # ) + # debug("开始分片文件内容") + # start_split = time.time() + # chunks = text_splitter.split_documents([document]) + # timings["split_text"] = time.time() - start_split + # debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}") + # + # if not chunks: + # raise ValueError(f"文件 {realpath} 未生成任何文档块") + # + # return chunks + # + # async def docs_embedding(self, request, chunks, service_params, userid, timings): + # """调用嵌入服务生成向量""" + # debug("调用嵌入服务生成向量") + # start_embedding = time.time() + # texts = [chunk.page_content for chunk in chunks] + # embeddings = [] + # for i in range(0, len(texts), 10): + # batch_texts = texts[i:i + 10] + # batch_embeddings = await APIService().get_embeddings( + # request=request, + # texts=batch_texts, + # upappid=service_params['embedding'], + # apiname="BAAI/bge-m3", + # user=userid + # ) + # embeddings.extend(batch_embeddings) + # + # if not embeddings or not all(len(vec) == 1024 for vec in embeddings): + # raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表") + # + # timings["generate_embeddings"] = time.time() - start_embedding + # debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}") + # return embeddings + # + # async def embedding_2_vdb(self, request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid, + # db_type, timings): + # """准备数据并插入 Milvus""" + # debug(f"准备数据并调用插入文件端点: {realpath}") + # filename = os.path.basename(realpath).rsplit('.', 1)[0] + # ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' + # upload_time = datetime.now().isoformat() + # + # chunks_data = [ + # { + # "userid": orgid, + # "knowledge_base_id": fiid, + # "text": chunk.page_content, + # "vector": embeddings[i], + # "document_id": id, + # "filename": filename + '.' + ext, + # "file_path": realpath, + # "upload_time": upload_time, + # "file_type": ext, + # } + # for i, chunk in enumerate(chunks) + # ] + # + # start_milvus = time.time() + # for i in range(0, len(chunks_data), 10): + # batch_chunks = chunks_data[i:i + 10] + # result = await APIService().milvus_insert_document( + # request=request, + # chunks=batch_chunks, + # db_type=db_type, + # upappid=service_params['vdb'], + # apiname="milvus/insertdocument", + # user=userid + # ) + # if result.get("status") != "success": + # raise ValueError(result.get("message", "Milvus 插入失败")) + # + # timings["insert_milvus"] = time.time() - start_milvus + # debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒") + # return chunks_data + # + # async def get_triples(self, request, chunks, service_params, userid, timings): + # """调用三元组抽取服务""" + # debug("调用三元组抽取服务") + # start_triples = time.time() + # chunk_texts = [doc.page_content for doc in chunks] + # triples = [] + # for i, chunk in enumerate(chunk_texts): + # result = await APIService().extract_triples( + # request=request, + # text=chunk, + # upappid=service_params['triples'], + # apiname="Babelscape/mrebel-large", + # user=userid + # ) + # if isinstance(result, list): + # triples.extend(result) + # debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组") + # else: + # error(f"分片 {i + 1} 处理失败: {str(result)}") + # + # unique_triples = [] + # seen = set() + # for t in triples: + # identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower()) + # if identifier not in seen: + # seen.add(identifier) + # unique_triples.append(t) + # else: + # for existing in unique_triples: + # if (existing['head'].lower() == t['head'].lower() and + # existing['tail'].lower() == t['tail'].lower() and + # len(t['type']) > len(existing['type'])): + # unique_triples.remove(existing) + # unique_triples.append(t) + # debug(f"替换三元组为更具体类型: {t}") + # break + # + # timings["extract_triples"] = time.time() - start_triples + # debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组") + # return unique_triples + # + # async def triple2graphdb(self, request, unique_triples, id, fiid, orgid, service_params, userid, timings): + # """调用 Neo4j 插入三元组""" + # debug(f"插入 {len(unique_triples)} 个三元组到 Neo4j") + # start_neo4j = time.time() + # if unique_triples: + # for i in range(0, len(unique_triples), 30): + # batch_triples = unique_triples[i:i + 30] + # neo4j_result = await APIService().neo4j_insert_triples( + # request=request, + # triples=batch_triples, + # document_id=id, + # knowledge_base_id=fiid, + # userid=orgid, + # upappid=service_params['gdb'], + # apiname="neo4j/inserttriples", + # user=userid + # ) + # if neo4j_result.get("status") != "success": + # raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}") + # info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}") + # timings["insert_neo4j"] = time.time() - start_neo4j + # debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒") + # else: + # debug("未抽取到三元组") + # timings["insert_neo4j"] = 0.0 + # + # async def delete_from_milvus(self, request, orgid, realpath, fiid, id, service_params, userid, db_type): + # """调用 Milvus 删除文档""" + # debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}") + # milvus_result = await APIService().milvus_delete_document( + # request=request, + # userid=orgid, + # file_path=realpath, + # knowledge_base_id=fiid, + # document_id=id, + # db_type=db_type, + # upappid=service_params['vdb'], + # apiname="milvus/deletedocument", + # user=userid + # ) + # if milvus_result.get("status") != "success": + # raise ValueError(milvus_result.get("message", "Milvus 删除失败")) + # + # async def delete_from_neo4j(self, request, id, service_params, userid): + # """调用 Neo4j 删除文档""" + # debug(f"调用 Neo4j 删除文档端点: document_id={id}") + # neo4j_result = await APIService().neo4j_delete_document( + # request=request, + # document_id=id, + # upappid=service_params['gdb'], + # apiname="neo4j/deletedocument", + # user=userid + # ) + # if neo4j_result.get("status") != "success": + # raise ValueError(neo4j_result.get("message", "Neo4j 删除失败")) + # nodes_deleted = neo4j_result.get("nodes_deleted", 0) + # rels_deleted = neo4j_result.get("rels_deleted", 0) + # info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系") + # return nodes_deleted, rels_deleted + + # async def file_uploaded(self, request, ns, userid): + # """将文档插入 Milvus 并抽取三元组到 Neo4j""" + # debug(f'Received ns: {ns=}') + # env = request._run_ns + # realpath = ns.get('realpath', '') + # fiid = ns.get('fiid', '') + # id = ns.get('id', '') + # orgid = ns.get('ownerid', '') + # db_type = '' + # + # debug(f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}') + # + # timings = {} + # start_total = time.time() + # + # try: + # if not orgid or not fiid or not id: + # raise ValueError("orgid、fiid 和 id 不能为空") + # if len(orgid) > 32 or len(fiid) > 255: + # raise ValueError("orgid 或 fiid 的长度超出限制") + # if not os.path.exists(realpath): + # raise ValueError(f"文件 {realpath} 不存在") + # + # # 获取服务参数 + # service_params = await get_service_params(orgid) + # if not service_params: + # raise ValueError("无法获取服务参数") + # + # chunks = await self.get_doucment_chunks(realpath, timings) + # embeddings = await self.docs_embedding(request, chunks, service_params, userid, timings) + # await self.embedding_2_vdb(request, chunks, embeddings, realpath, orgid, fiid, id, service_params, userid,db_type, timings) + # triples = await self.get_triples(request, chunks, service_params, userid, timings) + # await self.triple2graphdb(request, triples, id, fiid, orgid, service_params, userid, timings) + # + # timings["total"] = time.time() - start_total + # debug(f"总耗时: {timings['total']:.2f} 秒") + # return { + # "status": "success", + # "userid": orgid, + # "document_id": id, + # "collection_name": "ragdb", + # "timings": timings, + # "unique_triples": triples, + # "message": f"文件 {realpath} 成功嵌入并处理三元组", + # "status_code": 200 + # } + # except Exception as e: + # error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}") + # timings["total"] = time.time() - start_total + # return { + # "status": "error", + # "document_id": id, + # "collection_name": "ragdb", + # "timings": timings, + # "message": f"插入文档失败: {str(e)}", + # "status_code": 400 + # } + # + # async def file_deleted(self, request, recs, userid): + # """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录""" + # if not isinstance(recs, list): + # recs = [recs] + # results = [] + # total_nodes_deleted = 0 + # total_rels_deleted = 0 + # + # for rec in recs: + # id = rec.get('id', '') + # realpath = rec.get('realpath', '') + # fiid = rec.get('fiid', '') + # orgid = rec.get('ownerid', '') + # db_type = '' + # collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + # + # try: + # required_fields = ['id', 'realpath', 'fiid', 'ownerid'] + # missing_fields = [field for field in required_fields if not rec.get(field, '')] + # if missing_fields: + # raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") + # + # service_params = await get_service_params(orgid) + # if not service_params: + # raise ValueError("无法获取服务参数") + # + # # 调用 Milvus 删除 + # await self.delete_from_milvus(request, orgid, realpath, fiid, id, service_params, userid, db_type) + # + # # 调用 Neo4j 删除 + # neo4j_deleted_nodes = 0 + # neo4j_deleted_rels = 0 + # try: + # nodes_deleted, rels_deleted = await self.delete_from_neo4j(request, id, service_params, userid) + # neo4j_deleted_nodes += nodes_deleted + # neo4j_deleted_rels += rels_deleted + # total_nodes_deleted += nodes_deleted + # total_rels_deleted += rels_deleted + # except Exception as e: + # error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}") + # + # results.append({ + # "status": "success", + # "collection_name": collection_name, + # "document_id": id, + # "message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系", + # "status_code": 200 + # }) + # + # except Exception as e: + # error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}") + # results.append({ + # "status": "error", + # "collection_name": collection_name, + # "document_id": id, + # "message": f"删除文档 {realpath} 失败: {str(e)}", + # "status_code": 400 + # }) + # + # return { + # "status": "success" if all(r["status"] == "success" for r in results) else "partial", + # "results": results, + # "total_nodes_deleted": total_nodes_deleted, + # "total_rels_deleted": total_rels_deleted, + # "message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')} 个", + # "status_code": 200 if all(r["status"] == "success" for r in results) else 207 + # } + # async def test_ragfilemgr(): # """测试 RagFileMgr 类的 get_service_params""" diff --git a/rag/rag_operations.py b/rag/rag_operations.py new file mode 100644 index 0000000..493134a --- /dev/null +++ b/rag/rag_operations.py @@ -0,0 +1,450 @@ +""" +RAG 操作的通用函数库 +包含文档处理、搜索、嵌入等通用操作,供 folderinfo.py 和 ragapi.py 共同使用 +""" + +import os +import re +import time +import math +from datetime import datetime +from typing import List, Dict, Any, Optional +from langchain_core.documents import Document +from langchain_text_splitters import RecursiveCharacterTextSplitter + +from appPublic.log import debug, error, info +from filetxt.loader import fileloader, File2Text +from rag.uapi_service import APIService +from rag.service_opts import get_service_params +from rag.transaction_manager import TransactionManager, OperationType + + +class RagOperations: + """RAG 操作类,提供所有通用的 RAG 操作""" + + def __init__(self): + self.api_service = APIService() + + async def load_and_chunk_document(self, realpath: str, timings: Dict, + transaction_mgr: TransactionManager = None) -> List[Document]: + """加载文件并进行文本分片""" + debug(f"加载文件: {realpath}") + start_load = time.time() + + # 检查文件格式支持 + supported_formats = File2Text.supported_types() + debug(f"支持的文件格式:{supported_formats}") + ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' + if ext not in supported_formats: + raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") + + # 加载文件内容 + text = fileloader(realpath) + text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) + timings["load_file"] = time.time() - start_load + debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") + + if not text or not text.strip(): + raise ValueError(f"文件 {realpath} 加载为空") + + # 分片处理 + document = Document(page_content=text) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=500, + chunk_overlap=100, + length_function=len + ) + debug("开始分片文件内容") + start_split = time.time() + chunks = text_splitter.split_documents([document]) + timings["split_text"] = time.time() - start_split + debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}") + + if not chunks: + raise ValueError(f"文件 {realpath} 未生成任何文档块") + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.FILE_LOAD, + {'realpath': realpath, 'chunks_count': len(chunks)} + ) + + return chunks + + async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict, + userid: str, timings: Dict, + transaction_mgr: TransactionManager = None) -> List[List[float]]: + """生成嵌入向量""" + debug("调用嵌入服务生成向量") + start_embedding = time.time() + texts = [chunk.page_content for chunk in chunks] + embeddings = [] + + # 批量处理嵌入 + for i in range(0, len(texts), 10): + batch_texts = texts[i:i + 10] + batch_embeddings = await self.api_service.get_embeddings( + request=request, + texts=batch_texts, + upappid=service_params['embedding'], + apiname="BAAI/bge-m3", + user=userid + ) + embeddings.extend(batch_embeddings) + + if not embeddings or not all(len(vec) == 1024 for vec in embeddings): + raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表") + + timings["generate_embeddings"] = time.time() - start_embedding + debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}") + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.EMBEDDING, + {'embeddings_count': len(embeddings)} + ) + + return embeddings + + async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]], + realpath: str, orgid: str, fiid: str, id: str, service_params: Dict, + userid: str, db_type: str, timings: Dict, + transaction_mgr: TransactionManager = None): + """插入向量数据库""" + debug(f"准备数据并调用插入文件端点: {realpath}") + filename = os.path.basename(realpath).rsplit('.', 1)[0] + ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' + upload_time = datetime.now().isoformat() + + chunks_data = [ + { + "userid": orgid, + "knowledge_base_id": fiid, + "text": chunk.page_content, + "vector": embeddings[i], + "document_id": id, + "filename": filename + '.' + ext, + "file_path": realpath, + "upload_time": upload_time, + "file_type": ext, + } + for i, chunk in enumerate(chunks) + ] + + start_milvus = time.time() + for i in range(0, len(chunks_data), 10): + batch_chunks = chunks_data[i:i + 10] + result = await self.api_service.milvus_insert_document( + request=request, + chunks=batch_chunks, + db_type=db_type, + upappid=service_params['vdb'], + apiname="milvus/insertdocument", + user=userid + ) + if result.get("status") != "success": + raise ValueError(result.get("message", "Milvus 插入失败")) + + timings["insert_milvus"] = time.time() - start_milvus + debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒") + + # 记录事务操作,包含回滚函数 + if transaction_mgr: + async def rollback_vdb_insert(data, context): + await self.delete_from_vector_db( + context['request'], data['orgid'], data['realpath'], + data['fiid'], data['id'], context['service_params'], + context['userid'], data['db_type'] + ) + return f"已回滚向量数据库插入: {data['id']}" + + transaction_mgr.add_operation( + OperationType.VDB_INSERT, + { + 'orgid': orgid, 'realpath': realpath, 'fiid': fiid, + 'id': id, 'db_type': db_type + }, + rollback_func=rollback_vdb_insert + ) + + return chunks_data + + async def extract_triples(self, request, chunks: List[Document], service_params: Dict, + userid: str, timings: Dict, + transaction_mgr: TransactionManager = None) -> List[Dict]: + """抽取三元组""" + debug("调用三元组抽取服务") + start_triples = time.time() + chunk_texts = [doc.page_content for doc in chunks] + triples = [] + + for i, chunk in enumerate(chunk_texts): + result = await self.api_service.extract_triples( + request=request, + text=chunk, + upappid=service_params['triples'], + apiname="Babelscape/mrebel-large", + user=userid + ) + if isinstance(result, list): + triples.extend(result) + debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组") + else: + error(f"分片 {i + 1} 处理失败: {str(result)}") + + # 去重和优化三元组 + unique_triples = self._deduplicate_triples(triples) + + timings["extract_triples"] = time.time() - start_triples + debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组") + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.TRIPLES_EXTRACT, + {'triples_count': len(unique_triples)} + ) + + return unique_triples + + async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str, + orgid: str, service_params: Dict, userid: str, timings: Dict, + transaction_mgr: TransactionManager = None): + """插入图数据库""" + debug(f"插入 {len(triples)} 个三元组到 Neo4j") + start_neo4j = time.time() + + if triples: + for i in range(0, len(triples), 30): + batch_triples = triples[i:i + 30] + neo4j_result = await self.api_service.neo4j_insert_triples( + request=request, + triples=batch_triples, + document_id=id, + knowledge_base_id=fiid, + userid=orgid, + upappid=service_params['gdb'], + apiname="neo4j/inserttriples", + user=userid + ) + if neo4j_result.get("status") != "success": + raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}") + info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}") + + timings["insert_neo4j"] = time.time() - start_neo4j + debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒") + else: + debug("未抽取到三元组") + timings["insert_neo4j"] = 0.0 + + # 记录事务操作,包含回滚函数 + if transaction_mgr: + async def rollback_gdb_insert(data, context): + await self.delete_from_graph_db( + context['request'], data['id'], + context['service_params'], context['userid'] + ) + return f"已回滚图数据库插入: {data['id']}" + + transaction_mgr.add_operation( + OperationType.GDB_INSERT, + {'id': id, 'triples_count': len(triples)}, + rollback_func=rollback_gdb_insert + ) + + async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str, + id: str, service_params: Dict, userid: str, db_type: str): + """从向量数据库删除文档""" + debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}") + milvus_result = await self.api_service.milvus_delete_document( + request=request, + userid=orgid, + file_path=realpath, + knowledge_base_id=fiid, + document_id=id, + db_type=db_type, + upappid=service_params['vdb'], + apiname="milvus/deletedocument", + user=userid + ) + if milvus_result.get("status") != "success": + raise ValueError(milvus_result.get("message", "Milvus 删除失败")) + + async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str): + """从图数据库删除文档""" + debug(f"调用 Neo4j 删除文档端点: document_id={id}") + neo4j_result = await self.api_service.neo4j_delete_document( + request=request, + document_id=id, + upappid=service_params['gdb'], + apiname="neo4j/deletedocument", + user=userid + ) + if neo4j_result.get("status") != "success": + raise ValueError(neo4j_result.get("message", "Neo4j 删除失败")) + nodes_deleted = neo4j_result.get("nodes_deleted", 0) + rels_deleted = neo4j_result.get("rels_deleted", 0) + info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系") + return nodes_deleted, rels_deleted + + async def extract_entities(self, request, query: str, service_params: Dict, userid: str, + transaction_mgr: TransactionManager = None) -> List[str]: + """提取实体""" + debug(f"提取查询实体: {query}") + entities = await self.api_service.extract_entities( + request=request, + query=query, + upappid=service_params['entities'], + apiname="LTP/small", + user=userid + ) + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.ENTITY_EXTRACT, + {'query': query, 'entities_count': len(entities)} + ) + + return entities + + async def match_triplets(self, request, query: str, entities: List[str], orgid: str, + fiids: List[str], service_params: Dict, userid: str, + transaction_mgr: TransactionManager = None) -> List[Dict]: + """匹配三元组""" + debug("开始三元组匹配") + all_triplets = [] + + for kb_id in fiids: + debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}") + try: + neo4j_result = await self.api_service.neo4j_match_triplets( + request=request, + query=query, + query_entities=entities, + userid=orgid, + knowledge_base_id=kb_id, + upappid=service_params['gdb'], + apiname="neo4j/matchtriplets", + user=userid + ) + if neo4j_result.get("status") == "success": + triplets = neo4j_result.get("triplets", []) + all_triplets.extend(triplets) + debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组") + else: + error( + f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}") + except Exception as e: + error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}") + continue + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.TRIPLET_MATCH, + {'query': query, 'triplets_count': len(all_triplets)} + ) + + return all_triplets + + async def vector_search(self, request, query_vector: List[float], orgid: str, + fiids: List[str], limit: int, service_params: Dict, userid: str, + transaction_mgr: TransactionManager = None) -> List[Dict]: + """向量搜索""" + debug("开始向量搜索") + result = await self.api_service.milvus_search_query( + request=request, + query_vector=query_vector, + userid=orgid, + knowledge_base_ids=fiids, + limit=limit, + offset=0, + upappid=service_params['vdb'], + apiname="mlvus/searchquery", + user=userid + ) + + if result.get("status") != "success": + raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}") + + search_results = result.get("results", []) + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.VECTOR_SEARCH, + {'results_count': len(search_results)} + ) + + return search_results + + async def rerank_results(self, request, query: str, results: List[Dict], top_n: int, + service_params: Dict, userid: str, + transaction_mgr: TransactionManager = None) -> List[Dict]: + """重排序结果""" + debug("开始重排序") + reranked_results = await self.api_service.rerank_results( + request=request, + query=query, + results=results, + top_n=top_n, + upappid=service_params['reranker'], + apiname="BAAI/bge-reranker-v2-m3", + user=userid + ) + + # 记录事务操作 + if transaction_mgr: + transaction_mgr.add_operation( + OperationType.RERANK, + {'input_count': len(results), 'output_count': len(reranked_results)} + ) + + return reranked_results + + def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]: + """去重和优化三元组""" + unique_triples = [] + seen = set() + + for t in triples: + identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower()) + if identifier not in seen: + seen.add(identifier) + unique_triples.append(t) + else: + # 如果发现更具体的类型,则替换 + for existing in unique_triples: + if (existing['head'].lower() == t['head'].lower() and + existing['tail'].lower() == t['tail'].lower() and + len(t['type']) > len(existing['type'])): + unique_triples.remove(existing) + unique_triples.append(t) + debug(f"替换三元组为更具体类型: {t}") + break + + return unique_triples + + def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]: + """格式化搜索结果为统一格式""" + formatted_results = [] + + for res in results[:limit]: + rerank_score = res.get('rerank_score', 0) + score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) + score = max(0.0, min(1.0, score)) + + content = res.get('text', '') + title = res.get('metadata', {}).get('filename', 'Untitled') + document_id = res.get('metadata', {}).get('document_id', '') + + formatted_results.append({ + "content": content, + "title": title, + "metadata": {"document_id": document_id, "score": score}, + }) + + return formatted_results \ No newline at end of file diff --git a/rag/ragapi.py b/rag/ragapi.py index e4520de..bbc8084 100644 --- a/rag/ragapi.py +++ b/rag/ragapi.py @@ -6,6 +6,7 @@ import traceback import json import math from rag.service_opts import get_service_params, sor_get_service_params +from rag.rag_operations import RagOperations helptext = """kyrag API: @@ -131,14 +132,97 @@ async def fusedsearch(request, params_kw, *params, **kw): debug(f"fiids: {fiids}") # 验证 fiids的orgid与orgid = await f()是否一致 + await _validate_fiids_orgid(fiids, orgid, kw) + + service_params = await get_service_params(orgid) + if not service_params: + raise ValueError("无法获取服务参数") + + try: + timing_stats = {} + start_time = time.time() + rag_ops = RagOperations() + + entity_extract_start = time.time() + query_entities = await rag_ops.extract_entities(request, query, service_params, userid) + timing_stats["entity_extraction"] = time.time() - entity_extract_start + debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒") + + triplet_match_start = time.time() + all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid) + timing_stats["triplet_matching"] = time.time() - triplet_match_start + debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒") + + triplet_text_start = time.time() + combined_text = _combine_query_with_triplets(query, all_triplets) + timing_stats["triplet_text_combine"] = time.time() - triplet_text_start + debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒") + + vector_start = time.time() + query_vector = await rag_ops.api_service.get_embeddings( + request=request, + texts=[combined_text], + upappid=service_params['embedding'], + apiname="BAAI/bge-m3", + user=userid + ) + if not query_vector or not all(len(vec) == 1024 for vec in query_vector): + raise ValueError("查询向量必须是长度为 1024 的浮点数列表") + query_vector = query_vector[0] + timing_stats["vector_generation"] = time.time() - vector_start + debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒") + + search_start = time.time() + search_limit = limit + 5 + search_results = await rag_ops.vector_search( + request, query_vector, orgid, fiids, search_limit, service_params, userid + ) + timing_stats["vector_search"] = time.time() - search_start + debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") + debug(f"从向量数据中搜索到{len(search_results)}条数据") + + # 步骤6: 重排序(可选) + use_rerank = True + if use_rerank and search_results: + rerank_start = time.time() + debug("开始重排序") + reranked_results = await rag_ops.rerank_results( + request, combined_text, search_results, limit, service_params, userid + ) + reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True) + timing_stats["reranking"] = time.time() - rerank_start + debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒") + debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}") + final_results = reranked_results + else: + final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results] + + timing_stats["total_time"] = time.time() - start_time + info(f"融合搜索完成,返回 {len(final_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") + + formatted_results = rag_ops.format_search_results(final_results, limit) + info(f"融合搜索完成,返回 {len(formatted_results)} 条结果") + + return { + "records": formatted_results + } + + except Exception as e: + error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") + # 事务管理器会自动执行回滚 + return { + "records": [], + "timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, + "error": str(e) + } + + +async def _validate_fiids_orgid(fiids, orgid, kw): + """验证 fiids 的 orgid 与当前用户 orgid 是否一致""" if fiids: db = DBPools() dbname = kw.get('get_module_dbname')('rag') - sql_opts = """ - SELECT orgid - FROM kdb - WHERE id = ${id}$ - """ + sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$""" try: async with db.sqlorContext(dbname) as sor: result = await sor.sqlExe(sql_opts, {"id": fiids[0]}) @@ -149,211 +233,230 @@ async def fusedsearch(request, params_kw, *params, **kw): raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}") except Exception as e: error(f"orgid 验证失败: {str(e)}") - return json.dumps({"status": "error", "message": str(e)}) - service_params = await get_service_params(orgid) + raise - api_service = APIService() - start_time = time.time() - timing_stats = {} - try: - info( - f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}") - if not query or not orgid or not fiids: - raise ValueError("query、orgid 和 knowledge_base_ids 不能为空") - - # 提取实体 - entity_extract_start = time.time() - query_entities = await api_service.extract_entities( - request=request, - query=query, - upappid=service_params['entities'], - apiname="LTP/small", - user=userid - ) - timing_stats["entity_extraction"] = time.time() - entity_extract_start - debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒") - - # 调用 Neo4j 服务进行三元组匹配 - all_triplets = [] - triplet_match_start = time.time() - for kb_id in fiids: - debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}") - try: - neo4j_result = await api_service.neo4j_match_triplets( - request=request, - query=query, - query_entities=query_entities, - userid=orgid, - knowledge_base_id=kb_id, - upappid=service_params['gdb'], - apiname="neo4j/matchtriplets", - user=userid - ) - if neo4j_result.get("status") == "success": - triplets = neo4j_result.get("triplets", []) - all_triplets.extend(triplets) - debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}") - else: - error( - f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}") - except Exception as e: - error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}") - continue - timing_stats["triplet_matching"] = time.time() - triplet_match_start - debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒") - - # 拼接三元组文本 - triplet_text_start = time.time() - triplet_texts = [] - for triplet in all_triplets: - head = triplet.get('head', '') - type_ = triplet.get('type', '') - tail = triplet.get('tail', '') - if head and type_ and tail: - triplet_texts.append(f"{head} {type_} {tail}") - else: - debug(f"无效三元组: {triplet}") - combined_text = query - if triplet_texts: - combined_text += "".join(triplet_texts) - debug( - f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") - timing_stats["triplet_text_combine"] = time.time() - triplet_text_start - debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒") - - # 将拼接文本转换为向量 - vector_start = time.time() - query_vector = await api_service.get_embeddings( - request=request, - texts=[combined_text], - upappid=service_params['embedding'], - apiname="BAAI/bge-m3", - user=userid - ) - if not query_vector or not all(len(vec) == 1024 for vec in query_vector): - raise ValueError("查询向量必须是长度为 1024 的浮点数列表") - query_vector = query_vector[0] # 取第一个向量 - timing_stats["vector_generation"] = time.time() - vector_start - debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒") - - # 调用搜索端点 - sum = limit + 5 - search_start = time.time() - debug(f"orgid: {orgid}") - result = await api_service.milvus_search_query( - request=request, - query_vector=query_vector, - userid=orgid, - knowledge_base_ids=fiids, - limit=sum, - offset=0, - upappid=service_params['vdb'], - apiname="mlvus/searchquery", - user=userid - ) - timing_stats["vector_search"] = time.time() - search_start - debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") - - if result.get("status") != "success": - error(f"融合搜索失败: {result.get('message', '未知错误')}") - return {"results": [], "timing": timing_stats} - - unique_results = result.get("results", []) - sum = len(unique_results) - debug(f"从向量数据中搜索到{sum}条数据") - use_rerank = True - if use_rerank and unique_results: - rerank_start = time.time() - debug("开始重排序") - unique_results = await api_service.rerank_results( - request=request, - query=combined_text, - results=unique_results, - top_n=limit, - upappid=service_params['reranker'], - apiname="BAAI/bge-reranker-v2-m3", - user=userid - ) - unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) - timing_stats["reranking"] = time.time() - rerank_start - debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒") - debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") +def _combine_query_with_triplets(query, triplets): + """拼接查询文本和三元组文本""" + triplet_texts = [] + for triplet in triplets: + head = triplet.get('head', '') + type_ = triplet.get('type', '') + tail = triplet.get('tail', '') + if head and type_ and tail: + triplet_texts.append(f"{head} {type_} {tail}") else: - unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results] + debug(f"无效三元组: {triplet}") - timing_stats["total_time"] = time.time() - start_time - info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") + combined_text = query + if triplet_texts: + combined_text += "".join(triplet_texts) - # dify_result = [] - # for res in unique_results[:limit]: - # content = res.get('text', '') - # title = res.get('metadata', {}).get('filename', 'Untitled') - # document_id = res.get('metadata', {}).get('document_id', '') - # dify_result.append({ - # 'metadata': {'document_id': document_id}, - # 'title': title, - # 'content': content - # }) - # info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒") - # debug(f"result: {dify_result}") - # return dify_result + debug(f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") + return combined_text - dify_records = [] - dify_result = [] - for res in unique_results[:limit]: - rerank_score = res.get('rerank_score', 0) - score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) - score = max(0.0, min(1.0, score)) - content = res.get('text', '') - title = res.get('metadata', {}).get('filename', 'Untitled') - document_id = res.get('metadata', {}).get('document_id', '') - dify_records.append({ - "content": content, - "title": title, - "metadata": {"document_id": document_id, "score": score}, - }) - dify_result.append({ - "content": content, - "title": title, - "metadata": {"document_id": document_id, "score": score}, - }) - info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒") - debug(f"records: {dify_records}, result: {dify_result}") - # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}} - return {"records": dify_records} - - # dify_result = [] - # for res in unique_results[:limit]: - # rerank_score = res.get('rerank_score', 0) - # score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) - # score = max(0.0, min(1.0, score)) - # content = res.get('text', '') - # title = res.get('metadata', {}).get('filename', 'Untitled') - # document_id = res.get('metadata', {}).get('document_id', '') - # dify_result.append({ - # "metadata": { - # "_source": "konwledge", - # "dataset_id":"111111", - # "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx", - # "document_id": document_id, - # "document_name": "test.docx", - # "data_source_type": "upload_file", - # "segment_id": "7b391707-93bc-4654-80ae-7989f393b045", - # "retriever_from": "workflow", - # "score": score, - # "segment_hit_count": 7, - # "segment_word_count": 275, - # "segment_position": 5, - # "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73", - # "doc_metadata": None, - # "position":1 - # }, - # "title": title, - # "content": content - # }) - # return {"result": dify_result} - - except Exception as e: - error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") - return {"results": [], "timing": timing_stats} + # api_service = APIService() + # start_time = time.time() + # timing_stats = {} + # try: + # info( + # f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}") + # + # if not query or not orgid or not fiids: + # raise ValueError("query、orgid 和 knowledge_base_ids 不能为空") + # + # # 提取实体 + # entity_extract_start = time.time() + # query_entities = await api_service.extract_entities( + # request=request, + # query=query, + # upappid=service_params['entities'], + # apiname="LTP/small", + # user=userid + # ) + # timing_stats["entity_extraction"] = time.time() - entity_extract_start + # debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒") + # + # # 调用 Neo4j 服务进行三元组匹配 + # all_triplets = [] + # triplet_match_start = time.time() + # for kb_id in fiids: + # debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}") + # try: + # neo4j_result = await api_service.neo4j_match_triplets( + # request=request, + # query=query, + # query_entities=query_entities, + # userid=orgid, + # knowledge_base_id=kb_id, + # upappid=service_params['gdb'], + # apiname="neo4j/matchtriplets", + # user=userid + # ) + # if neo4j_result.get("status") == "success": + # triplets = neo4j_result.get("triplets", []) + # all_triplets.extend(triplets) + # debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}") + # else: + # error( + # f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}") + # except Exception as e: + # error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}") + # continue + # timing_stats["triplet_matching"] = time.time() - triplet_match_start + # debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒") + # + # # 拼接三元组文本 + # triplet_text_start = time.time() + # triplet_texts = [] + # for triplet in all_triplets: + # head = triplet.get('head', '') + # type_ = triplet.get('type', '') + # tail = triplet.get('tail', '') + # if head and type_ and tail: + # triplet_texts.append(f"{head} {type_} {tail}") + # else: + # debug(f"无效三元组: {triplet}") + # combined_text = query + # if triplet_texts: + # combined_text += "".join(triplet_texts) + # debug( + # f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") + # timing_stats["triplet_text_combine"] = time.time() - triplet_text_start + # debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒") + # + # # 将拼接文本转换为向量 + # vector_start = time.time() + # query_vector = await api_service.get_embeddings( + # request=request, + # texts=[combined_text], + # upappid=service_params['embedding'], + # apiname="BAAI/bge-m3", + # user=userid + # ) + # if not query_vector or not all(len(vec) == 1024 for vec in query_vector): + # raise ValueError("查询向量必须是长度为 1024 的浮点数列表") + # query_vector = query_vector[0] # 取第一个向量 + # timing_stats["vector_generation"] = time.time() - vector_start + # debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒") + # + # # 调用搜索端点 + # sum = limit + 5 + # search_start = time.time() + # debug(f"orgid: {orgid}") + # result = await api_service.milvus_search_query( + # request=request, + # query_vector=query_vector, + # userid=orgid, + # knowledge_base_ids=fiids, + # limit=sum, + # offset=0, + # upappid=service_params['vdb'], + # apiname="mlvus/searchquery", + # user=userid + # ) + # timing_stats["vector_search"] = time.time() - search_start + # debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") + # + # if result.get("status") != "success": + # error(f"融合搜索失败: {result.get('message', '未知错误')}") + # return {"results": [], "timing": timing_stats} + # + # unique_results = result.get("results", []) + # sum = len(unique_results) + # debug(f"从向量数据中搜索到{sum}条数据") + # use_rerank = True + # if use_rerank and unique_results: + # rerank_start = time.time() + # debug("开始重排序") + # unique_results = await api_service.rerank_results( + # request=request, + # query=combined_text, + # results=unique_results, + # top_n=limit, + # upappid=service_params['reranker'], + # apiname="BAAI/bge-reranker-v2-m3", + # user=userid + # ) + # unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) + # timing_stats["reranking"] = time.time() - rerank_start + # debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒") + # debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") + # else: + # unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results] + # + # timing_stats["total_time"] = time.time() - start_time + # info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") + # + # # dify_result = [] + # # for res in unique_results[:limit]: + # # content = res.get('text', '') + # # title = res.get('metadata', {}).get('filename', 'Untitled') + # # document_id = res.get('metadata', {}).get('document_id', '') + # # dify_result.append({ + # # 'metadata': {'document_id': document_id}, + # # 'title': title, + # # 'content': content + # # }) + # # info(f"融合搜索完成,返回 {len(dify_result)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒") + # # debug(f"result: {dify_result}") + # # return dify_result + # + # dify_records = [] + # dify_result = [] + # for res in unique_results[:limit]: + # rerank_score = res.get('rerank_score', 0) + # score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) + # score = max(0.0, min(1.0, score)) + # content = res.get('text', '') + # title = res.get('metadata', {}).get('filename', 'Untitled') + # document_id = res.get('metadata', {}).get('document_id', '') + # dify_records.append({ + # "content": content, + # "title": title, + # "metadata": {"document_id": document_id, "score": score}, + # }) + # dify_result.append({ + # "content": content, + # "title": title, + # "metadata": {"document_id": document_id, "score": score}, + # }) + # info(f"融合搜索完成,返回 {len(dify_records)} 条结果,总耗时: {(time.time() - start_time):.3f} 秒") + # debug(f"records: {dify_records}, result: {dify_result}") + # # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}} + # return {"records": dify_records} + # + # # dify_result = [] + # # for res in unique_results[:limit]: + # # rerank_score = res.get('rerank_score', 0) + # # score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) + # # score = max(0.0, min(1.0, score)) + # # content = res.get('text', '') + # # title = res.get('metadata', {}).get('filename', 'Untitled') + # # document_id = res.get('metadata', {}).get('document_id', '') + # # dify_result.append({ + # # "metadata": { + # # "_source": "konwledge", + # # "dataset_id":"111111", + # # "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx", + # # "document_id": document_id, + # # "document_name": "test.docx", + # # "data_source_type": "upload_file", + # # "segment_id": "7b391707-93bc-4654-80ae-7989f393b045", + # # "retriever_from": "workflow", + # # "score": score, + # # "segment_hit_count": 7, + # # "segment_word_count": 275, + # # "segment_position": 5, + # # "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73", + # # "doc_metadata": None, + # # "position":1 + # # }, + # # "title": title, + # # "content": content + # # }) + # # return {"result": dify_result} + # + # except Exception as e: + # error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") + # return {"results": [], "timing": timing_stats}