from sqlor.dbpools import DBPools
from appPublic.log import debug, error, info
import time
import traceback
import json
from rag.service_opts import get_service_params
from rag.rag_operations import RagOperations

helptext = """kyrag API:
1. List kdb tables:
path: /v1/get_kdbs
headers: {
    "Content-Type": "application/json"
}
response:
[{"id": "1", "name": "textdb", "description": "text database"},
 {"id": "testkdb", "name": "testkdb", "description": ""},
 {"id": "Vdtbt3qBfocteit1HIxVH", "name": "trm", "description": ""}]

2. Vector search over text chunks:
path: /v1/fusedsearch
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "What is knowledge extraction?",
    "fiids": ["1"]
}
response: {
    "records": [...]
}

3. docs documentation
path: /v1/docs

4. longmemory storage
"""

async def docs(request, params_kw, *params, **kw):
    return helptext

async def get_kdbs(request, params_kw, *params, **kw):
    """Return all kdb records of the caller's organization as JSON."""
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    debug(f"params_kw: {params_kw}")
    db = DBPools()
    dbname = kw.get('get_module_dbname')('rag')
    sql_opts = """
        SELECT id, name, description
        FROM kdb
        WHERE orgid = ${orgid}$
    """
    try:
        async with db.sqlorContext(dbname) as sor:
            opts_result = await sor.sqlExe(sql_opts, {"orgid": orgid})
            if not opts_result:
                error("No records found in the kdb table")
                return json.dumps({"status": "error", "message": "No records found"})
            return json.dumps(opts_result, ensure_ascii=False)
    except Exception as e:
        error(f"Querying the kdb table failed: {str(e)}, traceback: {traceback.format_exc()}")
        return json.dumps({"status": "error", "message": str(e)})
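# --- Example client (hedged sketch) -------------------------------------------
# A minimal sketch of how a client might call the endpoints documented in
# `helptext` above. It is not used by the service. The base URL, the HTTP
# method (POST), and the use of aiohttp are assumptions for illustration;
# adjust them to the actual deployment.
async def _example_client(base_url="http://localhost:8080"):
    import aiohttp  # local import so the service module itself does not need it
    headers = {"Content-Type": "application/json"}
    async with aiohttp.ClientSession(headers=headers) as session:
        # 1. List the knowledge bases visible to the caller.
        async with session.post(f"{base_url}/v1/get_kdbs", json={}) as resp:
            kdbs = await resp.json(content_type=None)
        # 2. Run a fused search against one of them.
        payload = {"query": "What is knowledge extraction?", "fiids": ["1"], "limit": 5}
        async with session.post(f"{base_url}/v1/fusedsearch", json=payload) as resp:
            records = await resp.json(content_type=None)
    return kdbs, records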
async def fusedsearch(request, params_kw, *params, **kw):
    """Fused search: entity extraction, triplet matching, vector search, rerank."""
    f = kw.get('get_userorgid')
    orgid = await f()
    debug(f"orgid: {orgid},{f=}")
    f = kw.get('get_user')
    userid = await f()
    debug(f"params_kw: {params_kw}")
    query = params_kw.get('query', '')
    # Normalize the limit parameter uniformly, for Dify and Coze compatibility.
    raw_limit = params_kw.get('limit') or (
        params_kw.get('retrieval_setting', {}).get('top_k')
        if isinstance(params_kw.get('retrieval_setting'), dict) else None
    )
    # Normalize to an integer.
    if raw_limit is None:
        limit = 5  # default when neither source is present
    elif isinstance(raw_limit, (int, float)):
        limit = int(raw_limit)  # numeric types convert directly
    elif isinstance(raw_limit, str):
        try:
            limit = int(raw_limit)  # parse string to integer
        except (TypeError, ValueError):
            limit = 5  # default on parse failure
    else:
        limit = 5  # default for unexpected types
    debug(f"limit: {limit}")

    raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
    # Normalize to a list of string ids.
    if raw_fiids is None:
        fiids = []  # neither parameter present
    elif isinstance(raw_fiids, list):
        fiids = [str(item).strip() for item in raw_fiids]  # already a list
    elif isinstance(raw_fiids, str):
        try:
            # Try to parse a JSON string first.
            parsed = json.loads(raw_fiids)
            if isinstance(parsed, list):
                fiids = [str(item).strip() for item in parsed]  # JSON array to string list
            else:
                # Comma-separated string or a single id.
                fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
        except json.JSONDecodeError:
            # Not valid JSON: split on commas.
            fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
    elif isinstance(raw_fiids, (int, float)):
        fiids = [str(int(raw_fiids))]  # numeric id to string list
    else:
        fiids = []  # unexpected types
    debug(f"fiids: {fiids}")

    # Verify that the fiids belong to the caller's organization.
    await _validate_fiids_orgid(fiids, orgid, kw)

    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("Failed to get service parameters")
    try:
        timing_stats = {}
        start_time = time.time()
        rag_ops = RagOperations()

        entity_extract_start = time.time()
        query_entities = await rag_ops.extract_entities(request, query, service_params, userid)
        timing_stats["entity_extraction"] = time.time() - entity_extract_start
        debug(f"Extracted entities: {query_entities}, took {timing_stats['entity_extraction']:.3f}s")

        triplet_match_start = time.time()
        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, userid)
        timing_stats["triplet_matching"] = time.time() - triplet_match_start
        debug(f"Triplet matching took {timing_stats['triplet_matching']:.3f}s in total")

        triplet_text_start = time.time()
        combined_text = _combine_query_with_triplets(query, all_triplets)
        timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
        debug(f"Combining triplet text took {timing_stats['triplet_text_combine']:.3f}s")

        vector_start = time.time()
        query_vector = await rag_ops.api_service.get_embeddings(
            request=request,
            texts=[combined_text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a list of 1024 floats")
        query_vector = query_vector[0]
        timing_stats["vector_generation"] = time.time() - vector_start
        debug(f"Generating the query vector took {timing_stats['vector_generation']:.3f}s")

        search_start = time.time()
        search_limit = limit + 5  # over-fetch so reranking has candidates to drop
        search_results = await rag_ops.vector_search(
            request, query_vector, orgid, fiids, search_limit, service_params, userid
        )
        timing_stats["vector_search"] = time.time() - search_start
        debug(f"Vector search took {timing_stats['vector_search']:.3f}s")
        debug(f"Vector search returned {len(search_results)} rows")

        # Step 6: reranking (optional)
        use_rerank = True
        if use_rerank and search_results:
            rerank_start = time.time()
            debug("Starting rerank")
            reranked_results = await rag_ops.rerank_results(
                request, combined_text, search_results, limit, service_params, userid
            )
            reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            timing_stats["reranking"] = time.time() - rerank_start
            debug(f"Reranking took {timing_stats['reranking']:.3f}s")
            debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
            final_results = reranked_results
        else:
            final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]

        timing_stats["total_time"] = time.time() - start_time
        formatted_results = rag_ops.format_search_results(final_results, limit)
        info(f"Fused search finished with {len(formatted_results)} results, total time: {timing_stats['total_time']:.3f}s")
        return {
            "records": formatted_results
        }
    except Exception as e:
        error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
        # The transaction manager rolls back automatically.
        return {
            "records": [],
            "timing": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
            "error": str(e)
        }

async def _validate_fiids_orgid(fiids, orgid, kw):
    """Verify that the requested fiids belong to the current user's orgid."""
    if fiids:
        db = DBPools()
        dbname = kw.get('get_module_dbname')('rag')
        sql_opts = """SELECT orgid FROM kdb WHERE id = ${id}$"""
        try:
            async with db.sqlorContext(dbname) as sor:
                # Check every requested kdb, not just the first one.
                for fiid in fiids:
                    result = await sor.sqlExe(sql_opts, {"id": fiid})
                    if not result:
                        raise ValueError(f"No record found for fiid={fiid}")
                    kdb_orgid = result[0].get('orgid')
                    if kdb_orgid != orgid:
                        raise ValueError(f"orgid mismatch: kdb.orgid={kdb_orgid}, user orgid={orgid}")
        except Exception as e:
            error(f"orgid validation failed: {str(e)}")
            raise

def _combine_query_with_triplets(query, triplets):
    """Concatenate the query text with the matched triplet texts."""
    triplet_texts = []
    for triplet in triplets:
        head = triplet.get('head', '')
        type_ = triplet.get('type', '')
        tail = triplet.get('tail', '')
        if head and type_ and tail:
            triplet_texts.append(f"{head} {type_} {tail}")
        else:
            debug(f"Invalid triplet: {triplet}")
    combined_text = query
    if triplet_texts:
        combined_text += "".join(triplet_texts)
    debug(f"Combined text: {combined_text[:200]}... (total length: {len(combined_text)}, triplet count: {len(triplet_texts)})")
    return combined_text
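# --- Worked example (hedged) ---------------------------------------------------
# Demonstrates `_combine_query_with_triplets` on hypothetical sample data: a
# complete triplet is appended as "head type tail", while an incomplete one
# (empty tail) is skipped and logged. Not called by the service.
def _demo_combine_query_with_triplets():
    combined = _combine_query_with_triplets(
        "What is knowledge extraction?",
        [
            {"head": "knowledge extraction", "type": "is_a", "tail": "NLP task"},
            {"head": "entity", "type": "part_of", "tail": ""},  # incomplete: skipped
        ],
    )
    assert combined == "What is knowledge extraction?knowledge extraction is_a NLP task"
    return combined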
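# --- Hedged sketch: fiids normalization in isolation ---------------------------
# `fusedsearch` normalizes `fiids`/`knowledge_id` inline. The helper below is a
# hypothetical extraction of that logic, kept separate so the accepted input
# shapes (None, list, JSON-array string, comma-separated string, number) can be
# exercised on their own. It is not called by the service.
def _normalize_fiids(raw_fiids):
    if raw_fiids is None:
        return []
    if isinstance(raw_fiids, list):
        return [str(item).strip() for item in raw_fiids]
    if isinstance(raw_fiids, str):
        try:
            parsed = json.loads(raw_fiids)
            if isinstance(parsed, list):
                return [str(item).strip() for item in parsed]
        except json.JSONDecodeError:
            pass
        # Valid JSON but not a list, or not JSON at all: split on commas.
        return [f.strip() for f in raw_fiids.split(',') if f.strip()]
    if isinstance(raw_fiids, (int, float)):
        return [str(int(raw_fiids))]
    return []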
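# --- Hedged sketch: limit normalization in isolation ---------------------------
# Likewise a hypothetical extraction of the inline `limit` handling, including
# the Dify/Coze `retrieval_setting.top_k` fallback. Note that a falsy `limit`
# (0 or "") falls through to `top_k` because of the `or`, mirroring the inline
# code. Not called by the service.
def _normalize_limit(params_kw, default=5):
    retrieval_setting = params_kw.get('retrieval_setting')
    raw_limit = params_kw.get('limit') or (
        retrieval_setting.get('top_k') if isinstance(retrieval_setting, dict) else None
    )
    if raw_limit is None:
        return default
    if isinstance(raw_limit, (int, float)):
        return int(raw_limit)
    if isinstance(raw_limit, str):
        try:
            return int(raw_limit)
        except ValueError:
            return default
    return default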
debug(f"records: {dify_records}, result: {dify_result}") # # return {"records": dify_records, "result": dify_result,"own": {"results": unique_results[:limit], "timing": timing_stats}} # return {"records": dify_records} # # # dify_result = [] # # for res in unique_results[:limit]: # # rerank_score = res.get('rerank_score', 0) # # score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) # # score = max(0.0, min(1.0, score)) # # content = res.get('text', '') # # title = res.get('metadata', {}).get('filename', 'Untitled') # # document_id = res.get('metadata', {}).get('document_id', '') # # dify_result.append({ # # "metadata": { # # "_source": "konwledge", # # "dataset_id":"111111", # # "dataset_name": "NVIDIA_GPU性能参数-RAG-V1.xlsx", # # "document_id": document_id, # # "document_name": "test.docx", # # "data_source_type": "upload_file", # # "segment_id": "7b391707-93bc-4654-80ae-7989f393b045", # # "retriever_from": "workflow", # # "score": score, # # "segment_hit_count": 7, # # "segment_word_count": 275, # # "segment_position": 5, # # "segment_index_node_hash": "1cd60b478221c9d4831a0b2af3e8b8581d94ecb53e8ffd46af687e8fc3077b73", # # "doc_metadata": None, # # "position":1 # # }, # # "title": title, # # "content": content # # }) # # return {"result": dify_result} # # except Exception as e: # error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") # return {"results": [], "timing": timing_stats}