import logging
import os
import traceback
import uuid
from datetime import datetime
from typing import List

import yaml
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from filetxt.loader import fileloader

from .vector import get_vector_db

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
    print(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}") from e

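# For reference, a minimal milvusconfig.yaml covering every key this module
# reads (the paths and names below are illustrative placeholders, not the
# actual deployment values):
#
#   database:
#     milvus_db_path: /path/to/milvus.db
#   logging:
#     name: rag
#     level: DEBUG
#     file: /path/to/logs/rag.log
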
# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# Guard so repeated imports do not attach duplicate handlers
if not logger.handlers:
    for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)


def load_and_split_data(file_path: str, userid: str) -> List[Document]:
    """
    Load a file, split it into chunks, and attach metadata to each chunk.

    Args:
        file_path (str): Path to the file.
        userid (str): User ID.

    Returns:
        List[Document]: The list of document chunks.

    Raises:
        ValueError: If the file or arguments are invalid.
    """
    try:
        # Validate the file
        if not os.path.exists(file_path):
            raise ValueError(f"File {file_path} does not exist")
        size = os.path.getsize(file_path)
        if size == 0:
            raise ValueError(f"File {file_path} is empty")
        logger.debug(f"Checking file: {file_path}, size: {size} bytes")
        # os.path.splitext avoids an IndexError on filenames without a dot
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        logger.debug(f"File extension: {ext}")

        # Load the file content with fileloader
        logger.debug("Starting file load")
        text = fileloader(file_path)
        if not text or not text.strip():
            raise ValueError(f"File {file_path} loaded as empty")

        # Create a single Document object
        document = Document(page_content=text)
        logger.debug("Load complete, produced 1 document")

        # Split the text
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents([document])
        logger.debug(f"Split complete, produced {len(chunks)} chunks")

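        # With no explicit separators, RecursiveCharacterTextSplitter uses its
        # defaults ("\n\n", "\n", " ", ""): chunks break on paragraph
        # boundaries first and fall back to finer splits only when a piece
        # still exceeds chunk_size.
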
        # Generate one document_id shared by every chunk of this file
        document_id = str(uuid.uuid4())

        # Attach metadata, ensuring every required field is present
        filename = os.path.basename(file_path)
        upload_time = datetime.now().isoformat()
        required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
        documents = []
        for i, chunk in enumerate(chunks):
            chunk.metadata.update({
                'userid': userid,
                'document_id': document_id,
                'filename': filename,
                'file_path': file_path,
                'upload_time': upload_time,
                'file_type': ext,
                'chunk_index': i,  # optional: tracks chunk order
                'source': file_path,  # optional: tracks the source file
            })
            # Verify metadata completeness
            if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
                raise ValueError(f"Document metadata is missing required fields or has empty values: {chunk.metadata}")
            documents.append(chunk)
            logger.debug(f"Produced chunk {i}: metadata={chunk.metadata}")

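        # A resulting chunk.metadata looks like this (illustrative values,
        # mirroring the test fixture at the bottom of this module):
        #   {'userid': 'testuser4', 'document_id': '<uuid4>', 'filename': 'test.txt',
        #    'file_path': '/share/wangmeihua/rag/data/test.txt',
        #    'upload_time': '2025-01-01T12:00:00', 'file_type': 'txt',
        #    'chunk_index': 0, 'source': '/share/wangmeihua/rag/data/test.txt'}
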
logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块,document_id: {document_id}")
|
||
return documents
|
||
|
||
    except Exception as e:
        logger.error(f"Failed to load or split file {file_path}: {str(e)}")
        logger.debug(traceback.format_exc())
        raise ValueError(f"Failed to load or split file: {str(e)}") from e

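# Example use of load_and_split_data (hypothetical call, for illustration):
#   chunks = load_and_split_data('/share/wangmeihua/rag/data/test.txt', 'testuser4')
#   print(len(chunks), chunks[0].metadata['document_id'])
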
def embed(file_path: str, userid: str, db_type: str) -> bool:
    """
    Embed a file into the Milvus vector database.

    Args:
        file_path (str): Path to the file.
        userid (str): User ID.
        db_type (str): Database type.

    Returns:
        bool: Whether the embedding succeeded.

    Note:
        ValueError (invalid arguments) and RuntimeError (database failure)
        raised internally are caught and logged; the function returns False
        instead of propagating them.
    """
    try:
        # Validate inputs
        if not userid or not db_type:
            raise ValueError("userid and db_type must not be empty")
        if "_" in userid:
            raise ValueError("userid must not contain an underscore")
        if "_" in db_type:
            raise ValueError("db_type must not contain an underscore")
        if not os.path.exists(file_path):
            raise ValueError(f"File {file_path} does not exist")

        supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        if ext not in supported_formats:
            logger.error(f"Unsupported format for file {file_path}; supported formats: {', '.join(supported_formats)}")
            raise ValueError(f"Unsupported file format: {ext}; supported formats: {', '.join(supported_formats)}")

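        # The whitelist above is assumed to mirror the formats that filetxt's
        # fileloader can parse; keep the two in sync if either changes.
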
        # Load and split the file
        logger.info(f"Processing file {file_path}, userid: {userid}, db_type: {db_type}")
        chunks = load_and_split_data(file_path, userid)
        if not chunks:
            logger.error(f"File {file_path} produced no document chunks")
            raise ValueError("No document chunks were produced")

        logger.debug(f"Processed file {file_path}, produced {len(chunks)} chunks")
        logger.debug(f"First chunk: {chunks[0].page_content[:200]}")

        # Insert into Milvus
        db = get_vector_db(userid, db_type, documents=chunks)
        if not db:
            logger.error(f"Failed to initialize or insert into vector database ragdb_{db_type}")
            raise RuntimeError("Database operation failed")

        logger.info(f"File {file_path} successfully embedded into database ragdb_{db_type}")
        return True

    except (ValueError, RuntimeError) as e:
        logger.error(f"Failed to embed file {file_path}: {str(e)}")
        return False
    except Exception as e:
        logger.error(f"Failed to embed file {file_path}: {str(e)}")
        logger.debug(traceback.format_exc())
        return False


if __name__ == "__main__":
    test_file = "/share/wangmeihua/rag/data/test.txt"
    userid = "testuser4"
    db_type = "textdb"
    result = embed(test_file, userid, db_type)
    print(f"Embedding result: {result}")
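    # Expected output: "Embedding result: True" on success; any validation or
    # database failure is logged and prints "Embedding result: False".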