rag
This commit is contained in:
parent
cddb4733fb
commit
3e69908c25
210
rag/aslmapi.py
Normal file
210
rag/aslmapi.py
Normal file
@ -0,0 +1,210 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from mem0 import AsyncMemory
|
||||
from appPublic.log import debug, error, info, exception
|
||||
import atexit
|
||||
|
||||
# 配置
|
||||
CONFIG = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "Qwen3-14B-FP8",
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 2000,
|
||||
"openai_base_url": "https://t2t.opencomputing.net:10443/qwen3-14b-fp8/v1",
|
||||
"api_key": "any-key"
|
||||
}
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "bge-m3",
|
||||
"embedding_dims": 1024,
|
||||
"openai_base_url": "https://embedding.opencomputing.net:10443/v1",
|
||||
"api_key": "any-key"
|
||||
}
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": "milvus",
|
||||
"config": {
|
||||
"collection_name": "mem0",
|
||||
"url": "/home/wangmeihua/mem0/milvus_demo.db",
|
||||
"embedding_model_dims": 1024,
|
||||
"metric_type": "COSINE",
|
||||
"db_name": "milvus"
|
||||
}
|
||||
},
|
||||
"version": "v1.1"
|
||||
}
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""管理用户记忆的类,基于 AsyncMemory 提供添加和检索功能"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(MemoryManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
self.memory = None
|
||||
self._initialized = False
|
||||
atexit.register(self._close)
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化 AsyncMemory"""
|
||||
if self._initialized:
|
||||
debug(f"AsyncMemory 已初始化,跳过")
|
||||
try:
|
||||
self.memory = await AsyncMemory.from_config(CONFIG)
|
||||
self._initialized = True
|
||||
debug("AsyncMemory 初始化成功")
|
||||
except Exception as e:
|
||||
error(f"AsyncMemory 初始化失败: {e}")
|
||||
raise
|
||||
|
||||
def _close(self):
|
||||
if self.memory and hasattr(self.memory.vector_store, 'client'):
|
||||
try:
|
||||
self.memory.vector_store.client.close()
|
||||
debug(f"Milvus连接已关闭")
|
||||
except Exception as e:
|
||||
error(f"关闭milvus连接失败:{e}")
|
||||
if self.memory and hasattr(self.memory.db, 'connection'):
|
||||
try:
|
||||
self.memory.db.connection.close()
|
||||
debug("SQLite 连接已关闭")
|
||||
except Exception as e:
|
||||
error(f"关闭 SQLite 连接失败: {e}")
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
"""确保 AsyncMemory 已初始化(懒加载)"""
|
||||
if self.memory is None:
|
||||
await self.initialize()
|
||||
|
||||
async def add_messages_to_memory(self, messages: List[Dict[str, str]], user_id: str) -> Dict[str, Any]:
|
||||
await self._ensure_initialized()
|
||||
try:
|
||||
result = await self.memory.add(messages, user_id=user_id)
|
||||
debug(f"用户 {user_id} 的消息添加成功,结果: {result}")
|
||||
return {"status": "success", "result": result}
|
||||
except Exception as e:
|
||||
error(f"用户 {user_id} 的消息添加失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def search_user_memories(self, query: str, user_id: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
await self._ensure_initialized()
|
||||
try:
|
||||
result = await self.memory.search(query=query, user_id=user_id, limit=limit)
|
||||
existing_memories = result["results"]
|
||||
debug(f"用户 {user_id} 的记忆检索成功,找到 {len(existing_memories)} 条记录")
|
||||
return existing_memories
|
||||
except Exception as e:
|
||||
error(f"用户 {user_id} 的记忆检索失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_memories(self, user_id: Optional[str] = None, agent_id: Optional[str] = None,
|
||||
run_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 100) -> Dict[str, Any]:
|
||||
await self._ensure_initialized()
|
||||
try:
|
||||
result = await self.memory.get_all(user_id=user_id, agent_id=agent_id, run_id=run_id,
|
||||
filters=filters, limit=limit)
|
||||
debug(f"检索所有记忆成功,找到 {len(result.get('results', []))} 条记录")
|
||||
return {"status": "success", "result": result}
|
||||
except Exception as e:
|
||||
error(f"检索所有记忆失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def delete_all_memories(self, user_id: Optional[str] = None, agent_id: Optional[str] = None,
|
||||
run_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
await self._ensure_initialized()
|
||||
try:
|
||||
result = await self.memory.delete_all(user_id=user_id, agent_id=agent_id, run_id=run_id)
|
||||
debug(f"删除所有记忆成功: {result}")
|
||||
return {"status": "success", "result": result}
|
||||
except Exception as e:
|
||||
error(f"删除所有记忆失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
|
||||
await self._ensure_initialized()
|
||||
try:
|
||||
result = await self.memory.history(memory_id)
|
||||
debug(f"获取记忆 ID {memory_id} 的历史成功,找到 {len(result)} 条记录")
|
||||
return {"status": "success", "result": result}
|
||||
except Exception as e:
|
||||
error(f"获取记忆 ID {memory_id} 的历史失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def reset_memory(self) -> Dict[str, Any]:
|
||||
await self._ensure_initialized()
|
||||
try:
|
||||
await self.memory.reset()
|
||||
debug("记忆存储重置成功")
|
||||
return {"status": "success", "result": {"message": "Memory store reset successfully"}}
|
||||
except Exception as e:
|
||||
error(f"记忆存储重置失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
async def test_memory_functions():
|
||||
"""异步测试函数,验证 MemoryManager 的 add_messages_to_memory 和 search_user_memories"""
|
||||
# 初始化 MemoryManager
|
||||
manager = MemoryManager()
|
||||
#测试获取所有记忆
|
||||
test_user_id = "test_user_123"
|
||||
try:
|
||||
result = await manager.get_all_memories(user_id=test_user_id, limit=10)
|
||||
if result["status"] != "success":
|
||||
error(f"测试获取所有记忆失败: {result['message']}")
|
||||
return
|
||||
debug(f"测试获取所有记忆成功: {result}")
|
||||
except Exception as e:
|
||||
error(f"测试获取所有记忆失败: {e}")
|
||||
return
|
||||
|
||||
# # 测试数据
|
||||
# test_user_id = "test_user_123"
|
||||
# test_messages = [
|
||||
# {"role": "user", "content": "您好,请您为我推荐一款饮料"},
|
||||
# {"role": "assistant", "content": "好的,您可以试试元气森林、冰红茶、水溶C等饮料"},
|
||||
# {"role": "user", "content": "感谢您的推荐,我喜欢喝元气森林。"},
|
||||
# {"role": "assistant", "content": "好的"}
|
||||
# ]
|
||||
#
|
||||
# # 测试添加消息
|
||||
# try:
|
||||
# result = await manager.add_messages_to_memory(test_messages, test_user_id)
|
||||
# if result["status"] != "success":
|
||||
# error(f"测试添加消息失败: {result['message']}")
|
||||
# return
|
||||
# if "result" not in result:
|
||||
# error("测试添加消息失败: 返回结果中缺少 'result' 键")
|
||||
# return
|
||||
# debug(f"测试添加消息成功: {result}")
|
||||
# except Exception as e:
|
||||
# error(f"测试添加消息失败: {e}")
|
||||
# return
|
||||
#
|
||||
# # 测试检索记忆
|
||||
# query = "今天想吃水果"
|
||||
# try:
|
||||
# memories = await manager.search_user_memories(query, test_user_id, limit=3)
|
||||
# if not isinstance(memories, list):
|
||||
# error("测试检索记忆失败: 返回结果不是列表")
|
||||
# return
|
||||
# if memories and ("user_id" not in memories[0] or memories[0]["user_id"] != test_user_id):
|
||||
# error("测试检索记忆失败: 返回结果的 user_id 不匹配")
|
||||
# return
|
||||
# debug(f"测试检索记忆成功,返回 {len(memories)} 条记录")
|
||||
# debug(f"检索结果: {memories}")
|
||||
# except Exception as e:
|
||||
# error(f"测试检索记忆失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_memory_functions())
|
||||
24
wwwroot/get_user_kdb.dspy
Normal file
24
wwwroot/get_user_kdb.dspy
Normal file
@ -0,0 +1,24 @@
|
||||
debug(params_kw)
|
||||
ns = params_kw.copy()
|
||||
orgid = await get_userorgid()
|
||||
debug(f'orgid={orgid}')
|
||||
|
||||
db = DBPools()
|
||||
sql = f"""select {params_kw.tblvalue} as {params_kw.valueField},
|
||||
{params_kw.tbltext} as {params_kw.textField}
|
||||
from {params_kw.table} where 1=1 """
|
||||
|
||||
if orgid:
|
||||
sql += " and orgid = ${orgid}$ "
|
||||
ns['orgid'] = orgid
|
||||
|
||||
if params_kw.get('cond'):
|
||||
sql += f" and {params_kw.cond} "
|
||||
|
||||
sql += f"order by {params_kw.textField}"
|
||||
|
||||
debug(f'/rag/kdb/get_user_kdb.dspy: {sql=}, ns={ns}')
|
||||
|
||||
async with db.sqlorContext(params_kw.dbname) as sor:
|
||||
rs = await sor.sqlExe(sql, ns)
|
||||
return rs if rs else []
|
||||
67
wwwroot/test.ui
Normal file
67
wwwroot/test.ui
Normal file
@ -0,0 +1,67 @@
|
||||
{
|
||||
"widgettype": "VBox",
|
||||
"options": {
|
||||
"width": "100%",
|
||||
"height": "100%",
|
||||
"bgcolor": "#f5f5f5"
|
||||
},
|
||||
"subwidgets": [
|
||||
{
|
||||
"widgettype": "Form",
|
||||
"id": "test_form",
|
||||
"options": {
|
||||
"fields": [
|
||||
{
|
||||
"name": "query",
|
||||
"uitype": "text",
|
||||
"label": "输入查询文本",
|
||||
"editable": true,
|
||||
"rows": 5
|
||||
},
|
||||
{
|
||||
"name": "fiids",
|
||||
"uitype": "checkbox",
|
||||
"label": "选择知识库",
|
||||
"valueField": "id",
|
||||
"textField": "name",
|
||||
"params": {
|
||||
"dbname": "{{get_module_dbname('rag')}}",
|
||||
"table": "kdb",
|
||||
"tblvalue": "id",
|
||||
"tbltext": "name",
|
||||
"valueField": "id",
|
||||
"textField": "name"
|
||||
},
|
||||
"dataurl": "{{entire_url('/rag/get_user_kdb.dspy')}}",
|
||||
"multicheck": true
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"uitype": "int",
|
||||
"label": "输入返回条数"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "result_panel",
|
||||
"widgettype": "VScrollPanel",
|
||||
"options": {
|
||||
"height": "90%",
|
||||
"bgcolor": "#ffffff",
|
||||
"border": "1px solid #cccccc"
|
||||
}
|
||||
}
|
||||
],
|
||||
"binds": [
|
||||
{
|
||||
"wid": "test_form",
|
||||
"event": "submit",
|
||||
"actiontype": "urlwidget",
|
||||
"target": "result_panel",
|
||||
"options": {
|
||||
"url": "{{entire_url('/rag/test_query.dspy')}}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
34
wwwroot/test_query.dspy
Normal file
34
wwwroot/test_query.dspy
Normal file
@ -0,0 +1,34 @@
|
||||
debug(f'{params_kw=}')
|
||||
orgid = await get_userorgid()
|
||||
|
||||
if not orgid:
|
||||
return UiError(
|
||||
title='授权错误',
|
||||
message='请先登录'
|
||||
)
|
||||
|
||||
fiids = params_kw.fiids
|
||||
query = params_kw.query
|
||||
limit = params_kw.limit
|
||||
|
||||
if not query or not fiids or not limit:
|
||||
return UiError(
|
||||
title='无效输入',
|
||||
message='请输入查询文本并选择至少一个知识库'
|
||||
)
|
||||
|
||||
try:
|
||||
env = DictObject(**globals())
|
||||
keys = [k for k in env.keys()]
|
||||
debug(f'{keys=}')
|
||||
result = await rfexe('fusedsearch', request, params_kw)
|
||||
debug(f'fusedsearch result: {result}')
|
||||
return {
|
||||
"widgettype":"MdWidget",
|
||||
"options":{
|
||||
"width": "100%",
|
||||
"mdtext": json.dumps(result, ensure_ascii=False)
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
return UiError(title='search failed', message=f'召回失败,failed({e})')
|
||||
Loading…
x
Reference in New Issue
Block a user