rag/build/lib/rag/embed.py
2025-07-16 15:06:59 +08:00

179 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
import yaml
import logging
from datetime import datetime
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections
from .vector import get_vector_db
from filetxt.loader import fileloader
# 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
print(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
raise RuntimeError(f"无法加载配置文件: {str(e)}")
# 配置日志
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(config['logging']['file'], encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
def load_and_split_data(file_path: str, userid: str) -> List[Document]:
"""
加载文件,分片并生成带有元数据的 Document 对象。
参数:
file_path (str): 文件路径
userid (str): 用户ID
返回:
List[Document]: 分片后的文档列表
异常:
ValueError: 文件或参数无效
"""
try:
# 验证文件
if not os.path.exists(file_path):
raise ValueError(f"文件 {file_path} 不存在")
if os.path.getsize(file_path) == 0:
raise ValueError(f"文件 {file_path} 为空")
logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节")
ext = file_path.rsplit('.', 1)[1].lower()
logger.debug(f"文件扩展名: {ext}")
# 使用 fileloader 加载文件内容
logger.debug("开始加载文件")
text = fileloader(file_path)
if not text or not text.strip():
raise ValueError(f"文件 {file_path} 加载为空")
# 创建单个 Document 对象
document = Document(page_content=text)
logger.debug(f"加载完成,生成 1 个文档")
# 分割文本
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000,
chunk_overlap=200,
length_function=len,
)
chunks = text_splitter.split_documents([document])
logger.debug(f"分割完成,生成 {len(chunks)} 个文档块")
# 为整个文件生成唯一的 document_id
document_id = str(uuid.uuid4())
# 添加元数据,确保包含所有必需字段
filename = os.path.basename(file_path)
upload_time = datetime.now().isoformat()
documents = []
for i, chunk in enumerate(chunks):
chunk.metadata.update({
'userid': userid,
'document_id': document_id,
'filename': filename,
'file_path': file_path,
'upload_time': upload_time,
'file_type': ext,
'chunk_index': i, # 可选,追踪分片顺序
'source': file_path, # 可选,追踪文件来源
})
# 验证元数据完整性
required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}")
documents.append(chunk)
logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}")
logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块document_id: {document_id}")
return documents
except Exception as e:
logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
raise ValueError(f"加载或分割文件失败: {str(e)}")
def embed(file_path: str, userid: str, db_type: str) -> bool:
"""
嵌入文件到 Milvus 向量数据库。
参数:
file_path (str): 文件路径
userid (str): 用户ID
db_type (str): 数据库类型
返回:
bool: 嵌入是否成功
异常:
ValueError: 参数无效
RuntimeError: 数据库操作失败
"""
try:
# 验证输入
if not userid or not db_type:
raise ValueError("userid 和 db_type 不能为空")
if "_" in userid:
raise ValueError("userid 不能包含下划线")
if "_" in db_type:
raise ValueError("db_type 不能包含下划线")
if not os.path.exists(file_path):
raise ValueError(f"文件 {file_path} 不存在")
supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
ext = file_path.rsplit('.', 1)[1].lower()
if ext not in supported_formats:
logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}")
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
# 加载并分割文件
logger.info(f"开始处理文件 {file_path}userid: {userid}db_type: {db_type}")
chunks = load_and_split_data(file_path, userid)
if not chunks:
logger.error(f"文件 {file_path} 未生成任何文档块")
raise ValueError("未生成任何文档块")
logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块")
logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}")
# 插入到 Milvus
db = get_vector_db(userid, db_type, documents=chunks)
if not db:
logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}")
raise RuntimeError(f"数据库操作失败")
logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}")
return True
except ValueError as ve:
logger.error(f"嵌入文件 {file_path} 失败: {str(ve)}")
return False
except RuntimeError as re:
logger.error(f"嵌入文件 {file_path} 失败: {str(re)}")
return False
except Exception as e:
logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
return False
if __name__ == "__main__":
test_file = "/share/wangmeihua/rag/data/test.txt"
userid = "testuser4"
db_type = "textdb"
result = embed(test_file, userid, db_type)
print(f"嵌入结果: {result}")