diff --git a/rag/init.py b/rag/init.py index becdec1..88e8254 100644 --- a/rag/init.py +++ b/rag/init.py @@ -8,6 +8,7 @@ from .folderinfo import RagFileMgr from .ragprogram import set_program, get_rag_programs from ragllm_utils import get_ragllms from appPublic.registerfunction import RegisterFunction +from sqlor.dbpools import DBPools async def _make_connection_request(action: str, params: dict = None) -> dict: """ @@ -138,6 +139,17 @@ async def docs() -> dict: """列出所有用户的知识库及其文件""" return await _make_connection_request("docs", {}) +async def get_user_kdbs(request): + env = request._run_ns + db = env.DBPools() + dbnme = env.get_module_dbname('rag') + userorgid = await env.get_userorgid() + async with db.sqlorContext(dbname) as sor: + sql = "select * from kdb where ownerid = ${orgid}$" + recs = await sor.sqlExe(sql, {'orgid': userorgid}) + return recs + return recs + def load_rag(): """ 初始化 ServerEnv,绑定 MilvusConnection 的所有功能。 @@ -157,4 +169,6 @@ def load_rag(): env.set_program = set_program env.get_rag_programs = get_rag_programs env.get_ragllms = get_ragllms + env.get_user_kdbs = get_user_kdbs + diff --git a/rag/ragllm_utils.py b/rag/ragllm_utils.py index c6467cc..57cdbc2 100644 --- a/rag/ragllm_utils.py +++ b/rag/ragllm_utils.py @@ -1,7 +1,7 @@ from sqlor.dbpools import DBPools async def get_ragllms_by_catelog(request, **params): - catelogid = params.get('catelogid') + catelogid = params.get('id') if not catelogid: raise 'need applies catelogid'