rag/rag/embed.py
2025-07-16 15:06:59 +08:00

183 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import uuid
import yaml
import logging
from datetime import datetime
from typing import List
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections
from vector import get_vector_db
from filetxt.loader import fileloader
from extract import extract_and_save_triplets
from kgc import KnowledgeGraph
# Load the application configuration (Milvus paths, logging settings).
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
    # BUG FIX: the original called `logger.error(...)` here, but `logger` is
    # only defined further below — a config-load failure therefore raised a
    # NameError that masked the real error. Use a fallback logger instead,
    # and chain the original exception for a complete traceback.
    logging.basicConfig(level=logging.ERROR)
    logging.getLogger(__name__).error(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}")
    raise RuntimeError(f"无法加载配置文件: {str(e)}") from e

# Configure the module logger from the loaded config: one file handler plus
# one console handler, both with a shared timestamped format. Existing
# handlers are cleared and propagation disabled to avoid duplicate records.
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear()
logger.propagate = False
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
def generate_document_id() -> str:
    """Return a fresh random UUID4 string used as a unique document id."""
    new_id = uuid.uuid4()
    return f"{new_id}"
def load_and_split_data(file_path: str, userid: str, document_id: str) -> List[Document]:
    """
    Load a file, split it into chunks, and return metadata-enriched Documents.

    Args:
        file_path: Path of the file to load.
        userid: Owner id stored in every chunk's metadata.
        document_id: Unique id shared by all chunks of this file.

    Returns:
        List of Document chunks, each carrying userid/document_id/filename/
        file_path/upload_time/file_type/chunk_index/source metadata.

    Raises:
        ValueError: If the file is missing, empty, fails to load, or a chunk
            ends up with a missing/empty required metadata field.
    """
    try:
        if not os.path.exists(file_path):
            raise ValueError(f"文件 {file_path} 不存在")
        if os.path.getsize(file_path) == 0:
            raise ValueError(f"文件 {file_path} 为空")
        logger.debug(f"检查文件: {file_path}, 大小: {os.path.getsize(file_path)} 字节")
        # BUG FIX: rsplit('.', 1)[1] raised IndexError for extensionless
        # paths; splitext degrades to '' and fails the metadata check below
        # with a clear message instead.
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        logger.debug(f"文件扩展名: {ext}")
        logger.debug("开始加载文件")
        text = fileloader(file_path)
        if not text or not text.strip():
            raise ValueError(f"文件 {file_path} 加载为空")
        document = Document(page_content=text)
        logger.debug(f"加载完成,生成 1 个文档")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents([document])
        logger.debug(f"分割完成,生成 {len(chunks)} 个文档块")
        filename = os.path.basename(file_path)
        upload_time = datetime.now().isoformat()
        # Fields every chunk must carry with a non-empty value; hoisted out
        # of the loop (the original rebuilt this list on every iteration).
        required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
        documents = []
        for i, chunk in enumerate(chunks):
            chunk.metadata.update({
                'userid': userid,
                'document_id': document_id,
                'filename': filename,
                'file_path': file_path,
                'upload_time': upload_time,
                'file_type': ext,
                'chunk_index': i,
                'source': file_path,
            })
            if not all(chunk.metadata.get(field) for field in required_fields):
                raise ValueError(f"文档元数据缺少必需字段或值为空: {chunk.metadata}")
            documents.append(chunk)
            logger.debug(f"生成文档块 {i}: metadata={chunk.metadata}")
        logger.debug(f"文件 {file_path} 加载并分割为 {len(documents)} 个文档块document_id: {document_id}")
        return documents
    except Exception as e:
        logger.error(f"加载或分割文件 {file_path} 失败: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        # Chain the cause so the original traceback is preserved.
        raise ValueError(f"加载或分割文件失败: {str(e)}") from e
def embed(file_path: str, userid: str, db_type: str) -> bool:
    """
    Embed a file into the Milvus vector database, extract triplets to disk,
    and load the triplets into Neo4j.

    Triplet extraction and Neo4j import are best-effort: their failures are
    logged but do not fail the call once the Milvus embedding succeeded.

    Args:
        file_path: Path of the file to embed (pdf/doc/docx/xlsx/xls/ppt/
            pptx/csv/txt).
        userid: Document owner; must be non-empty and contain no underscore.
        db_type: Target database suffix; must be non-empty and contain no
            underscore.

    Returns:
        True when the Milvus embedding succeeds, False on any error.
    """
    try:
        if not userid or not db_type:
            raise ValueError("userid 和 db_type 不能为空")
        if "_" in userid:
            raise ValueError("userid 不能包含下划线")
        if "_" in db_type:
            raise ValueError("db_type 不能包含下划线")
        if not os.path.exists(file_path):
            raise ValueError(f"文件 {file_path} 不存在")
        supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
        # BUG FIX: rsplit('.', 1)[1] raised IndexError for extensionless
        # paths; splitext yields '' which falls into the clear error below.
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        if ext not in supported_formats:
            logger.error(f"文件 {file_path} 格式不支持,支持的格式: {', '.join(supported_formats)}")
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
        document_id = generate_document_id()
        logger.info(f"生成 document_id: {document_id} for file: {file_path}")
        logger.info(f"开始处理文件 {file_path}userid: {userid}db_type: {db_type}")
        chunks = load_and_split_data(file_path, userid, document_id)
        if not chunks:
            logger.error(f"文件 {file_path} 未生成任何文档块")
            raise ValueError("未生成任何文档块")
        logger.debug(f"处理文件 {file_path},生成 {len(chunks)} 个文档块")
        logger.debug(f"第一个文档块: {chunks[0].page_content[:200]}")
        db = get_vector_db(userid, db_type, documents=chunks)
        if not db:
            logger.error(f"无法初始化或插入到向量数据库 ragdb_{db_type}")
            raise RuntimeError("数据库操作失败")
        # Best-effort triplet extraction + Neo4j import; never fails the call.
        try:
            full_text = fileloader(file_path)
            if full_text and full_text.strip():
                success = extract_and_save_triplets(full_text, document_id, userid)
                # NOTE(review): path is hard-coded here and must stay in sync
                # with extract_and_save_triplets — consider moving to config.
                triplet_file_path = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt"
                if success and os.path.exists(triplet_file_path):
                    logger.info(f"文件 {file_path} 三元组保存到: {triplet_file_path}")
                    try:
                        kg = KnowledgeGraph(data_path=triplet_file_path, document_id=document_id)
                        logger.info(f"Step 1: 导入图谱节点到 Neo4jdocument_id: {document_id}")
                        kg.create_graphnodes()
                        logger.info(f"Step 2: 导入图谱边到 Neo4jdocument_id: {document_id}")
                        kg.create_graphrels()
                        logger.info(f"Step 3: 导出 Neo4j 节点数据document_id: {document_id}")
                        kg.export_data()
                        logger.info(f"文件 {file_path} 三元组成功插入 Neo4j")
                    except Exception as e:
                        logger.warning(f"将三元组插入 Neo4j 失败: {str(e)},但不影响 Milvus 嵌入")
                else:
                    logger.warning(f"文件 {file_path} 的三元组抽取失败或文件不存在: {triplet_file_path}")
            else:
                logger.warning(f"文件 {file_path} 内容为空,无法抽取三元组")
        except Exception as e:
            logger.error(f"文件 {file_path} 三元组抽取失败: {str(e)},但不影响向量化")
        logger.info(f"文件 {file_path} 成功嵌入到数据库 ragdb_{db_type}")
        return True
    # BUG FIX: the original had three handlers with identical bodies, and
    # `except RuntimeError as re:` shadowed the stdlib `re` module name.
    except (ValueError, RuntimeError) as e:
        logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
        return False
    except Exception as e:
        logger.error(f"嵌入文件 {file_path} 失败: {str(e)}")
        import traceback
        logger.debug(traceback.format_exc())
        return False
if __name__ == "__main__":
    # Smoke test: embed a sample document for a test user and report outcome.
    sample_file = "/share/wangmeihua/rag/data/test.docx"
    sample_user = "testuser1"
    sample_db = "textdb"
    result = embed(sample_file, sample_user, sample_db)
    print(f"嵌入结果: {result}")