From 22ad6e48fd5f54f5d3507035fd118c29cf3b4b68 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Mon, 28 Jul 2025 10:55:08 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E7=9B=B8=E5=85=B3=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag/allfusedsearch.py | 290 ----------------------- rag/combinedsearch.py | 190 --------------- rag/deletefile.py | 138 ----------- rag/embed.py | 183 -------------- rag/extract.py | 225 ------------------ rag/fusedsearch.py | 290 ----------------------- rag/kdb.py | 81 ------- rag/kgc.py | 194 --------------- rag/query.py | 201 ---------------- rag/rerank.py | 80 ------- rag/searchquery.py | 363 ---------------------------- rag/test.py | 9 - rag/vector.py | 539 ------------------------------------------ rag/version.py | 1 - setup.py | 52 ---- 15 files changed, 2836 deletions(-) delete mode 100644 rag/allfusedsearch.py delete mode 100644 rag/combinedsearch.py delete mode 100644 rag/deletefile.py delete mode 100644 rag/embed.py delete mode 100644 rag/extract.py delete mode 100644 rag/fusedsearch.py delete mode 100644 rag/kdb.py delete mode 100644 rag/kgc.py delete mode 100644 rag/query.py delete mode 100644 rag/rerank.py delete mode 100644 rag/searchquery.py delete mode 100644 rag/test.py delete mode 100644 rag/vector.py delete mode 100644 rag/version.py delete mode 100755 setup.py diff --git a/rag/allfusedsearch.py b/rag/allfusedsearch.py deleted file mode 100644 index 6c515f2..0000000 --- a/rag/allfusedsearch.py +++ /dev/null @@ -1,290 +0,0 @@ -import os -import logging -import yaml -import numpy as np -from typing import List, Dict, Any -from pymilvus import Collection, utility -from langchain_huggingface import HuggingFaceEmbeddings -from vector import initialize_milvus_connection -from searchquery import extract_entities, match_triplets -from rerank import rerank_results -import torch -import time - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() -logger.propagate = False -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -try: - with open(config['logging']['file'], 'a', encoding='utf-8') as f: - pass -except Exception as e: - raise RuntimeError(f"日志文件 {config['logging']['file']} 不可写: {str(e)}") -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -# 初始化嵌入模型 -embedding = HuggingFaceEmbeddings( - model_name=TEXT_EMBEDDING_MODEL, - model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}, - encode_kwargs={'normalize_embeddings': True} -) -try: - test_vector = embedding.embed_query("test") - if len(test_vector) != 1024: - raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024") - logger.debug("嵌入模型加载成功") -except Exception as e: - logger.error(f"嵌入模型加载失败: {str(e)}") - raise RuntimeError(f"嵌入模型加载失败: 
{str(e)}") - -# 缓存三元组 -TRIPLET_CACHE = {} - - -def load_triplets_to_cache(userid: str, document_id: str) -> List[Dict]: - """加载三元组到缓存""" - cache_key = f"{document_id}_{userid}" - if cache_key in TRIPLET_CACHE: - logger.debug(f"从缓存加载三元组: {cache_key}") - return TRIPLET_CACHE[cache_key] - - triplet_file = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt" - triplets = [] - try: - with open(triplet_file, 'r', encoding='utf-8') as f: - for line in f: - parts = line.strip().split('\t') - if len(parts) < 3: - continue - head, type_, tail = parts[:3] - triplets.append({'head': head, 'type': type_, 'tail': tail}) - TRIPLET_CACHE[cache_key] = triplets - logger.debug(f"加载三元组文件: {triplet_file}, 数量: {len(triplets)}") - return triplets - except Exception as e: - logger.error(f"加载三元组失败: {triplet_file}, 错误: {str(e)}") - return [] - - -def fused_search( - query: str, - userid: str, - db_type: str, - file_paths: List[str], - limit: int = 5, - offset: int = 0, - use_rerank: bool = True -) -> List[Dict[str, Any]]: - """ - 融合 RAG 和三元组召回文本块: - - 收集所有输入文件的三元组,拼接为融合文本,向量化后在所有文件中搜索。 - - 结果去重并按 rerank_score 或 distance 排序,重排序使用融合文本。 - - 参数: - query (str): 查询文本 - userid (str): 用户 ID - db_type (str): 数据库类型 (e.g., 'textdb') - file_paths (List[str]): 文件路径列表 - limit (int): 返回结果数量 - offset (int): 偏移量 - use_rerank (bool): 是否使用重排序 - - 返回: - List[Dict[str, Any]]: 召回结果,包含 text、distance、source、metadata、rerank_score - """ - try: - logger.info(f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}") - start_time = time.time() - - # 参数验证 - if not query or not userid or not db_type or not file_paths: - raise ValueError("query、userid、db_type 和 file_paths 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") - - # 初始化 Milvus 连接 - connections = initialize_milvus_connection() - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return [] - collection = Collection(collection_name) - collection.load() - logger.debug(f"加载 Milvus 集合: {collection_name}") - - # 提取实体 - entity_start = time.time() - query_entities = extract_entities(query) - logger.debug(f"提取实体: {query_entities}, 耗时: {time.time() - entity_start:.3f}s") - - # 收集所有文件的 document_id 和三元组 - doc_id_map = {} - filenames = [] - all_triplets = [] - for file_path in file_paths: - filename = os.path.basename(file_path) - filenames.append(filename) - logger.debug(f"处理文件: {filename}") - - # 获取 document_id - results_query = collection.query( - expr=f"userid == '{userid}' and filename == '{filename}'", - output_fields=["document_id"], - limit=1 - ) - if not results_query: - logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档") - continue - document_id = results_query[0]["document_id"] - doc_id_map[filename] = document_id - load_triplets_to_cache(userid, document_id) - - # 获取匹配的三元组 - triplet_start = time.time() - matched_triplets = match_triplets(query, query_entities, userid, document_id) - logger.debug( - f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条, 耗时: {time.time() - triplet_start:.3f}s") - all_triplets.extend(matched_triplets) - - if not doc_id_map: - logger.warning("未找到任何有效文档") - return [] - - # 拼接融合文本 - triplet_texts = [] - for triplet in all_triplets: - head = triplet['head'] - type_ = triplet['type'] - tail = triplet['tail'] - if not head or not type_ or not tail: - logger.debug(f"无效三元组: {triplet}") - continue - triplet_texts.append(f"{head} {type_} {tail}") - - # 定义融合文本 - fused_text = query if not triplet_texts else f"{query} {' 
'.join(triplet_texts)}" - logger.debug(f"融合文本: {fused_text}, 三元组数量: {len(triplet_texts)}") - - # 向量化 - embed_start = time.time() - query_vector = embedding.embed_query(fused_text) - query_vector = np.array(query_vector) / np.linalg.norm(query_vector) - logger.debug(f"生成融合向量,维度: {len(query_vector)}, 耗时: {time.time() - embed_start:.3f}s") - - # Milvus 搜索 - expr = f"userid == '{userid}' and filename in {filenames}" - search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - milvus_start = time.time() - milvus_results = collection.search( - data=[query_vector], - anns_field="vector", - param=search_params, - limit=100, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - offset=offset - ) - logger.debug(f"Milvus 搜索耗时: {time.time() - milvus_start:.3f}s") - - results = [] - for hits in milvus_results: - for hit in hits: - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "source": "fused_query" if not triplet_texts else f"fused_triplets_{len(triplet_texts)}", - "metadata": { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - } - results.append(result) - logger.debug( - f"召回: text={result['text'][:100]}..., distance={result['distance']}, filename={result['metadata']['filename']}") - - # 去重 - unique_results = [] - seen_texts = set() - for result in results: - text = result['text'] - if not text: - logger.warning(f"发现空文本结果: {result['metadata']}") - continue - if text in seen_texts: - logger.debug(f"移除重复文本: text={text[:100]}..., filename={result['metadata']['filename']}") - continue - seen_texts.add(text) - unique_results.append(result) - logger.info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(results)})") - - # 可选:重排序 - if use_rerank and unique_results: - logger.debug("开始重排序") - logger.debug(f"重排序查询: {fused_text}") - rerank_start = time.time() - reranked_results = rerank_results(fused_text, unique_results) - reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True) - logger.debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}") - logger.debug(f"重排序耗时: {time.time() - rerank_start:.3f}s") - for i, result in enumerate(reranked_results): - logger.debug( - f"排序结果 {i + 1}: text={result['text'][:100]}..., distance={result['distance']}, rerank_score={result.get('rerank_score', 'N/A')}") - logger.info(f"总耗时: {time.time() - start_time:.3f}s") - return reranked_results[:limit] - - # 按 distance 降序排序 - sorted_results = sorted(unique_results, key=lambda x: x['distance'], reverse=True) - for i, result in enumerate(sorted_results): - logger.debug(f"排序结果 {i + 1}: text={result['text'][:100]}..., distance={result['distance']}") - logger.info(f"总耗时: {time.time() - start_time:.3f}s") - return sorted_results[:limit] - - except Exception as e: - logger.error(f"融合搜索失败: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return [] - - -if __name__ == "__main__": - query = "什么是知识抽取?" 
- userid = "testuser1" - db_type = "textdb" - file_paths = [ - "/share/wangmeihua/rag/data/test.docx", - "/share/wangmeihua/rag/data/zongshu.pdf", - "/share/wangmeihua/rag/data/qianru.pdf", - ] - try: - results = fused_search(query, userid, db_type, file_paths, limit=10, offset=0) - for i, result in enumerate(results): - print(f"Result {i + 1}:") - print(f"Text: {result['text'][:200]}...") - print(f"Distance: {result['distance']:.3f}") - print( - f"Rerank Score: {result.get('rerank_score', 'N/A') if isinstance(result.get('rerank_score'), (int, float)) else 'N/A':.3f}") - print(f"Source: {result['source']}") - print(f"Metadata: {result['metadata']}\n") - except Exception as e: - print(f"搜索失败: {str(e)}") \ No newline at end of file diff --git a/rag/combinedsearch.py b/rag/combinedsearch.py deleted file mode 100644 index 12230d8..0000000 --- a/rag/combinedsearch.py +++ /dev/null @@ -1,190 +0,0 @@ -import os -import yaml -import logging -from typing import List, Dict -from pymilvus import connections, Collection, utility -from langchain_huggingface import HuggingFaceEmbeddings -from query import search_query -from searchquery import searchquery -from rerank import rerank_results -from vector import initialize_milvus_connection, cleanup_milvus_connection -import torch -from functools import lru_cache - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() # 清除现有处理器 -logger.propagate = False # 禁用传播 -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -# 初始化嵌入模型(缓存) -@lru_cache(maxsize=1000) -def get_embedding(text: str) -> List[float]: - embedding = HuggingFaceEmbeddings( - model_name=TEXT_EMBEDDING_MODEL, - model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}, - encode_kwargs={'normalize_embeddings': True} - ) - vector = embedding.embed_query(text) - if len(vector) != 1024: - raise ValueError(f"嵌入模型输出维度 {len(vector)} 不匹配预期 1024") - return vector - -def combined_search(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 10, offset: int = 0) -> List[Dict]: - """ - 结合 RAG 和三元组检索,召回相关文本块,使用 BGE Reranker 重排序。 - - 参数: - query (str): 查询文本 - userid (str): 用户ID - db_type (str): 数据库类型 - file_paths (List[str]): 文档路径列表 - limit (int): 返回的最大结果数,默认为 10 - offset (int): 偏移量,默认为 0 - - 返回: - List[Dict]: 包含 text、distance、source、metadata 和 rerank_score 的结果列表 - """ - try: - # 参数验证 - if not query or not userid or not db_type or not file_paths: - raise ValueError("query、userid、db_type 和 file_paths 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") - if len(userid) > 100 or len(db_type) > 100: - raise ValueError("userid 或 db_type 的长度超出限制") - if limit <= 
0 or limit > 16384: - raise ValueError("limit 必须在 1 到 16384 之间") - if offset < 0: - raise ValueError("offset 不能为负数") - if limit + offset > 16384: - raise ValueError("limit + offset 不能超过 16384") - - for file_path in file_paths: - if not isinstance(file_path, str): - raise ValueError(f"file_path 必须是字符串: {file_path}") - if len(os.path.basename(file_path)) > 255: - raise ValueError(f"文件名长度超出 255 个字符: {file_path}") - - # 初始化 Milvus 连接 - initialize_milvus_connection() - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return [] - - # RAG 检索,使用默认 limit=3 - rag_results = search_query(query, userid, db_type, file_paths, offset=offset) - for result in rag_results: - result['source'] = 'rag' - logger.info(f"RAG 检索返回 {len(rag_results)} 条结果") - - # 三元组检索,使用默认 limit=3 - triplet_results = searchquery(query, userid, db_type, file_paths, offset=offset) - for result in triplet_results: - result['source'] = 'triplet' - logger.info(f"三元组检索返回 {len(triplet_results)} 条结果") - - # 记录三元组检索结果详情 - for idx, result in enumerate(triplet_results, 1): - logger.debug(f"三元组结果 {idx}: text={result['text'][:200]}..., distance={result['distance']:.4f}, metadata={result['metadata']}") - - # 合并结果 - all_results = rag_results + triplet_results - if not all_results: - logger.warning("RAG 和三元组检索均无结果") - return [] - - # 记录合并前的结果 - logger.debug("合并前结果:") - for idx, result in enumerate(all_results, 1): - logger.debug(f"结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, metadata={result['metadata']}") - - # 使用 BGE Reranker 重排序 - reranked_results = rerank_results(query, all_results, top_k=len(all_results)) - - # 按 rerank_score 排序(不去重) - sorted_results = sorted(reranked_results, key=lambda x: x['rerank_score'], reverse=True) - - # 记录排序后的结果 - logger.debug("重排序后结果:") - for idx, result in enumerate(sorted_results, 1): - logger.debug(f"排序结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, rerank_score={result['rerank_score']:.6f}, metadata={result['metadata']}") - - # 去重(基于 text,保留 rerank_score 最大的记录) - unique_results = [] - text_to_result = {} - for result in sorted_results: - text = result['text'] - if text not in text_to_result or result['rerank_score'] > text_to_result[text]['rerank_score']: - text_to_result[text] = result - unique_results = list(text_to_result.values()) - - # 记录去重后的结果 - logger.debug("去重后结果:") - for idx, result in enumerate(unique_results, 1): - logger.debug(f"去重结果 {idx} ({result['source']}): text={result['text'][:200]}..., distance={result['distance']:.4f}, rerank_score={result['rerank_score']:.6f}, metadata={result['metadata']}") - - # 限制结果数量 - final_results = unique_results[:limit] - logger.info(f"合并后返回 {len(final_results)} 条唯一结果") - - # 移除 weighted_score 字段(若存在),保留 rerank_score 和 source - for result in final_results: - result.pop('weighted_score', None) - - return final_results - - except Exception as e: - logger.error(f"合并搜索失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - return [] - finally: - cleanup_milvus_connection() - -if __name__ == "__main__": - # 测试代码 - query = "知识图谱构建需要什么技术?" 
- userid = "testuser1" - db_type = "textdb" - file_paths = [ - "/share/wangmeihua/rag/data/test.docx", - "/share/wangmeihua/rag/data/zongshu.pdf", - "/share/wangmeihua/rag/data/qianru.pdf" - ] - limit = 10 - offset = 0 - - try: - results = combined_search(query, userid, db_type, file_paths, limit, offset) - print(f"搜索结果 ({len(results)} 条):") - for idx, result in enumerate(results, 1): - print(f"结果 {idx}:") - print(f"内容: {result['text'][:200]}...") - print(f"距离: {result['distance']}") - print(f"来源: {result['source']}") - print(f"重排序分数: {result['rerank_score']}") - print(f"元数据: {result['metadata']}") - print("-" * 50) - except Exception as e: - print(f"搜索失败: {str(e)}") \ No newline at end of file diff --git a/rag/deletefile.py b/rag/deletefile.py deleted file mode 100644 index d6c4e4e..0000000 --- a/rag/deletefile.py +++ /dev/null @@ -1,138 +0,0 @@ -import logging -import yaml -import os -from pymilvus import connections, Collection, utility -from vector import initialize_milvus_connection - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()): - handler.setFormatter(formatter) - logger.addHandler(handler) - -def delete_document(db_type: str, userid: str, filename: str) -> bool: - """ - 根据 db_type、userid 和 filename 删除用户的指定文件数据。 - - 参数: - db_type (str): 数据库类型(如 'textdb', 'pptdb') - userid (str): 用户 ID - filename (str): 文件名(如 'test.docx') - - 返回: - bool: 删除是否成功 - - 异常: - ValueError: 参数无效 - RuntimeError: 数据库操作失败 - """ - try: - # 参数验证 - if not db_type or "_" in db_type: - raise ValueError("db_type 不能为空且不能包含下划线") - if not userid or "_" in userid: - raise ValueError("userid 不能为空且不能包含下划线") - if not filename: - raise ValueError("filename 不能为空") - if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255: - raise ValueError("db_type、userid 或 filename 的长度超出限制") - - # 初始化 Milvus 连接 - initialize_milvus_connection() - logger.debug(f"已连接到 Milvus Lite,路径: {MILVUS_DB_PATH}") - - # 检查集合是否存在 - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return False - - # 加载集合 - try: - collection = Collection(collection_name) - collection.load() - logger.debug(f"加载集合: {collection_name}") - except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") - raise RuntimeError(f"加载集合失败: {str(e)}") - - # 查询匹配的 document_id - expr = f"userid == '{userid}' and filename == '{filename}'" - logger.debug(f"查询表达式: {expr}") - try: - results = collection.query( - expr=expr, - output_fields=["document_id"], - limit=1000 - ) - if not results: - logger.warning(f"没有找到 userid={userid}, filename={filename} 的记录") - return False - document_ids = list(set(result["document_id"] for result in results if "document_id" in result)) - logger.debug(f"找到 {len(document_ids)} 个 document_id: 
{document_ids}") - except Exception as e: - logger.error(f"查询 document_id 失败: {str(e)}") - raise RuntimeError(f"查询失败: {str(e)}") - - # 执行删除 - total_deleted = 0 - for doc_id in document_ids: - try: - delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'" - logger.debug(f"删除表达式: {delete_expr}") - delete_result = collection.delete(delete_expr) - deleted_count = delete_result.delete_count - total_deleted += deleted_count - logger.info(f"成功删除 document_id={doc_id} 的 {deleted_count} 条记录") - except Exception as e: - logger.error(f"删除 document_id={doc_id} 失败: {str(e)}") - continue - - if total_deleted == 0: - logger.warning(f"没有删除任何记录,userid={userid}, filename={filename}") - return False - - logger.info(f"总计删除 {total_deleted} 条记录,userid={userid}, filename={filename}") - return True - - except ValueError as ve: - logger.error(f"参数验证失败: {str(ve)}") - return False - except RuntimeError as re: - logger.error(f"数据库操作失败: {str(re)}") - return False - except Exception as e: - logger.error(f"删除文件失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - return False - finally: - try: - connections.disconnect("default") - logger.debug("已断开 Milvus 连接") - except Exception as e: - logger.warning(f"断开 Milvus 连接失败: {str(e)}") - -if __name__ == "__main__": - # 测试用例 - db_type = "textdb" - userid = "testuser2" - filename = "test.docx" - - logger.info(f"测试:删除 userid={userid}, filename={filename} 的文件") - result = delete_document(db_type, userid, filename) - print(f"删除结果: {result}") \ No newline at end of file diff --git a/rag/embed.py b/rag/embed.py deleted file mode 100644 index f7a8a1c..0000000 --- a/rag/embed.py +++ /dev/null @@ -1,183 +0,0 @@ -import os -import uuid -import yaml -import logging -from datetime import datetime -from typing import List -from langchain_core.documents import Document -from langchain_text_splitters import RecursiveCharacterTextSplitter -from pymilvus import connections -from vector import get_vector_db -from filetxt.loader import fileloader -from extract import extract_and_save_triplets -from kgc import KnowledgeGraph - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] -except Exception as e: - logger.error(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() -logger.propagate = False -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -def generate_document_id() -> str: - """为文件生成唯一的 document_id""" - return str(uuid.uuid4()) - -def load_and_split_data(file_path: str, userid: str, document_id: str) -> List[Document]: - """ - 加载文件,分片并生成带有元数据的 Document 对象。 - """ - try: - if not os.path.exists(file_path): - raise ValueError(f"文件 {file_path} 不存在") - if os.path.getsize(file_path) == 0: - raise ValueError(f"文件 {file_path} 为空") - logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节") - ext = 
file_path.rsplit('.', 1)[1].lower() - logger.debug(f"文件扩展名: {ext}") - - logger.debug("开始加载文件") - text = fileloader(file_path) - if not text or not text.strip(): - raise ValueError(f"文件 {file_path} 加载为空") - - document = Document(page_content=text) - logger.debug(f"加载完成,生成 1 个文档") - - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=2000, - chunk_overlap=200, - length_function=len, - ) - chunks = text_splitter.split_documents([document]) - logger.debug(f"分割完成,生成 {len(chunks)} 个文档块") - - filename = os.path.basename(file_path) - upload_time = datetime.now().isoformat() - documents = [] - for i, chunk in enumerate(chunks): - chunk.metadata.update({ - 'userid': userid, - 'document_id': document_id, - 'filename': filename, - 'file_path': file_path, - 'upload_time': upload_time, - 'file_type': ext, - 'chunk_index': i, - 'source': file_path, - }) - required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type'] - if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields): - raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}") - documents.append(chunk) - logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}") - - logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块,document_id: {document_id}") - return documents - - except Exception as e: - logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - raise ValueError(f"加载或分割文件失败: {str(e)}") - -def embed(file_path: str, userid: str, db_type: str) -> bool: - """ - 嵌入文件到 Milvus 向量数据库,抽取三元组保存到指定路径,并将三元组存储到 Neo4j。 - """ - try: - if not userid or not db_type: - raise ValueError("userid 和 db_type 不能为空") - if "_" in userid: - raise ValueError("userid 不能包含下划线") - if "_" in db_type: - raise ValueError("db_type 不能包含下划线") - if not os.path.exists(file_path): - raise ValueError(f"文件 {file_path} 不存在") - - supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'} - ext = file_path.rsplit('.', 1)[1].lower() - if ext not in supported_formats: - logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}") - raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") - - document_id = generate_document_id() - logger.info(f"生成 document_id: {document_id} for file: {file_path}") - - logger.info(f"开始处理文件 {file_path},userid: {userid},db_type: {db_type}") - chunks = load_and_split_data(file_path, userid, document_id) - if not chunks: - logger.error(f"文件 {file_path} 未生成任何文档块") - raise ValueError("未生成任何文档块") - - logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块") - logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}") - - db = get_vector_db(userid, db_type, documents=chunks) - if not db: - logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}") - raise RuntimeError(f"数据库操作失败") - - try: - full_text = fileloader(file_path) - if full_text and full_text.strip(): - success = extract_and_save_triplets(full_text, document_id, userid) - triplet_file_path = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt" - if success and os.path.exists(triplet_file_path): - logger.info(f"文件 {file_path} 三元组保存到: {triplet_file_path}") - try: - kg = KnowledgeGraph(data_path=triplet_file_path, document_id=document_id) - logger.info(f"Step 1: 导入图谱节点到 Neo4j,document_id: {document_id}") - kg.create_graphnodes() - logger.info(f"Step 2: 导入图谱边到 Neo4j,document_id: {document_id}") - kg.create_graphrels() - logger.info(f"Step 3: 导出 Neo4j 节点数据,document_id: {document_id}") - kg.export_data() - logger.info(f"文件 
{file_path} 三元组成功插入 Neo4j") - except Exception as e: - logger.warning(f"将三元组插入 Neo4j 失败: {str(e)},但不影响 Milvus 嵌入") - else: - logger.warning(f"文件 {file_path} 的三元组抽取失败或文件不存在: {triplet_file_path}") - else: - logger.warning(f"文件 {file_path} 内容为空,无法抽取三元组") - except Exception as e: - logger.error(f"文件 {file_path} 三元组抽取失败: {str(e)},但不影响向量化") - - logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}") - return True - - except ValueError as ve: - logger.error(f"嵌入文件 {file_path} 失败: {str(ve)}") - return False - except RuntimeError as re: - logger.error(f"嵌入文件 {file_path} 失败: {str(re)}") - return False - except Exception as e: - logger.error(f"嵌入文件 {file_path} 失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - return False - -if __name__ == "__main__": - test_file = "/share/wangmeihua/rag/data/test.docx" - userid = "testuser1" - db_type = "textdb" - result = embed(test_file, userid, db_type) - print(f"嵌入结果: {result}") \ No newline at end of file diff --git a/rag/extract.py b/rag/extract.py deleted file mode 100644 index d2f75ee..0000000 --- a/rag/extract.py +++ /dev/null @@ -1,225 +0,0 @@ -import os -import torch -import re -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer -import logging -import yaml -import time - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()): - handler.setFormatter(formatter) - logger.addHandler(handler) - -# 三元组保存路径 -TRIPLES_OUTPUT_DIR = "/share/wangmeihua/rag/triples" -os.makedirs(TRIPLES_OUTPUT_DIR, exist_ok=True) - -# 加载 mREBEL 模型和分词器 -local_path = "/share/models/Babelscape/mrebel-large" -try: - tokenizer = AutoTokenizer.from_pretrained(local_path, src_lang="zh_CN", tgt_lang="tp_XX") - model = AutoModelForSeq2SeqLM.from_pretrained(local_path) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) - triplet_id = tokenizer.convert_tokens_to_ids("") - logger.debug(f"成功加载 mREBEL 模型,分词器 triplet_id: {triplet_id}") -except Exception as e: - logger.error(f"加载 mREBEL 模型失败: {str(e)}") - raise RuntimeError(f"加载 mREBEL 模型失败: {str(e)}") - -# 优化生成参数 -gen_kwargs = { - "max_length": 512, - "min_length": 10, - "length_penalty": 0.5, - "num_beams": 3, - "num_return_sequences": 1, - "no_repeat_ngram_size": 2, - "early_stopping": True, - "decoder_start_token_id": triplet_id, -} - -def split_document(text: str, max_chunk_size: int = 150) -> list: - """分割文档为语义完整的块""" - sentences = re.split(r'(?<=[。!?;\n])', text) - chunks = [] - current_chunk = "" - - for sentence in sentences: - if len(current_chunk) + len(sentence) <= max_chunk_size: - current_chunk += sentence - else: - if current_chunk: - chunks.append(current_chunk) - current_chunk = sentence - - if current_chunk: - chunks.append(current_chunk) - - return chunks - -def extract_triplets_typed(text: str) -> list: - """解析 mREBEL 生成文本,匹配 格式""" - triplets = [] - logger.debug(f"原始生成文本: {text}") - - # 分割标记 - tokens = [] - in_tag 
= False - buffer = "" - for char in text: - if char == '<': - in_tag = True - if buffer: - tokens.append(buffer.strip()) - buffer = "" - buffer += char - elif char == '>': - in_tag = False - buffer += char - tokens.append(buffer.strip()) - buffer = "" - else: - buffer += char - if buffer: - tokens.append(buffer.strip()) - - # 过滤特殊标记 - special_tokens = ["", "", "", "tp_XX", "__en__", "__zh__", "zh_CN"] - tokens = [t for t in tokens if t not in special_tokens and t] - - logger.debug(f"处理后标记: {tokens}") - - # 解析三元组 - i = 0 - while i < len(tokens): - if tokens[i] == "" and i + 5 < len(tokens): - entity1 = tokens[i + 1] - type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else "" - entity2 = tokens[i + 3] - type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else "" - relation = tokens[i + 5] - - if entity1 and type1 and entity2 and type2 and relation: - triplets.append({ - 'head': entity1.strip(), - 'head_type': type1, - 'type': relation.strip(), - 'tail': entity2.strip(), - 'tail_type': type2 - }) - logger.debug(f"添加三元组: {entity1}({type1}) - {relation} - {entity2}({type2})") - i += 6 - else: - i += 1 - - return triplets - -def extract_and_save_triplets(text: str, document_id: str, userid: str) -> bool: - """ - 从文本中抽取三元组并保存到指定路径。 - - 参数: - text (str): 输入文本 - document_id (str): 文档ID - userid (str): 用户ID - - 返回: - bool: 三元组抽取和保存是否成功 - """ - try: - if not text or not document_id or not userid: - raise ValueError("text、document_id 和 userid 不能为空") - if "_" in document_id or "_" in userid: - raise ValueError("document_id 和 userid 不能包含下划线") - - start_time = time.time() - logger.info(f"开始抽取文档 {document_id} 的三元组,userid: {userid}") - - # 分割文本为语义块 - text_chunks = split_document(text, max_chunk_size=150) - logger.debug(f"分割为 {len(text_chunks)} 个文本块") - - # 处理所有文本块 - all_triplets = [] - for i, chunk in enumerate(text_chunks): - logger.debug(f"处理块 {i + 1}/{len(text_chunks)}: {chunk[:50]}...") - - # 分词 - model_inputs = tokenizer( - chunk, - max_length=256, - padding=True, - truncation=True, - return_tensors="pt" - ).to(device) - - # 生成 - try: - generated_tokens = model.generate( - model_inputs["input_ids"], - attention_mask=model_inputs["attention_mask"], - **gen_kwargs, - ) - decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False) - for idx, sentence in enumerate(decoded_preds): - logger.debug(f"块 {i + 1} 生成文本: {sentence}") - triplets = extract_triplets_typed(sentence) - if triplets: - logger.debug(f"块 {i + 1} 提取到 {len(triplets)} 个三元组") - all_triplets.extend(triplets) - except Exception as e: - logger.warning(f"处理块 {i + 1} 时出错: {str(e)}") - continue - - # 去重 - unique_triplets = [] - seen = set() - for t in all_triplets: - identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower()) - if identifier not in seen: - seen.add(identifier) - unique_triplets.append(t) - - # 保存结果 - output_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt") - try: - with open(output_file, "w", encoding="utf-8") as f: - for t in unique_triplets: - f.write(f"{t['head']}\t{t['type']}\t{t['tail']}\t{t['head_type']}\t{t['tail_type']}\n") - logger.info(f"文档 {document_id} 的 {len(unique_triplets)} 个三元组已保存到: {output_file}") - except Exception as e: - logger.error(f"保存文档 {document_id} 的三元组失败: {str(e)}") - return False - - end_time = time.time() - logger.info(f"文档 {document_id} 三元组抽取完成,耗时: {end_time - start_time:.2f} 秒") - return True - - except Exception as e: - logger.error(f"抽取或保存三元组失败: {str(e)}") - 
import traceback - logger.debug(traceback.format_exc()) - return False - -if __name__ == "__main__": - # 测试用例 - test_text = "知识图谱是一个结构化的语义知识库。深度学习是基于深层神经网络的机器学习子集。" - document_id = "testdoc123" - userid = "testuser1" - result = extract_and_save_triplets(test_text, document_id, userid) - print(f"抽取结果: {result}") \ No newline at end of file diff --git a/rag/fusedsearch.py b/rag/fusedsearch.py deleted file mode 100644 index a2f9ad1..0000000 --- a/rag/fusedsearch.py +++ /dev/null @@ -1,290 +0,0 @@ -import os -import logging -import yaml -import numpy as np -from typing import List, Dict, Any -from pymilvus import Collection, utility -from langchain_huggingface import HuggingFaceEmbeddings -from vector import initialize_milvus_connection -from searchquery import extract_entities, match_triplets -from rerank import rerank_results -import torch - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() -logger.propagate = False -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -try: - with open(config['logging']['file'], 'a', encoding='utf-8') as f: - pass -except Exception as e: - raise RuntimeError(f"日志文件 {config['logging']['file']} 不可写: {str(e)}") -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -# 初始化嵌入模型 -embedding = HuggingFaceEmbeddings( - model_name=TEXT_EMBEDDING_MODEL, - model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}, - encode_kwargs={'normalize_embeddings': True} -) -try: - test_vector = embedding.embed_query("test") - if len(test_vector) != 1024: - raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024") - logger.debug("嵌入模型加载成功") -except Exception as e: - logger.error(f"嵌入模型加载失败: {str(e)}") - raise RuntimeError(f"嵌入模型加载失败: {str(e)}") - -def fused_search( - query: str, - userid: str, - db_type: str, - file_paths: List[str], - limit: int = 10, - offset: int = 0, - use_rerank: bool = True -) -> List[Dict[str, Any]]: - """ - 融合 RAG 和三元组召回文本块: - - 调用 searchquery.py 的 extract_entities 和 match_triplets 获取三元组。 - - 将所有匹配三元组拼接为融合文本,向量化后在 Milvus 中搜索。 - - 参数: - query (str): 查询文本 - userid (str): 用户 ID - db_type (str): 数据库类型 (e.g., 'textdb') - file_paths (List[str]): 文件路径列表 - limit (int): 返回结果数量 - offset (int): 偏移量 - use_rerank (bool): 是否使用重排序 - - 返回: - List[Dict[str, Any]]: 召回结果,包含 text、distance、metadata - """ - try: - logger.info(f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}") - - # 参数验证 - if not query or not userid or not db_type or not file_paths: - raise ValueError("query、userid、db_type 和 file_paths 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") - - # 初始化 Milvus 连接 - connections = initialize_milvus_connection() - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return [] - collection = 
Collection(collection_name) - collection.load() - logger.debug(f"加载 Milvus 集合: {collection_name}") - - # 提取实体 - query_entities = extract_entities(query) - logger.debug(f"提取实体: {query_entities}") - - # 收集所有结果 - results = [] - search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - - for file_path in file_paths: - filename = os.path.basename(file_path) - logger.debug(f"处理文件: {filename}") - - # 获取 document_id - results_query = collection.query( - expr=f"userid == '{userid}' and filename == '{filename}'", - output_fields=["document_id"], - limit=1 - ) - if not results_query: - logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档") - continue - document_id = results_query[0]["document_id"] - logger.debug(f"找到 document_id: {document_id}") - - # 获取匹配的三元组 - matched_triplets = match_triplets(query, query_entities, userid, document_id) - logger.debug(f"匹配三元组: {matched_triplets}") - - # 若无三元组,使用原查询向量化 - if not matched_triplets: - logger.debug(f"无匹配三元组,使用原查询: {query}") - query_vector = embedding.embed_query(query) - expr = f"userid == '{userid}' and filename == '{filename}'" - milvus_results = collection.search( - data=[query_vector], - anns_field="vector", - param=search_params, - limit=limit, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - offset=offset - ) - for hits in milvus_results: - for hit in hits: - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "source": "fused_query", - "metadata": { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - } - results.append(result) - logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}") - continue - - # 拼接所有三元组 - triplet_texts = [] - for triplet in matched_triplets: - head = triplet['head'] - type = triplet['type'] - tail = triplet['tail'] - if not head or not type or not tail: - logger.debug(f"无效三元组: {triplet}") - continue - triplet_texts.append(f"{head} {type} {tail}") - if not triplet_texts: - logger.debug(f"无有效三元组,使用原查询: {query}") - query_vector = embedding.embed_query(query) - expr = f"userid == '{userid}' and filename == '{filename}'" - milvus_results = collection.search( - data=[query_vector], - anns_field="vector", - param=search_params, - limit=5, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - offset=offset - ) - for hits in milvus_results: - for hit in hits: - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "source": "fused_query", - "metadata": { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - } - results.append(result) - logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}") - continue - - # 生成融合文本 - fused_text = f"{query} {' '.join(triplet_texts)}" - logger.debug(f"融合文本: {fused_text}") - - # 向量化 - fused_vector = embedding.embed_query(fused_text) - fused_vector = np.array(fused_vector) / np.linalg.norm(fused_vector) - logger.debug(f"生成融合向量,维度: {len(fused_vector)}") - - # Milvus 搜索 - expr = f"userid == '{userid}' and filename == '{filename}'" - 
milvus_results = collection.search( - data=[fused_vector], - anns_field="vector", - param=search_params, - limit=5, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - offset=offset - ) - - for hits in milvus_results: - for hit in hits: - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "source": f"fused_triplets_{len(triplet_texts)}", - "metadata": { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - } - results.append(result) - logger.debug(f"召回: text={result['text'][:100]}..., distance={result['distance']}") - - # 去重 - unique_results = [] - seen_texts = set() - for result in results: - text = result['text'] - if text not in seen_texts: - seen_texts.add(text) - unique_results.append(result) - logger.debug(f"去重后结果数量: {len(unique_results)}") - - # 可选:重排序 - if use_rerank and unique_results: - logger.debug("开始重排序") - reranked_results = rerank_results(query, unique_results) - # 按 rerank_score 降序排序 - reranked_results = sorted(reranked_results, key=lambda x: x['rerank_score'], reverse=True) - for i, result in enumerate(reranked_results): - logger.debug(f"排序结果 {i+1}: text={result['text'][:100]}..., distance={result['distance']}, rerank_score={result['rerank_score']}") - return reranked_results[:limit] - - # 按 distance 降序排序 - sorted_results = sorted(unique_results, key=lambda x: x['distance'], reverse=True) - for i, result in enumerate(sorted_results): - logger.debug(f"排序结果 {i+1}: text={result['text'][:100]}..., distance={result['distance']}") - return sorted_results[:limit] - - except Exception as e: - logger.error(f"融合搜索失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - return [] - -if __name__ == "__main__": - query = "知识图谱构建需要什么技术?" 
- userid = "testuser1" - db_type = "textdb" - file_paths = [ - "/share/wangmeihua/rag/data/test.docx", - "/share/wangmeihua/rag/data/zongshu.pdf", - "/share/wangmeihua/rag/data/qianru.pdf" - ] - results = fused_search(query, userid, db_type, file_paths, limit=10, offset=0) - for i, result in enumerate(results): - print(f"Result {i+1}:") - print(f"Text: {result['text'][:200]}...") - print(f"Distance: {result['distance']}") - print(f"Source: {result['source']}") - print(f"Metadata: {result['metadata']}\n") \ No newline at end of file diff --git a/rag/kdb.py b/rag/kdb.py deleted file mode 100644 index 47ef805..0000000 --- a/rag/kdb.py +++ /dev/null @@ -1,81 +0,0 @@ - -from traceback import format_exc -from appPublic.uniqueID import getID -from appPublic.timeUtils import curDateString -from appPublic.dictObject import DictObject -from sqlor.dbpools import DBPools -from ahserver.serverenv import get_serverenv -from ahserver.filestorage import FileStorage - -async def add_kdb(kdb:dict) -> None: - """ - 添加知识库 - """ - kdb = DictObject(**kdb) - kdb.parentid=None - if kdb.id is None: - kdb.id = getID() - kdb.entity_type = '0' - kdb.create_date = curDateString() - if kdb.orgid is None: - e = Exception(f'Can not add none orgid kdb') - exception(f'{e}\n{format_exc()}') - raise e - - f = get_serverenv('get_module_dbname') - dbname = f('rag') - db = DBPools() - async with db.sqlorContext(dbname) as sor: - await C('kdb', kdb.copy()) - -async def add_dir(kdb:dict) -> None: - """ - 添加子目录 - """ - kdb = DictObject(**kdb) - if kdb.parentid is None: - e = Exception(f'Can not add root folder') - exception(f'{e}\n{format_exc()}') - raise e - if kdb.id is None: - kdb.id = getID() - kdb.entity_type = '1' - kdb.create_date = curDateString() - f = get_serverenv('get_module_dbname') - dbname = f('rag') - db = DBPools() - async with db.sqlorContext(dbname) as sor: - await C('kdb', kdb.copy()) - -async def add_doc(doc:dict) -> None: - """ - 添加文档 - """ - doc = DictObject(**doc) - if doc.parentid is None: - e = Exception(f'Can not add root document') - exception(f'{e}\n{format_exc()}') - raise e - if doc.id is None: - doc.id = getID() - fs = FileStorage() - doc.realpath = fs.realPath(doc.webpath) - doc.create_date = curDateString() - f = get_serverenv('get_module_dbname') - dbname = f('rag') - db = DBPools() - async with db.sqlorContext(dbname) as sor: - await C('doc', doc.copy()) - -async def get_all_docs(sor, kdbid): - """ - 获取所有kdbid下的文档,含子目录的 - """ - docs = await sor.R('doc', {'parentid':kdbid}) - kdbs = await sor.R('kdb', {'parentid':kdbid}) - for kdb in kdbs: - docs1 = await get_all_docs(kdb.id) - docs += docs1 - return docs - - diff --git a/rag/kgc.py b/rag/kgc.py deleted file mode 100644 index ef989df..0000000 --- a/rag/kgc.py +++ /dev/null @@ -1,194 +0,0 @@ -import os -import logging -import re -from py2neo import Graph, Node, Relationship -from typing import Set, List, Dict, Tuple - -from ufw.common import share_dir - -# 配置日志 -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -class KnowledgeGraph: - def __init__(self, data_path: str, document_id: str = None): - self.data_path = data_path - self.document_id = document_id or os.path.basename(data_path).split('_')[0] - self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh')) - logger.info(f"开始构建知识图谱,data_path: {self.data_path}, document_id: {self.document_id}") - # 验证 data_path 是否有效 - if not os.path.exists(self.data_path): - logger.error(f"数据路径 {self.data_path} 不存在") 
- raise ValueError(f"数据路径 {self.data_path} 不存在") - - def _normalize_label(self, entity_type: str) -> str: - """规范化实体类型为 Neo4j 标签""" - if not entity_type or not entity_type.strip(): - return 'Entity' - entity_type = re.sub(r'[^\w\s]', '', entity_type.strip()) - words = entity_type.split() - label = '_'.join(word.capitalize() for word in words if word) - return label or 'Entity' - - def _clean_relation(self, relation: str) -> Tuple[str, str]: - """清洗关系,返回 (rel_type, rel_name)""" - relation = relation.strip() - if not relation: - return 'RELATED_TO', '相关' - if relation.startswith('<') and relation.endswith('>'): - cleaned_relation = relation[1:-1] - rel_name = cleaned_relation - rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper() - else: - rel_name = relation - rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper() - if 'instance of' in relation.lower(): - rel_type = 'INSTANCE_OF' - rel_name = '实例' - elif 'subclass of' in relation.lower(): - rel_type = 'SUBCLASS_OF' - rel_name = '子类' - elif 'part of' in relation.lower(): - rel_type = 'PART_OF' - rel_name = '部分' - logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})") - return rel_type, rel_name - - def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]: - """读取三元组数据,返回节点和关系""" - nodes_by_label = {} - relations_by_type = {} - triples = [] - - try: - logger.debug(f"尝试读取文件: {self.data_path}") - with open(self.data_path, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if not line or line.startswith('#'): - continue - parts = line.split('\t') - if len(parts) != 5: - logger.warning(f"无效行: {line}") - continue - head, relation, tail, head_type, tail_type = parts - head_label = self._normalize_label(head_type) - tail_label = self._normalize_label(tail_type) - logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}") - - if head_label not in nodes_by_label: - nodes_by_label[head_label] = set() - if tail_label not in nodes_by_label: - nodes_by_label[tail_label] = set() - nodes_by_label[head_label].add(head) - nodes_by_label[tail_label].add(tail) - - rel_type, rel_name = self._clean_relation(relation) - if rel_type not in relations_by_type: - relations_by_type[rel_type] = [] - relations_by_type[rel_type].append({ - 'head': head, - 'tail': tail, - 'head_label': head_label, - 'tail_label': tail_label, - 'rel_name': rel_name - }) - - triples.append({ - 'head': head, - 'relation': relation, - 'tail': tail, - 'head_type': head_type, - 'tail_type': tail_type - }) - - logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个") - logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条") - return nodes_by_label, relations_by_type, triples - - except Exception as e: - logger.error(f"读取数据失败: {str(e)},data_path: {self.data_path}") - raise RuntimeError(f"读取数据失败: {str(e)}") - - def create_node(self, label: str, nodes: Set[str]): - """创建节点,包含 document_id 属性""" - count = 0 - for node_name in nodes: - query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n" - try: - if self.g.run(query).data(): - continue - node = Node(label, name=node_name, document_id=self.document_id) - self.g.create(node) - count += 1 - logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})") - except Exception as e: - logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}") - logger.info(f"创建 {label} 节点: {count}/{len(nodes)} 个") - return count - - def create_relationship(self, 
rel_type: str, relations: List[Dict]): - """创建关系""" - count = 0 - total = len(relations) - seen_edges = set() - for rel in relations: - head, tail, head_label, tail_label, rel_name = ( - rel['head'], rel['tail'], rel['head_label'], rel['tail_label'], rel['rel_name'] - ) - edge_key = f"{head_label}:{head}###{tail_label}:{tail}###{rel_type}" - if edge_key in seen_edges: - continue - seen_edges.add(edge_key) - - query = ( - f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), " - f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) " - f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)" - ) - try: - self.g.run(query) - count += 1 - logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})") - except Exception as e: - logger.error(f"创建关系失败: {query}, 错误: {str(e)}") - logger.info(f"创建 {rel_type} 关系: {count}/{total} 条") - return count - - def create_graphnodes(self): - """创建所有节点""" - nodes_by_label, _, _ = self.read_nodes() - total = 0 - for label, nodes in nodes_by_label.items(): - total += self.create_node(label, nodes) - logger.info(f"总计创建节点: {total} 个") - return total - - def create_graphrels(self): - """创建所有关系""" - _, relations_by_type, _ = self.read_nodes() - total = 0 - for rel_type, relations in relations_by_type.items(): - total += self.create_relationship(rel_type, relations) - logger.info(f"总计创建关系: {total} 条") - return total - - def export_data(self): - """导出节点到文件,包含 document_id""" - nodes_by_label, _, _ = self.read_nodes() - os.makedirs('dict', exist_ok=True) - for label, nodes in nodes_by_label.items(): - with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f: - f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes))) - logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个") - return - -if __name__ == '__main__': - data_path = '/share/wangmeihua/rag/triples/26911c68-9107-4bb4-8f31-ff776991a119_testuser2.txt' - handler = KnowledgeGraph(data_path) - logger.info("Step 1: 导入图谱节点中") - handler.create_graphnodes() - logger.info("Step 2: 导入图谱边中") - handler.create_graphrels() - logger.info("Step 3: 导出数据") - handler.export_data() \ No newline at end of file diff --git a/rag/query.py b/rag/query.py deleted file mode 100644 index fb44cb4..0000000 --- a/rag/query.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -import yaml -import logging -from typing import List, Dict -from pymilvus import connections, Collection, utility -from langchain_huggingface import HuggingFaceEmbeddings -from vector import initialize_milvus_connection, cleanup_milvus_connection -import torch - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() # 清除现有处理器,避免重复 -logger.propagate = False # 禁用传播到父级 -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) 
-stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -def search_query(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0) -> List[Dict]: - """ - 根据用户输入的查询文本,在指定 db_type 的知识库中搜索与 userid 相关的指定文档。 - - 参数: - query (str): 用户输入的查询文本 - userid (str): 用户ID,用于过滤 - db_type (str): 数据库类型(例如 'textdb') - file_paths (List[str]): 文档路径列表(支持1到多个文件) - limit (int): 返回的最大结果数,默认为 10 - offset (int): 偏移量,用于分页,默认为 0 - - 返回: - List[Dict]: 搜索结果,每个元素为包含 text、distance 和 metadata 的字典 - - 异常: - ValueError: 参数无效 - RuntimeError: 模型加载或 Milvus 操作失败 - """ - try: - # 参数验证 - if not query: - raise ValueError("查询文本不能为空") - if not userid or not db_type: - raise ValueError("userid 和 db_type 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") - if len(userid) > 100 or len(db_type) > 100: - raise ValueError("userid 或 db_type 的长度超出限制") - if limit <= 0 or limit > 16384: - raise ValueError("limit 必须在 1 到 16384 之间") - if offset < 0: - raise ValueError("offset 不能为负数") - if limit + offset > 16384: - raise ValueError("limit + offset 不能超过 16384") - if not file_paths: - raise ValueError("file_paths 不能为空") - for file_path in file_paths: - if not isinstance(file_path, str): - raise ValueError(f"file_path 必须是字符串: {file_path}") - if len(os.path.basename(file_path)) > 255: - raise ValueError(f"文件名长度超出 255 个字符: {file_path}") - if "_" in os.path.basename(file_path): - raise ValueError(f"文件名 {file_path} 不能包含下划线") - - # 初始化嵌入模型 - model_path = TEXT_EMBEDDING_MODEL - if not os.path.exists(model_path): - raise ValueError(f"模型路径 {model_path} 不存在") - - embedding = HuggingFaceEmbeddings( - model_name=model_path, - model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}, - encode_kwargs={'normalize_embeddings': True} - ) - try: - test_vector = embedding.embed_query("test") - if len(test_vector) != 1024: - raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024") - logger.debug("嵌入模型加载成功") - except Exception as e: - logger.error(f"嵌入模型加载失败: {str(e)}") - raise RuntimeError(f"嵌入模型加载失败: {str(e)}") - - # 将查询转换为向量 - query_vector = embedding.embed_query(query) - logger.debug(f"查询向量维度: {len(query_vector)}") - - # 连接到 Milvus - initialize_milvus_connection() - - # 检查集合是否存在 - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return [] - - # 加载集合 - try: - collection = Collection(collection_name) - collection.load() - logger.debug(f"加载集合: {collection_name}") - except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") - raise RuntimeError(f"加载集合失败: {str(e)}") - - # 构造搜索参数 - search_params = { - "metric_type": "COSINE", # 与 vector.py 一致 - "params": {"nprobe": 10} # 优化搜索性能 - } - - # 构造过滤表达式,限制在指定文件 - filenames = [os.path.basename(file_path) for file_path in file_paths] - filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames]) - expr = f"userid == '{userid}' and ({filename_expr})" - logger.debug(f"搜索参数: {search_params}, 表达式: {expr}, limit: {limit}, offset: {offset}") - - # 执行搜索 - try: - results = collection.search( - data=[query_vector], - anns_field="vector", - param=search_params, - limit=limit, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - offset=offset - ) - except Exception as e: - logger.error(f"搜索失败: {str(e)}") - raise RuntimeError(f"搜索失败: {str(e)}") - - # 处理搜索结果 - 
search_results = [] - for hits in results: - for hit in hits: - metadata = { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "metadata": metadata - } - search_results.append(result) - logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}") - - logger.debug(f"搜索完成,返回 {len(search_results)} 条结果") - return search_results - - except Exception as e: - logger.error(f"搜索失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - raise - finally: - cleanup_milvus_connection() - -if __name__ == "__main__": - # 测试代码 - query = "知识图谱的知识融合是什么?" - userid = "testuser2" - db_type = "textdb" - file_paths = [ - "/share/wangmeihua/rag/data/test.docx", - "/share/wangmeihua/rag/data/test.txt" - ] - limit = 5 - offset = 0 - - try: - results = search_query(query, userid, db_type, file_paths, limit, offset) - print(f"搜索结果 ({len(results)} 条):") - for idx, result in enumerate(results, 1): - print(f"结果 {idx}:") - print(f"内容: {result['text'][:200]}...") - print(f"距离: {result['distance']}") - print(f"元数据: {result['metadata']}") - print("-" * 50) - except Exception as e: - print(f"搜索失败: {str(e)}") \ No newline at end of file diff --git a/rag/rerank.py b/rag/rerank.py deleted file mode 100644 index cc1ff69..0000000 --- a/rag/rerank.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import yaml -import logging -from typing import List, Dict -from pymilvus.model.reranker import BGERerankFunction -import torch - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() # 清除现有处理器 -logger.propagate = False # 禁用传播 -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -def rerank_results(query: str, results: List[Dict], top_k: int = 10) -> List[Dict]: - """ - 使用 BGE Reranker 模型对查询和文本块进行重排序。 - - 参数: - query (str): 查询文本 - results (List[Dict]): 包含 text、distance、source 和 metadata 的结果列表 - top_k (int): 返回的最大结果数,默认为 10 - - 返回: - List[Dict]: 重排序后的结果列表,包含 text、distance、source、metadata 和 rerank_score - """ - try: - # 初始化 BGE Reranker - bge_rf = BGERerankFunction( - model_name="/share/models/BAAI/bge-reranker-v2-m3", - device="cuda:0" if torch.cuda.is_available() else "cpu" - ) - logger.debug(f"BGE Reranker 初始化成功,模型路径: /share/models/BAAI/bge-reranker-v2-m3, 设备: {'cuda:0' if torch.cuda.is_available() else 'cpu'}") - - # 提取文本块 - documents = [result['text'] for result in results] - if not documents: - logger.warning("无文本块可重排序") - return results - - # 重排序 - rerank_results = bge_rf( - query=query, - documents=documents, - top_k=min(top_k, 
len(documents)) - ) - - # 构建重排序结果 - reranked = [] - for result in rerank_results: - original_result = results[result.index].copy() - original_result['rerank_score'] = result.score - reranked.append(original_result) - logger.debug(f"重排序结果: text={result.text[:200]}..., rerank_score={result.score:.6f}, source={original_result['source']}") - - logger.info(f"重排序返回 {len(reranked)} 条结果") - return reranked - - except Exception as e: - logger.error(f"重排序失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - # 回退到原始结果 - return results \ No newline at end of file diff --git a/rag/searchquery.py b/rag/searchquery.py deleted file mode 100644 index 54073b8..0000000 --- a/rag/searchquery.py +++ /dev/null @@ -1,363 +0,0 @@ -import os -import yaml -import logging -from typing import List, Dict -from pymilvus import connections, Collection, utility -from langchain_huggingface import HuggingFaceEmbeddings -import numpy as np -from scipy.spatial.distance import cosine -from ltp import LTP -from vector import initialize_milvus_connection, cleanup_milvus_connection -import torch - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"加载配置文件失败: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -logger.handlers.clear() # 清理现有处理器,避免重复 -logger.propagate = False # 禁用传播到父级 -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8') -file_handler.setFormatter(formatter) -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) -logger.addHandler(file_handler) -logger.addHandler(stream_handler) - -# 三元组保存路径 -TRIPLES_OUTPUT_DIR = '/share/wangmeihua/rag/triples' - -# 初始化嵌入模型 -embedding = HuggingFaceEmbeddings( - model_name=TEXT_EMBEDDING_MODEL, - model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}, - encode_kwargs={'normalize_embeddings': True} -) -try: - test_vector = embedding.embed_query("test") - if len(test_vector) != 1024: - raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024") - logger.debug("嵌入模型加载成功") -except Exception as e: - logger.error(f"嵌入模型加载失败: {str(e)}") - raise RuntimeError(f"嵌入模型加载失败: {str(e)}") - -# 初始化 LTP 模型 -try: - model_path = "/share/models/LTP/small" - if not os.path.isdir(model_path): - logger.warning(f"本地模型路径 {model_path} 不存在,尝试使用 Hugging Face 模型 'hit-scir/ltp-small'") - model_path = "hit-scir/ltp-small" - ltp = LTP(pretrained_model_name_or_path=model_path) - if torch.cuda.is_available(): - ltp.to("cuda") - logger.debug("LTP 模型加载成功") -except Exception as e: - logger.error(f"加载 LTP 模型失败: {str(e)}") - raise RuntimeError(f"加载 LTP 模型失败: {str(e)}") - -def extract_entities(query: str) -> List[str]: - """ - 从查询文本中抽取实体,包括: - - LTP NER 识别的实体(所有类型)。 - - LTP POS 标注为名词('n')的词。 - - LTP POS 标注为动词('v')的词。 - - 连续名词合并(如 '苹果 公司' -> '苹果公司'),移除子词。 - """ - try: - if not query: - raise ValueError("查询文本不能为空") - - # 使用 LTP pipeline 获取分词、词性、NER 结果 - result = ltp.pipeline([query], tasks=["cws", "pos", "ner"]) - words = result.cws[0] - 
pos_list = result.pos[0] - ner = result.ner[0] - - entities = [] - subword_set = set() # 记录连续名词的子词 - - # 提取 1:NER 实体(所有类型) - logger.debug(f"NER 结果: {ner}") - for entity_type, entity, start, end in ner: - entities.append(entity) - - # 提取 2:合并连续名词 - combined = "" - combined_words = [] # 记录当前连续名词的单词 - for i in range(len(words)): - if pos_list[i] == 'n': - combined += words[i] - combined_words.append(words[i]) - if i + 1 < len(words) and pos_list[i + 1] == 'n': - continue - if combined: - entities.append(combined) - subword_set.update(combined_words) - logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}") - combined = "" - combined_words = [] - else: - combined = "" - combined_words = [] - logger.debug(f"连续名词子词集合: {subword_set}") - - # 提取 3:POS 名词('n'),排除子词 - for word, pos in zip(words, pos_list): - if pos == 'n' and word not in subword_set: - entities.append(word) - - # 提取 4:POS 动词('v') - for word, pos in zip(words, pos_list): - if pos == 'v': - entities.append(word) - - # 去重 - unique_entities = list(dict.fromkeys(entities)) - logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}") - return unique_entities - - except Exception as e: - logger.error(f"实体抽取失败: {str(e)}") - return [] - -def load_triplets_from_file(triplet_file: str) -> List[Dict]: - """从三元组文件中加载""" - triplets = [] - try: - if not os.path.exists(triplet_file): - logger.warning(f"三元组文件 {triplet_file} 不存在") - return [] - - with open(triplet_file, 'r', encoding='utf-8') as f: - for line in f: - if line.strip(): - parts = line.strip().split('\t') - if len(parts) >= 5: - head, relation, tail, head_type, tail_type = parts[:5] - triplets.append({ - 'head': head, - 'head_type': head_type, - 'type': relation, - 'tail': tail, - 'tail_type': tail_type - }) - logger.debug(f"从 {triplet_file} 加载 {len(triplets)} 个三元组") - return triplets - except Exception as e: - logger.error(f"加载三元组文件 {triplet_file} 失败: {str(e)}") - return [] - -def match_triplets(query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]: - """ - 匹配查询实体与文档三元组,使用语义嵌入: - - 初始匹配:实体与 head 或 tail 相似度 ≥ 0.8。 - - 返回匹配的三元组。 - """ - matched_triplets = [] - ENTITY_SIMILARITY_THRESHOLD = 0.8 # 实体与 head/tail 相似度阈值 - - try: - # 加载三元组 - triplet_file = os.path.join(TRIPLES_OUTPUT_DIR, f"{document_id}_{userid}.txt") - doc_triplets = load_triplets_from_file(triplet_file) - if not doc_triplets: - logger.debug(f"文档 document_id={document_id} 无三元组") - return [] - - # 缓存查询实体嵌入 - entity_vectors = {entity: embedding.embed_query(entity) for entity in query_entities} - - # 初始匹配 - for entity in query_entities: - entity_vec = entity_vectors[entity] - for d_triplet in doc_triplets: - d_head_vec = embedding.embed_query(d_triplet['head']) - d_tail_vec = embedding.embed_query(d_triplet['tail']) - head_similarity = 1 - cosine(entity_vec, d_head_vec) - tail_similarity = 1 - cosine(entity_vec, d_tail_vec) - - if head_similarity >= ENTITY_SIMILARITY_THRESHOLD or tail_similarity >= ENTITY_SIMILARITY_THRESHOLD: - matched_triplets.append(d_triplet) - logger.debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} " - f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})") - - # 去重 - unique_matched = [] - seen = set() - for t in matched_triplets: - identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower()) - if identifier not in seen: - seen.add(identifier) - unique_matched.append(t) - - logger.info(f"找到 {len(unique_matched)} 个匹配的三元组") - return unique_matched - - except Exception as e: - 
logger.error(f"匹配三元组失败: {str(e)}") - return [] - -def searchquery(query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0) -> List[Dict]: - """ - 根据查询抽取实体,匹配指定文档的三元组,并在 Milvus 中搜索相关文档片段。 - """ - try: - if not query or not userid or not db_type or not file_paths: - raise ValueError("query、userid、db_type 和 file_paths 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") - if len(userid) > 100 or len(db_type) > 100: - raise ValueError("userid 或 db_type 的长度超出限制") - if limit <= 0 or limit > 16384: - raise ValueError("limit 必须在 1 到 16384 之间") - if offset < 0: - raise ValueError("offset 不能为负数") - if limit + offset > 16384: - raise ValueError("limit + offset 不能超过 16384") - - initialize_milvus_connection() - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return [] - - collection = Collection(collection_name) - collection.load() - - documents = [] - for file_path in file_paths: - filename = os.path.basename(file_path) - results = collection.query( - expr=f"userid == '{userid}' and filename == '{filename}'", - output_fields=["document_id", "filename"], - limit=1 - ) - if not results: - logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档") - continue - documents.append(results[0]) - - if not documents: - logger.warning("没有找到任何有效文档") - return [] - - logger.info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}") - - query_entities = extract_entities(query) - if not query_entities: - logger.warning("未从查询中提取到实体") - return [] - - search_results = [] - for doc in documents: - document_id = doc["document_id"] - filename = doc["filename"] - logger.debug(f"处理文档: document_id={document_id}, filename={filename}") - - matched_triplets = match_triplets(query, query_entities, userid, document_id) - if not matched_triplets: - logger.debug(f"文档 document_id={document_id} 未找到匹配的三元组") - continue - - for triplet in matched_triplets: - head = triplet['head'] - type = triplet['type'] - tail = triplet['tail'] - if not head or not type or not tail: - logger.debug(f"无效三元组: head={head}, type={type}, tail={tail}") - continue - - triplet_text = f"{head} {type} {tail}" - logger.debug(f"搜索三元组: {triplet_text} (文档: {filename})") - try: - search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - query_vector = embedding.embed_query(triplet_text) - expr = f"userid == '{userid}' and filename == '{filename}' and text like '%{head}%{tail}%'" - logger.debug(f"搜索表达式: {expr}") - - results = collection.search( - data=[query_vector], - anns_field="vector", - param=search_params, - limit=limit, - expr=expr, - output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - offset=offset - ) - - for hits in results: - for hit in hits: - metadata = { - "userid": hit.entity.get("userid"), - "document_id": hit.entity.get("document_id"), - "filename": hit.entity.get("filename"), - "file_path": hit.entity.get("file_path"), - "upload_time": hit.entity.get("upload_time"), - "file_type": hit.entity.get("file_type") - } - result = { - "text": hit.entity.get("text"), - "distance": hit.distance, - "metadata": metadata - } - search_results.append(result) - logger.debug(f"命中: text: {result['text'][:200]}..., 距离: {hit.distance}, 元数据: {metadata}") - except Exception as e: - logger.warning(f"三元组 {triplet_text} 在文档 {filename} 搜索失败: {str(e)}") - continue - - unique_results = [] - seen_texts = set() - for result in 
sorted(search_results, key=lambda x: x['distance'], reverse=True): - if result['text'] not in seen_texts: - unique_results.append(result) - seen_texts.add(result['text']) - if len(unique_results) >= limit: - break - - logger.info(f"返回 {len(unique_results)} 条唯一结果") - return unique_results - - except Exception as e: - logger.error(f"搜索失败: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) - return [] - finally: - cleanup_milvus_connection() - -if __name__ == "__main__": - query = "什么是知识图谱的知识抽取?" - userid = "testuser1" - db_type = "textdb" - file_paths = [ - "/share/wangmeihua/rag/data/test.docx", - "/share/wangmeihua/rag/data/zongshu.pdf", - "/share/wangmeihua/rag/data/qianru.pdf" - ] - limit = 5 - offset = 0 - - try: - results = searchquery(query, userid, db_type, file_paths, limit, offset) - print(f"搜索结果 ({len(results)} 条):") - for idx, result in enumerate(results, 1): - print(f"结果 {idx}:") - print(f"内容: {result['text'][:200]}...") - print(f"距离: {result['distance']}") - print(f"元数据: {result['metadata']}") - print("-" * 50) - except Exception as e: - print(f"搜索失败: {str(e)}") \ No newline at end of file diff --git a/rag/test.py b/rag/test.py deleted file mode 100644 index 85b1895..0000000 --- a/rag/test.py +++ /dev/null @@ -1,9 +0,0 @@ -from py2neo import Graph,Node,Relationship,NodeMatcher - -username = 'neo4j' -password = '261229..wmh' -auth = (username, password) -graph=Graph("bolt://10.18.34.18:7687", auth = auth) - -book_node=Node('经名',name='十三经') -graph.create(book_node) \ No newline at end of file diff --git a/rag/vector.py b/rag/vector.py deleted file mode 100644 index fc2a3fd..0000000 --- a/rag/vector.py +++ /dev/null @@ -1,539 +0,0 @@ -import os -import uuid -import json -import yaml -from datetime import datetime -from typing import List, Dict, Optional -from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType -from langchain_milvus import Milvus -from langchain_huggingface import HuggingFaceEmbeddings -from langchain_core.documents import Document -import torch -import logging -import time - -# 加载配置文件 -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] - TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model'] -except Exception as e: - print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"无法加载配置文件: {str(e)}") - -# 配置日志 -logger = logging.getLogger(config['logging']['name']) -logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG)) -os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()): - handler.setFormatter(formatter) - logger.addHandler(handler) - -def ensure_milvus_directory() -> None: - """确保 Milvus 数据库目录存在""" - db_dir = os.path.dirname(MILVUS_DB_PATH) - if not os.path.exists(db_dir): - os.makedirs(db_dir, exist_ok=True) - logger.debug(f"创建 Milvus 目录: {db_dir}") - if not os.access(db_dir, os.W_OK): - raise RuntimeError(f"Milvus 目录 {db_dir} 不可写") - -def initialize_milvus_connection() -> None: - """初始化 Milvus 连接,确保单一连接""" - try: - if not connections.has_connection("default"): - connections.connect("default", uri=MILVUS_DB_PATH) - logger.debug(f"已连接到 Milvus Lite,路径: {MILVUS_DB_PATH}") - else: - 
logger.debug("已存在 Milvus 连接,跳过重复连接") - except Exception as e: - logger.error(f"连接 Milvus 失败: {str(e)}") - raise RuntimeError(f"连接 Milvus 失败: {str(e)}") - -def cleanup_milvus_connection() -> None: - """清理 Milvus 连接,确保资源释放""" - try: - if connections.has_connection("default"): - connections.disconnect("default") - logger.debug("已断开 Milvus 连接") - time.sleep(3) - except Exception as e: - logger.warning(f"断开 Milvus 连接失败: {str(e)}") - -def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus: - """ - 初始化或访问 Milvus Lite 向量数据库集合,按 db_type 组织,利用 userid 区分用户,document_id 区分文档,并插入文档。 - """ - try: - # 参数验证 - if not userid or not db_type: - raise ValueError("userid 和 db_type 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") - if len(userid) > 100 or len(db_type) > 100: - raise ValueError("userid 和 db_type 的长度应小于 100") - if not documents or not all(isinstance(doc, Document) for doc in documents): - raise ValueError("documents 不能为空且必须是 Document 对象列表") - required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"] - for doc in documents: - if not all(field in doc.metadata and doc.metadata[field] for field in required_fields): - raise ValueError(f"文档元数据缺少必需字段或字段值为空: {doc.metadata}") - if doc.metadata["userid"] != userid: - raise ValueError(f"文档元数据的 userid {doc.metadata['userid']} 与输入 userid {userid} 不一致") - - ensure_milvus_directory() - initialize_milvus_connection() - - # 初始化嵌入模型 - model_path = TEXT_EMBEDDING_MODEL - if not os.path.exists(model_path): - raise ValueError(f"模型路径 {model_path} 不存在") - - embedding = HuggingFaceEmbeddings( - model_name=model_path, - model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}, - encode_kwargs={'normalize_embeddings': True} - ) - try: - test_vector = embedding.embed_query("test") - if len(test_vector) != 1024: - raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 1024") - logger.debug(f"嵌入模型加载成功,输出维度: {len(test_vector)}") - except Exception as e: - logger.error(f"嵌入模型加载失败: {str(e)}") - raise RuntimeError(f"加载模型失败: {str(e)}") - - # 集合名称 - collection_name = f"ragdb_{db_type}" - if len(collection_name) > 255: - raise ValueError(f"集合名称 {collection_name} 超过 255 个字符") - logger.debug(f"集合名称: {collection_name}") - - # 定义 schema,包含所有固定字段 - fields = [ - FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True), - FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100), - FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36), - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024), - FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255), - FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024), - FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64), - FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64), - ] - schema = CollectionSchema( - fields=fields, - description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段", - auto_id=True, - primary_field="pk", - ) - - # 检查集合是否存在 - if utility.has_collection(collection_name): - try: - collection = Collection(collection_name) - existing_schema = collection.schema - expected_fields = {f.name for f in fields} - actual_fields = {f.name for f in existing_schema.fields} - vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None) - - schema_compatible = False - if expected_fields == actual_fields 
and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR: - dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None - schema_compatible = dim == 1024 - logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, " - f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, " - f"dim={dim if dim is not None else '未定义'}") - if not schema_compatible: - logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: " - f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, " - f"vector_field: {vector_field is not None}, " - f"dtype: {vector_field.dtype if vector_field else '无'}, " - f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}") - utility.drop_collection(collection_name) - else: - collection.load() - logger.debug(f"集合 {collection_name} 已存在并加载成功") - except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") - raise RuntimeError(f"加载集合失败: {str(e)}") - - # 创建新集合 - if not utility.has_collection(collection_name): - try: - collection = Collection(collection_name, schema) - collection.create_index( - field_name="vector", - index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"} - ) - collection.create_index( - field_name="userid", - index_params={"index_type": "INVERTED"} - ) - collection.create_index( - field_name="document_id", - index_params={"index_type": "INVERTED"} - ) - collection.create_index( - field_name="filename", - index_params={"index_type": "INVERTED"} - ) - collection.create_index( - field_name="file_path", - index_params={"index_type": "INVERTED"} - ) - collection.create_index( - field_name="upload_time", - index_params={"index_type": "INVERTED"} - ) - collection.create_index( - field_name="file_type", - index_params={"index_type": "INVERTED"} - ) - collection.load() - logger.debug(f"成功创建并加载集合: {collection_name}") - except Exception as e: - logger.error(f"创建集合 {collection_name} 失败: {str(e)}") - raise RuntimeError(f"创建集合失败: {str(e)}") - - # 初始化 Milvus 向量存储 - try: - vector_store = Milvus( - embedding_function=embedding, - collection_name=collection_name, - connection_args={"uri": MILVUS_DB_PATH}, - drop_old=False, - auto_id=True, - primary_field="pk", - ) - logger.debug(f"成功初始化 Milvus 向量存储: {collection_name}") - except Exception as e: - logger.error(f"初始化 Milvus 向量存储失败: {str(e)}") - raise RuntimeError(f"初始化向量存储失败: {str(e)}") - - # 插入文档 - try: - logger.debug(f"正在为 userid {userid} 插入 {len(documents)} 个文档到 {collection_name}") - for doc in documents: - logger.debug(f"插入文档元数据: {doc.metadata}") - vector_store.add_documents(documents=documents) - logger.debug(f"成功插入 {len(documents)} 个文档") - - # 立即查询验证 - collection = Collection(collection_name) - collection.load() - results = collection.query( - expr=f"userid == '{userid}'", - output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"], - limit=10 - ) - for result in results: - logger.debug(f"插入后查询结果: pk={result['pk']}, document_id={result['document_id']}, " - f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', " - f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}") - except Exception as e: - logger.error(f"插入文档失败: {str(e)}") - raise RuntimeError(f"插入文档失败: {str(e)}") - - return vector_store - - except Exception as e: - logger.error(f"初始化 Milvus 向量存储失败: {str(e)}") - raise - 
finally: - cleanup_milvus_connection() - -def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]: - """ - 获取指定 userid 和 db_type 下的 document_id 与元数据的映射。 - """ - try: - if not userid or "_" in userid: - raise ValueError("userid 不能为空且不能包含下划线") - if not db_type or "_" in db_type: - raise ValueError("db_type 不能为空且不能包含下划线") - - initialize_milvus_connection() - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return {} - - collection = Collection(collection_name) - collection.load() - - results = collection.query( - expr=f"userid == '{userid}'", - output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"], - limit=100 - ) - mapping = {} - for result in results: - doc_id = result.get("document_id") - if doc_id: - mapping[doc_id] = { - "userid": result.get("userid", ""), - "filename": result.get("filename", ""), - "file_path": result.get("file_path", ""), - "upload_time": result.get("upload_time", ""), - "file_type": result.get("file_type", "") - } - logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}") - - logger.debug(f"找到 {len(mapping)} 个文档的映射") - return mapping - - except Exception as e: - logger.error(f"获取文档映射失败: {str(e)}") - raise RuntimeError(f"获取文档映射失败: {str(e)}") - -def list_user_collections() -> Dict[str, Dict]: - """ - 列出所有数据库类型(db_type)及其包含的用户(userid)与对应的文档(document_id)映射。 - """ - try: - ensure_milvus_directory() - initialize_milvus_connection() - collections = utility.list_collections() - - db_types_with_data = {} - for col in collections: - if col.startswith("ragdb_"): - db_type = col[len("ragdb_"):] - logger.debug(f"处理集合: {col} (db_type: {db_type})") - - collection = Collection(col) - collection.load() - - batch_size = 1000 - offset = 0 - user_document_map = {} - while True: - try: - results = collection.query( - expr="", - output_fields=["userid", "document_id"], - limit=batch_size, - offset=offset - ) - if not results: - break - for result in results: - userid = result.get("userid") - doc_id = result.get("document_id") - if userid and doc_id: - if userid not in user_document_map: - user_document_map[userid] = set() - user_document_map[userid].add(doc_id) - offset += batch_size - except Exception as e: - logger.error(f"查询集合 {col} 的用户和文档失败: {str(e)}") - break - - # 转换为列表以便序列化 - user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()} - logger.debug(f"集合 {col} 中找到用户和文档映射: {user_document_map}") - - db_types_with_data[db_type] = { - "userids": user_document_map - } - - logger.debug(f"可用 db_types 和数据: {db_types_with_data}") - return db_types_with_data - - except Exception as e: - logger.error(f"列出集合和用户失败: {str(e)}") - raise - -def view_collection_details(userid: str) -> None: - """ - 查看特定 userid 在所有集合中的内容和容量,包含 document_id 和元数据。 - """ - try: - if not userid or "_" in userid: - raise ValueError("userid 不能为空且不能包含下划线") - - logger.debug(f"正在查看 userid {userid} 的集合") - ensure_milvus_directory() - initialize_milvus_connection() - collections = utility.list_collections() - db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")] - - if not db_types: - logger.debug(f"未找到任何集合") - return - - for db_type in db_types: - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - continue - - collection = Collection(collection_name) - collection.load() - - try: - all_pks = collection.query(expr=f"userid == '{userid}'", 
output_fields=["pk"], limit=10000) - num_entities = len(all_pks) - results = collection.query( - expr=f"userid == '{userid}'", - output_fields=["userid","text", "document_id", "filename", "file_path", "upload_time", "file_type"], - limit=10 - ) - logger.debug(f"集合 {collection_name} 中 userid {userid} 的文档数: {num_entities}") - - if num_entities == 0: - logger.debug(f"集合 {collection_name} 中 userid {userid} 无文档") - continue - - logger.debug(f"集合 {collection_name} 中 userid {userid} 的内容:") - for idx, doc in enumerate(results, 1): - metadata = { - "userid": doc.get("userid", ""), - "filename": doc.get("filename", ""), - "file_path": doc.get("file_path", ""), - "upload_time": doc.get("upload_time", ""), - "file_type": doc.get("file_type", "") - } - logger.debug(f"文档 {idx}: 内容: {doc.get('text', '')[:200]}..., 元数据: {metadata}") - except Exception as e: - logger.error(f"查询集合 {collection_name} 的文档失败: {str(e)}") - continue - - except Exception as e: - logger.error(f"无法查看 userid {userid} 的集合详情: {str(e)}") - raise - -def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]: - """ - 查看指定 db_type 中的向量数据,可选按 userid 和 document_id 过滤,包含完整元数据和向量。 - """ - try: - if not db_type or "_" in db_type: - raise ValueError("db_type 不能为空且不能包含下划线") - if limit <= 0 or limit > 16384: - raise ValueError("limit 必须在 1 到 16384 之间") - if userid and "_" in userid: - raise ValueError("userid 不能包含下划线") - if document_id and "_" in document_id: - raise ValueError("document_id 不能包含下划线") - - initialize_milvus_connection() - collection_name = f"ragdb_{db_type}" - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") - return {} - - collection = Collection(collection_name) - collection.load() - logger.debug(f"加载集合: {collection_name}") - - expr = [] - if userid: - expr.append(f"userid == '{userid}'") - if document_id: - expr.append(f"document_id == '{document_id}'") - expr = " && ".join(expr) if expr else "" - - results = collection.query( - expr=expr, - output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"], - limit=limit - ) - - vector_data = {} - for doc in results: - pk = doc.get("pk", str(uuid.uuid4())) - text = doc.get("text", "") - doc_id = doc.get("document_id", "") - vector = doc.get("vector", []) - metadata = { - "filename": doc.get("filename", ""), - "file_path": doc.get("file_path", ""), - "upload_time": doc.get("upload_time", ""), - "file_type": doc.get("file_type", "") - } - vector_data[pk] = { - "text": text, - "vector": vector, - "document_id": doc_id, - "metadata": metadata - } - logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}") - - logger.debug(f"共找到 {len(vector_data)} 条向量数据") - return vector_data - - except Exception as e: - logger.error(f"查看向量数据失败: {str(e)}") - raise - -def main(): - userid = "testuser1" - db_type = "textdb" - - # logger.info("\n测试 1:带文档初始化") - # documents = [ - # Document( - # page_content="深度学习是基于深层神经网络的机器学习子集。", - # metadata={ - # "userid": userid, - # "document_id": str(uuid.uuid4()), - # "filename": "test_doc1.txt", - # "file_path": "/path/to/test_doc1.txt", - # "upload_time": datetime.now().isoformat(), - # "file_type": "txt" - # } - # ), - # Document( - # page_content="知识图谱是一个结构化的语义知识库。", - # metadata={ - # "userid": userid, - # "document_id": str(uuid.uuid4()), - # "filename": "test_doc2.txt", - # "file_path": "/path/to/test_doc2.txt", - # "upload_time": 
datetime.now().isoformat(), - # "file_type": "txt" - # } - # ), - # ] - # - # try: - # vector_store = get_vector_db(userid, db_type, documents=documents) - # logger.info(f"集合: ragdb_{db_type}") - # logger.info(f"成功为 userid {userid} 在 {db_type} 中插入文档") - # except Exception as e: - # logger.error(f"失败: {str(e)}") - - logger.info("\n测试 2:列出所有 db_types 和文档映射") - try: - db_types = list_user_collections() - logger.info(f"可用 db_types 和文档: {db_types}") - except Exception as e: - logger.error(f"失败: {str(e)}") - - logger.info(f"\n测试 3:查看 userid {userid} 的所有集合") - try: - view_collection_details(userid) - except Exception as e: - logger.error(f"失败: {str(e)}") - - # logger.info(f"\n测试 4:查看向量数据") - # try: - # vector_data = view_vector_data(db_type, userid=userid) - # logger.info(f"向量数据: {vector_data}") - # except Exception as e: - # logger.error(f"失败: {str(e)}") - - logger.info(f"\n测试 5:获取 userid {userid} 在{db_type}数据库的文档映射") - try: - mapping = get_document_mapping(userid, db_type) - logger.info(f"文档映射: {mapping}") - except Exception as e: - logger.error(f"失败: {str(e)}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/rag/version.py b/rag/version.py deleted file mode 100644 index b8023d8..0000000 --- a/rag/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '0.0.1' diff --git a/setup.py b/setup.py deleted file mode 100755 index 86b9fff..0000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- - -from rag.version import __version__ -try: - from setuptools import setup -except ImportError: - from distutils.core import setup -required = [] -with open('requirements.txt', 'r') as f: - ls = f.read() - required = ls.split('\n') - -with open('rag/version.py', 'r') as f: - x = f.read() - y = x[x.index("'")+1:] - z = y[:y.index("'")] - version = z -with open("README.md", "r") as fh: - long_description = fh.read() - -name = "rag" -description = "rag" -author = "yumoqing" -email = "yumoqing@gmail.com" - -package_data = {} - -setup( - name="rag", - version=version, - - # uncomment the following lines if you fill them out in release.py - description=description, - author=author, - author_email=email, - platforms='any', - install_requires=required , - packages=[ - "rag" - ], - package_data=package_data, - keywords = [ - ], - url="https://github.com/yumoqing/rag", - long_description=long_description, - long_description_content_type="text/markdown", - classifiers = [ - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: MIT License', - ], -)
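
As a usage sketch only: the snippet below shows one way the deleted query.search_query and rerank.rerank_results entry points could be driven together, following the calls and placeholder values in their own __main__ test blocks. The flat imports assume the rag/ directory is on the import path, and the added "source" tag is an assumption taken from the rerank_results docstring rather than from any deleted caller.

# Hypothetical two-stage driver for the deleted retrieval pipeline.
# Assumes the deleted rag/query.py and rag/rerank.py are importable as flat
# modules; query text, userid, db_type and file paths are the placeholder
# values used in their __main__ blocks.
from query import search_query        # dense Milvus search over user files
from rerank import rerank_results     # BGE reranker over retrieved chunks

if __name__ == "__main__":
    query = "知识图谱的知识融合是什么?"
    userid = "testuser2"
    db_type = "textdb"
    file_paths = [
        "/share/wangmeihua/rag/data/test.docx",
        "/share/wangmeihua/rag/data/test.txt",
    ]

    # Stage 1: vector search restricted to this user's documents.
    hits = search_query(query, userid, db_type, file_paths, limit=5, offset=0)

    # The rerank_results docstring expects a 'source' field on each result,
    # so tag the vector hits before reranking (this key is an assumption here).
    for hit in hits:
        hit["source"] = "vector"

    # Stage 2: rerank the candidate chunks; falls back to the original order
    # inside rerank_results if the reranker fails.
    reranked = rerank_results(query, hits, top_k=5)
    for idx, r in enumerate(reranked, 1):
        print(f"{idx}. rerank_score={r.get('rerank_score')}, text={r['text'][:100]}...")

The retrieve-then-rerank split mirrors how the deleted modules were factored: query.py only filters by userid and filename and returns raw COSINE distances, while rerank.py re-scores the shortlisted chunks with the BGE reranker.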