rag/rag/vector.py
2025-07-16 15:06:59 +08:00

539 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
import json
import yaml
from datetime import datetime
from typing import List, Dict, Optional
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from langchain_milvus import Milvus
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
import torch
import logging
import time
# Location of the YAML configuration file; override it with the CONFIG_PATH
# environment variable.
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # Values read by the helpers below: the Milvus database path passed to
    # connections.connect / Milvus(connection_args=...), and the embedding
    # model path passed to HuggingFaceEmbeddings.
    MILVUS_DB_PATH, TEXT_EMBEDDING_MODEL = (
        config['database']['milvus_db_path'],
        config['models']['text_embedding_model'],
    )
except Exception as load_error:
    # The logger is configured from this same file and does not exist yet,
    # so report the failure with print before aborting the import.
    print(f"加载配置文件 {CONFIG_PATH} 失败: {str(load_error)}")
    raise RuntimeError(f"无法加载配置文件: {str(load_error)}")
# Configure the module logger from the "logging" section of the config:
# "name" selects the logger, "level" its threshold, "file" the log file path.
_logging_conf = config['logging']
logger = logging.getLogger(_logging_conf['name'])
# Accept either a numeric level or a level name, matched case-insensitively
# ("info" -> logging.INFO). Unrecognised names fall back to DEBUG.
_log_level = _logging_conf['level']
logger.setLevel(
    _log_level if isinstance(_log_level, int)
    else getattr(logging, str(_log_level).upper(), logging.DEBUG)
)
_log_dir = os.path.dirname(_logging_conf['file'])
if _log_dir:
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the log path actually contains one.
    os.makedirs(_log_dir, exist_ok=True)
# logging.getLogger returns one shared object per name. Attach handlers only
# once so that re-importing or reloading this module does not stack another
# handler pair and duplicate every log line.
if not logger.handlers:
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    for handler in (logging.FileHandler(_logging_conf['file'], encoding='utf-8'), logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
def ensure_milvus_directory() -> None:
    """
    Ensure the directory that holds the Milvus database file exists and is writable.

    The directory is the parent of the module-level ``MILVUS_DB_PATH``. A bare
    file name (no directory component) refers to the current working directory.

    Raises:
        RuntimeError: if the directory is not writable.
        OSError: propagated from ``os.makedirs`` if the directory cannot be created.
    """
    # os.path.dirname("milvus.db") is "", and os.makedirs("") raises
    # FileNotFoundError, so a bare file name is treated as living in the CWD.
    db_dir = os.path.dirname(MILVUS_DB_PATH) or "."
    if not os.path.exists(db_dir):
        os.makedirs(db_dir, exist_ok=True)
        logger.debug(f"创建 Milvus 目录: {db_dir}")
    if not os.access(db_dir, os.W_OK):
        raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
def initialize_milvus_connection() -> None:
    """
    Open the pymilvus connection aliased "default" unless it already exists.

    The connection targets the module-level ``MILVUS_DB_PATH``. Calling this
    repeatedly is safe: an existing "default" connection is reused.

    Raises:
        RuntimeError: if checking for or opening the connection fails.
    """
    alias = "default"
    try:
        already_connected = connections.has_connection(alias)
        if already_connected:
            logger.debug("已存在 Milvus 连接,跳过重复连接")
            return
        connections.connect(alias, uri=MILVUS_DB_PATH)
        logger.debug(f"已连接到 Milvus Lite路径: {MILVUS_DB_PATH}")
    except Exception as err:
        logger.error(f"连接 Milvus 失败: {str(err)}")
        raise RuntimeError(f"连接 Milvus 失败: {str(err)}")
def cleanup_milvus_connection(wait_seconds: float = 3) -> None:
    """
    Close the pymilvus connection aliased "default", if it is open.

    Args:
        wait_seconds: How long to pause after a successful disconnect
            (default 3, the previously hard-coded value). The pause presumably
            gives Milvus Lite time to release the database file before it is
            reopened (TODO confirm). It is skipped when there was no
            connection to close, and callers that do not need it can pass 0.

    Any ``Exception`` is logged as a warning and not raised. This is
    best-effort cleanup, also used from ``finally`` blocks, where raising
    would mask the caller's original error.
    """
    try:
        if connections.has_connection("default"):
            connections.disconnect("default")
            logger.debug("已断开 Milvus 连接")
            if wait_seconds > 0:
                time.sleep(wait_seconds)
    except Exception as e:
        logger.warning(f"断开 Milvus 连接失败: {str(e)}")
def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus:
    """
    Create or open the Milvus collection for ``db_type`` and insert ``documents``.

    All users share one collection per ``db_type``, named ``ragdb_<db_type>``.
    Rows are told apart by the ``userid`` and ``document_id`` scalar fields.

    Args:
        userid: Owner of the documents. Must be non-empty, contain no
            underscore, be at most 100 characters long, and equal
            ``doc.metadata["userid"]`` for every document.
        db_type: Collection suffix. Follows the same non-empty /
            no-underscore / at-most-100-characters rules as ``userid``.
        documents: Non-empty list of ``Document`` objects. Each metadata must
            contain non-empty ``userid``, ``document_id``, ``filename``,
            ``file_path``, ``upload_time`` and ``file_type`` values.

    Returns:
        The langchain ``Milvus`` vector store bound to the collection. It is
        built with its own ``connection_args``; presumably it stays usable
        after the "default" pymilvus connection is closed on exit (TODO confirm).

    Raises:
        ValueError: invalid arguments, a missing embedding model path, or a
            collection name longer than 255 characters.
        RuntimeError: the embedding self-test, collection load/creation,
            vector store initialisation, or insertion failed.
        Errors from the directory/connection helpers propagate unchanged.
        Every error is logged before being re-raised.

    Side effects:
        An existing collection whose schema differs from the one defined here
        is dropped (its data is lost) and recreated. ``cleanup_milvus_connection``
        always runs on exit.
    """
    # Embedding size shared by the model check and the "vector" field.
    expected_dim = 1024
    # Scalar fields that receive an INVERTED index on a newly created collection.
    scalar_index_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
    try:
        # ---- Argument validation ----
        if not userid or not db_type:
            raise ValueError("userid 和 db_type 不能为空")
        if "_" in userid or "_" in db_type:
            raise ValueError("userid 和 db_type 不能包含下划线")
        if len(userid) > 100 or len(db_type) > 100:
            raise ValueError("userid 和 db_type 的长度应小于 100")
        if not documents or not all(isinstance(doc, Document) for doc in documents):
            raise ValueError("documents 不能为空且必须是 Document 对象列表")
        required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
        for doc in documents:
            if not all(field in doc.metadata and doc.metadata[field] for field in required_fields):
                raise ValueError(f"文档元数据缺少必需字段或字段值为空: {doc.metadata}")
            if doc.metadata["userid"] != userid:
                raise ValueError(f"文档元数据的 userid {doc.metadata['userid']} 与输入 userid {userid} 不一致")

        ensure_milvus_directory()
        initialize_milvus_connection()

        # ---- Embedding model ----
        model_path = TEXT_EMBEDDING_MODEL
        if not os.path.exists(model_path):
            raise ValueError(f"模型路径 {model_path} 不存在")
        embedding = HuggingFaceEmbeddings(
            model_name=model_path,
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        try:
            # Embed a probe string to confirm the model works and that its
            # output size matches the schema's vector dimension.
            test_vector = embedding.embed_query("test")
            if len(test_vector) != expected_dim:
                raise ValueError(f"嵌入模型输出维度 {len(test_vector)} 不匹配预期 {expected_dim}")
            logger.debug(f"嵌入模型加载成功,输出维度: {len(test_vector)}")
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {str(e)}")
            raise RuntimeError(f"加载模型失败: {str(e)}")

        # ---- Collection name and schema ----
        collection_name = f"ragdb_{db_type}"
        if len(collection_name) > 255:
            raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
        logger.debug(f"集合名称: {collection_name}")
        fields = [
            FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
            FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
            FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=expected_dim),
            FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
            FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
        ]
        schema = CollectionSchema(
            fields=fields,
            description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段",
            auto_id=True,
            primary_field="pk",
        )

        # ---- Reuse an existing collection only if its schema matches ----
        if utility.has_collection(collection_name):
            try:
                collection = Collection(collection_name)
                existing_fields = collection.schema.fields
                expected_fields = {f.name for f in fields}
                actual_fields = {f.name for f in existing_fields}
                vector_field = next((f for f in existing_fields if f.name == "vector"), None)
                # Initialise dim unconditionally. The debug log below reads it
                # even when the field names differ; leaving it unbound there
                # raises UnboundLocalError instead of dropping the collection.
                dim = None
                if vector_field is not None and getattr(vector_field, 'params', None):
                    dim = vector_field.params.get('dim', None)
                schema_compatible = (
                    expected_fields == actual_fields
                    and vector_field is not None
                    and vector_field.dtype == DataType.FLOAT_VECTOR
                    and dim == expected_dim
                )
                logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, "
                             f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else ''}, "
                             f"dim={dim if dim is not None else '未定义'}")
                if not schema_compatible:
                    logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: "
                                   f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or ''}, "
                                   f"vector_field: {vector_field is not None}, "
                                   f"dtype: {vector_field.dtype if vector_field else ''}, "
                                   f"dim: {dim if dim is not None else '未定义'}")
                    # Incompatible schema: drop it so it is recreated below.
                    utility.drop_collection(collection_name)
                else:
                    collection.load()
                    logger.debug(f"集合 {collection_name} 已存在并加载成功")
            except Exception as e:
                logger.error(f"加载集合 {collection_name} 失败: {str(e)}")
                raise RuntimeError(f"加载集合失败: {str(e)}")

        # ---- Create the collection (new, or just dropped as incompatible) ----
        if not utility.has_collection(collection_name):
            try:
                collection = Collection(collection_name, schema)
                collection.create_index(
                    field_name="vector",
                    index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
                )
                for field_name in scalar_index_fields:
                    collection.create_index(
                        field_name=field_name,
                        index_params={"index_type": "INVERTED"}
                    )
                collection.load()
                logger.debug(f"成功创建并加载集合: {collection_name}")
            except Exception as e:
                logger.error(f"创建集合 {collection_name} 失败: {str(e)}")
                raise RuntimeError(f"创建集合失败: {str(e)}")

        # ---- LangChain vector store over the collection ----
        try:
            vector_store = Milvus(
                embedding_function=embedding,
                collection_name=collection_name,
                connection_args={"uri": MILVUS_DB_PATH},
                drop_old=False,
                auto_id=True,
                primary_field="pk",
            )
            logger.debug(f"成功初始化 Milvus 向量存储: {collection_name}")
        except Exception as e:
            logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
            raise RuntimeError(f"初始化向量存储失败: {str(e)}")

        # ---- Insert, then read back a few of the user's rows for the debug log ----
        try:
            logger.debug(f"正在为 userid {userid} 插入 {len(documents)} 个文档到 {collection_name}")
            for doc in documents:
                logger.debug(f"插入文档元数据: {doc.metadata}")
            vector_store.add_documents(documents=documents)
            logger.debug(f"成功插入 {len(documents)} 个文档")
            collection = Collection(collection_name)
            collection.load()
            results = collection.query(
                expr=f"userid == '{userid}'",
                output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
                limit=10
            )
            for result in results:
                logger.debug(f"插入后查询结果: pk={result['pk']}, document_id={result['document_id']}, "
                             f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', "
                             f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}")
        except Exception as e:
            logger.error(f"插入文档失败: {str(e)}")
            raise RuntimeError(f"插入文档失败: {str(e)}")
        return vector_store
    except Exception as e:
        logger.error(f"初始化 Milvus 向量存储失败: {str(e)}")
        raise
    finally:
        cleanup_milvus_connection()
def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]:
    """
    Return ``{document_id: metadata}`` for the documents ``userid`` has in ``ragdb_<db_type>``.

    Each metadata dict has the keys ``userid``, ``filename``, ``file_path``,
    ``upload_time`` and ``file_type``. Several rows may share one document_id
    (presumably chunks of the same file); the last row read for an id wins.
    Rows without a document_id are skipped.

    Returns an empty dict when the collection does not exist.

    Raises:
        RuntimeError: wrapping any failure, including invalid arguments
            (empty or underscore-containing ``userid`` / ``db_type``).
    """
    batch_size = 1000
    # Milvus rejects queries whose offset + limit exceeds this window, so
    # offset pagination reads at most this many rows.
    max_window = 16384
    try:
        if not userid or "_" in userid:
            raise ValueError("userid 不能为空且不能包含下划线")
        if not db_type or "_" in db_type:
            raise ValueError("db_type 不能为空且不能包含下划线")
        initialize_milvus_connection()
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"集合 {collection_name} 不存在")
            return {}
        collection = Collection(collection_name)
        collection.load()
        mapping = {}
        offset = 0
        # Page through the user's rows. A single small query would silently
        # miss documents whose rows fall past the first page.
        while offset < max_window:
            page_size = min(batch_size, max_window - offset)
            results = collection.query(
                expr=f"userid == '{userid}'",
                output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
                limit=page_size,
                offset=offset,
            )
            for result in results:
                doc_id = result.get("document_id")
                if doc_id:
                    mapping[doc_id] = {
                        "userid": result.get("userid", ""),
                        "filename": result.get("filename", ""),
                        "file_path": result.get("file_path", ""),
                        "upload_time": result.get("upload_time", ""),
                        "file_type": result.get("file_type", "")
                    }
                    logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}")
            # A short page means every matching row has been read.
            if len(results) < page_size:
                break
            offset += page_size
        else:
            logger.warning(f"userid {userid} 在集合 {collection_name} 中的记录达到查询窗口上限 {max_window},文档映射可能不完整")
        logger.debug(f"找到 {len(mapping)} 个文档的映射")
        return mapping
    except Exception as e:
        logger.error(f"获取文档映射失败: {str(e)}")
        raise RuntimeError(f"获取文档映射失败: {str(e)}")
def list_user_collections() -> Dict[str, Dict]:
    """
    Map each ``ragdb_<db_type>`` collection to the users and documents stored in it.

    Returns:
        ``{db_type: {"userids": {userid: [document_id, ...]}}}``. Document ids
        are de-duplicated through a set, so their order is unspecified. Rows
        with an empty ``userid`` or ``document_id`` are ignored.

    Raises:
        Errors from preparing the directory, connecting, or listing/loading a
        collection are logged and re-raised. A failing page query is logged
        and only ends the scan of that collection, which keeps the rows already read.
    """
    batch_size = 1000
    try:
        ensure_milvus_directory()
        initialize_milvus_connection()
        db_types_with_data = {}
        for col in utility.list_collections():
            if not col.startswith("ragdb_"):
                continue
            db_type = col[len("ragdb_"):]
            logger.debug(f"处理集合: {col} (db_type: {db_type})")
            collection = Collection(col)
            collection.load()
            user_document_map = {}
            offset = 0
            while True:
                try:
                    results = collection.query(
                        expr="",
                        output_fields=["userid", "document_id"],
                        limit=batch_size,
                        offset=offset
                    )
                except Exception as e:
                    logger.error(f"查询集合 {col} 的用户和文档失败: {str(e)}")
                    break
                for result in results:
                    userid = result.get("userid")
                    doc_id = result.get("document_id")
                    if userid and doc_id:
                        user_document_map.setdefault(userid, set()).add(doc_id)
                # A short (or empty) page means every row has been read; stop
                # instead of issuing one more query that returns nothing.
                if len(results) < batch_size:
                    break
                offset += batch_size
            # Convert the sets to lists so the result can be serialised.
            user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()}
            logger.debug(f"集合 {col} 中找到用户和文档映射: {user_document_map}")
            db_types_with_data[db_type] = {
                "userids": user_document_map
            }
        logger.debug(f"可用 db_types 和数据: {db_types_with_data}")
        return db_types_with_data
    except Exception as e:
        logger.error(f"列出集合和用户失败: {str(e)}")
        raise
def view_collection_details(userid: str) -> None:
    """
    Log (at DEBUG) the row count and a short preview of ``userid``'s data in every ``ragdb_*`` collection.

    The count comes from a primary-key query capped at 10000 rows, so larger
    counts are reported as 10000. The preview shows up to 10 rows, with the
    text truncated to 200 characters.

    Args:
        userid: User to inspect; must be non-empty and contain no underscore.

    Raises:
        ValueError: if ``userid`` is invalid.
        Errors from preparing the directory, connecting, listing or loading
        collections are re-raised. A failed query on one collection is logged
        and that collection is skipped. All errors are logged.
    """
    count_limit = 10000
    preview_limit = 10
    try:
        if not userid or "_" in userid:
            raise ValueError("userid 不能为空且不能包含下划线")
        logger.debug(f"正在查看 userid {userid} 的集合")
        ensure_milvus_directory()
        initialize_milvus_connection()
        collections = utility.list_collections()
        db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")]
        if not db_types:
            logger.debug("未找到任何集合")
            return
        for db_type in db_types:
            collection_name = f"ragdb_{db_type}"
            if not utility.has_collection(collection_name):
                logger.warning(f"集合 {collection_name} 不存在")
                continue
            collection = Collection(collection_name)
            collection.load()
            try:
                all_pks = collection.query(expr=f"userid == '{userid}'", output_fields=["pk"], limit=count_limit)
                num_entities = len(all_pks)
                logger.debug(f"集合 {collection_name} 中 userid {userid} 的文档数: {num_entities}")
                if num_entities == 0:
                    logger.debug(f"集合 {collection_name} 中 userid {userid} 无文档")
                    continue
                # Only fetch the preview once we know there is something to show.
                results = collection.query(
                    expr=f"userid == '{userid}'",
                    output_fields=["userid", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
                    limit=preview_limit
                )
                logger.debug(f"集合 {collection_name} 中 userid {userid} 的内容:")
                for idx, doc in enumerate(results, 1):
                    metadata = {
                        "userid": doc.get("userid", ""),
                        "filename": doc.get("filename", ""),
                        "file_path": doc.get("file_path", ""),
                        "upload_time": doc.get("upload_time", ""),
                        "file_type": doc.get("file_type", "")
                    }
                    logger.debug(f"文档 {idx}: 内容: {doc.get('text', '')[:200]}..., 元数据: {metadata}")
            except Exception as e:
                logger.error(f"查询集合 {collection_name} 的文档失败: {str(e)}")
                continue
    except Exception as e:
        logger.error(f"无法查看 userid {userid} 的集合详情: {str(e)}")
        raise
def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]:
    """
    Return up to ``limit`` rows of ``ragdb_<db_type>``, keyed by primary key, including raw vectors.

    ``userid`` and ``document_id`` are optional filters. A falsy value means
    "no filter", and when both are given they are combined with AND.

    Each value has the shape
    ``{"text", "vector", "document_id", "metadata": {"filename", "file_path", "upload_time", "file_type"}}``.
    Returns an empty dict when the collection does not exist.

    Raises:
        ValueError: empty/underscore-containing ``db_type``, ``limit`` outside
            1..16384, or an underscore in ``userid`` / ``document_id``.
        Errors from connecting or querying propagate. Every error is logged
        before being re-raised.
    """
    metadata_keys = ("filename", "file_path", "upload_time", "file_type")
    try:
        if not db_type or "_" in db_type:
            raise ValueError("db_type 不能为空且不能包含下划线")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit 必须在 1 到 16384 之间")
        if userid and "_" in userid:
            raise ValueError("userid 不能包含下划线")
        if document_id and "_" in document_id:
            raise ValueError("document_id 不能包含下划线")

        initialize_milvus_connection()
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"集合 {collection_name} 不存在")
            return {}
        collection = Collection(collection_name)
        collection.load()
        logger.debug(f"加载集合: {collection_name}")

        # Build the filter from whichever optional arguments were supplied.
        candidate_clauses = (
            (userid, f"userid == '{userid}'"),
            (document_id, f"document_id == '{document_id}'"),
        )
        filter_expr = " && ".join(clause for present, clause in candidate_clauses if present)

        rows = collection.query(
            expr=filter_expr,
            output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"],
            limit=limit
        )

        vector_data = {}
        for row in rows:
            pk = row.get("pk", str(uuid.uuid4()))
            text = row.get("text", "")
            doc_id = row.get("document_id", "")
            vector = row.get("vector", [])
            metadata = {key: row.get(key, "") for key in metadata_keys}
            vector_data[pk] = dict(text=text, vector=vector, document_id=doc_id, metadata=metadata)
            logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}")
        logger.debug(f"共找到 {len(vector_data)} 条向量数据")
        return vector_data
    except Exception as e:
        logger.error(f"查看向量数据失败: {str(e)}")
        raise
def main():
    """
    Run manual smoke tests against the configured Milvus database.

    Each step logs its title, then either its result or "失败: <error>".
    Steps 1 (insert sample documents) and 4 (dump raw vectors) are disabled;
    their code is kept in comments so they can be re-enabled.
    """
    userid = "testuser1"
    db_type = "textdb"

    def run_step(title, action, describe=None):
        # Log the title, run the action, and log its result (if describe is
        # given) or the failure message. Errors never stop later steps.
        logger.info(title)
        try:
            outcome = action()
            if describe is not None:
                logger.info(describe(outcome))
        except Exception as e:
            logger.error(f"失败: {str(e)}")

    # Step 1 (disabled): insert two sample documents.
    # documents = [
    #     Document(
    #         page_content="深度学习是基于深层神经网络的机器学习子集。",
    #         metadata={
    #             "userid": userid,
    #             "document_id": str(uuid.uuid4()),
    #             "filename": "test_doc1.txt",
    #             "file_path": "/path/to/test_doc1.txt",
    #             "upload_time": datetime.now().isoformat(),
    #             "file_type": "txt",
    #         },
    #     ),
    #     Document(
    #         page_content="知识图谱是一个结构化的语义知识库。",
    #         metadata={
    #             "userid": userid,
    #             "document_id": str(uuid.uuid4()),
    #             "filename": "test_doc2.txt",
    #             "file_path": "/path/to/test_doc2.txt",
    #             "upload_time": datetime.now().isoformat(),
    #             "file_type": "txt",
    #         },
    #     ),
    # ]
    # run_step("\n测试 1带文档初始化",
    #          lambda: get_vector_db(userid, db_type, documents=documents),
    #          lambda _: f"成功为 userid {userid} 在 {db_type} 中插入文档")

    run_step("\n测试 2列出所有 db_types 和文档映射",
             list_user_collections,
             lambda db_types: f"可用 db_types 和文档: {db_types}")

    run_step(f"\n测试 3查看 userid {userid} 的所有集合",
             lambda: view_collection_details(userid))

    # Step 4 (disabled): dump the raw vector data for the user.
    # run_step(f"\n测试 4查看向量数据",
    #          lambda: view_vector_data(db_type, userid=userid),
    #          lambda vector_data: f"向量数据: {vector_data}")

    run_step(f"\n测试 5获取 userid {userid}{db_type}数据库的文档映射",
             lambda: get_document_mapping(userid, db_type),
             lambda mapping: f"文档映射: {mapping}")


if __name__ == "__main__":
    main()