import os
import re
import time
import math
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from appPublic.log import debug, error, info
from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType


class RagOperations:
    """RAG operations class providing all common RAG operations."""

    def __init__(self):
        self.api_service = APIService()

    async def load_and_chunk_document(self, realpath: str, timings: Dict,
                                      transaction_mgr: TransactionManager = None) -> List[Document]:
        """Load a file and split its text into chunks."""
        debug(f"Loading file: {realpath}")
        start_load = time.time()

        # Check that the file format is supported
        supported_formats = File2Text.supported_types()
        debug(f"Supported file formats: {supported_formats}")
        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
        if ext not in supported_formats:
            raise ValueError(f"Unsupported file format: {ext}; supported formats: {', '.join(supported_formats)}")

        # Load the file content
        text = fileloader(realpath)
        # Keep only CJK characters, ASCII letters/digits, whitespace and a few punctuation marks
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        timings["load_file"] = time.time() - start_load
        debug(f"File loading took {timings['load_file']:.2f}s, text length: {len(text)}")

        if not text or not text.strip():
            raise ValueError(f"File {realpath} loaded empty")

        # Split into chunks
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            length_function=len
        )
        debug("Splitting file content into chunks")
        start_split = time.time()
        chunks = text_splitter.split_documents([document])
        timings["split_text"] = time.time() - start_split
        debug(f"Text splitting took {timings['split_text']:.2f}s, chunk count: {len(chunks)}")

        if not chunks:
            raise ValueError(f"File {realpath} produced no document chunks")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.FILE_LOAD,
                {'realpath': realpath, 'chunks_count': len(chunks)}
            )

        return chunks

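    # Illustrative sketch only: load_and_chunk_document() above is typically paired with
    # generate_embeddings() below, both updating a caller-owned "timings" dict, e.g.
    #   chunks = await ops.load_and_chunk_document("/data/report.pdf", timings)
    #   vectors = await ops.generate_embeddings(request, chunks, service_params, userid, timings)
    # (the path and variable names here are placeholders, not part of this module).
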
    async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                  userid: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None) -> List[List[float]]:
        """Generate embedding vectors for the given chunks."""
        debug("Calling the embedding service to generate vectors")
        start_embedding = time.time()
        texts = [chunk.page_content for chunk in chunks]
        embeddings = []

        # Process embeddings in batches of 10
        for i in range(0, len(texts), 10):
            batch_texts = texts[i:i + 10]
            batch_embeddings = await self.api_service.get_embeddings(
                request=request,
                texts=batch_texts,
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
            embeddings.extend(batch_embeddings)

        if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
            raise ValueError("Every embedding must be a list of 1024 floats")

        timings["generate_embeddings"] = time.time() - start_embedding
        debug(f"Embedding generation took {timings['generate_embeddings']:.2f}s, embedding count: {len(embeddings)}")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'embeddings_count': len(embeddings)}
            )

        return embeddings

    async def generate_multi_embeddings(self, request, inputs: List[Dict], service_params: Dict,
                                        userid: str, timings: Dict,
                                        transaction_mgr: TransactionManager = None) -> Dict[str, Dict]:
        """Call the multimodal embedding service (CLIP)."""
        debug("Calling the multimodal embedding service")
        start = time.time()

        result = await self.api_service.get_multi_embeddings(
            request=request,
            inputs=inputs,
            upappid=service_params['embedding'],
            apiname="black/clip",
            user=userid
        )
        debug(f"Multimodal embedding result: {result}")
        timings["multi_embedding"] = time.time() - start
        debug(f"Multimodal embedding took {timings['multi_embedding']:.2f}s, processed {len(result)} items")

        # ==================== Error checking and filtering ====================
        valid_results = {}
        error_count = 0
        error_examples = []

        for key, entry in result.items():
            if entry.get("type") == "error":
                error_count += 1
                if len(error_examples) < 3:  # keep only the first three examples
                    error_examples.append(f"{key} → {entry['error']}")
                # Drop failed entries
                continue
            valid_results[key] = entry

        if error_count > 0:
            error(f"Multimodal embedding failed for {error_count} items. Examples: {'; '.join(error_examples)}")
            raise RuntimeError(f"Multimodal embedding failed for {error_count} items")
        else:
            debug("All multimodal embeddings succeeded")

        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'count': len(valid_results)}
            )

        return valid_results

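    # Expected shape of each entry in the result above, as this module consumes it
    # (inferred from the handling code below; the service contract itself is not shown here):
    #   {"type": "text" | "image" | "video" | "audio", "vector": [...],
    #    "path": "<media path>", "face_vecs": [[...], ...]}
    # or, on failure: {"type": "error", "error": "<message>"}
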
    # Unified insertion into the vector database
    async def insert_all_vectors(
            self,
            request,
            text_chunks: List[Document],
            realpath: str,
            orgid: str,
            fiid: str,
            document_id: str,
            service_params: Dict,
            userid: str,
            db_type: str,
            timings: Dict,
            img_paths: List[str] = None,
            text_embeddings: List[List[float]] = None,
            multi_results: Dict = None,
            transaction_mgr: TransactionManager = None
    ) -> Dict[str, int]:
        """
        Unified insert that supports two modes:
        1. Plain-text mode: text_embeddings is provided
        2. Multimodal mode: multi_results is provided (from generate_multi_embeddings)
        """
        img_paths = img_paths or []
        all_chunks = []
        start = time.time()
        filename = os.path.basename(realpath)
        upload_time = datetime.now().isoformat()

        # ==================== 1. Plain-text mode (BGE) ====================
        if text_embeddings is not None:
            debug(f"[Plain-text mode] inserting {len(text_embeddings)} text vectors")
            for i, chunk in enumerate(text_chunks):
                all_chunks.append({
                    "userid": orgid,
                    "knowledge_base_id": fiid,
                    "text": chunk.page_content,
                    "vector": text_embeddings[i],
                    "document_id": document_id,
                    "filename": filename,
                    "file_path": realpath,
                    "upload_time": upload_time,
                    "file_type": "text",
                })

        # ==================== 2. Multimodal mode (CLIP, mixed content) ====================
        if multi_results is not None:
            debug(f"[Multimodal mode] parsing {len(multi_results)} CLIP results")
            # Iterate over multi_results
            for raw_key, entry in multi_results.items():
                typ = entry["type"]

                # --- Text ---
                if typ == "text":
                    # raw_key is the original text
                    all_chunks.append({
                        "userid": orgid,
                        "knowledge_base_id": fiid,
                        "text": raw_key,
                        "vector": entry["vector"],
                        "document_id": document_id,
                        "filename": filename,
                        "file_path": realpath,
                        "upload_time": upload_time,
                        "file_type": "text",
                    })
                    continue

                # --- Image ---
                if typ == "image":
                    img_path = entry.get("path") or raw_key
                    img_name = os.path.basename(img_path)

                    # Whole-image vector
                    if "vector" in entry:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Image: {img_path}] image from file {realpath}",
                            "vector": entry["vector"],
                            "document_id": document_id,
                            "filename": img_name,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "image",
                        })

                    # Face vectors (currently disabled)
                    face_vecs = entry.get("face_vecs", [])
                    face_count = len(face_vecs)
                    # if face_count > 0:
                    #     for f_idx, fvec in enumerate(face_vecs):
                    #         debug(f"Face vector dimension: {len(fvec)}")
                    #         all_chunks.append({
                    #             "userid": orgid,
                    #             "knowledge_base_id": fiid,
                    #             "text": f"[Face {f_idx + 1}/{face_count} in {img_name}] face from image {img_path} in file {realpath}",
                    #             "vector": fvec,
                    #             "document_id": document_id,
                    #             "filename": img_name,
                    #             "file_path": realpath,
                    #             "upload_time": upload_time,
                    #             "file_type": "face",
                    #         })
                    continue

                # --- Video ---
                if typ == "video":
                    video_path = entry.get("path") or raw_key
                    video_name = os.path.basename(video_path)

                    if "vector" in entry:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Video: {video_name}]",
                            "vector": entry["vector"],
                            "document_id": document_id,
                            "filename": video_path,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "video",
                        })

                    # Video face vectors (currently disabled)
                    face_vecs = entry.get("face_vecs", [])
                    face_count = len(face_vecs)
                    # if face_count > 0:
                    #     for f_idx, fvec in enumerate(face_vecs):
                    #         all_chunks.append({
                    #             "userid": orgid,
                    #             "knowledge_base_id": fiid,
                    #             "text": f"[Face {f_idx + 1}/{face_count} in video {video_name}] from {video_path}",
                    #             "vector": fvec,
                    #             "document_id": document_id,
                    #             "filename": video_path,
                    #             "file_path": realpath,
                    #             "upload_time": upload_time,
                    #             "file_type": "face",
                    #         })
                    continue

                # --- Audio ---
                if typ == "audio":
                    audio_path = entry.get("path") or raw_key
                    audio_name = os.path.basename(audio_path)

                    if "vector" in entry:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Audio: {audio_name}]",
                            "vector": entry["vector"],
                            "document_id": document_id,
                            "filename": audio_path,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "audio",
                        })
                    continue

                # --- Unknown type ---
                debug(f"Skipping unknown type: {typ} → {raw_key}")

        # ==================== 3. Batch insert into Milvus ====================
        if not all_chunks:
            debug("No vectors to insert")
            return {"text": 0, "image": 0, "face": 0}

        for i in range(0, len(all_chunks), 10):
            batch = all_chunks[i:i + 10]
            result = await self.api_service.milvus_insert_document(
                request=request,
                chunks=batch,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
                user=userid,
                db_type=db_type
            )
            if result.get("status") != "success":
                raise ValueError(f"Milvus insert failed: {result.get('message')}")

        # ==================== 4. Unified rollback (registered only once) ====================
        if transaction_mgr and all_chunks:
            async def rollback_all(data, context):
                try:
                    await self.delete_from_vector_db(
                        request=context['request'],
                        orgid=data['orgid'],
                        realpath=data['realpath'],
                        fiid=data['fiid'],
                        id=data['id'],
                        service_params=context['service_params'],
                        userid=context['userid'],
                        db_type=data['db_type']
                    )
                    return f"Rolled back all vectors for document_id={data['id']}"
                except Exception as e:
                    error(f"Unified rollback failed: {e}")
                    raise

            transaction_mgr.add_operation(
                OperationType.VDB_INSERT,
                {
                    'orgid': orgid,
                    'realpath': realpath,
                    'fiid': fiid,
                    'id': document_id,
                    'db_type': db_type
                },
                rollback_func=rollback_all
            )

        # ==================== 5. Stats and return ====================
        stats = {
            "text": len([c for c in all_chunks if c["file_type"] == "text"]),
            "image": len([c for c in all_chunks if c["file_type"] == "image"]),
            "face": len([c for c in all_chunks if c["file_type"] == "face"])
        }

        timings["insert_all"] = time.time() - start
        debug(
            f"Unified insert finished: text {stats['text']}, image {stats['image']}, face {stats['face']}, took {timings['insert_all']:.2f}s")
        return stats

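    # Illustrative call shapes for insert_all_vectors() above (argument values are placeholders):
    #   plain-text mode:  await ops.insert_all_vectors(request, chunks, realpath, orgid, fiid,
    #                         document_id, service_params, userid, db_type, timings,
    #                         text_embeddings=embeddings)
    #   multimodal mode:  await ops.insert_all_vectors(request, [], realpath, orgid, fiid,
    #                         document_id, service_params, userid, db_type, timings,
    #                         multi_results=clip_results)
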
    # async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
    #                               realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
    #                               userid: str, db_type: str, timings: Dict,
    #                               transaction_mgr: TransactionManager = None):
    #     """Insert into the vector database."""
    #     debug(f"Preparing data and calling the document insert endpoint: {realpath}")
    #     filename = os.path.basename(realpath).rsplit('.', 1)[0]
    #     ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
    #     upload_time = datetime.now().isoformat()
    #
    #     chunks_data = [
    #         {
    #             "userid": orgid,
    #             "knowledge_base_id": fiid,
    #             "text": chunk.page_content,
    #             "vector": embeddings[i],
    #             "document_id": id,
    #             "filename": filename + '.' + ext,
    #             "file_path": realpath,
    #             "upload_time": upload_time,
    #             "file_type": ext,
    #         }
    #         for i, chunk in enumerate(chunks)
    #     ]
    #
    #     start_milvus = time.time()
    #     for i in range(0, len(chunks_data), 10):
    #         batch_chunks = chunks_data[i:i + 10]
    #         debug(f"Data passed in: {batch_chunks}")
    #         result = await self.api_service.milvus_insert_document(
    #             request=request,
    #             chunks=batch_chunks,
    #             db_type=db_type,
    #             upappid=service_params['vdb'],
    #             apiname="milvus/insertdocument",
    #             user=userid
    #         )
    #         if result.get("status") != "success":
    #             raise ValueError(result.get("message", "Milvus insert failed"))
    #
    #     timings["insert_milvus"] = time.time() - start_milvus
    #     debug(f"Milvus insert took {timings['insert_milvus']:.2f}s")
    #
    #     # Record the transaction operation, including a rollback function
    #     if transaction_mgr:
    #         async def rollback_vdb_insert(data, context):
    #             try:
    #                 # Defensive checks
    #                 required_context = ['request', 'service_params', 'userid']
    #                 missing_context = [k for k in required_context if k not in context or context[k] is None]
    #                 if missing_context:
    #                     raise ValueError(f"Rollback context is missing fields: {', '.join(missing_context)}")
    #
    #                 required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
    #                 missing_data = [k for k in required_data if k not in data or data[k] is None]
    #                 if missing_data:
    #                     raise ValueError(f"VDB_INSERT data is missing fields: {', '.join(missing_data)}")
    #
    #                 await self.delete_from_vector_db(
    #                     context['request'], data['orgid'], data['realpath'],
    #                     data['fiid'], data['id'], context['service_params'],
    #                     context['userid'], data['db_type']
    #                 )
    #                 return f"Rolled back vector database insert: {data['id']}"
    #             except Exception as e:
    #                 error(f"Vector database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
    #                 raise
    #
    #         transaction_mgr.add_operation(
    #             OperationType.VDB_INSERT,
    #             {
    #                 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
    #                 'id': id, 'db_type': db_type
    #             },
    #             rollback_func=rollback_vdb_insert
    #         )
    #
    #     return chunks_data
    #
    # async def insert_image_vectors(
    #         self,
    #         request,
    #         multi_results: Dict[str, Dict],
    #         realpath: str,
    #         orgid: str,
    #         fiid: str,
    #         document_id: str,
    #         service_params: Dict,
    #         userid: str,
    #         db_type: str,
    #         timings: Dict,
    #         transaction_mgr: TransactionManager = None
    # ) -> tuple[int, int]:
    #
    #     start = time.time()
    #     image_chunks = []
    #     face_chunks = []
    #
    #     for img_path, info in multi_results.items():
    #         # img_name = os.path.basename(img_path)
    #
    #         # 1. Insert the whole image
    #         if info.get("type") in ["image", "video"] and "vector" in info:
    #             image_chunks.append({
    #                 "userid": orgid,
    #                 "knowledge_base_id": fiid,
    #                 "text": f"[Image: {img_path}]",
    #                 "vector": info["vector"],
    #                 "document_id": document_id,
    #                 "filename": os.path.basename(realpath),
    #                 "file_path": realpath,
    #                 "upload_time": datetime.now().isoformat(),
    #                 "file_type": "image"
    #             })
    #
    #         # 2. Insert each face
    #         face_vecs = info.get("face_vecs")
    #         face_count = info.get("face_count", 0)
    #
    #         if face_count > 0 and face_vecs and len(face_vecs) == face_count:
    #             for idx, face_vec in enumerate(face_vecs):
    #                 face_chunks.append({
    #                     "userid": orgid,
    #                     "knowledge_base_id": fiid,
    #                     "text": f"[Face {idx + 1}/{face_count} in {img_path}]",
    #                     "vector": face_vec,
    #                     "document_id": document_id,
    #                     "filename": os.path.basename(realpath),
    #                     "file_path": realpath,
    #                     "upload_time": datetime.now().isoformat(),
    #                     "file_type": "face",
    #                 })
    #
    #     if image_chunks:
    #         for i in range(0, len(image_chunks), 10):
    #             await self.api_service.milvus_insert_document(
    #                 request=request,
    #                 chunks=image_chunks[i:i + 10],
    #                 upappid=service_params['vdb'],
    #                 apiname="milvus/insertdocument",
    #                 user=userid,
    #                 db_type=db_type
    #             )
    #
    #     if face_chunks:
    #         for i in range(0, len(face_chunks), 10):
    #             await self.api_service.milvus_insert_document(
    #                 request=request,
    #                 chunks=face_chunks[i:i + 10],
    #                 upappid=service_params['vdb'],
    #                 apiname="milvus/insertdocument",
    #                 user=userid,
    #                 db_type=db_type
    #             )
    #     timings["insert_images"] = time.time() - start
    #     image_count = len(image_chunks)
    #     face_count = len(face_chunks)
    #
    #     debug(f"Multimodal insert finished: {image_count} images, {face_count} faces")
    #
    #     if transaction_mgr and (image_count + face_count > 0):
    #         transaction_mgr.add_operation(
    #             OperationType.IMAGE_VECTORS_INSERT,
    #             {"images": image_count, "faces": face_count, "document_id": document_id}
    #         )
    #
    #     # Record the transaction operation, including a rollback function
    #     if transaction_mgr:
    #         async def rollback_multimodal(data, context):
    #             try:
    #                 # Defensive checks
    #                 required_context = ['request', 'service_params', 'userid']
    #                 missing_context = [k for k in required_context if k not in context or context[k] is None]
    #                 if missing_context:
    #                     raise ValueError(f"Rollback context is missing fields: {', '.join(missing_context)}")
    #
    #                 required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
    #                 missing_data = [k for k in required_data if k not in data or data[k] is None]
    #                 if missing_data:
    #                     raise ValueError(f"Multimodal rollback data is missing fields: {', '.join(missing_data)}")
    #
    #                 await self.delete_from_vector_db(
    #                     context['request'], data['orgid'], data['realpath'],
    #                     data['fiid'], data['id'], context['service_params'],
    #                     context['userid'], data['db_type']
    #                 )
    #                 return f"Rolled back multimodal vectors: {data['id']}"
    #             except Exception as e:
    #                 error(f"Multimodal vector database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
    #                 raise
    #
    #         transaction_mgr.add_operation(
    #             OperationType.VDB_INSERT,
    #             {
    #                 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
    #                 'id': id, 'db_type': db_type
    #             },
    #             rollback_func=rollback_multimodal
    #         )
    #
    #     return image_count, face_count

    async def insert_to_vector_text(self, request,
                                    db_type: str, fields: Dict, service_params: Dict,
                                    userid: str, timings: Dict) -> List[Dict]:
        """Insert a single plain-text record into the vector database, supporting a dynamic schema."""
        chunk_data = {}
        debug("Preparing single plain-text data and calling the insert endpoint")
        start = time.time()
        for key, value in fields.items():
            chunk_data[key] = value

        chunks_data = [chunk_data]
        debug(f"Data passed to the vector database insert: {chunks_data}")

        # Call the Milvus insert
        result = await self.api_service.milvus_insert_document(
            request=request,
            chunks=chunks_data,
            upappid=service_params['vdb'],
            apiname="milvus/insertdocument",
            user=userid,
            db_type=db_type
        )
        if result.get("status") != "success":
            raise ValueError(result.get("message", "Milvus insert failed"))

        debug(f"Successfully inserted plain text into collection {result.get('collection_name')}")
        timings["textinsert"] = time.time() - start
        debug(f"Plain-text insert took {timings['textinsert']:.2f}s")

        return chunks_data

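    # Illustrative "fields" dict for insert_to_vector_text() above. The keys are whatever
    # the target collection's schema expects; this particular set is only an assumption
    # mirroring the field names used elsewhere in this module:
    #   fields = {"userid": orgid, "knowledge_base_id": fiid, "text": "...",
    #             "vector": embedding, "document_id": doc_id}
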
    async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
                              userid: str, timings: Dict,
                              transaction_mgr: TransactionManager = None) -> List[Dict]:
        """Extract triples from the chunks."""
        debug("Calling the triple extraction service")
        start_triples = time.time()
        chunk_texts = [doc.page_content for doc in chunks]
        triples = []

        for i, chunk in enumerate(chunk_texts):
            result = await self.api_service.extract_triples(
                request=request,
                text=chunk,
                upappid=service_params['triples'],
                apiname="Babelscape/mrebel-large",
                user=userid
            )
            if isinstance(result, list):
                triples.extend(result)
                debug(f"Chunk {i + 1} yielded {len(result)} triples")
            else:
                error(f"Chunk {i + 1} failed: {str(result)}")

        # Deduplicate and refine the triples
        unique_triples = self._deduplicate_triples(triples)

        timings["extract_triples"] = time.time() - start_triples
        debug(f"Triple extraction took {timings['extract_triples']:.2f}s, extracted {len(unique_triples)} triples")

        # Record the transaction operation
        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.TRIPLES_EXTRACT,
                {'triples_count': len(unique_triples)}
            )

        return unique_triples

    async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
                                 orgid: str, service_params: Dict, userid: str, timings: Dict,
                                 transaction_mgr: TransactionManager = None):
        """Insert triples into the graph database."""
        debug(f"Inserting {len(triples)} triples into Neo4j")
        start_neo4j = time.time()

        if triples:
            for i in range(0, len(triples), 30):
                batch_triples = triples[i:i + 30]
                neo4j_result = await self.api_service.neo4j_insert_triples(
                    request=request,
                    triples=batch_triples,
                    document_id=id,
                    knowledge_base_id=fiid,
                    userid=orgid,
                    upappid=service_params['gdb'],
                    apiname="neo4j/inserttriples",
                    user=userid
                )
                if neo4j_result.get("status") != "success":
                    raise ValueError(f"Neo4j triple insert failed: {neo4j_result.get('message', 'unknown error')}")
                info(f"Document triples successfully inserted into Neo4j: {neo4j_result.get('message')}")

            timings["insert_neo4j"] = time.time() - start_neo4j
            debug(f"Neo4j insert took {timings['insert_neo4j']:.2f}s")
        else:
            debug("No triples extracted")
            timings["insert_neo4j"] = 0.0

        # Record the transaction operation, including a rollback function
        if transaction_mgr:
            async def rollback_gdb_insert(data, context):
                await self.delete_from_graph_db(
                    context['request'], data['id'],
                    context['service_params'], context['userid']
                )
                return f"Rolled back graph database insert: {data['id']}"

            transaction_mgr.add_operation(
                OperationType.GDB_INSERT,
                {'id': id, 'triples_count': len(triples)},
                rollback_func=rollback_gdb_insert
            )

    async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str,
                                    id: str, service_params: Dict, userid: str, db_type: str):
        """Delete a document from the vector database."""
        debug(f"Calling the document delete endpoint: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
        milvus_result = await self.api_service.milvus_delete_document(
            request=request,
            userid=orgid,
            file_path=realpath,
            knowledge_base_id=fiid,
            document_id=id,
            db_type=db_type,
            upappid=service_params['vdb'],
            apiname="milvus/deletedocument",
            user=userid
        )
        if milvus_result.get("status") != "success":
            raise ValueError(milvus_result.get("message", "Milvus delete failed"))

    async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
        """Delete a document from the graph database."""
        debug(f"Calling the Neo4j document delete endpoint: document_id={id}")
        neo4j_result = await self.api_service.neo4j_delete_document(
            request=request,
            document_id=id,
            upappid=service_params['gdb'],
            apiname="neo4j/deletedocument",
            user=userid
        )
        if neo4j_result.get("status") != "success":
            raise ValueError(neo4j_result.get("message", "Neo4j delete failed"))
        nodes_deleted = neo4j_result.get("nodes_deleted", 0)
        rels_deleted = neo4j_result.get("rels_deleted", 0)
        info(f"Deleted {nodes_deleted} Neo4j nodes and {rels_deleted} relationships for document_id={id}")
        return nodes_deleted, rels_deleted

    async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
                               timings: Dict) -> List[str]:
        """Extract entities from the query."""
        debug(f"Extracting query entities: {query}")
        start_extract = time.time()
        entities = await self.api_service.extract_entities(
            request=request,
            query=query,
            upappid=service_params['entities'],
            apiname="LTP/small",
            user=userid
        )
        timings["entity_extraction"] = time.time() - start_extract
        debug(f"Extracted entities: {entities}, took {timings['entity_extraction']:.3f}s")
        return entities

    async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
                             fiids: List[str], service_params: Dict, userid: str,
                             timings: Dict) -> List[Dict]:
        """Match triples against the graph database."""
        debug("Starting triple matching")
        start_triplet = time.time()
        all_triplets = []

        for kb_id in fiids:
            debug(f"Calling Neo4j triple matching: knowledge_base_id={kb_id}")
            try:
                neo4j_result = await self.api_service.neo4j_match_triplets(
                    request=request,
                    query=query,
                    query_entities=entities,
                    userid=orgid,
                    knowledge_base_id=kb_id,
                    upappid=service_params['gdb'],
                    apiname="neo4j/matchtriplets",
                    user=userid
                )
                if neo4j_result.get("status") == "success":
                    triplets = neo4j_result.get("triplets", [])
                    all_triplets.extend(triplets)
                    debug(f"Knowledge base {kb_id} matched {len(triplets)} triples")
                else:
                    error(f"Neo4j triple matching failed: knowledge_base_id={kb_id}, error: {neo4j_result.get('message', 'unknown error')}")
            except Exception as e:
                error(f"Neo4j triple matching failed: knowledge_base_id={kb_id}, error: {str(e)}")
                continue

        timings["triplet_matching"] = time.time() - start_triplet
        debug(f"Total triple matching took {timings['triplet_matching']:.3f}s")
        return all_triplets

    async def generate_query_vector(self, request, text: str, service_params: Dict,
                                    userid: str, timings: Dict) -> List[float]:
        """Generate the query vector."""
        debug(f"Generating query vector: {text[:200]}...")
        start_vector = time.time()
        query_vector = await self.api_service.get_embeddings(
            request=request,
            texts=[text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
            raise ValueError("The query vector must be a list of 1024 floats")
        query_vector = query_vector[0]
        timings["vector_generation"] = time.time() - start_vector
        debug(f"Query vector generation took {timings['vector_generation']:.3f}s")
        return query_vector

    async def vector_search(self, request, query_vector: List[float], orgid: str,
                            fiids: List[str], limit: int, service_params: Dict, userid: str,
                            timings: Dict) -> List[Dict]:
        """Run a vector search."""
        debug("Starting vector search")
        start_search = time.time()
        result = await self.api_service.milvus_search_query(
            request=request,
            query_vector=query_vector,
            userid=orgid,
            knowledge_base_ids=fiids,
            limit=limit,
            offset=0,
            upappid=service_params['vdb'],
            apiname="mlvus/searchquery",
            user=userid
        )

        if result.get("status") != "success":
            raise ValueError(f"Vector search failed: {result.get('message', 'unknown error')}")

        search_results = result.get("results", [])
        timings["vector_search"] = time.time() - start_search
        debug(f"Vector search took {timings['vector_search']:.3f}s")
        debug(f"Retrieved {len(search_results)} results from the vector database")
        return search_results

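    # Illustrative query-side flow composing the methods in this class
    # (variable names are placeholders, not part of this module):
    #   qv = await ops.generate_query_vector(request, query, service_params, userid, timings)
    #   hits = await ops.vector_search(request, qv, orgid, fiids, limit, service_params, userid, timings)
    #   reranked = await ops.rerank_results(request, query, hits, top_n, service_params, userid, timings)
    #   payload = ops.format_search_results(reranked, limit)
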
    async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
                             service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
        """Rerank the search results."""
        debug("Starting reranking")
        start_rerank = time.time()
        reranked_results = await self.api_service.rerank_results(
            request=request,
            query=query,
            results=results,
            top_n=top_n,
            upappid=service_params['reranker'],
            apiname="BAAI/bge-reranker-v2-m3",
            user=userid
        )
        reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
        timings["reranking"] = time.time() - start_rerank
        debug(f"Reranking took {timings['reranking']:.3f}s")
        debug(f"Rerank score distribution: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
        return reranked_results

    def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
        """Deduplicate and refine the triples."""
        unique_triples = []
        seen = set()

        for t in triples:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
            if identifier not in seen:
                seen.add(identifier)
                unique_triples.append(t)
            else:
                # If a more specific type is found, replace the existing triple
                for existing in unique_triples:
                    if (existing['head'].lower() == t['head'].lower() and
                            existing['tail'].lower() == t['tail'].lower() and
                            len(t['type']) > len(existing['type'])):
                        unique_triples.remove(existing)
                        unique_triples.append(t)
                        debug(f"Replaced triple with a more specific type: {t}")
                        break

        return unique_triples

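    # Note on behaviour of the deduplication above: the identifier includes head, tail AND
    # type, so only exact (case-insensitive) duplicates are dropped; e.g. two triples sharing
    # head "Alice" and tail "Acme" but with types "works" and "works for" are both kept.
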
    def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
        """Format search results into a unified structure."""
        formatted_results = []
        # for res in results[:limit]:
        #     score = res.get('rerank_score', res.get('distance', 0))
        #
        #     content = res.get('text', '')
        #     title = res.get('metadata', {}).get('filename', 'Untitled')
        #     document_id = res.get('metadata', {}).get('document_id', '')
        #
        #     formatted_results.append({
        #         "content": content,
        #         "title": title,
        #         "metadata": {"document_id": document_id, "score": score},
        #     })

        # Score normalization: squash the rerank score through a sigmoid, falling back
        # to 1 - distance when no rerank score is present
        for res in results[:limit]:
            rerank_score = res.get('rerank_score')
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))

            content = res.get('text', '')
            title = res.get('metadata', {}).get('filename', 'Untitled')
            document_id = res.get('metadata', {}).get('document_id', '')

            formatted_results.append({
                "content": content,
                "title": title,
                "metadata": {"document_id": document_id, "score": score},
            })

        return formatted_results
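

# Illustrative end-to-end sketch, not part of the original module: it only shows how the
# methods above are meant to be composed when ingesting a text document. The request,
# service_params, orgid, fiid, document_id, userid and db_type values are assumed to be
# supplied by the caller exactly as the individual methods expect them.
async def _example_ingest_text_document(request, realpath: str, orgid: str, fiid: str,
                                        document_id: str, service_params: Dict,
                                        userid: str, db_type: str) -> Dict[str, int]:
    ops = RagOperations()
    timings: Dict[str, float] = {}
    # 1. Load the file and split it into chunks
    chunks = await ops.load_and_chunk_document(realpath, timings)
    # 2. Generate BGE embeddings for the chunks
    embeddings = await ops.generate_embeddings(request, chunks, service_params, userid, timings)
    # 3. Insert the text vectors in plain-text mode
    stats = await ops.insert_all_vectors(
        request, chunks, realpath, orgid, fiid, document_id,
        service_params, userid, db_type, timings,
        text_embeddings=embeddings
    )
    # 4. Optionally extract triples and insert them into the graph database
    triples = await ops.extract_triples(request, chunks, service_params, userid, timings)
    await ops.insert_to_graph_db(request, triples, document_id, fiid, orgid,
                                 service_params, userid, timings)
    return stats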