From e590c1084fca5f3e9645dfce8134dcbe323d82cf Mon Sep 17 00:00:00 2001
From: yumoqing
Date: Thu, 11 Sep 2025 14:33:23 +0800
Subject: [PATCH] bugfix: move get_service_params into rag/service_opts.py and
 short-circuit file_uploaded through the new helper pipeline

---
 rag/folderinfo.py   | 261 +++++++++++++++-----------------------------
 rag/service_opts.py |  77 ++++++++++++++
 2 files changed, 165 insertions(+), 173 deletions(-)
 create mode 100644 rag/service_opts.py

diff --git a/rag/folderinfo.py b/rag/folderinfo.py
index e5434cb..ae60e95 100644
--- a/rag/folderinfo.py
+++ b/rag/folderinfo.py
@@ -17,6 +17,7 @@ import traceback
 from filetxt.loader import fileloader
 from ahserver.serverenv import get_serverenv
 from typing import List, Dict, Any
+from rag.service_opts import get_service_params, sor_get_service_params
 import json
 
 class RagFileMgr(FileMgr):
@@ -43,73 +44,6 @@ where a.orgid = b.orgid
 			return r.quota, r.expired_date
 		return None, None
 
-	async def get_service_params(self,sor, orgid):
-		""" 根据 orgid 从数据库获取服务参数 (仅 upappid),假设 service_opts 表返回单条记录。 """
-		sql_opts = """
-			SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
-			FROM service_opts
-			WHERE orgid = ${orgid}$
-		"""
-		opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
-		if not opts_result:
-			error(f"未找到 orgid={orgid} 的服务配置")
-			return None
-		opts = opts_result[0]
-
-		# 收集服务 ID
-		service_ids = set()
-		for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
-			if opts[key]:
-				service_ids.add(opts[key])
-
-		# 检查 service_ids 是否为空
-		if not service_ids:
-			error(f"未找到任何服务 ID for orgid={orgid}")
-			return None
-
-		# 手动构造 IN 子句的 ID 列表
-		id_list = ','.join([f"'{id}'" for id in service_ids])  # 确保每个 ID 被单引号包裹
-		sql_services = """
-			SELECT id, name, upappid
-			FROM ragservices
-			WHERE id IN ${id_list}$
-		"""
-		services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
-		if not services_result:
-			error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
-			return None
-
-		# 构建服务参数字典,基于 name 字段匹配,仅存储 upappid
-		service_params = {
-			'embedding': None,
-			'vdb': None,
-			'reranker': None,
-			'triples': None,
-			'gdb': None,
-			'entities': None
-		}
-		for service in services_result:
-			name = service['name']
-			if name == 'bgem3嵌入':
-				service_params['embedding'] = service['upappid']
-			elif name == 'milvus向量检索':
-				service_params['vdb'] = service['upappid']
-			elif name == 'bgem2v3重排':
-				service_params['reranker'] = service['upappid']
-			elif name == 'mrebel三元组抽取':
-				service_params['triples'] = service['upappid']
-			elif name == 'neo4j删除知识库':
-				service_params['gdb'] = service['upappid']
-			elif name == 'small实体抽取':
-				service_params['entities'] = service['upappid']
-
-		# 检查是否所有服务参数都已填充
-		missing_services = [k for k, v in service_params.items() if v is None]
-		if missing_services:
-			error(f"未找到以下服务的配置: {missing_services}")
-			return None
-		return service_params
-
 	async def file_uploaded(self, request, ns, userid):
 		"""将文档插入 Milvus 并抽取三元组到 Neo4j"""
 		debug(f'Received ns: {ns=}')
@@ -128,45 +62,30 @@ where a.orgid = b.orgid
 		timings = {}
 		start_total = time.time()
 
+		service_params = await get_service_params(orgid)
+		chunks = await self.get_doucment_chunks(realpath)
+		embeddings = await self.docs_embedding(chunks)
+		await self.embedding_2_vdb(id, fiid, orgid, realpath, embeddings)
+		triples = await self.get_triples(chunks)
+		await self.triple2graphdb(id, fiid, orgid, realpath, triples)
+		return  # the legacy inline pipeline below is now unreachable
 		try:
 			if not orgid or not fiid or not id:
-				raise ValueError("orgid、fiid 和 id 不能为空")
+				raise ValueError("orgid、fiid 和 id 不能为空")
 			if len(orgid) > 32 or len(fiid) > 255:
-				raise ValueError("orgid 或 fiid 的长度超出限制")
+				raise ValueError("orgid 或 fiid 的长度超出限制")
或 fiid 的长度超出限制") if not os.path.exists(realpath): - raise ValueError(f"文件 {realpath} 不存在") - - # 检查 hashvalue 是否已存在 - db = DBPools() - dbname = env.get_module_dbname('rag') - sql_check_hash = """ - SELECT hashvalue - FROM file - WHERE hashvalue = ${hashvalue}$ - """ - async with db.sqlorContext(dbname) as sor: - hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue}) - if hash_result: - debug(f"文件已存在: hashvalue={hashvalue}") - timings["total"] = time.time() - start_total - return { - "status": "error", - "document_id": id, - "collection_name": "ragdb", - "timings": timings, - "message": f"文件已存在: hashvalue={hashvalue}", - "status_code": 400 - } + raise ValueError(f"文件 {realpath} 不存在") # 获取服务参数 - service_params = await self.get_service_params(sor, orgid) + service_params = await get_service_params(orgid) if not service_params: - raise ValueError("无法获取服务参数") + raise ValueError("无法获取服务参数") supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'} ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else '' if ext not in supported_formats: - raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") + raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") debug(f"加载文件: {realpath}") start_load = time.time() @@ -176,22 +95,22 @@ where a.orgid = b.orgid timings["load_file"] = time.time() - start_load debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") if not text or not text.strip(): - raise ValueError(f"文件 {realpath} 加载为空") + raise ValueError(f"文件 {realpath} 加载为空") document = Document(page_content=text) text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, - chunk_overlap=100, - length_function=len) + chunk_size=500, + chunk_overlap=100, + length_function=len) debug("开始分片文件内容") start_split = time.time() chunks = text_splitter.split_documents([document]) timings["split_text"] = time.time() - start_split debug( - f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}") + f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}, 分片内容: {[chunk.page_content[:50] for chunk in chunks[:5]]}") debug(f"分片内容: {[chunk.page_content[:100] + '...' 
for chunk in chunks]}") if not chunks: - raise ValueError(f"文件 {realpath} 未生成任何文档块") + raise ValueError(f"文件 {realpath} 未生成任何文档块") filename = os.path.basename(realpath).rsplit('.', 1)[0] upload_time = datetime.now().isoformat() @@ -201,23 +120,23 @@ where a.orgid = b.orgid texts = [chunk.page_content for chunk in chunks] embeddings = [] for i in range(0, len(texts), 10): # 每次处理 10 个文本块 - batch_texts = texts[i:i + 10] - batch_embeddings = await api_service.get_embeddings( + batch_texts = texts[i:i + 10] + batch_embeddings = await api_service.get_embeddings( request=request, texts=batch_texts, upappid=service_params['embedding'], apiname="BAAI/bge-m3", user=userid - ) - embeddings.extend(batch_embeddings) + ) + embeddings.extend(batch_embeddings) if not embeddings or not all(len(vec) == 1024 for vec in embeddings): - raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表") + raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表") timings["generate_embeddings"] = time.time() - start_embedding debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}") chunks_data = [] for i, chunk in enumerate(chunks): - chunks_data.append({ + chunks_data.append({ "userid": orgid, "knowledge_base_id": fiid, "text": chunk.page_content, @@ -227,38 +146,38 @@ where a.orgid = b.orgid "file_path": realpath, "upload_time": upload_time, "file_type": ext, - }) + }) debug(f"调用插入文件端点: {realpath}") start_milvus = time.time() for i in range(0, len(chunks_data), 10): # 每次处理 10 条数据 - batch_chunks = chunks_data[i:i + 10] - result = await api_service.milvus_insert_document( - request=request, - chunks=batch_chunks, - db_type=db_type, - upappid=service_params['vdb'], - apiname="milvus/insertdocument", - user=userid - ) - if result.get("status") != "success": - raise ValueError(result.get("message", "Milvus 插入失败")) + batch_chunks = chunks_data[i:i + 10] + result = await api_service.milvus_insert_document( + request=request, + chunks=batch_chunks, + db_type=db_type, + upappid=service_params['vdb'], + apiname="milvus/insertdocument", + user=userid + ) + if result.get("status") != "success": + raise ValueError(result.get("message", "Milvus 插入失败")) timings["insert_milvus"] = time.time() - start_milvus debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒") if result.get("status") != "success": - timings["total"] = time.time() - start_total - return {"status": "error", "document_id": id, "timings": timings, + timings["total"] = time.time() - start_total + return {"status": "error", "document_id": id, "timings": timings, "message": result.get("message", "未知错误"), "status_code": 400} debug("调用三元组抽取服务") start_triples = time.time() unique_triples = [] try: - chunk_texts = [doc.page_content for doc in chunks] - debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取") - triples = [] - for i, chunk in enumerate(chunk_texts): + chunk_texts = [doc.page_content for doc in chunks] + debug(f"处理 {len(chunk_texts)} 个分片进行三元组抽取") + triples = [] + for i, chunk in enumerate(chunk_texts): result = await api_service.extract_triples( request=request, text=chunk, @@ -283,16 +202,16 @@ where a.orgid = b.orgid if (existing['head'].lower() == t['head'].lower() and existing['tail'].lower() == t['tail'].lower() and len(t['type']) > len(existing['type'])): - unique_triples.remove(existing) - unique_triples.append(t) - debug(f"替换三元组为更具体类型: {t}") - break + unique_triples.remove(existing) + unique_triples.append(t) + debug(f"替换三元组为更具体类型: {t}") + break - timings["extract_triples"] = time.time() - start_triples - debug( - f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 
+			timings["extract_triples"] = time.time() - start_triples
+			debug(
+				f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组: {unique_triples[:5]}")
 
-			if unique_triples:
+			if unique_triples:
 				debug(f"抽取到 {len(unique_triples)} 个三元组,调用 Neo4j 服务插入")
 				start_neo4j = time.time()
 				for i in range(0, len(unique_triples), 30):  # 每次插入 30 个三元组
@@ -322,11 +241,11 @@ where a.orgid = b.orgid
 				info(f"文件 {realpath} 三元组成功插入 Neo4j: {neo4j_result.get('message')}")
 				timings["insert_neo4j"] = time.time() - start_neo4j
 				debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f} 秒")
-			else:
+			else:
 				debug(f"文件 {realpath} 未抽取到三元组")
 				timings["insert_neo4j"] = 0.0
 
-		except Exception as e:
+		except Exception as e:
 			timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
 				timings["extract_triples"]
 			timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
 				"insert_neo4j"]
@@ -386,36 +305,34 @@ where a.orgid = b.orgid
 		collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
 		try:
-			required_fields = ['id', 'realpath', 'fiid', 'ownerid']
-			missing_fields = [field for field in required_fields if not rec.get(field, '')]
-			if missing_fields:
-				raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
+			required_fields = ['id', 'realpath', 'fiid', 'ownerid']
+			missing_fields = [field for field in required_fields if not rec.get(field, '')]
+			if missing_fields:
+				raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
 
-			# 获取服务参数
-			service_params = await self.get_service_params(sor, orgid)
-			if not service_params:
-				raise ValueError("无法获取服务参数")
+			# 获取服务参数
+			service_params = await get_service_params(orgid)
+			if not service_params:
+				raise ValueError("无法获取服务参数")
 
-			debug(
-				f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
-			milvus_result = await api_service.milvus_delete_document(
-				request=request,
-				userid=orgid,
-				file_path=realpath,
-				knowledge_base_id=fiid,
-				document_id=id,
-				db_type=db_type,
-				upappid=service_params['vdb'],
-				apiname="milvus/deletedocument",
-				user=userid
-			)
+			debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
+			milvus_result = await api_service.milvus_delete_document(
+				request=request,
+				userid=orgid,
+				file_path=realpath,
+				knowledge_base_id=fiid,
+				document_id=id,
+				db_type=db_type,
+				upappid=service_params['vdb'],
+				apiname="milvus/deletedocument",
+				user=userid
+			)
 
-			if milvus_result.get("status") != "success":
-				raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
+			if milvus_result.get("status") != "success":
+				raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
 
-			neo4j_deleted_nodes = 0
-			neo4j_deleted_rels = 0
-			try:
+			neo4j_deleted_nodes = 0
+			neo4j_deleted_rels = 0
 			debug(f"调用 Neo4j 删除文档端点: document_id={id}")
 			neo4j_result = await api_service.neo4j_delete_document(
 				request=request,
@@ -433,26 +350,24 @@ where a.orgid = b.orgid
 			total_nodes_deleted += nodes_deleted
 			total_rels_deleted += rels_deleted
 			info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
-			except Exception as e:
-				error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}")
-			results.append({
-				"status": "success",
-				"collection_name": collection_name,
-				"document_id": id,
-				"message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
-				"status_code": 200
-			})
+			results.append({
+				"status": "success",
+				"collection_name": collection_name,
+				"document_id": id,
+				"message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系",
+				"status_code": 200
+			})
 		except Exception as e:
-			error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
-			results.append({
+			error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}")
+			results.append({
 				"status": "error",
 				"collection_name": collection_name,
 				"document_id": id,
 				"message": f"删除文档 {realpath} 失败: {str(e)}",
 				"status_code": 400
-			})
+			})
 
 		return {
 			"status": "success" if all(r["status"] == "success" for r in results) else "partial",
diff --git a/rag/service_opts.py b/rag/service_opts.py
new file mode 100644
index 0000000..72c1833
--- /dev/null
+++ b/rag/service_opts.py
@@ -0,0 +1,77 @@
+from appPublic.log import error  # assumed import path, matching the logging helpers used in rag/folderinfo.py
+from sqlor.dbpools import DBPools  # assumed import path; DBPools is used the same way in rag/folderinfo.py
+from ahserver.serverenv import get_serverenv
+
+async def sor_get_service_params(sor, orgid):
+	"""Fetch the service params (upappid only) for orgid; assumes service_opts holds a single row per orgid."""
+	sql_opts = """
+		SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
+		FROM service_opts
+		WHERE orgid = ${orgid}$
+	"""
+	opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
+	if not opts_result:
+		error(f"未找到 orgid={orgid} 的服务配置")
+		return None
+	opts = opts_result[0]
+
+	# collect the configured service ids
+	service_ids = set()
+	for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
+		if opts[key]:
+			service_ids.add(opts[key])
+
+	# bail out when no service is configured at all
+	if not service_ids:
+		error(f"未找到任何服务 ID for orgid={orgid}")
+		return None
+
+	# build the IN-clause list by hand, wrapping every id in single quotes
+	id_list = ','.join([f"'{id}'" for id in service_ids])
+	sql_services = """
+		SELECT id, name, upappid
+		FROM ragservices
+		WHERE id IN ${id_list}$
+	"""
+	services_result = await sor.sqlExe(sql_services, {'id_list': id_list})
+	if not services_result:
+		error(f"未找到服务 ID {service_ids} 的 ragservices 配置")
+		return None
+
+	# map each service role to its upappid, matched on the name field
+	service_params = {
+		'embedding': None,
+		'vdb': None,
+		'reranker': None,
+		'triples': None,
+		'gdb': None,
+		'entities': None
+	}
+	for service in services_result:
+		name = service['name']
+		if name == 'bgem3嵌入':
+			service_params['embedding'] = service['upappid']
+		elif name == 'milvus向量检索':
+			service_params['vdb'] = service['upappid']
+		elif name == 'bgem2v3重排':
+			service_params['reranker'] = service['upappid']
+		elif name == 'mrebel三元组抽取':
+			service_params['triples'] = service['upappid']
+		elif name == 'neo4j删除知识库':
+			service_params['gdb'] = service['upappid']
+		elif name == 'small实体抽取':
+			service_params['entities'] = service['upappid']
+
+	# every one of the six services must be configured
+	missing_services = [k for k, v in service_params.items() if v is None]
+	if missing_services:
+		error(f"未找到以下服务的配置: {missing_services}")
+		return None
+	return service_params
+
+async def get_service_params(orgid):
+	db = DBPools()
+	dbname = get_serverenv('get_module_dbname')('rag')
+	async with db.sqlorContext(dbname) as sor:
+		return await sor_get_service_params(sor, orgid)
+	return None
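
Usage note (not part of the diff): a minimal sketch of how the two new helpers in rag/service_opts.py are meant to be called. It assumes the ahserver environment and DBPools are already initialized the way rag/folderinfo.py does; the orgid value and both wrapper coroutines below are hypothetical, for illustration only.

    from rag.service_opts import get_service_params, sor_get_service_params

    async def show_service_params():
        # one-shot lookup: opens and closes its own sqlorContext via DBPools
        params = await get_service_params('org-0001')  # hypothetical orgid
        if params is None:
            # no service_opts row, or one of the six services is unconfigured
            return
        # params maps each role to an upappid:
        # 'embedding', 'vdb', 'reranker', 'triples', 'gdb', 'entities'
        print(params['embedding'], params['vdb'])

    async def show_service_params_in_tx(sor):
        # reuse a sqlor context the caller already holds (same DB transaction)
        return await sor_get_service_params(sor, 'org-0001')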
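
Reviewer note: after this patch, file_uploaded returns unconditionally before its try block, so the whole legacy inline pipeline below the return is dead code that could be deleted in a follow-up. The new flow depends on five RagFileMgr helpers that do not appear in this diff (get_doucment_chunks, docs_embedding, embedding_2_vdb, get_triples, triple2graphdb). The sketch below only records the contracts they would have to satisfy, inferred from the inline code being retired; the signatures are assumptions, not the actual implementations.

    class RagFileMgr(FileMgr):
        async def get_doucment_chunks(self, realpath):
            """Load realpath via fileloader and split it with
            RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)."""

        async def docs_embedding(self, chunks):
            """Return one 1024-dim vector per chunk, requested in batches of 10
            through the org's 'embedding' service (BAAI/bge-m3)."""

        async def embedding_2_vdb(self, id, fiid, orgid, realpath, embeddings):
            """Insert chunk records into Milvus in batches of 10 via the 'vdb' service."""

        async def get_triples(self, chunks):
            """Extract (head, type, tail) triples per chunk and deduplicate,
            keeping the more specific relation type."""

        async def triple2graphdb(self, id, fiid, orgid, realpath, triples):
            """Insert the triples into Neo4j in batches of 30."""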