import os
import uuid
import logging
import traceback
from datetime import datetime
from typing import List

import yaml
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pymilvus import connections

from vector import get_vector_db
from filetxt.loader import fileloader
from extract import extract_and_save_triplets
from kgc import KnowledgeGraph

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    MILVUS_DB_PATH = config['database']['milvus_db_path']
except Exception as e:
    # The module logger is configured below from this same config, so it is not
    # available yet; fall back to the root logger before re-raising.
    logging.error(f"Failed to load config file {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Unable to load config file: {str(e)}")

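# A minimal sketch of what milvusconfig.yaml is expected to contain, inferred
# from the keys read above and below; the example values are deployment-specific
# assumptions, not taken from the real config:
#
#   database:
#     milvus_db_path: /path/to/milvus.db
#   logging:
#     name: rag
#     level: DEBUG
#     file: /path/to/logs/rag.log
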
# Configure logging
logger = logging.getLogger(config['logging']['name'])
logger.setLevel(getattr(logging, config['logging']['level'], logging.DEBUG))
logger.handlers.clear()
logger.propagate = False
os.makedirs(os.path.dirname(config['logging']['file']), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(config['logging']['file'], encoding='utf-8')
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)

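# With the formatter above, emitted records look like (illustrative example):
#   2024-01-01 12:00:00,000 - INFO - Generated document_id: 8f14e45f-... for file: /share/wangmeihua/rag/data/test.docx
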
def generate_document_id() -> str:
    """Generate a unique document_id for a file."""
    return str(uuid.uuid4())

def load_and_split_data(file_path: str, userid: str, document_id: str) -> List[Document]:
    """
    Load a file, split it into chunks, and build Document objects with metadata.
    """
    try:
        if not os.path.exists(file_path):
            raise ValueError(f"File {file_path} does not exist")
        if os.path.getsize(file_path) == 0:
            raise ValueError(f"File {file_path} is empty")
        logger.debug(f"Checking file: {file_path}, size: {os.path.getsize(file_path)} bytes")
        # os.path.splitext is safer than rsplit('.') here: it returns an empty
        # string for extensionless paths instead of raising IndexError.
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        logger.debug(f"File extension: {ext}")

        logger.debug("Loading file")
        text = fileloader(file_path)
        if not text or not text.strip():
            raise ValueError(f"File {file_path} loaded empty")

        document = Document(page_content=text)
        logger.debug("Loading finished, produced 1 document")

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents([document])
        logger.debug(f"Splitting finished, produced {len(chunks)} chunks")

        filename = os.path.basename(file_path)
        upload_time = datetime.now().isoformat()
        documents = []
        for i, chunk in enumerate(chunks):
            chunk.metadata.update({
                'userid': userid,
                'document_id': document_id,
                'filename': filename,
                'file_path': file_path,
                'upload_time': upload_time,
                'file_type': ext,
                'chunk_index': i,
                'source': file_path,
            })
            required_fields = ['userid', 'document_id', 'filename', 'file_path', 'upload_time', 'file_type']
            if not all(field in chunk.metadata and chunk.metadata[field] for field in required_fields):
                raise ValueError(f"Chunk metadata is missing required fields or has empty values: {chunk.metadata}")
            documents.append(chunk)
            logger.debug(f"Built chunk {i}: metadata={chunk.metadata}")

        logger.debug(f"File {file_path} loaded and split into {len(documents)} chunks, document_id: {document_id}")
        return documents

    except Exception as e:
        logger.error(f"Failed to load or split file {file_path}: {str(e)}")
        logger.debug(traceback.format_exc())
        raise ValueError(f"Failed to load or split file: {str(e)}")

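# For a file like /data/report.pdf, each chunk's metadata ends up shaped like
# the following (illustrative values only; the fields are the ones set above):
#   {'userid': 'alice', 'document_id': '<uuid4>', 'filename': 'report.pdf',
#    'file_path': '/data/report.pdf', 'upload_time': '2024-01-01T00:00:00',
#    'file_type': 'pdf', 'chunk_index': 0, 'source': '/data/report.pdf'}
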
def embed(file_path: str, userid: str, db_type: str) -> bool:
    """
    Embed a file into the Milvus vector database, extract triplets to the
    configured path, and store the triplets in Neo4j.
    """
    try:
        if not userid or not db_type:
            raise ValueError("userid and db_type must not be empty")
        if "_" in userid:
            raise ValueError("userid must not contain underscores")
        if "_" in db_type:
            raise ValueError("db_type must not contain underscores")
        if not os.path.exists(file_path):
            raise ValueError(f"File {file_path} does not exist")

        supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        if ext not in supported_formats:
            logger.error(f"File {file_path} has an unsupported format; supported formats: {', '.join(supported_formats)}")
            raise ValueError(f"Unsupported file format: {ext}, supported formats: {', '.join(supported_formats)}")

        document_id = generate_document_id()
        logger.info(f"Generated document_id: {document_id} for file: {file_path}")

        logger.info(f"Processing file {file_path}, userid: {userid}, db_type: {db_type}")
        chunks = load_and_split_data(file_path, userid, document_id)
        if not chunks:
            logger.error(f"File {file_path} produced no chunks")
            raise ValueError("No chunks were produced")

        logger.debug(f"Processing file {file_path}, produced {len(chunks)} chunks")
        logger.debug(f"First chunk: {chunks[0].page_content[:200]}")

        db = get_vector_db(userid, db_type, documents=chunks)
        if not db:
            logger.error(f"Failed to initialize or insert into vector database ragdb_{db_type}")
            raise RuntimeError("Database operation failed")

        # Triplet extraction and Neo4j import are best-effort: failures are
        # logged but never undo the Milvus embedding above.
        try:
            full_text = fileloader(file_path)
            if full_text and full_text.strip():
                success = extract_and_save_triplets(full_text, document_id, userid)
                triplet_file_path = f"/share/wangmeihua/rag/triples/{document_id}_{userid}.txt"
                if success and os.path.exists(triplet_file_path):
                    logger.info(f"Triplets for file {file_path} saved to: {triplet_file_path}")
                    try:
                        kg = KnowledgeGraph(data_path=triplet_file_path, document_id=document_id)
                        logger.info(f"Step 1: importing graph nodes into Neo4j, document_id: {document_id}")
                        kg.create_graphnodes()
                        logger.info(f"Step 2: importing graph edges into Neo4j, document_id: {document_id}")
                        kg.create_graphrels()
                        logger.info(f"Step 3: exporting Neo4j node data, document_id: {document_id}")
                        kg.export_data()
                        logger.info(f"Triplets for file {file_path} successfully inserted into Neo4j")
                    except Exception as e:
                        logger.warning(f"Failed to insert triplets into Neo4j: {str(e)}; the Milvus embedding is unaffected")
                else:
                    logger.warning(f"Triplet extraction for file {file_path} failed or the triplet file does not exist: {triplet_file_path}")
            else:
                logger.warning(f"File {file_path} is empty; cannot extract triplets")
        except Exception as e:
            logger.error(f"Triplet extraction for file {file_path} failed: {str(e)}; vectorization is unaffected")

        logger.info(f"File {file_path} successfully embedded into database ragdb_{db_type}")
        return True

    except (ValueError, RuntimeError) as e:
        logger.error(f"Failed to embed file {file_path}: {str(e)}")
        return False
    except Exception as e:
        logger.error(f"Failed to embed file {file_path}: {str(e)}")
        logger.debug(traceback.format_exc())
        return False

if __name__ == "__main__":
    test_file = "/share/wangmeihua/rag/data/test.docx"
    userid = "testuser1"
    db_type = "textdb"
    result = embed(test_file, userid, db_type)
    print(f"Embedding result: {result}")