Delete
This commit is contained in:
parent
c0a62bb5e7
commit
e4cdcc1f5a
@ -1,2 +0,0 @@
from .version import __version__
@@ -1,138 +0,0 @@
import logging
import yaml
import os
from pymilvus import connections, Collection, utility
from vector import initialize_milvus_connection

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
    TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
    print(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}")

# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

def delete_document(db_type: str, userid: str, filename: str) -> bool:
    """
    Delete a user's data for the given file, identified by db_type, userid and filename.

    Args:
        db_type (str): database type (e.g. 'textdb', 'pptdb')
        userid (str): user ID
        filename (str): file name (e.g. 'test.docx')

    Returns:
        bool: whether the deletion succeeded

    Raises:
        ValueError: invalid arguments
        RuntimeError: database operation failed
    """
    try:
        # Validate arguments
        if not db_type or "_" in db_type:
            raise ValueError("db_type must be non-empty and must not contain underscores")
        if not userid or "_" in userid:
            raise ValueError("userid must be non-empty and must not contain underscores")
        if not filename:
            raise ValueError("filename must not be empty")
        if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255:
            raise ValueError("db_type, userid or filename exceeds the length limit")

        # Initialize the Milvus connection
        initialize_milvus_connection()
        logger.debug(f"Connected to Milvus Lite, path: {MILVUS_DB_PATH}")

        # Check whether the collection exists
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"Collection {collection_name} does not exist")
            return False

        # Load the collection
        try:
            collection = Collection(collection_name)
            collection.load()
            logger.debug(f"Loaded collection: {collection_name}")
        except Exception as e:
            logger.error(f"Failed to load collection {collection_name}: {str(e)}")
            raise RuntimeError(f"Failed to load collection: {str(e)}")

        # Query the matching document_ids
        expr = f"userid == '{userid}' and filename == '{filename}'"
        logger.debug(f"Query expression: {expr}")
        try:
            results = collection.query(
                expr=expr,
                output_fields=["document_id"],
                limit=1000
            )
            if not results:
                logger.warning(f"No records found for userid={userid}, filename={filename}")
                return False
            document_ids = list(set(result["document_id"] for result in results if "document_id" in result))
            logger.debug(f"Found {len(document_ids)} document_ids: {document_ids}")
        except Exception as e:
            logger.error(f"Failed to query document_id: {str(e)}")
            raise RuntimeError(f"Query failed: {str(e)}")

        # Perform the deletion
        total_deleted = 0
        for doc_id in document_ids:
            try:
                delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'"
                logger.debug(f"Delete expression: {delete_expr}")
                delete_result = collection.delete(delete_expr)
                deleted_count = delete_result.delete_count
                total_deleted += deleted_count
                logger.info(f"Deleted {deleted_count} records for document_id={doc_id}")
            except Exception as e:
                logger.error(f"Failed to delete document_id={doc_id}: {str(e)}")
                continue

        if total_deleted == 0:
            logger.warning(f"No records deleted, userid={userid}, filename={filename}")
            return False

        logger.info(f"Deleted {total_deleted} records in total, userid={userid}, filename={filename}")
        return True

    except ValueError as ve:
        logger.error(f"Argument validation failed: {str(ve)}")
        return False
    except RuntimeError as re:
        logger.error(f"Database operation failed: {str(re)}")
        return False
    except Exception as e:
        logger.error(f"Failed to delete file: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False
    finally:
        try:
            connections.disconnect("default")
            logger.debug("Disconnected from Milvus")
        except Exception as e:
            logger.warning(f"Failed to disconnect from Milvus: {str(e)}")

if __name__ == "__main__":
    # Test case
    db_type = "textdb"
    userid = "testuser4"
    filename = "聚类结果1.xlsx"

    logger.info(f"Test: delete the file with userid={userid}, filename={filename}")
    result = delete_document(db_type, userid, filename)
    print(f"Deletion result: {result}")
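Note: both the query and delete paths above splice userid and filename straight into the Milvus filter expression, so a filename containing a single quote would break (or subvert) the expression. A minimal hardening sketch, assuming recent Milvus honors backslash escapes inside quoted string literals (quote_expr_value is a hypothetical helper, not part of this codebase):

def quote_expr_value(value: str) -> str:
    # Escape backslashes and single quotes before embedding a user-supplied
    # string in a Milvus filter expression (escaping behavior is an assumption).
    return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"

expr = f"userid == {quote_expr_value(userid)} and filename == {quote_expr_value(filename)}"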
@@ -1,178 +0,0 @@
import os
import uuid
import yaml
import logging
from datetime import datetime
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections
from .vector import get_vector_db
from filetxt.loader import fileloader

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
    print(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}")

# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

def load_and_split_data(file_path: str, userid: str) -> List[Document]:
    """
    Load a file, split it into chunks, and produce Document objects with metadata.

    Args:
        file_path (str): file path
        userid (str): user ID

    Returns:
        List[Document]: the list of chunked documents

    Raises:
        ValueError: invalid file or arguments
    """
    try:
        # Validate the file
        if not os.path.exists(file_path):
            raise ValueError(f"File {file_path} does not exist")
        if os.path.getsize(file_path) == 0:
            raise ValueError(f"File {file_path} is empty")
        logger.debug(f"Checked file: {file_path}, size: {os.path.getsize(file_path)} bytes")
        # splitext avoids an IndexError on extensionless paths
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        logger.debug(f"File extension: {ext}")

        # Load the file content with fileloader
        logger.debug("Starting to load the file")
        text = fileloader(file_path)
        if not text or not text.strip():
            raise ValueError(f"File {file_path} loaded as empty")

        # Create a single Document object
        document = Document(page_content=text)
        logger.debug("Loading finished, produced 1 document")

        # Split the text
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents([document])
        logger.debug(f"Splitting finished, produced {len(chunks)} chunks")

        # Generate one unique document_id for the whole file
        document_id = str(uuid.uuid4())

        # Attach metadata, making sure every required field is present
        filename = os.path.basename(file_path)
        upload_time = datetime.now().isoformat()
        documents = []
        for i, chunk in enumerate(chunks):
            chunk.metadata.update({
                'userid': userid,
                'document_id': document_id,
                'filename': filename,
                'file_path': file_path,
                'upload_time': upload_time,
                'file_type': ext,
                'chunk_index': i,  # optional, tracks chunk order
                'source': file_path,  # optional, tracks file origin
            })
            # Verify metadata completeness
            required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
            if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
                raise ValueError(f"Document metadata is missing required fields or has empty values: {chunk.metadata}")
            documents.append(chunk)
            logger.debug(f"Produced chunk {i}: metadata={chunk.metadata}")

        logger.debug(f"File {file_path} loaded and split into {len(documents)} chunks, document_id: {document_id}")
        return documents

    except Exception as e:
        logger.error(f"Failed to load or split file {file_path}: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        raise ValueError(f"Failed to load or split file: {str(e)}")

def embed(file_path: str, userid: str, db_type: str) -> bool:
    """
    Embed a file into the Milvus vector database.

    Args:
        file_path (str): file path
        userid (str): user ID
        db_type (str): database type

    Returns:
        bool: whether the embedding succeeded

    Raises:
        ValueError: invalid arguments
        RuntimeError: database operation failed
    """
    try:
        # Validate input
        if not userid or not db_type:
            raise ValueError("userid and db_type must not be empty")
        if "_" in userid:
            raise ValueError("userid must not contain underscores")
        if "_" in db_type:
            raise ValueError("db_type must not contain underscores")
        if not os.path.exists(file_path):
            raise ValueError(f"File {file_path} does not exist")

        supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        if ext not in supported_formats:
            logger.error(f"File {file_path} has an unsupported format; supported formats: {', '.join(supported_formats)}")
            raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")

        # Load and split the file
        logger.info(f"Start processing file {file_path}, userid: {userid}, db_type: {db_type}")
        chunks = load_and_split_data(file_path, userid)
        if not chunks:
            logger.error(f"File {file_path} produced no chunks")
            raise ValueError("No document chunks were produced")

        logger.debug(f"Processing file {file_path}, produced {len(chunks)} chunks")
        logger.debug(f"First chunk: {chunks[0].page_content[:200]}")

        # Insert into Milvus
        db = get_vector_db(userid, db_type, documents=chunks)
        if not db:
            logger.error(f"Could not initialize or insert into vector database ragdb_{db_type}")
            raise RuntimeError("Database operation failed")

        logger.info(f"File {file_path} successfully embedded into database ragdb_{db_type}")
        return True

    except ValueError as ve:
        logger.error(f"Failed to embed file {file_path}: {str(ve)}")
        return False
    except RuntimeError as re:
        logger.error(f"Failed to embed file {file_path}: {str(re)}")
        return False
    except Exception as e:
        logger.error(f"Failed to embed file {file_path}: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False

if __name__ == "__main__":
    test_file = "/share/wangmeihua/rag/data/test.txt"
    userid = "testuser4"
    db_type = "textdb"
    result = embed(test_file, userid, db_type)
    print(f"Embedding result: {result}")
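Note: with chunk_size=2000 and chunk_overlap=200 above, consecutive chunks share their boundary text, so a sentence cut at a split point stays retrievable from either neighbor. A standalone sketch that makes the overlap visible (synthetic input; on typical word-separated text each comparison should print True):

from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200, length_function=len)
text = " ".join(f"word{i}" for i in range(1500))  # roughly 10k characters of space-separated tokens
chunks = splitter.split_text(text)
for prev, cur in zip(chunks, chunks[1:]):
    # The head of each chunk repeats the tail of the previous one
    print(len(prev), cur[:60] in prev)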
@@ -1,15 +0,0 @@
from appPublic.worker import awaitify
from ahserver.serverenv import ServerEnv
from .kdb import add_kdb, add_dir, add_doc, get_all_docs
from .query import search_query
from .embed import embed

def load_rag():
    env = ServerEnv()
    env.add_kdb = add_kdb
    env.query = awaitify(search_query)
    env.embed = awaitify(embed)
    env.add_dir = add_dir
    env.add_doc = add_doc
    env.get_all_docs = get_all_docs
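Note: awaitify here turns the blocking search_query and embed into awaitable callables so async handlers can serve them. The implementation inside appPublic.worker is not shown in this diff; a minimal stand-in with the same observable behavior might look like:

import asyncio
from functools import wraps

def awaitify(func):
    # Assumed behavior: awaiting the wrapper runs the blocking call in the default executor
    @wraps(func)
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
    return wrapper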
@@ -1,81 +0,0 @@
from traceback import format_exc
from appPublic.uniqueID import getID
from appPublic.timeUtils import curDateString
from appPublic.dictObject import DictObject
from appPublic.log import exception  # assumed source of the exception() logger used below
from sqlor.dbpools import DBPools
from ahserver.serverenv import get_serverenv
from ahserver.filestorage import FileStorage

async def add_kdb(kdb: dict) -> None:
    """
    Add a knowledge base.
    """
    kdb = DictObject(**kdb)
    kdb.parentid = None
    if kdb.id is None:
        kdb.id = getID()
    kdb.entity_type = '0'
    kdb.create_date = curDateString()
    if kdb.orgid is None:
        e = Exception('Can not add a kdb without an orgid')
        exception(f'{e}\n{format_exc()}')
        raise e

    f = get_serverenv('get_module_dbname')
    dbname = f('rag')
    db = DBPools()
    async with db.sqlorContext(dbname) as sor:
        await sor.C('kdb', kdb.copy())

async def add_dir(kdb: dict) -> None:
    """
    Add a subdirectory.
    """
    kdb = DictObject(**kdb)
    if kdb.parentid is None:
        e = Exception('Can not add a root folder')
        exception(f'{e}\n{format_exc()}')
        raise e
    if kdb.id is None:
        kdb.id = getID()
    kdb.entity_type = '1'
    kdb.create_date = curDateString()
    f = get_serverenv('get_module_dbname')
    dbname = f('rag')
    db = DBPools()
    async with db.sqlorContext(dbname) as sor:
        await sor.C('kdb', kdb.copy())

async def add_doc(doc: dict) -> None:
    """
    Add a document.
    """
    doc = DictObject(**doc)
    if doc.parentid is None:
        e = Exception('Can not add a root document')
        exception(f'{e}\n{format_exc()}')
        raise e
    if doc.id is None:
        doc.id = getID()
    fs = FileStorage()
    doc.realpath = fs.realPath(doc.webpath)
    doc.create_date = curDateString()
    f = get_serverenv('get_module_dbname')
    dbname = f('rag')
    db = DBPools()
    async with db.sqlorContext(dbname) as sor:
        await sor.C('doc', doc.copy())

async def get_all_docs(sor, kdbid):
    """
    Get all documents under kdbid, including those in subdirectories.
    """
    docs = await sor.R('doc', {'parentid': kdbid})
    kdbs = await sor.R('kdb', {'parentid': kdbid})
    for kdb in kdbs:
        docs1 = await get_all_docs(sor, kdb.id)
        docs += docs1
    return docs
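Note: get_all_docs walks the kdb tree depth-first and reuses the caller's sor, so the whole traversal runs inside one sqlorContext. An equivalent iterative sketch, built on the same sor.R interface shown above, avoids deep recursion on heavily nested directories:

async def get_all_docs_iter(sor, kdbid):
    # Traverse the directory tree rooted at kdbid with an explicit worklist
    docs, pending = [], [kdbid]
    while pending:
        parent = pending.pop()
        docs += await sor.R('doc', {'parentid': parent})
        pending += [kdb.id for kdb in await sor.R('kdb', {'parentid': parent})]
    return docs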
@@ -1,180 +0,0 @@
import os
import yaml
import logging
from typing import List, Dict
from pymilvus import connections, Collection, utility
from langchain_huggingface import HuggingFaceEmbeddings
from .vector import get_vector_db, initialize_milvus_connection, cleanup_milvus_connection
import torch

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
    TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
    print(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}")

# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

def search_query(query: str, userid: str, db_type: str, limit: int = 10, offset: int = 0) -> List[Dict]:
    """
    Search the knowledge base of the given db_type for documents belonging to userid, using the user's query text.

    Args:
        query (str): the user's query text
        userid (str): user ID, used for filtering
        db_type (str): database type (e.g. 'textdb')
        limit (int): maximum number of results to return, default 10
        offset (int): offset for pagination, default 0

    Returns:
        List[Dict]: search results, each element a dict containing text and metadata

    Raises:
        ValueError: invalid arguments
        RuntimeError: model loading or Milvus operation failed
    """
    try:
        # Validate arguments
        if not query:
            raise ValueError("The query text must not be empty")
        if not userid or not db_type:
            raise ValueError("userid and db_type must not be empty")
        if "_" in userid or "_" in db_type:
            raise ValueError("userid and db_type must not contain underscores")
        if len(userid) > 100 or len(db_type) > 100:
            raise ValueError("userid and db_type must be shorter than 100 characters")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit must be between 1 and 16384")
        if offset < 0:
            raise ValueError("offset must not be negative")
        if limit + offset > 16384:
            raise ValueError("limit + offset must not exceed 16384")

        # Initialize the embedding model
        model_path = TEXT_EMBEDDING_MODEL
        if not os.path.exists(model_path):
            raise ValueError(f"Model path {model_path} does not exist")

        embedding = HuggingFaceEmbeddings(
            model_name=model_path,
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        try:
            test_vector = embedding.embed_query("test")
            if len(test_vector) != 1024:
                raise ValueError(f"Embedding model output dimension {len(test_vector)} does not match the expected 1024")
            logger.debug("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {str(e)}")
            raise RuntimeError(f"Failed to load embedding model: {str(e)}")

        # Convert the query into a vector
        query_vector = embedding.embed_query(query)
        logger.debug(f"Query vector dimension: {len(query_vector)}")

        # Connect to Milvus
        initialize_milvus_connection()

        # Check whether the collection exists
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"Collection {collection_name} does not exist")
            return []

        # Load the collection
        try:
            collection = Collection(collection_name)
            collection.load()
            logger.debug(f"Loaded collection: {collection_name}")
        except Exception as e:
            logger.error(f"Failed to load collection {collection_name}: {str(e)}")
            raise RuntimeError(f"Failed to load collection: {str(e)}")

        # Build the search parameters
        search_params = {
            "metric_type": "COSINE",  # consistent with vector.py
            "params": {"nprobe": 10}  # tunes search performance
        }

        # Build the filter expression
        expr = f"userid == '{userid}'"
        logger.debug(f"Search params: {search_params}, expression: {expr}, limit: {limit}, offset: {offset}")

        # Run the search
        try:
            results = collection.search(
                data=[query_vector],
                anns_field="vector",
                param=search_params,
                limit=limit,
                expr=expr,
                output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
                offset=offset
            )
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            raise RuntimeError(f"Search failed: {str(e)}")

        # Process the search results
        search_results = []
        for hits in results:
            for hit in hits:
                metadata = {
                    "userid": hit.entity.get("userid"),
                    "document_id": hit.entity.get("document_id"),
                    "filename": hit.entity.get("filename"),
                    "file_path": hit.entity.get("file_path"),
                    "upload_time": hit.entity.get("upload_time"),
                    "file_type": hit.entity.get("file_type")
                }
                result = {
                    "text": hit.entity.get("text"),
                    "distance": hit.distance,
                    "metadata": metadata
                }
                search_results.append(result)
                logger.debug(f"Hit: text: {result['text'][:200]}..., distance: {hit.distance}, metadata: {metadata}")

        logger.debug(f"Search finished, returning {len(search_results)} results")
        return search_results

    except Exception as e:
        logger.error(f"Search failed: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        raise
    finally:
        cleanup_milvus_connection()

if __name__ == "__main__":
    # Test code
    query = "知识图谱的知识融合是什么?"
    userid = "testuser1"
    db_type = "textdb"
    limit = 2
    offset = 0

    try:
        results = search_query(query, userid, db_type, limit, offset)
        print(f"Search results ({len(results)} items):")
        for idx, result in enumerate(results, 1):
            print(f"Result {idx}:")
            print(f"Content: {result['text'][:200]}...")
            print(f"Distance: {result['distance']}")
            print(f"Metadata: {result['metadata']}")
            print("-" * 50)
    except Exception as e:
        print(f"Search failed: {str(e)}")
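Note: because search_query enforces limit + offset <= 16384, paging through a large result set has a hard ceiling. A small caller-side pager built only on the function above (the page size of 100 is an arbitrary choice):

def paged_search(query: str, userid: str, db_type: str, page_size: int = 100):
    # Yield successive pages until Milvus returns fewer hits than requested
    offset = 0
    while offset + page_size <= 16384:
        page = search_query(query, userid, db_type, limit=page_size, offset=offset)
        if not page:
            break
        yield page
        if len(page) < page_size:
            break
        offset += page_size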
@@ -1,53 +0,0 @@
from ahserver.serverenv import ServerEnv
from ahserver.configuredServer import ConfiguredServer
from ahserver.webapp import webapp
from appPublic.worker import awaitify
from query import search_query
from embed import embed
from deletefile import delete_document
import logging
import os

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.expanduser('~/rag/logs/rag.log'), encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def get_module_dbname(name):
    """
    Get the default database name: prefer the RAG_DB_TYPE environment variable, returning 'textdb' when it is unset.
    """
    return os.getenv('RAG_DB_TYPE', 'textdb')


def init():
    """
    Initialize the RAG system and bind the core functions to ServerEnv.
    """
    try:
        logger.info("Initializing the RAG system")
        g = ServerEnv()

        # Bind the core functions as async versions
        g.embed = awaitify(embed)
        g.query = awaitify(search_query)
        g.delete = awaitify(delete_document)
        g.get_module_dbname = get_module_dbname

        logger.info("RAG system initialization complete")
        return g
    except Exception as e:
        logger.error(f"Initialization failed: {str(e)}")
        raise


if __name__ == '__main__':
    logger.info("Starting the RAG web server")
    webapp(init)
@@ -1,539 +0,0 @@
import os
import uuid
import json
import yaml
from datetime import datetime
from typing import List, Dict, Optional
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from langchain_milvus import Milvus
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
import torch
import logging
import time

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
    TEXT_EMBEDDING_MODEL = config['models']['text_embedding_model']
except Exception as e:
    print(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}")

# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

def ensure_milvus_directory() -> None:
    """Make sure the Milvus database directory exists."""
    db_dir = os.path.dirname(MILVUS_DB_PATH)
    if not os.path.exists(db_dir):
        os.makedirs(db_dir, exist_ok=True)
        logger.debug(f"Created Milvus directory: {db_dir}")
    if not os.access(db_dir, os.W_OK):
        raise RuntimeError(f"Milvus directory {db_dir} is not writable")

def initialize_milvus_connection() -> None:
    """Initialize the Milvus connection, ensuring a single connection."""
    try:
        if not connections.has_connection("default"):
            connections.connect("default", uri=MILVUS_DB_PATH)
            logger.debug(f"Connected to Milvus Lite, path: {MILVUS_DB_PATH}")
        else:
            logger.debug("A Milvus connection already exists, skipping reconnect")
    except Exception as e:
        logger.error(f"Failed to connect to Milvus: {str(e)}")
        raise RuntimeError(f"Failed to connect to Milvus: {str(e)}")

def cleanup_milvus_connection() -> None:
    """Clean up the Milvus connection, making sure resources are released."""
    try:
        if connections.has_connection("default"):
            connections.disconnect("default")
            logger.debug("Disconnected from Milvus")
        time.sleep(3)
    except Exception as e:
        logger.warning(f"Failed to disconnect from Milvus: {str(e)}")

def get_vector_db(userid: str, db_type: str, documents: List[Document]) -> Milvus:
    """
    Initialize or access a Milvus Lite collection, organized by db_type, using userid to separate users and document_id to separate documents, and insert the given documents.
    """
    try:
        # Validate arguments
        if not userid or not db_type:
            raise ValueError("userid and db_type must not be empty")
        if "_" in userid or "_" in db_type:
            raise ValueError("userid and db_type must not contain underscores")
        if len(userid) > 100 or len(db_type) > 100:
            raise ValueError("userid and db_type must be shorter than 100 characters")
        if not documents or not all(isinstance(doc, Document) for doc in documents):
            raise ValueError("documents must be a non-empty list of Document objects")
        required_fields = ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]
        for doc in documents:
            if not all(field in doc.metadata and doc.metadata[field] for field in required_fields):
                raise ValueError(f"Document metadata is missing required fields or has empty values: {doc.metadata}")
            if doc.metadata["userid"] != userid:
                raise ValueError(f"Document metadata userid {doc.metadata['userid']} does not match the input userid {userid}")

        ensure_milvus_directory()
        initialize_milvus_connection()

        # Initialize the embedding model
        model_path = TEXT_EMBEDDING_MODEL
        if not os.path.exists(model_path):
            raise ValueError(f"Model path {model_path} does not exist")

        embedding = HuggingFaceEmbeddings(
            model_name=model_path,
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        try:
            test_vector = embedding.embed_query("test")
            if len(test_vector) != 1024:
                raise ValueError(f"Embedding model output dimension {len(test_vector)} does not match the expected 1024")
            logger.debug(f"Embedding model loaded successfully, output dimension: {len(test_vector)}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {str(e)}")
            raise RuntimeError(f"Failed to load model: {str(e)}")

        # Collection name
        collection_name = f"ragdb_{db_type}"
        if len(collection_name) > 255:
            raise ValueError(f"Collection name {collection_name} exceeds 255 characters")
        logger.debug(f"Collection name: {collection_name}")

        # Define the schema with all fixed fields
        fields = [
            FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
            FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
            FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
            FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
            FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
        ]
        schema = CollectionSchema(
            fields=fields,
            description=f"{db_type} data collection, shared across users, with document_id and metadata fields",
            auto_id=True,
            primary_field="pk",
        )

        # Check whether the collection exists
        if utility.has_collection(collection_name):
            try:
                collection = Collection(collection_name)
                existing_schema = collection.schema
                expected_fields = {f.name for f in fields}
                actual_fields = {f.name for f in existing_schema.fields}
                vector_field = next((f for f in existing_schema.fields if f.name == "vector"), None)

                schema_compatible = False
                dim = None  # defined up front so the debug log below cannot hit a NameError
                if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR:
                    dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None
                    schema_compatible = dim == 1024
                logger.debug(f"Checked schema of collection {collection_name}: fields match={expected_fields == actual_fields}, "
                             f"vector_field present={vector_field is not None}, dtype={vector_field.dtype if vector_field else 'none'}, "
                             f"dim={dim if dim is not None else 'undefined'}")
                if not schema_compatible:
                    logger.warning(f"Schema of collection {collection_name} is incompatible, reasons: "
                                   f"field mismatch: {expected_fields.symmetric_difference(actual_fields) or 'none'}, "
                                   f"vector_field: {vector_field is not None}, "
                                   f"dtype: {vector_field.dtype if vector_field else 'none'}, "
                                   f"dim: {vector_field.params.get('dim', 'undefined') if vector_field and hasattr(vector_field, 'params') and vector_field.params else 'undefined'}")
                    utility.drop_collection(collection_name)
                else:
                    collection.load()
                    logger.debug(f"Collection {collection_name} already exists and loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load collection {collection_name}: {str(e)}")
                raise RuntimeError(f"Failed to load collection: {str(e)}")

        # Create a new collection
        if not utility.has_collection(collection_name):
            try:
                collection = Collection(collection_name, schema)
                collection.create_index(
                    field_name="vector",
                    index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
                )
                collection.create_index(
                    field_name="userid",
                    index_params={"index_type": "INVERTED"}
                )
                collection.create_index(
                    field_name="document_id",
                    index_params={"index_type": "INVERTED"}
                )
                collection.create_index(
                    field_name="filename",
                    index_params={"index_type": "INVERTED"}
                )
                collection.create_index(
                    field_name="file_path",
                    index_params={"index_type": "INVERTED"}
                )
                collection.create_index(
                    field_name="upload_time",
                    index_params={"index_type": "INVERTED"}
                )
                collection.create_index(
                    field_name="file_type",
                    index_params={"index_type": "INVERTED"}
                )
                collection.load()
                logger.debug(f"Successfully created and loaded collection: {collection_name}")
            except Exception as e:
                logger.error(f"Failed to create collection {collection_name}: {str(e)}")
                raise RuntimeError(f"Failed to create collection: {str(e)}")

        # Initialize the Milvus vector store
        try:
            vector_store = Milvus(
                embedding_function=embedding,
                collection_name=collection_name,
                connection_args={"uri": MILVUS_DB_PATH},
                drop_old=False,
                auto_id=True,
                primary_field="pk",
            )
            logger.debug(f"Successfully initialized Milvus vector store: {collection_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Milvus vector store: {str(e)}")
            raise RuntimeError(f"Failed to initialize vector store: {str(e)}")

        # Insert the documents
        try:
            logger.debug(f"Inserting {len(documents)} documents into {collection_name} for userid {userid}")
            for doc in documents:
                logger.debug(f"Inserting document metadata: {doc.metadata}")
            vector_store.add_documents(documents=documents)
            logger.debug(f"Successfully inserted {len(documents)} documents")

            # Verify with an immediate query
            collection = Collection(collection_name)
            collection.load()
            results = collection.query(
                expr=f"userid == '{userid}'",
                output_fields=["pk", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
                limit=10
            )
            for result in results:
                logger.debug(f"Post-insert query result: pk={result['pk']}, document_id={result['document_id']}, "
                             f"metadata={{'filename': '{result['filename']}', 'file_path': '{result['file_path']}', "
                             f"'upload_time': '{result['upload_time']}', 'file_type': '{result['file_type']}'}}")
        except Exception as e:
            logger.error(f"Failed to insert documents: {str(e)}")
            raise RuntimeError(f"Failed to insert documents: {str(e)}")

        return vector_store

    except Exception as e:
        logger.error(f"Failed to initialize Milvus vector store: {str(e)}")
        raise
    finally:
        cleanup_milvus_connection()

def get_document_mapping(userid: str, db_type: str) -> Dict[str, Dict]:
    """
    Get the mapping from document_id to metadata for the given userid and db_type.
    """
    try:
        if not userid or "_" in userid:
            raise ValueError("userid must be non-empty and must not contain underscores")
        if not db_type or "_" in db_type:
            raise ValueError("db_type must be non-empty and must not contain underscores")

        initialize_milvus_connection()
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"Collection {collection_name} does not exist")
            return {}

        collection = Collection(collection_name)
        collection.load()

        results = collection.query(
            expr=f"userid == '{userid}'",
            output_fields=["userid", "document_id", "filename", "file_path", "upload_time", "file_type"],
            limit=100
        )
        mapping = {}
        for result in results:
            doc_id = result.get("document_id")
            if doc_id:
                mapping[doc_id] = {
                    "userid": result.get("userid", ""),
                    "filename": result.get("filename", ""),
                    "file_path": result.get("file_path", ""),
                    "upload_time": result.get("upload_time", ""),
                    "file_type": result.get("file_type", "")
                }
                logger.debug(f"document_id: {doc_id}, metadata: {mapping[doc_id]}")

        logger.debug(f"Found mappings for {len(mapping)} documents")
        return mapping

    except Exception as e:
        logger.error(f"Failed to get document mapping: {str(e)}")
        raise RuntimeError(f"Failed to get document mapping: {str(e)}")

def list_user_collections() -> Dict[str, Dict]:
    """
    List all database types (db_type) and, for each one, the mapping from its users (userid) to their documents (document_id).
    """
    try:
        ensure_milvus_directory()
        initialize_milvus_connection()
        collections = utility.list_collections()

        db_types_with_data = {}
        for col in collections:
            if col.startswith("ragdb_"):
                db_type = col[len("ragdb_"):]
                logger.debug(f"Processing collection: {col} (db_type: {db_type})")

                collection = Collection(col)
                collection.load()

                batch_size = 1000
                offset = 0
                user_document_map = {}
                while True:
                    try:
                        results = collection.query(
                            expr="",
                            output_fields=["userid", "document_id"],
                            limit=batch_size,
                            offset=offset
                        )
                        if not results:
                            break
                        for result in results:
                            userid = result.get("userid")
                            doc_id = result.get("document_id")
                            if userid and doc_id:
                                if userid not in user_document_map:
                                    user_document_map[userid] = set()
                                user_document_map[userid].add(doc_id)
                        offset += batch_size
                    except Exception as e:
                        logger.error(f"Failed to query users and documents of collection {col}: {str(e)}")
                        break

                # Convert to lists for serialization
                user_document_map = {uid: list(doc_ids) for uid, doc_ids in user_document_map.items()}
                logger.debug(f"Found user/document mapping in collection {col}: {user_document_map}")

                db_types_with_data[db_type] = {
                    "userids": user_document_map
                }

        logger.debug(f"Available db_types and data: {db_types_with_data}")
        return db_types_with_data

    except Exception as e:
        logger.error(f"Failed to list collections and users: {str(e)}")
        raise

def view_collection_details(userid: str) -> None:
    """
    View the content and size for a specific userid across all collections, including document_id and metadata.
    """
    try:
        if not userid or "_" in userid:
            raise ValueError("userid must be non-empty and must not contain underscores")

        logger.debug(f"Viewing collections for userid {userid}")
        ensure_milvus_directory()
        initialize_milvus_connection()
        collections = utility.list_collections()
        db_types = [col[len("ragdb_"):] for col in collections if col.startswith("ragdb_")]

        if not db_types:
            logger.debug("No collections found")
            return

        for db_type in db_types:
            collection_name = f"ragdb_{db_type}"
            if not utility.has_collection(collection_name):
                logger.warning(f"Collection {collection_name} does not exist")
                continue

            collection = Collection(collection_name)
            collection.load()

            try:
                all_pks = collection.query(expr=f"userid == '{userid}'", output_fields=["pk"], limit=10000)
                num_entities = len(all_pks)
                results = collection.query(
                    expr=f"userid == '{userid}'",
                    output_fields=["userid", "text", "document_id", "filename", "file_path", "upload_time", "file_type"],
                    limit=10
                )
                logger.debug(f"Number of documents for userid {userid} in collection {collection_name}: {num_entities}")

                if num_entities == 0:
                    logger.debug(f"No documents for userid {userid} in collection {collection_name}")
                    continue

                logger.debug(f"Content of userid {userid} in collection {collection_name}:")
                for idx, doc in enumerate(results, 1):
                    metadata = {
                        "userid": doc.get("userid", ""),
                        "filename": doc.get("filename", ""),
                        "file_path": doc.get("file_path", ""),
                        "upload_time": doc.get("upload_time", ""),
                        "file_type": doc.get("file_type", "")
                    }
                    logger.debug(f"Document {idx}: content: {doc.get('text', '')[:200]}..., metadata: {metadata}")
            except Exception as e:
                logger.error(f"Failed to query documents of collection {collection_name}: {str(e)}")
                continue

    except Exception as e:
        logger.error(f"Unable to view collection details for userid {userid}: {str(e)}")
        raise

def view_vector_data(db_type: str, userid: Optional[str] = None, document_id: Optional[str] = None, limit: int = 100) -> Dict[str, Dict]:
    """
    View the vector data in the given db_type, optionally filtered by userid and document_id, including full metadata and vectors.
    """
    try:
        if not db_type or "_" in db_type:
            raise ValueError("db_type must be non-empty and must not contain underscores")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit must be between 1 and 16384")
        if userid and "_" in userid:
            raise ValueError("userid must not contain underscores")
        if document_id and "_" in document_id:
            raise ValueError("document_id must not contain underscores")

        initialize_milvus_connection()
        collection_name = f"ragdb_{db_type}"
        if not utility.has_collection(collection_name):
            logger.warning(f"Collection {collection_name} does not exist")
            return {}

        collection = Collection(collection_name)
        collection.load()
        logger.debug(f"Loaded collection: {collection_name}")

        expr = []
        if userid:
            expr.append(f"userid == '{userid}'")
        if document_id:
            expr.append(f"document_id == '{document_id}'")
        expr = " && ".join(expr) if expr else ""

        results = collection.query(
            expr=expr,
            output_fields=["pk", "text", "document_id", "vector", "filename", "file_path", "upload_time", "file_type"],
            limit=limit
        )

        vector_data = {}
        for doc in results:
            pk = doc.get("pk", str(uuid.uuid4()))
            text = doc.get("text", "")
            doc_id = doc.get("document_id", "")
            vector = doc.get("vector", [])
            metadata = {
                "filename": doc.get("filename", ""),
                "file_path": doc.get("file_path", ""),
                "upload_time": doc.get("upload_time", ""),
                "file_type": doc.get("file_type", "")
            }
            vector_data[pk] = {
                "text": text,
                "vector": vector,
                "document_id": doc_id,
                "metadata": metadata
            }
            logger.debug(f"pk: {pk}, text: {text[:200]}..., document_id: {doc_id}, vector_length: {len(vector)}, metadata: {metadata}")

        logger.debug(f"Found {len(vector_data)} vector records in total")
        return vector_data

    except Exception as e:
        logger.error(f"Failed to view vector data: {str(e)}")
        raise

def main():
    userid = "testuser1"
    db_type = "textdb"

    # logger.info("\nTest 1: initialize with documents")
    # documents = [
    #     Document(
    #         page_content="深度学习是基于深层神经网络的机器学习子集。",
    #         metadata={
    #             "userid": userid,
    #             "document_id": str(uuid.uuid4()),
    #             "filename": "test_doc1.txt",
    #             "file_path": "/path/to/test_doc1.txt",
    #             "upload_time": datetime.now().isoformat(),
    #             "file_type": "txt"
    #         }
    #     ),
    #     Document(
    #         page_content="知识图谱是一个结构化的语义知识库。",
    #         metadata={
    #             "userid": userid,
    #             "document_id": str(uuid.uuid4()),
    #             "filename": "test_doc2.txt",
    #             "file_path": "/path/to/test_doc2.txt",
    #             "upload_time": datetime.now().isoformat(),
    #             "file_type": "txt"
    #         }
    #     ),
    # ]
    #
    # try:
    #     vector_store = get_vector_db(userid, db_type, documents=documents)
    #     logger.info(f"Collection: ragdb_{db_type}")
    #     logger.info(f"Successfully inserted documents for userid {userid} into {db_type}")
    # except Exception as e:
    #     logger.error(f"Failed: {str(e)}")

    logger.info("\nTest 2: list all db_types and document mappings")
    try:
        db_types = list_user_collections()
        logger.info(f"Available db_types and documents: {db_types}")
    except Exception as e:
        logger.error(f"Failed: {str(e)}")

    logger.info(f"\nTest 3: view all collections for userid {userid}")
    try:
        view_collection_details(userid)
    except Exception as e:
        logger.error(f"Failed: {str(e)}")

    # logger.info("\nTest 4: view vector data")
    # try:
    #     vector_data = view_vector_data(db_type, userid=userid)
    #     logger.info(f"Vector data: {vector_data}")
    # except Exception as e:
    #     logger.error(f"Failed: {str(e)}")

    logger.info(f"\nTest 5: get the document mapping for userid {userid} in the {db_type} database")
    try:
        mapping = get_document_mapping(userid, db_type)
        logger.info(f"Document mapping: {mapping}")
    except Exception as e:
        logger.error(f"Failed: {str(e)}")

if __name__ == "__main__":
    main()
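Note: get_vector_db rejects any chunk that lacks one of the six required metadata fields, so callers must populate all of them before insertion. A minimal sketch of a valid call, using hypothetical values:

import uuid
from datetime import datetime
from langchain_core.documents import Document

doc = Document(
    page_content="Example chunk text.",
    metadata={
        "userid": "alice",                  # hypothetical user; no underscores allowed
        "document_id": str(uuid.uuid4()),
        "filename": "example.txt",
        "file_path": "/tmp/example.txt",
        "upload_time": datetime.now().isoformat(),
        "file_type": "txt",
    },
)
store = get_vector_db("alice", "textdb", documents=[doc])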
@@ -1 +0,0 @@
__version__ = '0.0.1'
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +0,0 @@
网的 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
@@ -1,3 +0,0 @@
Relation Extraction c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
Web c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac
的知识 c0f603f3-1bbe-45dc-bb4a-1005e26bf1ac