diff --git a/rag/ragapi.py b/rag/ragapi.py
index dca0925..cfbd795 100644
--- a/rag/ragapi.py
+++ b/rag/ragapi.py
@@ -7,6 +7,7 @@ import json
 import math
 from rag.service_opts import get_service_params, sor_get_service_params
 from rag.rag_operations import RagOperations
+from rag.aslmapi import MemoryManager
 
 
 helptext = """kyrag API:
@@ -31,8 +32,25 @@ data: {
 
 3. docs (API documentation)
 path: /v1/docs
 
-4. longmemory storage
+4. Add user messages to memory:
+path: /v1/add_user_messages
+headers: {
+    "Content-Type": "application/json"
+}
+data: {
+    "orgid": "user organization ID",
+    "messages": [{"role": "user", "content": "message content"}, ...]
+}
+5. Get all memories for a user:
+path: /v1/get_user_memories
+headers: {
+    "Content-Type": "application/json"
+}
+data: {
+    "orgid": "user organization ID",
+    "limit": 10  # optional, defaults to 10
+}
 """
 
 async def docs(request, params_kw, *params, **kw):
@@ -72,8 +90,9 @@ async def get_kdbs(request, params_kw, *params, **kw):
     })
     return result
 
-async def fusedsearch(request, params_kw, *params, **kw):
+async def fusedsearch(request, params_kw, *params):
     """Fused search; calls the service-layer endpoints"""
+    kw = request._run_ns
     f = kw.get('get_userorgid')
     orgid = await f()
     debug(f"orgid: {orgid},{f=}")
@@ -155,6 +174,7 @@ async def fusedsearch(request, params_kw, *params, **kw):
 
     if use_rerank and search_results:
         final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params, userid, timings)
+        debug(f"final_results: {final_results}")
     else:
         final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
 
@@ -175,6 +195,69 @@
         }
     }
 
+async def add_user_messages(request, params_kw, *params, **kw):
+    """Add user messages to the memory store.
+
+    Args:
+        request: the HTTP request object
+        params_kw: request-parameter dict containing orgid and messages
+        *params: arbitrary positional arguments
+        **kw: arbitrary keyword arguments (e.g. get_module_dbname, get_userorgid)
+
+    Returns:
+        str: a JSON response containing status plus a result or an error message
+    """
+    debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
+    orgid = params_kw.get('orgid')
+    messages = params_kw.get('messages')
+
+    if not orgid or not isinstance(orgid, str):
+        error("invalid orgid: must be a non-empty string")
+        return json.dumps({"status": "error", "message": "invalid orgid: must be a non-empty string"})
+
+    if not messages or not isinstance(messages, list):
+        error("invalid messages: must be a non-empty list")
+        return json.dumps({"status": "error", "message": "invalid messages: must be a non-empty list"})
+
+    try:
+        manager = MemoryManager()
+        result = await manager.add_messages_to_memory(messages, orgid)
+        debug(f"result of adding user messages: {result}")
+        return json.dumps(result, ensure_ascii=False)
+    except Exception as e:
+        error(f"failed to add user messages: {str(e)}")
+        return json.dumps({"status": "error", "message": str(e)})
+
+async def get_user_memories(request, params_kw, *params, **kw):
+    """Fetch all memory records for a user.
+
+    Args:
+        request: the HTTP request object
+        params_kw: request-parameter dict containing orgid and an optional limit
+        *params: arbitrary positional arguments
+        **kw: arbitrary keyword arguments (e.g. get_module_dbname, get_userorgid)
+
+    Returns:
+        str: a JSON response containing status plus a result or an error message
+    """
+    debug(f"Received request: path={request.path}, params_kw={params_kw}, kw={kw}")
+    orgid = params_kw.get('orgid')
+    limit = params_kw.get('limit', 10)
+
+    if not orgid or not isinstance(orgid, str):
+        error("invalid orgid: must be a non-empty string")
+        return json.dumps({"status": "error", "message": "invalid orgid: must be a non-empty string"})
+
+    try:
+        limit = int(limit) if isinstance(limit, (str, int, float)) else 10
+        manager = MemoryManager()
+        result = await manager.get_all_memories(user_id=orgid, limit=limit)
+        debug(f"memory lookup result for user {orgid}: {result}")
+        return json.dumps(result, ensure_ascii=False)
+    except Exception as e:
+        error(f"failed to fetch memories for user {orgid}: {str(e)}")
+        return json.dumps({"status": "error", "message": str(e)})
+
 async def _validate_fiids_orgid(fiids, orgid, kw):
     """Check that the orgid of the given fiids matches the current user's orgid"""
     if fiids:
@@ -213,208 +296,4 @@ def _combine_query_with_triplets(query, triplets):
     debug(f"combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
     return combined_text
 
-    # api_service = APIService()
-    # start_time = time.time()
-    # timing_stats = {}
-    # try:
-    #     info(
-    #         f"starting fused search: query={query}, userid={orgid}, knowledge_base_ids={fiids}")
-    #
-    #     if not query or not orgid or not fiids:
-    #         raise ValueError("query, orgid and knowledge_base_ids must not be empty")
-    #
-    #     # extract entities
-    #     entity_extract_start = time.time()
-    #     query_entities = await api_service.extract_entities(
-    #         request=request,
-    #         query=query,
-    #         upappid=service_params['entities'],
-    #         apiname="LTP/small",
-    #         user=userid
-    #     )
-    #     timing_stats["entity_extraction"] = time.time() - entity_extract_start
-    #     debug(f"extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f}s")
-    #
-    #     # call the Neo4j service for triplet matching
-    #     all_triplets = []
-    #     triplet_match_start = time.time()
-    #     for kb_id in fiids:
-    #         debug(f"calling Neo4j triplet matching: knowledge_base_id={kb_id}")
-    #         try:
-    #             neo4j_result = await api_service.neo4j_match_triplets(
-    #                 request=request,
-    #                 query=query,
-    #                 query_entities=query_entities,
-    #                 userid=orgid,
-    #                 knowledge_base_id=kb_id,
-    #                 upappid=service_params['gdb'],
-    #                 apiname="neo4j/matchtriplets",
-    #                 user=userid
-    #             )
-    #             if neo4j_result.get("status") == "success":
-    #                 triplets = neo4j_result.get("triplets", [])
-    #                 all_triplets.extend(triplets)
-    #                 debug(f"knowledge base {kb_id} matched {len(triplets)} triplets: {triplets[:5]}")
-    #             else:
-    #                 error(
-    #                     f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
-    #         except Exception as e:
-    #             error(f"Neo4j triplet matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
-    #             continue
-    #     timing_stats["triplet_matching"] = time.time() - triplet_match_start
-    #     debug(f"total triplet-matching time: {timing_stats['triplet_matching']:.3f}s")
-    #
-    #     # concatenate the triplet text
-    #     triplet_text_start = time.time()
-    #     triplet_texts = []
-    #     for triplet in all_triplets:
-    #         head = triplet.get('head', '')
-    #         type_ = triplet.get('type', '')
-    #         tail = triplet.get('tail', '')
-    #         if head and type_ and tail:
-    #             triplet_texts.append(f"{head} {type_} {tail}")
-    #         else:
-    #             debug(f"invalid triplet: {triplet}")
-    #     combined_text = query
-    #     if triplet_texts:
-    #         combined_text += "".join(triplet_texts)
-    #     debug(
-    #         f"combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
-    #     timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
-    #     debug(f"triplet-text concatenation took {timing_stats['triplet_text_combine']:.3f}s")
-    #
-    #     # convert the combined text into a query vector
-    #     vector_start = time.time()
-    #     query_vector = await api_service.get_embeddings(
-    #         request=request,
-    #         texts=[combined_text],
-    #         upappid=service_params['embedding'],
-    #         apiname="BAAI/bge-m3",
-    #         user=userid
-    #     )
-    #     if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
-    #         raise ValueError("the query vector must be a list of 1024 floats")
-    #     query_vector = query_vector[0]  # take the first vector
-    #     timing_stats["vector_generation"] = time.time() - vector_start
-    #     debug(f"query-vector generation took {timing_stats['vector_generation']:.3f}s")
-    #
-    #     # call the search endpoint
-    #     sum = limit + 5
-    #     search_start = time.time()
-    #     debug(f"orgid: {orgid}")
-    #     result = await api_service.milvus_search_query(
-    #         request=request,
-    #         query_vector=query_vector,
-    #         userid=orgid,
-    #         knowledge_base_ids=fiids,
-    #         limit=sum,
-    #         offset=0,
-    #         upappid=service_params['vdb'],
-    #         apiname="mlvus/searchquery",
-    #         user=userid
-    #     )
-    #     timing_stats["vector_search"] = time.time() - search_start
-    #     debug(f"vector search took {timing_stats['vector_search']:.3f}s")
-    #
-    #     if result.get("status") != "success":
-    #         error(f"fused search failed: {result.get('message', 'unknown error')}")
-    #         return {"results": [], "timing": timing_stats}
-    #
-    #     unique_results = result.get("results", [])
-    #     sum = len(unique_results)
-    #     debug(f"vector search returned {sum} rows")
-    #     use_rerank = True
-    #     if use_rerank and unique_results:
-    #         rerank_start = time.time()
-    #         debug("starting rerank")
-    #         unique_results = await api_service.rerank_results(
-    #             request=request,
-    #             query=combined_text,
-    #             results=unique_results,
-    #             top_n=limit,
-    #             upappid=service_params['reranker'],
-    #             apiname="BAAI/bge-reranker-v2-m3",
-    #             user=userid
-    #         )
-    #         unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
-    #         timing_stats["reranking"] = time.time() - rerank_start
-    #         debug(f"reranking took {timing_stats['reranking']:.3f}s")
-    #         debug(f"rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
-    #     else:
-    #         unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
-    #
-    #     timing_stats["total_time"] = time.time() - start_time
-    #     info(f"fused search finished, returning {len(unique_results)} results, total time: {timing_stats['total_time']:.3f}s")
-    #
-    #     # dify_result = []
-    #     # for res in unique_results[:limit]:
-    #     #     content = res.get('text', '')
-    #     #     title = res.get('metadata', {}).get('filename', 'Untitled')
-    #     #     document_id = res.get('metadata', {}).get('document_id', '')
-    #     #     dify_result.append({
-    #     #         'metadata': {'document_id': document_id},
-    #     #         'title': title,
-    #     #         'content': content
-    #     #     })
-    #     # info(f"fused search finished, returning {len(dify_result)} results, total time: {(time.time() - start_time):.3f}s")
-    #     # debug(f"result: {dify_result}")
-    #     # return dify_result
-    #
-    #     dify_records = []
-    #     dify_result = []
-    #     for res in unique_results[:limit]:
-    #         rerank_score = res.get('rerank_score', 0)
-    #         score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
-    #         score = max(0.0, min(1.0, score))
-    #         content = res.get('text', '')
-    #         title = res.get('metadata', {}).get('filename', 'Untitled')
-    #         document_id = res.get('metadata', {}).get('document_id', '')
-    #         dify_records.append({
-    #             "content": content,
-    #             "title": title,
-    #             "metadata": {"document_id": document_id, "score": score},
-    #         })
-    #         dify_result.append({
-    #             "content": content,
-    #             "title": title,
-    #             "metadata": {"document_id": document_id, "score": score},
-    #         })
-    #     info(f"fused search finished, returning {len(dify_records)} results, total time: {(time.time() - start_time):.3f}s")
-    #     debug(f"records: {dify_records}, result: {dify_result}")
-    #     # return {"records": dify_records, "result": dify_result, "own": {"results": unique_results[:limit], "timing": timing_stats}}
-    #     return {"records": dify_records}
-    #
-    #     # dify_result = []
-    #     # for res in unique_results[:limit]:
-    #     #     rerank_score = res.get('rerank_score', 0)
-    #     #     score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
-    #     #     score = max(0.0, min(1.0, score))
-    #     #     content = res.get('text', '')
-    #     #     title = res.get('metadata', {}).get('filename', 'Untitled')
-    #     #     document_id = res.get('metadata', {}).get('document_id', '')
-    #     #     dify_result.append({
-    #     #         "metadata": {
-    #     #             "_source": "konwledge",
-    #     #             "dataset_id": "111111",
-    #     #             "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx",
-    #     #             "document_id": document_id,
-    #     #             "document_name": "test.docx",
-    #     #             "data_source_type": "upload_file",
-    #     #             "segment_id": "7b391707-93bc-4654-80ae-7989f393b045",
-    #     #             "retriever_from": "workflow",
-    #     #             "score": score,
-    #     #             "segment_hit_count": 7,
-    #     #             "segment_word_count": 275,
-    #     #             "segment_position": 5,
-    #     #             "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73",
-    #     #             "doc_metadata": None,
-    #     #             "position": 1
-    #     #         },
-    #     #         "title": title,
-    #     #         "content": content
-    #     #     })
-    #     # return {"result": dify_result}
-    #
-    # except Exception as e:
-    #     error(f"fused search failed: {str(e)}, stack: {traceback.format_exc()}")
-    #     return {"results": [], "timing": timing_stats}
+