async def insert_to_vector_text(self, request,
                                db_type: str, fields: Dict, service_params: Dict,
                                userid: str, timings: Dict) -> List[Dict]:
    """Insert a single plain-text record into the vector database.

    The record is a shallow copy of ``fields``; whatever keys the caller
    supplies are forwarded unchanged, so the schema is decided by the caller.
    The copy is wrapped in a one-element list and sent to the
    ``milvus/insertdocument`` endpoint through ``self.api_service``.

    Args:
        request: Request object, forwarded to the API service.
        db_type: Database type, forwarded to the insert endpoint.
        fields: Field name -> value mapping describing the single record.
        service_params: Service parameters; ``service_params['vdb']`` is
            passed as ``upappid``.
        userid: User identifier passed to the API service as ``user``.
        timings: Mutable dict; the elapsed time in seconds is stored under
            ``"textinsert"`` when the insert succeeds.

    Returns:
        The one-element list that was sent to the insert endpoint.

    Raises:
        KeyError: If ``service_params`` has no ``'vdb'`` entry.
        ValueError: If the endpoint's ``status`` is not ``"success"``.
    """
    debug("准备单一纯文本数据并调用插入端点")
    began = time.time()

    # Copy so the payload is independent of the caller's dict object.
    record = dict(fields.items())
    chunks_data = [record]
    debug(f"向量库插入传入的数据是:{chunks_data}")

    # Call the Milvus insert endpoint.
    response = await self.api_service.milvus_insert_document(
        request=request,
        chunks=chunks_data,
        upappid=service_params['vdb'],
        apiname="milvus/insertdocument",
        user=userid,
        db_type=db_type,
    )
    status = response.get("status")
    if status != "success":
        failure = response.get("message", "Milvus 插入失败")
        raise ValueError(failure)

    debug(f"成功插入纯文本到集合 {response.get('collection_name')}")
    timings["textinsert"] = time.time() - began
    debug(f"插入纯文本耗时: {timings['textinsert']:.2f} 秒")

    return chunks_data
async def textinsert(request, params_kw, *params):
    """Embed one plain-text snippet and store it in the vector and graph stores.

    Reads ``text``, ``fiid`` and ``db_type`` from ``params_kw`` and resolves
    the caller's organisation id and user id through the ``get_userorgid`` and
    ``get_user`` callables found in ``request._run_ns``. It then runs, in
    order: embedding generation, vector-store insert, triple extraction and
    graph-store insert.

    Args:
        request: Request object; ``request._run_ns`` supplies the helper
            callables and the object is forwarded to every service call.
        params_kw: Request parameters with ``text``, ``fiid`` and ``db_type``.
        *params: Unused extra positional arguments.

    Returns:
        dict with ``status`` ("success"), ``userid`` (the organisation id),
        ``collection_name``, ``message``, ``status_code`` (200),
        ``unique_triples`` and ``timings``.

    Raises:
        ValueError: If a required field is missing or too long, if service
            parameters cannot be loaded, or if any pipeline step fails. The
            original exception is chained as ``__cause__``.
    """
    kw = request._run_ns
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    f = kw.get('get_user')
    userid = await f()
    text = params_kw.get('text', '')
    fiid = params_kw.get('fiid')
    db_type = params_kw.get('db_type')
    # One id is shared by the vector record and the graph insert.
    doc_id = str(uuid.uuid4())
    debug(f"params_kw: {params_kw}")

    timings = {}
    start_total = time.time()

    try:
        # Validate required fields.
        if not orgid or not fiid or not text or not db_type:
            raise ValueError("orgid、fiid、db_type 和 text 不能为空")
        if len(orgid) > 32 or len(fiid) > 255:
            raise ValueError("orgid 或 fiid 的长度超出限制")
        # NOTE(review): this module also defines _validate_fiids_orgid; it is
        # not called here - confirm whether fiid ownership should be checked.

        # Load service parameters for the organisation.
        service_params = await get_service_params(orgid)
        if not service_params:
            raise ValueError("无法获取服务参数")
        rag_ops = RagOperations()

        # Generate the embedding vector for the text.
        embedding = await rag_ops.generate_query_vector(
            request, text, service_params, userid, timings
        )

        # Insert the record into the vector store.
        fields = {
            "text": text,
            "fiid": fiid,
            "orgid": orgid,
            "vector": embedding,
            "id": doc_id,
        }
        await rag_ops.insert_to_vector_text(
            request, db_type, fields, service_params, userid, timings
        )

        # Extract triples from the text.
        # NOTE(review): Document must be in scope in this module - confirm import.
        chunks = [Document(page_content=text)]
        triples = await rag_ops.extract_triples(
            request, chunks, service_params, userid, timings
        )

        # Insert the triples into the graph store.
        await rag_ops.insert_to_graph_db(
            request, triples, doc_id, fiid, orgid, service_params, userid, timings
        )

        timings["total"] = time.time() - start_total
        debug(f"总耗时: {timings['total']:.2f} 秒")
    except Exception as e:
        # Failures are reported by raising, so no error payload is built here.
        error(f"插入文档失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        raise ValueError(str(e)) from e

    result = {
        "status": "success",
        "userid": orgid,
        # db_type is validated as non-empty above, so the name is always concrete.
        "collection_name": f"ragdb_{db_type}",
        "message": "文本成功嵌入并处理三元组",
        "status_code": 200,
        "unique_triples": triples,
        "timings": timings,
    }
    debug(f"最终结果是:{result}")
    return result
debug(f"传入的orgid是:{orgid}") dbname = get_serverenv('get_module_dbname')('rag') async with db.sqlorContext(dbname) as sor: return await sor_get_service_params(sor, orgid) diff --git a/rag/uapi_service.py b/rag/uapi_service.py index bd99273..c429dc5 100644 --- a/rag/uapi_service.py +++ b/rag/uapi_service.py @@ -275,7 +275,7 @@ class APIService: uapi = UAPI(request, DictObject(**globals())) params_kw = { "chunks": chunks, - "dbtype": db_type + "db_type": db_type } b = await uapi.call(upappid, apiname, user, params_kw) d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档插入服务", request_id) @@ -296,7 +296,7 @@ class APIService: "file_path": file_path, "knowledge_base_id": knowledge_base_id, "document_id": document_id, - "dbtype": db_type + "db_type": db_type } b = await uapi.call(upappid, apiname, user, params_kw) d = await self.handle_uapi_response(b, upappid, apiname, "Milvus 文档删除服务", request_id) diff --git a/wwwroot/upload_file.dspy b/wwwroot/upload_file.dspy index 0b90e61..5ceaed6 100644 --- a/wwwroot/upload_file.dspy +++ b/wwwroot/upload_file.dspy @@ -1,6 +1,10 @@ fmgr = RagFileMgr(params_kw.fiid) try: - stat = await fmgr.add_file(request, params_kw) - return UiMessage(title='Add file', message='file add success') + stat = await fmgr.add_file(request, params_kw) + if stat.get('status_code') == 200: + return UiMessage(title='Add file', message='file add success') + else: + error_message = stat.get('message', 'Unknown error occurred') + return UiError(title='Add file', message=f'file add failed: {error_message}') except Exception as e: - return UiError(title='Add file', message=f'file add failed({e})') + return UiError(title='Add file', message=f'file add failed({e})')