From 516edb5b6ab9b64318beb4728ac0eb9f63521fc7 Mon Sep 17 00:00:00 2001
From: wangmeihua <13383952685@163.com>
Date: Tue, 9 Sep 2025 13:17:22 +0800
Subject: [PATCH] ragapi

---
 rag/folderinfo.py   | 837 ++++++++++++++++++++++----------------------
 rag/ragapi.py       |  71 +++-
 rag/uapi_service.py |   1 +
 3 files changed, 481 insertions(+), 428 deletions(-)

diff --git a/rag/folderinfo.py b/rag/folderinfo.py
index 07d4fa9..f4c2aaa 100644
--- a/rag/folderinfo.py
+++ b/rag/folderinfo.py
@@ -20,463 +20,461 @@ from typing import List, Dict, Any
 import json
 
 class RagFileMgr(FileMgr):
-    async def get_folder_ownerid(self, sor):
-        fiid = self.fiid
-        recs = await sor.R('kdb', {'id': self.fiid})
-        if len(recs) > 0:
-            return recs[0].orgid
-        return None
+    async def get_folder_ownerid(self, sor):
+        fiid = self.fiid
+        recs = await sor.R('kdb', {'id': self.fiid})
+        if len(recs) > 0:
+            return recs[0].orgid
+        return None
 
-    async def get_organization_quota(self, sor, orgid):
-        sql = """select a.* from ragquota a, kdb b
+    async def get_organization_quota(self, sor, orgid):
+        sql = """select a.* from ragquota a, kdb b
 where a.orgid = b.orgid
-        and b.id = ${id}$
-        and ${today}$ >= a.enabled_date
-        and ${today}$ < a.expired_date
+        and b.id = ${id}$
+        and ${today}$ >= a.enabled_date
+        and ${today}$ < a.expired_date
 """
-        recs = await sor.sqlExe(sql, {
-            'id': self.fiid,
-            'today': curDateString()
-        })
-        if len(recs) > 0:
-            r = recs[0]
-            return r.quota, r.expired_date
-        return None, None
+        recs = await sor.sqlExe(sql, {
+            'id': self.fiid,
+            'today': curDateString()
+        })
+        if len(recs) > 0:
+            r = recs[0]
+            return r.quota, r.expired_date
+        return None, None
 
-    async def get_service_params(self,orgid):
-        """ Fetch the service parameters (upappid only) from the database by orgid; assumes the service_opts table returns a single record. """
-        db = DBPools()
-        dbname = "kyrag"
+    async def get_service_params(self,orgid):
+        """ Fetch the service parameters (upappid only) from the database by orgid; assumes the service_opts table returns a single record. """
+        db = DBPools()
+        dbname = "kyrag"
 
-        sql_opts = """
-        SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
-        FROM service_opts
-        WHERE orgid = ${orgid}$
-        """
-        try:
-            async with db.sqlorContext(dbname) as sor:
-                opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
-                if not opts_result:
-                    error(f"No service configuration found for orgid={orgid}")
-                    return None
-                opts = opts_result[0]
-        except Exception as e:
-            error(f"Failed to query service_opts: {str(e)}, traceback: {traceback.format_exc()}")
-            return None
+        sql_opts = """
+        SELECT embedding_id, vdb_id, reranker_id, triples_id, gdb_id, entities_id
+        FROM service_opts
+        WHERE orgid = ${orgid}$
+        """
+        try:
+            async with db.sqlorContext(dbname) as sor:
+                opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
+                if not opts_result:
+                    error(f"No service configuration found for orgid={orgid}")
+                    return None
+                opts = opts_result[0]
+        except Exception as e:
+            error(f"Failed to query service_opts: {str(e)}, traceback: {traceback.format_exc()}")
+            return None
 
-        # Collect the service IDs
-        service_ids = set()
-        for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
-            if opts[key]:
-                service_ids.add(opts[key])
+        # Collect the service IDs
+        service_ids = set()
+        for key in ['embedding_id', 'vdb_id', 'reranker_id', 'triples_id', 'gdb_id', 'entities_id']:
+            if opts[key]:
+                service_ids.add(opts[key])
 
-        # Make sure at least one service ID was found
-        if not service_ids:
-            error(f"No service IDs found for orgid={orgid}")
-            return None
+        # Make sure at least one service ID was found
+        if not service_ids:
+            error(f"No service IDs found for orgid={orgid}")
+            return None
 
-        # Build the ID list for the IN clause by hand
-        id_list = ','.join([f"'{id}'" for id in service_ids])  # wrap each ID in single quotes
-        sql_services = f"""
-            SELECT id, name, upappid
-            FROM ragservices
-            WHERE id IN ({id_list})
-        """
-        try:
-            async with db.sqlorContext(dbname) as sor:
-                services_result = await sor.sqlExe(sql_services, {})
-                if not services_result:
-                    error(f"No ragservices configuration found for service IDs {service_ids}")
-                    return None
+        # Build the ID list for the IN clause by hand
+        id_list = ','.join([f"'{id}'" for id in service_ids])  # wrap each ID in single quotes
+        sql_services = f"""
+            SELECT id, name, upappid
+            FROM ragservices
+            WHERE id IN ({id_list})
+        """
+        try:
+            async with db.sqlorContext(dbname) as sor:
+                services_result = await sor.sqlExe(sql_services, {})
+                if not services_result:
+                    error(f"No ragservices configuration found for service IDs {service_ids}")
+                    return None
 
-                # Build the service-parameter dict by matching on the name field; store only the upappid
-                service_params = {
-                    'embedding': None,
-                    'vdb': None,
-                    'reranker': None,
-                    'triples': None,
-                    'gdb': None,
-                    'entities': None
-                }
-                for service in services_result:
-                    name = service['name']
-                    if name == 'bgem3嵌入':
-                        service_params['embedding'] = service['upappid']
-                    elif name == 'milvus向量检索':
-                        service_params['vdb'] = service['upappid']
-                    elif name == 'bgem2v3重排':
-                        service_params['reranker'] = service['upappid']
-                    elif name == 'mrebel三元组抽取':
-                        service_params['triples'] = service['upappid']
-                    elif name == 'neo4j删除知识库':
-                        service_params['gdb'] = service['upappid']
-                    elif name == 'small实体抽取':
-                        service_params['entities'] = service['upappid']
+                # Build the service-parameter dict by matching on the name field; store only the upappid
+                service_params = {
+                    'embedding': None,
+                    'vdb': None,
+                    'reranker': None,
+                    'triples': None,
+                    'gdb': None,
+                    'entities': None
+                }
+                for service in services_result:
+                    name = service['name']
+                    if name == 'bgem3嵌入':
+                        service_params['embedding'] = service['upappid']
+                    elif name == 'milvus向量检索':
+                        service_params['vdb'] = service['upappid']
+                    elif name == 'bgem2v3重排':
+                        service_params['reranker'] = service['upappid']
+                    elif name == 'mrebel三元组抽取':
+                        service_params['triples'] = service['upappid']
+                    elif name == 'neo4j删除知识库':
+                        service_params['gdb'] = service['upappid']
+                    elif name == 'small实体抽取':
+                        service_params['entities'] = service['upappid']
 
-                # Verify that every service slot has been filled
-                missing_services = [k for k, v in service_params.items() if v is None]
-                if missing_services:
-                    error(f"Missing configuration for the following services: {missing_services}")
-                    return None
+                # Verify that every service slot has been filled
+                missing_services = [k for k, v in service_params.items() if v is None]
+                if missing_services:
+                    error(f"Missing configuration for the following services: {missing_services}")
+                    return None
 
-                return service_params
-        except Exception as e:
-            error(f"Failed to query ragservices: {str(e)}, traceback: {traceback.format_exc()}")
-            return None
+                return service_params
+        except Exception as e:
+            error(f"Failed to query ragservices: {str(e)}, traceback: {traceback.format_exc()}")
+            return None
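
The IN clause above is built by string interpolation, presumably because sqlExe's ${name}$ placeholders bind scalars rather than lists. The IDs come from the org's own service_opts row, but a defensive validation step keeps the interpolation safe regardless. A minimal sketch of that idea; the quote_id_list helper and the allowed character set are assumptions, not code from this patch:

    import re

    def quote_id_list(ids):
        """Quote IDs for a SQL IN clause, rejecting anything that could escape the quotes."""
        safe = []
        for i in ids:
            s = str(i)
            # permit only characters plausible in a ragservices id (assumed alphabet)
            if not re.fullmatch(r'[A-Za-z0-9_\-]+', s):
                raise ValueError(f"unsafe service id: {s!r}")
            safe.append(f"'{s}'")
        return ','.join(safe)

    # usage: id_list = quote_id_list(service_ids)
    print(quote_id_list({'svc_01', 'svc-02'}))  # -> 'svc_01','svc-02' (set order may vary)
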
 
-    async def file_uploaded(self, request, ns, userid):
-        """Insert the document into Milvus and extract triples into Neo4j"""
-        debug(f'Received ns: {ns=}')
-        realpath = ns.get('realpath', '')
-        fiid = ns.get('fiid', '')
-        id = ns.get('id', '')
-        orgid = ns.get('ownerid', '')
-        hashvalue = ns.get('hashvalue', '')
-        db_type = ''
+    async def file_uploaded(self, request, ns, userid):
+        """Insert the document into Milvus and extract triples into Neo4j"""
+        debug(f'Received ns: {ns=}')
+        realpath = ns.get('realpath', '')
+        fiid = ns.get('fiid', '')
+        id = ns.get('id', '')
+        orgid = ns.get('ownerid', '')
+        hashvalue = ns.get('hashvalue', '')
+        db_type = ''
 
-        api_service = APIService()
-        debug(
-            f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
+        api_service = APIService()
+        debug(
+            f'Inserting document: file_path={realpath}, userid={orgid}, db_type={db_type}, knowledge_base_id={fiid}, document_id={id}')
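
For orientation, file_uploaded reads exactly five keys from ns; a payload of roughly this shape exercises the happy path. The values are illustrative only, not taken from the repository:

    ns = {
        "realpath": "/data/uploads/report.pdf",  # must exist on disk
        "fiid": "kb_0001",                       # knowledge-base (folder) id, <= 255 chars
        "id": "doc_0001",                        # document id
        "ownerid": "org_0001",                   # used as orgid, <= 32 chars
        "hashvalue": "9f2c",                     # content hash used for dedup
    }
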
 
-        timings = {}
-        start_total = time.time()
+        timings = {}
+        start_total = time.time()
 
-        try:
-            if not orgid or not fiid or not id:
-                raise ValueError("orgid, fiid and id must not be empty")
-            if len(orgid) > 32 or len(fiid) > 255:
-                raise ValueError("orgid or fiid exceeds the length limit")
-            if not os.path.exists(realpath):
-                raise ValueError(f"File {realpath} does not exist")
+        try:
+            if not orgid or not fiid or not id:
+                raise ValueError("orgid, fiid and id must not be empty")
+            if len(orgid) > 32 or len(fiid) > 255:
+                raise ValueError("orgid or fiid exceeds the length limit")
+            if not os.path.exists(realpath):
+                raise ValueError(f"File {realpath} does not exist")
 
-            # Check whether the hashvalue already exists
-            db = DBPools()
-            dbname = "kyrag"
-            sql_check_hash = """
-            SELECT hashvalue
-            FROM file
-            WHERE hashvalue = ${hashvalue}$
-            """
-            async with db.sqlorContext(dbname) as sor:
-                hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue})
-                if hash_result:
-                    debug(f"File already exists: hashvalue={hashvalue}")
-                    timings["total"] = time.time() - start_total
-                    return {
-                        "status": "error",
-                        "document_id": id,
-                        "collection_name": "ragdb",
-                        "timings": timings,
-                        "message": f"File already exists: hashvalue={hashvalue}",
-                        "status_code": 400
-                    }
+            # Check whether the hashvalue already exists
+            db = DBPools()
+            dbname = "kyrag"
+            sql_check_hash = """
+            SELECT hashvalue
+            FROM file
+            WHERE hashvalue = ${hashvalue}$
+            """
+            async with db.sqlorContext(dbname) as sor:
+                hash_result = await sor.sqlExe(sql_check_hash, {"hashvalue": hashvalue})
+                if hash_result:
+                    debug(f"File already exists: hashvalue={hashvalue}")
+                    timings["total"] = time.time() - start_total
+                    return {
+                        "status": "error",
+                        "document_id": id,
+                        "collection_name": "ragdb",
+                        "timings": timings,
+                        "message": f"File already exists: hashvalue={hashvalue}",
+                        "status_code": 400
+                    }
 
-            # Fetch the service parameters
-            service_params = await self.get_service_params(orgid)
-            if not service_params:
-                raise ValueError("Unable to fetch the service parameters")
+            # Fetch the service parameters
+            service_params = await self.get_service_params(orgid)
+            if not service_params:
+                raise ValueError("Unable to fetch the service parameters")
 
-            supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
-            ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
-            if ext not in supported_formats:
-                raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")
+            supported_formats = {'pdf', 'docx', 'xlsx', 'pptx', 'csv', 'txt'}
+            ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
+            if ext not in supported_formats:
+                raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")
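
The patch does not show how hashvalue is produced; the dedup check above only requires that it be a stable digest of the file contents. One conventional way to derive it, sketched here purely as an assumption:

    import hashlib

    def file_hashvalue(path: str, bufsize: int = 1 << 20) -> str:
        """SHA-256 of the file contents, read in 1 MiB blocks."""
        h = hashlib.sha256()
        with open(path, 'rb') as f:
            while block := f.read(bufsize):
                h.update(block)
        return h.hexdigest()
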
 
-            debug(f"Loading file: {realpath}")
-            start_load = time.time()
-            text = fileloader(realpath)
-            # debug(f"Processed file content: {text=}")
-            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n]', '', text)
-            timings["load_file"] = time.time() - start_load
-            debug(f"File load took {timings['load_file']:.2f}s, text length: {len(text)}")
-            if not text or not text.strip():
-                raise ValueError(f"File {realpath} loaded empty")
+            debug(f"Loading file: {realpath}")
+            start_load = time.time()
+            text = fileloader(realpath)
+            # debug(f"Processed file content: {text=}")
+            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
+            timings["load_file"] = time.time() - start_load
+            debug(f"File load took {timings['load_file']:.2f}s, text length: {len(text)}")
+            if not text or not text.strip():
+                raise ValueError(f"File {realpath} loaded empty")
 
-            document = Document(page_content=text)
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=500,
-                chunk_overlap=100,
-                length_function=len)
-            debug("Splitting the file content into chunks")
-            start_split = time.time()
-            chunks = text_splitter.split_documents([document])
-            timings["split_text"] = time.time() - start_split
-            debug(
-                f"Text splitting took {timings['split_text']:.2f}s, chunk count: {len(chunks)}, chunk preview: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
-            if not chunks:
-                raise ValueError(f"File {realpath} produced no document chunks")
+            document = Document(page_content=text)
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=500,
+                chunk_overlap=100,
+                length_function=len)
+            debug("Splitting the file content into chunks")
+            start_split = time.time()
+            chunks = text_splitter.split_documents([document])
+            timings["split_text"] = time.time() - start_split
+            debug(
+                f"Text splitting took {timings['split_text']:.2f}s, chunk count: {len(chunks)}, chunk preview: {[chunk.page_content[:50] for chunk in chunks[:5]]}")
+            debug(f"Chunk contents: {[chunk.page_content[:100] + '...' for chunk in chunks]}")
+            if not chunks:
+                raise ValueError(f"File {realpath} produced no document chunks")
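
With chunk_size=500 and chunk_overlap=100, consecutive chunks share up to 100 characters, so a sentence cut at a boundary still appears whole in at least one chunk. A standalone check of that behavior; the import path varies by langchain version, langchain_text_splitters is assumed here:

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=100, length_function=len)
    text = "word " * 300                       # ~1500 characters of dummy text
    parts = splitter.split_text(text)
    print(len(parts), [len(p) for p in parts])  # every part is <= 500 characters
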
 
-            filename = os.path.basename(realpath).rsplit('.', 1)[0]
-            upload_time = datetime.now().isoformat()
+            filename = os.path.basename(realpath).rsplit('.', 1)[0]
+            upload_time = datetime.now().isoformat()
 
-            debug("Calling the embedding service to generate vectors")
-            start_embedding = time.time()
-            texts = [chunk.page_content for chunk in chunks]
-            embeddings = []
-            for i in range(0, len(texts), 10):  # process 10 text chunks per batch
-                batch_texts = texts[i:i + 10]
-                batch_embeddings = await api_service.get_embeddings(
-                    request=request,
-                    texts=batch_texts,
-                    upappid=service_params['embedding'],
-                    apiname="BAAI/bge-m3",
-                    user=userid
-                )
-                embeddings.extend(batch_embeddings)
-            if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
-                raise ValueError("Every embedding must be a list of 1024 floats")
-            timings["generate_embeddings"] = time.time() - start_embedding
-            debug(f"Embedding generation took {timings['generate_embeddings']:.2f}s, embedding count: {len(embeddings)}")
+            debug("Calling the embedding service to generate vectors")
+            start_embedding = time.time()
+            texts = [chunk.page_content for chunk in chunks]
+            embeddings = []
+            for i in range(0, len(texts), 10):  # process 10 text chunks per batch
+                batch_texts = texts[i:i + 10]
+                batch_embeddings = await api_service.get_embeddings(
+                    request=request,
+                    texts=batch_texts,
+                    upappid=service_params['embedding'],
+                    apiname="BAAI/bge-m3",
+                    user=userid
+                )
+                embeddings.extend(batch_embeddings)
+            if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
+                raise ValueError("Every embedding must be a list of 1024 floats")
+            timings["generate_embeddings"] = time.time() - start_embedding
+            debug(f"Embedding generation took {timings['generate_embeddings']:.2f}s, embedding count: {len(embeddings)}")
 
-            chunks_data = []
-            for i, chunk in enumerate(chunks):
-                chunks_data.append({
-                    "userid": orgid,
-                    "knowledge_base_id": fiid,
-                    "text": chunk.page_content,
-                    "vector": embeddings[i],
-                    "document_id": id,
-                    "filename": filename + '.' + ext,
-                    "file_path": realpath,
-                    "upload_time": upload_time,
-                    "file_type": ext,
-                })
+            chunks_data = []
+            for i, chunk in enumerate(chunks):
+                chunks_data.append({
+                    "userid": orgid,
+                    "knowledge_base_id": fiid,
+                    "text": chunk.page_content,
+                    "vector": embeddings[i],
+                    "document_id": id,
+                    "filename": filename + '.' + ext,
+                    "file_path": realpath,
+                    "upload_time": upload_time,
+                    "file_type": ext,
+                })
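
The embedding loop above and the Milvus insert loop below follow the same slice-by-10 pattern. Factored out, it is just list slicing; this helper is a refactoring sketch, not code from the patch:

    from typing import Iterator, Sequence, TypeVar

    T = TypeVar("T")

    def batched(seq: Sequence[T], size: int) -> Iterator[Sequence[T]]:
        """Yield consecutive slices of at most `size` items."""
        for i in range(0, len(seq), size):
            yield seq[i:i + size]

    assert [list(b) for b in batched([1, 2, 3, 4, 5], 2)] == [[1, 2], [3, 4], [5]]
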
 
-            debug(f"Calling the insert-document endpoint: {realpath}")
-            start_milvus = time.time()
-            for i in range(0, len(chunks_data), 10):  # insert 10 rows per batch
-                batch_chunks = chunks_data[i:i + 10]
-                result = await api_service.milvus_insert_document(
-                    request=request,
-                    chunks=batch_chunks,
-                    db_type=db_type,
-                    upappid=service_params['vdb'],
-                    apiname="milvus/insertdocument",
-                    user=userid
-                )
-                if result.get("status") != "success":
-                    raise ValueError(result.get("message", "Milvus insert failed"))
-            timings["insert_milvus"] = time.time() - start_milvus
-            debug(f"Milvus insert took {timings['insert_milvus']:.2f}s")
+            debug(f"Calling the insert-document endpoint: {realpath}")
+            start_milvus = time.time()
+            for i in range(0, len(chunks_data), 10):  # insert 10 rows per batch
+                batch_chunks = chunks_data[i:i + 10]
+                result = await api_service.milvus_insert_document(
+                    request=request,
+                    chunks=batch_chunks,
+                    db_type=db_type,
+                    upappid=service_params['vdb'],
+                    apiname="milvus/insertdocument",
+                    user=userid
+                )
+                if result.get("status") != "success":
+                    raise ValueError(result.get("message", "Milvus insert failed"))
+            timings["insert_milvus"] = time.time() - start_milvus
+            debug(f"Milvus insert took {timings['insert_milvus']:.2f}s")
 
-            if result.get("status") != "success":
-                timings["total"] = time.time() - start_total
-                return {"status": "error", "document_id": id, "timings": timings,
-                        "message": result.get("message", "unknown error"), "status_code": 400}
+            if result.get("status") != "success":
+                timings["total"] = time.time() - start_total
+                return {"status": "error", "document_id": id, "timings": timings,
+                        "message": result.get("message", "unknown error"), "status_code": 400}
 
-            debug("Calling the triple-extraction service")
-            start_triples = time.time()
-            unique_triples = []
-            try:
-                chunk_texts = [doc.page_content for doc in chunks]
-                debug(f"Extracting triples from {len(chunk_texts)} chunks")
-                tasks = [
-                    api_service.extract_triples(
-                        request=request,
-                        text=chunk,
-                        upappid=service_params['triples'],
-                        apiname="Babelscape/mrebel-large",
-                        user=userid
-                    ) for chunk in chunk_texts
-                ]
-                results = await asyncio.gather(*tasks, return_exceptions=True)
-                triples = []
-                for i, result in enumerate(results):
-                    if isinstance(result, list):
-                        triples.extend(result)
-                        debug(f"Chunk {i + 1} yielded {len(result)} triples: {result[:5]}")
-                    else:
-                        error(f"Chunk {i + 1} failed: {str(result)}")
+            debug("Calling the triple-extraction service")
+            start_triples = time.time()
+            unique_triples = []
+            try:
+                chunk_texts = [doc.page_content for doc in chunks]
+                debug(f"Extracting triples from {len(chunk_texts)} chunks")
+                triples = []
+                for i, chunk in enumerate(chunk_texts):
+                    result = await api_service.extract_triples(
+                        request=request,
+                        text=chunk,
+                        upappid=service_params['triples'],
+                        apiname="Babelscape/mrebel-large",
+                        user=userid
+                    )
+                    if isinstance(result, list):
+                        triples.extend(result)
+                        debug(f"Chunk {i + 1} yielded {len(result)} triples: {result[:5]}")
+                    else:
+                        error(f"Chunk {i + 1} failed: {str(result)}")
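
The patch replaces asyncio.gather over all chunks with a strictly sequential loop. If the motivation was to avoid overloading the extraction service, a middle ground is gather behind a semaphore; a sketch under that assumption, where extract stands in for api_service.extract_triples:

    import asyncio

    async def extract_all(chunks, extract, max_in_flight=4):
        """Run extract() over all chunks with at most max_in_flight concurrent calls."""
        sem = asyncio.Semaphore(max_in_flight)

        async def one(chunk):
            async with sem:
                return await extract(chunk)

        return await asyncio.gather(*(one(c) for c in chunks), return_exceptions=True)

    async def _demo():
        async def fake_extract(c):
            await asyncio.sleep(0.01)
            return [{"head": c, "type": "is", "tail": "demo"}]
        print(await extract_all(["a", "b", "c"], fake_extract))

    asyncio.run(_demo())
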
 
-                seen = set()
-                for t in triples:
-                    identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
-                    if identifier not in seen:
-                        seen.add(identifier)
-                        unique_triples.append(t)
-                    else:
-                        for existing in unique_triples:
-                            if (existing['head'].lower() == t['head'].lower() and
-                                    existing['tail'].lower() == t['tail'].lower() and
-                                    len(t['type']) > len(existing['type'])):
-                                unique_triples.remove(existing)
-                                unique_triples.append(t)
-                                debug(f"Replaced triple with a more specific type: {t}")
-                                break
+                seen = set()
+                for t in triples:
+                    identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
+                    if identifier not in seen:
+                        seen.add(identifier)
+                        unique_triples.append(t)
+                    else:
+                        for existing in unique_triples:
+                            if (existing['head'].lower() == t['head'].lower() and
+                                    existing['tail'].lower() == t['tail'].lower() and
+                                    len(t['type']) > len(existing['type'])):
+                                unique_triples.remove(existing)
+                                unique_triples.append(t)
+                                debug(f"Replaced triple with a more specific type: {t}")
+                                break
 
-                timings["extract_triples"] = time.time() - start_triples
-                debug(
-                    f"Triple extraction took {timings['extract_triples']:.2f}s, extracted {len(unique_triples)} triples: {unique_triples[:5]}")
+                timings["extract_triples"] = time.time() - start_triples
+                debug(
+                    f"Triple extraction took {timings['extract_triples']:.2f}s, extracted {len(unique_triples)} triples: {unique_triples[:5]}")
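
The dedup above keys the seen-set on (head, tail, type) and then scans the whole list for a (head, tail) match with a longer type, which is O(n) per duplicate. If the intent is one triple per (head, tail) pair keeping the most specific type, as the replace branch suggests, a dict-keyed equivalent does it in one pass. This is a refactoring sketch, not the patch's code:

    def dedup_triples(triples):
        """Keep one triple per (head, tail), preferring the longest relation type."""
        best = {}
        for t in triples:
            key = (t['head'].lower(), t['tail'].lower())
            if key not in best or len(t['type']) > len(best[key]['type']):
                best[key] = t
        return list(best.values())

    print(dedup_triples([
        {'head': 'Paris', 'type': 'in', 'tail': 'France'},
        {'head': 'paris', 'type': 'capital of', 'tail': 'france'},
    ]))  # -> the 'capital of' triple wins
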
 
-                if unique_triples:
-                    debug(f"Extracted {len(unique_triples)} triples, calling the Neo4j service to insert them")
-                    start_neo4j = time.time()
-                    for i in range(0, len(unique_triples), 30):  # insert 30 triples per batch
-                        batch_triples = unique_triples[i:i + 30]
-                        neo4j_result = await api_service.neo4j_insert_triples(
-                            request=request,
-                            triples=batch_triples,
-                            document_id=id,
-                            knowledge_base_id=fiid,
-                            userid=orgid,
-                            upappid=service_params['gdb'],
-                            apiname="neo4j/inserttriples",
-                            user=userid
-                        )
-                        debug(f"Neo4j service response: {neo4j_result}")
-                        if neo4j_result.get("status") != "success":
-                            timings["insert_neo4j"] = time.time() - start_neo4j
-                            timings["total"] = time.time() - start_total
-                            return {
-                                "status": "error",
-                                "document_id": id,
-                                "collection_name": "ragdb",
-                                "timings": timings,
-                                "message": f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}",
-                                "status_code": 400
-                            }
-                    info(f"Triples for file {realpath} inserted into Neo4j: {neo4j_result.get('message')}")
-                    timings["insert_neo4j"] = time.time() - start_neo4j
-                    debug(f"Neo4j insert took {timings['insert_neo4j']:.2f}s")
-                else:
-                    debug(f"No triples extracted from file {realpath}")
-                    timings["insert_neo4j"] = 0.0
+                if unique_triples:
+                    debug(f"Extracted {len(unique_triples)} triples, calling the Neo4j service to insert them")
+                    start_neo4j = time.time()
+                    for i in range(0, len(unique_triples), 30):  # insert 30 triples per batch
+                        batch_triples = unique_triples[i:i + 30]
+                        neo4j_result = await api_service.neo4j_insert_triples(
+                            request=request,
+                            triples=batch_triples,
+                            document_id=id,
+                            knowledge_base_id=fiid,
+                            userid=orgid,
+                            upappid=service_params['gdb'],
+                            apiname="neo4j/inserttriples",
+                            user=userid
+                        )
+                        debug(f"Neo4j service response: {neo4j_result}")
+                        if neo4j_result.get("status") != "success":
+                            timings["insert_neo4j"] = time.time() - start_neo4j
+                            timings["total"] = time.time() - start_total
+                            return {
+                                "status": "error",
+                                "document_id": id,
+                                "collection_name": "ragdb",
+                                "timings": timings,
+                                "message": f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}",
+                                "status_code": 400
+                            }
+                    info(f"Triples for file {realpath} inserted into Neo4j: {neo4j_result.get('message')}")
+                    timings["insert_neo4j"] = time.time() - start_neo4j
+                    debug(f"Neo4j insert took {timings['insert_neo4j']:.2f}s")
+                else:
+                    debug(f"No triples extracted from file {realpath}")
+                    timings["insert_neo4j"] = 0.0
 
-            except Exception as e:
-                timings["extract_triples"] = time.time() - start_triples if "extract_triples" not in timings else \
-                    timings["extract_triples"]
-                timings["insert_neo4j"] = time.time() - start_neo4j if "insert_neo4j" not in timings else timings[
-                    "insert_neo4j"]
-                debug(f"Triple processing or Neo4j insert failed: {str(e)}, traceback: {traceback.format_exc()}")
-                timings["total"] = time.time() - start_total
-                return {
-                    "status": "success",
-                    "document_id": id,
-                    "collection_name": "ragdb",
-                    "timings": timings,
-                    "unique_triples": unique_triples,
-                    "message": f"File {realpath} embedded successfully, but triple processing or the Neo4j insert failed: {str(e)}",
-                    "status_code": 200
-                }
+            except Exception as e:
+                if "extract_triples" not in timings:
+                    timings["extract_triples"] = time.time() - start_triples
+                if "insert_neo4j" not in timings:
+                    # start_neo4j only exists if the Neo4j phase was reached
+                    timings["insert_neo4j"] = time.time() - start_neo4j if "start_neo4j" in locals() else 0.0
+                debug(f"Triple processing or Neo4j insert failed: {str(e)}, traceback: {traceback.format_exc()}")
+                timings["total"] = time.time() - start_total
+                return {
+                    "status": "success",
+                    "document_id": id,
+                    "collection_name": "ragdb",
+                    "timings": timings,
+                    "unique_triples": unique_triples,
+                    "message": f"File {realpath} embedded successfully, but triple processing or the Neo4j insert failed: {str(e)}",
+                    "status_code": 200
+                }
 
-            timings["total"] = time.time() - start_total
-            debug(f"Total time: {timings['total']:.2f}s")
-            return {
-                "status": "success",
-                "userid": orgid,
-                "document_id": id,
-                "collection_name": "ragdb",
-                "timings": timings,
-                "unique_triples": unique_triples,
-                "message": f"File {realpath} embedded and triples processed successfully",
-                "status_code": 200
-            }
+            timings["total"] = time.time() - start_total
+            debug(f"Total time: {timings['total']:.2f}s")
+            return {
+                "status": "success",
+                "userid": orgid,
+                "document_id": id,
+                "collection_name": "ragdb",
+                "timings": timings,
+                "unique_triples": unique_triples,
+                "message": f"File {realpath} embedded and triples processed successfully",
+                "status_code": 200
+            }
 
-        except Exception as e:
-            error(f"Document insert failed: {str(e)}, traceback: {traceback.format_exc()}")
-            timings["total"] = time.time() - start_total
-            return {
-                "status": "error",
-                "document_id": id,
-                "collection_name": "ragdb",
-                "timings": timings,
-                "message": f"Document insert failed: {str(e)}",
-                "status_code": 400
-            }
+        except Exception as e:
+            error(f"Document insert failed: {str(e)}, traceback: {traceback.format_exc()}")
+            timings["total"] = time.time() - start_total
+            return {
+                "status": "error",
+                "document_id": id,
+                "collection_name": "ragdb",
+                "timings": timings,
+                "message": f"Document insert failed: {str(e)}",
+                "status_code": 400
+            }
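
Note the error-handling asymmetry here: once the Milvus insert has succeeded, a triple-extraction or Neo4j failure still yields status "success" with status_code 200, and the degradation is only visible in the message. An illustrative shape of that degraded-success response (values invented for illustration):

    resp = {
        "status": "success",          # the Milvus insert succeeded
        "status_code": 200,
        "unique_triples": [],         # whatever was extracted before the failure
        "message": "File /data/x.pdf embedded successfully, but triple processing "
                   "or the Neo4j insert failed: ...",
    }
    degraded = resp["status"] == "success" and "failed" in resp["message"]
    print(degraded)  # callers must inspect the message to detect this case
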
document_id={id}") - milvus_result = await api_service.milvus_delete_document( - request=request, - userid=orgid, - file_path=realpath, - knowledge_base_id=fiid, - document_id=id, - db_type=db_type, - upappid=service_params['vdb'], - apiname="milvus/deletedocument", - user=userid - ) + debug( + f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}") + milvus_result = await api_service.milvus_delete_document( + request=request, + userid=orgid, + file_path=realpath, + knowledge_base_id=fiid, + document_id=id, + db_type=db_type, + upappid=service_params['vdb'], + apiname="milvus/deletedocument", + user=userid + ) - if milvus_result.get("status") != "success": - raise ValueError(milvus_result.get("message", "Milvus 删除失败")) + if milvus_result.get("status") != "success": + raise ValueError(milvus_result.get("message", "Milvus 删除失败")) - neo4j_deleted_nodes = 0 - neo4j_deleted_rels = 0 - try: - debug(f"调用 Neo4j 删除文档端点: document_id={id}") - neo4j_result = await api_service.neo4j_delete_document( - request=request, - document_id=id, - upappid=service_params['gdb'], - apiname="neo4j/deletedocument", - user=userid - ) - if neo4j_result.get("status") != "success": - raise ValueError(neo4j_result.get("message", "Neo4j 删除失败")) - nodes_deleted = neo4j_result.get("nodes_deleted", 0) - rels_deleted = neo4j_result.get("rels_deleted", 0) - neo4j_deleted_nodes += nodes_deleted - neo4j_deleted_rels += rels_deleted - total_nodes_deleted += nodes_deleted - total_rels_deleted += rels_deleted - info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系") - except Exception as e: - error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}") + neo4j_deleted_nodes = 0 + neo4j_deleted_rels = 0 + try: + debug(f"调用 Neo4j 删除文档端点: document_id={id}") + neo4j_result = await api_service.neo4j_delete_document( + request=request, + document_id=id, + upappid=service_params['gdb'], + apiname="neo4j/deletedocument", + user=userid + ) + if neo4j_result.get("status") != "success": + raise ValueError(neo4j_result.get("message", "Neo4j 删除失败")) + nodes_deleted = neo4j_result.get("nodes_deleted", 0) + rels_deleted = neo4j_result.get("rels_deleted", 0) + neo4j_deleted_nodes += nodes_deleted + neo4j_deleted_rels += rels_deleted + total_nodes_deleted += nodes_deleted + total_rels_deleted += rels_deleted + info(f"成功删除 document_id={id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系") + except Exception as e: + error(f"删除 document_id={id} 的 Neo4j 数据失败: {str(e)}") - results.append({ - "status": "success", - "collection_name": collection_name, - "document_id": id, - "message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系", - "status_code": 200 - }) + results.append({ + "status": "success", + "collection_name": collection_name, + "document_id": id, + "message": f"成功删除文件 {realpath} 的 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系", + "status_code": 200 + }) - except Exception as e: - error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}") - results.append({ - "status": "error", - "collection_name": collection_name, - "document_id": id, - "message": f"删除文档 {realpath} 失败: {str(e)}", - "status_code": 400 - }) + except Exception as e: + error(f"删除文档 {realpath} 失败: {str(e)}, 堆栈: {traceback.format_exc()}") + results.append({ + "status": "error", + "collection_name": collection_name, + "document_id": id, + "message": f"删除文档 {realpath} 失败: {str(e)}", + "status_code": 400 + }) - return { - "status": "success" if 
all(r["status"] == "success" for r in results) else "partial", - "results": results, - "total_nodes_deleted": total_nodes_deleted, - "total_rels_deleted": total_rels_deleted, - "message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')} 个", - "status_code": 200 if all(r["status"] == "success" for r in results) else 207 - } + return { + "status": "success" if all(r["status"] == "success" for r in results) else "partial", + "results": results, + "total_nodes_deleted": total_nodes_deleted, + "total_rels_deleted": total_rels_deleted, + "message": f"处理 {len(recs)} 个文件,成功删除 {sum(1 for r in results if r['status'] == 'success')} 个", + "status_code": 200 if all(r["status"] == "success" for r in results) else 207 + } async def test_ragfilemgr(): """测试 RagFileMgr 类的 get_service_params""" @@ -505,11 +503,4 @@ async def test_ragfilemgr(): if __name__ == "__main__": - asyncio.run(test_ragfilemgr()) - - -## usage -# mgr = RagFileMgr(fiid) -# await mgr.add_file(request, params_kw) -# await mgr.delete_file(request, file_id) -## + asyncio.run(test_ragfilemgr()) \ No newline at end of file diff --git a/rag/ragapi.py b/rag/ragapi.py index b143d24..5c3c4da 100644 --- a/rag/ragapi.py +++ b/rag/ragapi.py @@ -5,6 +5,7 @@ from appPublic.log import debug, error, info import time import traceback import json +import math helptext = """kyrag API: @@ -81,11 +82,44 @@ async def fusedsearch(request, params_kw, *params, **kw): # orgid = "04J6VbxLqB_9RPMcgOv_8" # userid = "04J6VbxLqB_9RPMcgOv_8" query = params_kw.get('query', '') - fiids = params_kw.get('fiids', []) - limit = int(params_kw.get('limit', 5)) + # 统一模式处理 limit 参数 + raw_limit = params_kw.get('limit') or ( + params_kw.get('retrieval_setting', {}).get('top_k') + if isinstance(params_kw.get('retrieval_setting'), dict) + else None + ) + + # 标准化为整数值 + if raw_limit is None: + limit = 5 # 两个来源都不存在时使用默认值 + elif isinstance(raw_limit, (int, float)): + limit = int(raw_limit) # 数值类型直接转换 + elif isinstance(raw_limit, str): + try: + # 字符串转换为整数 + limit = int(raw_limit) + except (TypeError, ValueError): + limit = 5 # 转换失败使用默认值 + else: + limit = 5 # 其他意外类型使用默认值 + debug(f"limit: {limit}") + raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') + + # 标准化为列表格式 + if raw_fiids is None: + fiids = [] # 两个参数都不存在 + elif isinstance(raw_fiids, list): + fiids = [str(item).strip() for item in raw_fiids] # 已经是列表 + elif isinstance(raw_fiids, str): + # 处理逗号分隔的字符串或单个ID字符串 + fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()] + elif isinstance(raw_fiids, (int, float)): + fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表 + else: + fiids = [] # 其他意外类型 + debug(f"fiids: {fiids}") - if isinstance(fiids, str): - fiids = [f.strip() for f in fiids.split(',') if f.strip()] + # 验证 fiids的orgid与orgid = await f()是否一致 if fiids: db = DBPools() @@ -197,6 +231,7 @@ async def fusedsearch(request, params_kw, *params, **kw): # 调用搜索端点 sum = limit + 5 search_start = time.time() + debug(f"orgid: {orgid}") result = await api_service.milvus_search_query( request=request, query_vector=query_vector, @@ -240,8 +275,34 @@ async def fusedsearch(request, params_kw, *params, **kw): timing_stats["total_time"] = time.time() - start_time info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") - return {"results": unique_results[:limit], "timing": timing_stats} + # debug(f"results: {unique_results[:limit]},timing: {timing_stats}") + # return {"results": unique_results[:limit], "timing": timing_stats} + + + dify_records = [] + dify_result = [] + for res 
@@ -240,8 +275,34 @@
         timing_stats["total_time"] = time.time() - start_time
         info(f"Fused search finished: {len(unique_results)} results, total time: {timing_stats['total_time']:.3f}s")
-        return {"results": unique_results[:limit], "timing": timing_stats}
+        # debug(f"results: {unique_results[:limit]}, timing: {timing_stats}")
+        # return {"results": unique_results[:limit], "timing": timing_stats}
+
+
+        dify_records = []
+        dify_result = []
+        for res in unique_results[:limit]:
+            rerank_score = res.get('rerank_score')
+            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
+            score = max(0.0, min(1.0, score))
+            content = res.get('text', '')
+            title = res.get('metadata', {}).get('filename', 'Untitled')
+            document_id = res.get('metadata', {}).get('document_id', '')
+            dify_records.append({
+                "content": content,
+                "score": score,
+                "title": title
+            })
+            dify_result.append({
+                "content": content,
+                "title": title,
+                "metadata": {"document_id": document_id}
+            })
+
+        info(f"Fused search finished: {len(dify_records)} results, total time: {(time.time() - start_time):.3f}s")
+        debug(f"records: {dify_records}, result: {dify_result}")
+        return {"records": dify_records, "result": dify_result, "own": {"results": unique_results[:limit], "timing": timing_stats}}
     except Exception as e:
         error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
         return {"results": [], "timing": timing_stats}
diff --git a/rag/uapi_service.py b/rag/uapi_service.py
index 0c5ddcc..0bd2e13 100644
--- a/rag/uapi_service.py
+++ b/rag/uapi_service.py
@@ -321,6 +321,7 @@ class APIService:
     async def milvus_search_query(self, request, query_vector: List[float], userid: str, knowledge_base_ids: list, limit: int, offset: int, upappid: str, apiname: str, user: str) -> Dict[str, Any]:
         """Search Milvus over the user's knowledge bases"""
         request_id = str(uuid.uuid4())
+        debug(f"userid: {userid}")
         debug(f"Request #{request_id} started for Milvus search")
         try:
             uapi = UAPI(request, DictObject(**globals()))
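
The Dify-style records above need a score in [0, 1], while raw reranker scores are unbounded logits, so the loop squashes them with a logistic function and falls back to 1 - distance when no rerank score is present (which assumes a distance already in [0, 1]). The mapping in isolation:

    import math

    def to_unit_score(rerank_score=None, distance=0.0):
        if rerank_score is not None:
            s = 1 / (1 + math.exp(-rerank_score))  # logistic squash of the logit
        else:
            s = 1 - distance                       # assumes a [0, 1] distance metric
        return max(0.0, min(1.0, s))

    print(to_unit_score(2.0))        # ~0.88
    print(to_unit_score(None, 0.3))  # 0.7

Note the fallback branch only triggers when the result carries no rerank_score at all, which is why the loop reads the key without a default.
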