"""kyrag HTTP API endpoints.

Exposes three async handlers in the server's handler convention
``(request, params_kw, *params, **kw)``:

* ``docs``        -- returns the plain-text API help document.
* ``get_kdbs``    -- dumps the ``kdb`` table as a JSON string.
* ``fusedsearch`` -- graph + vector fused retrieval over one or more
  knowledge bases, with optional reranking.

NOTE(review): ``DBPools``, ``debug``, ``info`` and ``error`` are used below
but never imported here; they appear to be injected into the handler
namespace by the serving framework (sqlor / appPublic style) -- confirm, or
add the appropriate framework imports.
"""
import json
import time
import traceback

from rag.uapi_service import APIService
from rag.folderinfo import RagFileMgr

helptext = """kyrag API:

1. 得到kdb表:
path: /v1/get_kdbs
headers: {
    "Content-Type": "application/json"
}
response:

2. 向量检索文本块:
path: /v1/fusedsearch
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "什么是知识抽取。"
    "fiids":["1"]
}
3、docs文档
"""


async def docs(request, params_kw, *params, **kw):
    """Return the plain-text API help document (``helptext``)."""
    return helptext


async def get_kdbs(request, params_kw, *params, **kw):
    """Return every row of the ``kdb`` table as a JSON string.

    Returns:
        JSON string: either the list of ``{id, name, description}`` rows,
        or ``{"status": "error", "message": ...}`` on failure / empty table.
    """
    # was: print(...) -- route through the module's logging helper instead.
    debug("初始化数据库连接池...")
    db = DBPools()
    # `get_module_dbname` is a framework callable passed in via **kw.
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
        SELECT id, name, description
        FROM kdb
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {})
            if not opts_result:
                error("未找到 kdb 表记录")
                return json.dumps({"status": "error", "message": "未找到记录"})
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        error(f"查询 kdb 表失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})


async def fusedsearch(request, params_kw, *params, **kw):
    """Fused (graph + vector) search across the caller's knowledge bases.

    Pipeline: entity extraction -> Neo4j triplet matching -> query/triplet
    text fusion -> embedding -> Milvus vector search -> optional rerank.

    Args:
        params_kw: expects ``query`` (str) and ``fiids`` (list of kdb ids).
        kw: framework-injected callables (``get_userorgid``, ``get_user``,
            ``get_module_dbname``).

    Returns:
        ``{"results": [...], "timing": {...}}`` on success or internal
        failure; a JSON error string when org ownership validation fails.
    """
    f = kw.get('get_userorgid')
    orgid = await f()
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')
    # BUG FIX: helptext documents the request key as "fiids"; the original
    # read only "fiid". Accept "fiids" first, keep "fiid" as a fallback for
    # backward compatibility.
    fiids = params_kw.get('fiids') or params_kw.get('fiid') or []

    # Verify the first requested knowledge base belongs to the caller's org.
    if fiids:
        db = DBPools()
        dbname = kw.get('get_module_dbname')('rag')
        sql_opts = """
            SELECT orgid
            FROM kdb
            WHERE id = ${fiid}$
        """
        try:
            async with db.sqlorContext(dbname) as sor:
                result = await sor.sqlExe(sql_opts, {"fiid": fiids[0]})
                if not result:
                    raise ValueError(f"未找到 fiid={fiids[0]} 的记录")
                kdb_orgid = result[0].get('orgid')
                if kdb_orgid != orgid:
                    raise ValueError(f"orgid 不一致: kdb.orgid={kdb_orgid}, user orgid={orgid}")
        except Exception as e:
            error(f"orgid 验证失败: {str(e)}")
            return json.dumps({"status": "error", "message": str(e)})

    api_service = APIService()
    start_time = time.time()
    timing_stats = {}
    try:
        info(
            f"开始融合搜索: query={query}, userid={orgid}, knowledge_base_ids={fiids}")

        # BUG FIX: original tested the undefined name `knowledge_base_ids`
        # (guaranteed NameError); the list of KB ids is `fiids`.
        if not query or not orgid or not fiids:
            raise ValueError("query、orgid 和 knowledge_base_ids 不能为空")

        # BUG FIX: original passed the literal string "fiids[0]" to
        # RagFileMgr instead of the first knowledge-base id. Also moved
        # inside the try so an empty `fiids` is reported, not an IndexError.
        ragfilemgr = RagFileMgr(fiids[0])
        service_params = ragfilemgr.get_service_params(orgid)

        # --- entity extraction -------------------------------------------
        entity_extract_start = time.time()
        query_entities = await api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")

        # --- Neo4j triplet matching (best-effort per knowledge base) ------
        all_triplets = []
        triplet_match_start = time.time()
        for kb_id in fiids:
            debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=query_entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组: {triplets[:5]}")
                else:
                    error(
                        f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
            except Exception as e:
                # One failing KB must not abort the whole fused search.
                error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
                continue
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")

        # --- fuse query text with matched triplets ------------------------
        triplet_text_start = time.time()
        triplet_texts = []
        for triplet in all_triplets:
            head = triplet.get('head', '')
            type_ = triplet.get('type', '')
            tail = triplet.get('tail', '')
            if head and type_ and tail:
                triplet_texts.append(f"{head} {type_} {tail}")
            else:
                debug(f"无效三元组: {triplet}")
        combined_text = query
        if triplet_texts:
            combined_text += " [三元组] " + "; ".join(triplet_texts)
        debug(
            f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")

        # --- embed fused text ---------------------------------------------
        vector_start = time.time()
        query_vector = await api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            user=userid
        )
        # Embedding service is expected to return 1024-dim vectors.
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
        query_vector = query_vector[0]  # single query -> first (only) vector
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"生成查询向量耗时: {timing_stats['vector_generation']:.3f} 秒")

        # --- vector search -------------------------------------------------
        limit = 5
        search_start = time.time()
        result = await api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            # BUG FIX: original passed [fiids] (a nested list); the other
            # call sites treat `fiids` itself as the list of KB ids.
            knowledge_base_ids=fiids,
            upappid=service_params['vdb'],
            # NOTE(review): "mlvus" looks like a typo for "milvus", but it is
            # the identifier the remote service routes on -- confirm before
            # changing it.
            apiname="mlvus/searchquery",
            user=userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")

        if result.get("status") != "success":
            error(f"融合搜索失败: {result.get('message', '未知错误')}")
            return {"results": [], "timing": timing_stats}

        unique_results = result.get("results", [])
        use_rerank = True
        if use_rerank and unique_results:
            rerank_start = time.time()
            debug("开始重排序")
            # NOTE(review): this calls the APIService instance directly;
            # every other step uses a named method. Confirm APIService
            # defines __call__, otherwise this should be its rerank method.
            unique_results = await api_service(
                request=request,
                query=combined_text,
                results=unique_results,
                top_n=limit,
                upappid=service_params['reranker'],
                apiname="BAAI/bge-reranker-v2-m3",
                user=userid
            )
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
            debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            # No rerank: strip any stale rerank scores from the payload.
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]

        timing_stats["total_time"] = time.time() - start_time
        info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
        return {"results": unique_results[:limit], "timing": timing_stats}

    except Exception as e:
        error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
        return {"results": [], "timing": timing_stats}