# rag/rag/rag_operations.py
import os
import re
import time
import math
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from appPublic.log import debug, error, info
from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType
class RagOperations:
"""RAG 操作类,提供所有通用的 RAG 操作"""
def __init__(self):
self.api_service = APIService()
async def load_and_chunk_document(self, realpath: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[Document]:
"""加载文件并进行文本分片"""
debug(f"加载文件: {realpath}")
start_load = time.time()
# 检查文件格式支持
supported_formats = File2Text.supported_types()
debug(f"支持的文件格式:{supported_formats}")
ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
# 加载文件内容
text = fileloader(realpath)
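        # Keep only CJK characters, ASCII letters/digits, whitespace and a few basic
        # separators (. ; , / and newlines); every other symbol is stripped before chunking.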
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip():
raise ValueError(f"文件 {realpath} 加载为空")
# 分片处理
document = Document(page_content=text)
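        # Split into 500-character chunks with a 100-character overlap between
        # neighbouring chunks, measured with plain len().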
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
length_function=len
)
debug("开始分片文件内容")
start_split = time.time()
chunks = text_splitter.split_documents([document])
timings["split_text"] = time.time() - start_split
debug(f"文本分片耗时: {timings['split_text']:.2f} 秒, 分片数量: {len(chunks)}")
if not chunks:
raise ValueError(f"文件 {realpath} 未生成任何文档块")
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.FILE_LOAD,
{'realpath': realpath, 'chunks_count': len(chunks)}
)
return chunks
async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
userid: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[List[float]]:
"""生成嵌入向量"""
debug("调用嵌入服务生成向量")
start_embedding = time.time()
texts = [chunk.page_content for chunk in chunks]
embeddings = []
# 批量处理嵌入
for i in range(0, len(texts), 10):
batch_texts = texts[i:i + 10]
batch_embeddings = await self.api_service.get_embeddings(
request=request,
texts=batch_texts,
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
embeddings.extend(batch_embeddings)
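        # BAAI/bge-m3 dense embeddings are 1024-dimensional; anything else is rejected
        # before it can reach the vector store.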
if not embeddings or not all(len(vec) == 1024 for vec in embeddings):
raise ValueError("所有嵌入向量必须是长度为 1024 的浮点数列表")
timings["generate_embeddings"] = time.time() - start_embedding
debug(f"生成嵌入向量耗时: {timings['generate_embeddings']:.2f} 秒, 嵌入数量: {len(embeddings)}")
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.EMBEDDING,
{'embeddings_count': len(embeddings)}
)
return embeddings
async def generate_multi_embeddings(self, request, inputs: List[Dict], service_params: Dict,
userid: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> Dict[str, Dict]:
"""调用多模态嵌入服务CLIP"""
debug("调用多模态嵌入服务")
start = time.time()
result = await self.api_service.get_multi_embeddings(
request=request,
inputs=inputs,
upappid=service_params['embedding'],
apiname="black/clip",
user=userid
)
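        # Illustrative shape of the returned mapping, inferred from how it is consumed
        # below (which keys are present varies with the media type):
        #   {"<text or media path>": {"type": "text" | "image" | "video" | "audio" | "error",
        #                             "vector": [...], "path": "...",
        #                             "face_vecs": [[...], ...], "error": "..."}}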
debug(f"多模态返回结果是{result}")
timings["multi_embedding"] = time.time() - start
debug(f"多模态嵌入耗时: {timings['multi_embedding']:.2f}秒,处理 {len(result)}")
# ==================== 新增:错误检查 + 过滤 ====================
valid_results = {}
error_count = 0
error_examples = []
for key, info in result.items():
if info.get("type") == "error":
error_count += 1
                if len(error_examples) < 3:  # record only the first three examples
                    error_examples.append(f"{key}: {info['error']}")
                # Discard error entries outright
continue
valid_results[key] = info
if error_count > 0:
error(f"多模态嵌入失败 {error_count} 条!示例:{'; '.join(error_examples)}")
raise RuntimeError(f"多模态嵌入有{error_count} 条失败")
else:
debug("多模态嵌入全部成功!")
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.EMBEDDING,
{'count': len(result)}
)
return result
    # Unified insertion into the vector database
async def insert_all_vectors(
self,
request,
text_chunks: List[Document],
realpath: str,
orgid: str,
fiid: str,
document_id: str,
service_params: Dict,
userid: str,
db_type: str,
timings: Dict,
img_paths: List[str] = None,
text_embeddings: List[List[float]] = None,
multi_results: Dict = None,
transaction_mgr: TransactionManager = None
) -> Dict[str, int]:
"""
统一插入函数:支持两种模式
1. 纯文本模式text_embeddings 有值
2. 多模态模式multi_results 有值(来自 generate_multi_embeddings
"""
img_paths = img_paths or []
all_chunks = []
start = time.time()
filename = os.path.basename(realpath)
upload_time = datetime.now().isoformat()
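        # Both modes below append rows to all_chunks in the flat schema expected by
        # milvus/insertdocument: userid, knowledge_base_id, text, vector, document_id,
        # filename, file_path, upload_time and file_type.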
        # ==================== 1. Text-only mode (BGE) ====================
if text_embeddings is not None:
debug(f"【纯文本模式】插入 {len(text_embeddings)} 条文本向量")
for i, chunk in enumerate(text_chunks):
all_chunks.append({
"userid": orgid,
"knowledge_base_id": fiid,
"text": chunk.page_content,
"vector": text_embeddings[i],
"document_id": document_id,
"filename": filename,
"file_path": realpath,
"upload_time": upload_time,
"file_type": "text",
})
        # ==================== 2. Multimodal mode (mixed CLIP results) ====================
if multi_results is not None:
debug(f"【多模态模式】解析 {len(multi_results)} 条 CLIP 结果")
# 遍历 multi_results
for raw_key, info in multi_results.items():
typ = info["type"]
# --- 文本 ---
if typ == "text":
# raw_key 就是原文
all_chunks.append({
"userid": orgid,
"knowledge_base_id": fiid,
"text": raw_key,
"vector": info["vector"],
"document_id": document_id,
"filename": filename,
"file_path": realpath,
"upload_time": upload_time,
"file_type": "text",
})
continue
# --- 图像 ---
if typ == "image":
img_path = info.get("path") or raw_key
img_name = os.path.basename(img_path)
# 整图向量
if "vector" in info:
all_chunks.append({
"userid": orgid,
"knowledge_base_id": fiid,
"text": f"[Image: {img_path}]图片来源于文件{realpath}",
"vector": info["vector"],
"document_id": document_id,
"filename": img_name,
"file_path": realpath,
"upload_time": upload_time,
"file_type": "image",
})
# 人脸向量
face_vecs = info.get("face_vecs", [])
face_count = len(face_vecs)
# if face_count > 0:
# for f_idx, fvec in enumerate(face_vecs):
# debug(f"人脸向量维度是:{len(fvec)}")
# all_chunks.append({
# "userid": orgid,
# "knowledge_base_id": fiid,
# "text": f"[Face {f_idx + 1}/{face_count} in {img_name}]人脸来源于{realpath}的{img_path}图片",
# "vector": fvec,
# "document_id": document_id,
# "filename": img_name,
# "file_path": realpath,
# "upload_time": upload_time,
# "file_type": "face",
# })
                    continue
# --- 视频 ---
if typ == "video":
video_path = info.get("path") or raw_key
video_name = os.path.basename(video_path)
if "vector" in info:
all_chunks.append({
"userid": orgid,
"knowledge_base_id": fiid,
"text": f"[Video: {video_name}]",
"vector": info["vector"],
"document_id": document_id,
"filename": video_path,
"file_path": realpath,
"upload_time": upload_time,
"file_type": "video",
})
# 视频人脸
face_vecs = info.get("face_vecs", [])
face_count = len(face_vecs)
# if face_count > 0 :
# for f_idx, fvec in enumerate(face_vecs):
# all_chunks.append({
# "userid": orgid,
# "knowledge_base_id": fiid,
# "text": f"[Face {f_idx + 1}/{face_count} in video {video_name}]来源于{video_path}",
# "vector": fvec,
# "document_id": document_id,
# "filename": video_path,
# "file_path": realpath,
# "upload_time": upload_time,
# "file_type": "face",
# })
                    continue
# --- 音频 ---
if typ == "audio":
audio_path = info.get("path") or raw_key
audio_name = os.path.basename(audio_path)
if "vector" in info:
all_chunks.append({
"userid": orgid,
"knowledge_base_id": fiid,
"text": f"[Audio: {audio_name}]",
"vector": info["vector"],
"document_id": document_id,
"filename": audio_path,
"file_path": realpath,
"upload_time": upload_time,
"file_type": "audio",
})
continue
# --- 未知类型 ---
debug(f"未知类型跳过: {typ}{raw_key}")
        # ==================== 3. Batch insert into Milvus ====================
if not all_chunks:
debug("无向量需要插入")
return {"text": 0, "image": 0, "face": 0}
for i in range(0, len(all_chunks), 10):
batch = all_chunks[i:i + 10]
result = await self.api_service.milvus_insert_document(
request=request,
chunks=batch,
upappid=service_params['vdb'],
apiname="milvus/insertdocument",
user=userid,
db_type=db_type
)
if result.get("status") != "success":
raise ValueError(f"Milvus 插入失败: {result.get('message')}")
        # ==================== 4. Unified rollback (registered only once) ====================
if transaction_mgr and all_chunks:
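            # A single rollback entry covers the whole document: on failure,
            # delete_from_vector_db removes every vector stored under this document_id.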
async def rollback_all(data, context):
try:
await self.delete_from_vector_db(
request=context['request'],
orgid=data['orgid'],
realpath=data['realpath'],
fiid=data['fiid'],
id=data['document_id'],
service_params=context['service_params'],
userid=context['userid'],
db_type=data['db_type']
)
return f"已回滚 document_id={data['document_id']} 的所有向量"
except Exception as e:
error(f"统一回滚失败: {e}")
raise
transaction_mgr.add_operation(
OperationType.VDB_INSERT,
{
'orgid': orgid,
'realpath': realpath,
'fiid': fiid,
                    'document_id': document_id,
'db_type': db_type
},
rollback_func=rollback_all
)
        # ==================== 5. Return statistics ====================
stats = {
"text": len([c for c in all_chunks if c["file_type"] == "text"]),
"image": len([c for c in all_chunks if c["file_type"] == "image"]),
"face": len([c for c in all_chunks if c["file_type"] == "face"])
}
timings["insert_all"] = time.time() - start
debug(
f"统一插入完成: 文本 {stats['text']}, 图像 {stats['image']}, 人脸 {stats['face']}, 耗时 {timings['insert_all']:.2f}s")
return stats
# async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
# realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
# userid: str, db_type: str, timings: Dict,
# transaction_mgr: TransactionManager = None):
# """插入向量数据库"""
# debug(f"准备数据并调用插入文件端点: {realpath}")
# filename = os.path.basename(realpath).rsplit('.', 1)[0]
# ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
# upload_time = datetime.now().isoformat()
#
# chunks_data = [
# {
# "userid": orgid,
# "knowledge_base_id": fiid,
# "text": chunk.page_content,
# "vector": embeddings[i],
# "document_id": id,
# "filename": filename + '.' + ext,
# "file_path": realpath,
# "upload_time": upload_time,
# "file_type": ext,
# }
# for i, chunk in enumerate(chunks)
# ]
#
# start_milvus = time.time()
# for i in range(0, len(chunks_data), 10):
# batch_chunks = chunks_data[i:i + 10]
# debug(f"传入的数据是:{batch_chunks}")
# result = await self.api_service.milvus_insert_document(
# request=request,
# chunks=batch_chunks,
# db_type=db_type,
# upappid=service_params['vdb'],
# apiname="milvus/insertdocument",
# user=userid
# )
# if result.get("status") != "success":
# raise ValueError(result.get("message", "Milvus 插入失败"))
#
# timings["insert_milvus"] = time.time() - start_milvus
# debug(f"Milvus 插入耗时: {timings['insert_milvus']:.2f} 秒")
#
# # 记录事务操作,包含回滚函数
# if transaction_mgr:
# async def rollback_vdb_insert(data, context):
# try:
# # 防御性检查
# required_context = ['request', 'service_params', 'userid']
# missing_context = [k for k in required_context if k not in context or context[k] is None]
# if missing_context:
# raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
#
# required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
# missing_data = [k for k in required_data if k not in data or data[k] is None]
# if missing_data:
# raise ValueError(f"VDB_INSERT 数据缺少字段: {', '.join(missing_data)}")
#
# await self.delete_from_vector_db(
# context['request'], data['orgid'], data['realpath'],
# data['fiid'], data['id'], context['service_params'],
# context['userid'], data['db_type']
# )
# return f"已回滚向量数据库插入: {data['id']}"
# except Exception as e:
# error(f"回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}")
# raise
#
# transaction_mgr.add_operation(
# OperationType.VDB_INSERT,
# {
# 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
# 'id': id, 'db_type': db_type
# },
# rollback_func=rollback_vdb_insert
# )
#
# return chunks_data
#
# async def insert_image_vectors(
# self,
# request,
# multi_results: Dict[str, Dict],
# realpath: str,
# orgid: str,
# fiid: str,
# document_id: str,
# service_params: Dict,
# userid: str,
# db_type: str,
# timings: Dict,
# transaction_mgr: TransactionManager = None
# ) -> tuple[int, int]:
#
# start = time.time()
# image_chunks = []
# face_chunks = []
#
# for img_path, info in multi_results.items():
# # img_name = os.path.basename(img_path)
#
# # 1. 插入整张图
# if info.get("type") in ["image", "video"] and "vector" in info:
# image_chunks.append({
# "userid": orgid,
# "knowledge_base_id": fiid,
# "text": f"[Image: {img_path}]",
# "vector": info["vector"],
# "document_id": document_id,
# "filename": os.path.basename(realpath),
# "file_path": realpath,
# "upload_time": datetime.now().isoformat(),
# "file_type": "image"
# })
#
# # 2. 插入每张人脸
# face_vecs = info.get("face_vecs")
# face_count = info.get("face_count", 0)
#
# if face_count > 0 and face_vecs and len(face_vecs) == face_count:
# for idx, face_vec in enumerate(face_vecs):
# face_chunks.append({
# "userid": orgid,
# "knowledge_base_id": fiid,
# "text": f"[Face {idx + 1}/{face_count} in {img_path}]",
# "vector": face_vec,
# "document_id": document_id,
# "filename": os.path.basename(realpath),
# "file_path": realpath,
# "upload_time": datetime.now().isoformat(),
# "file_type": "face",
# })
#
# if image_chunks:
# for i in range(0, len(image_chunks), 10):
# await self.api_service.milvus_insert_document(
# request=request,
# chunks=image_chunks[i:i + 10],
# upappid=service_params['vdb'],
# apiname="milvus/insertdocument",
# user=userid,
# db_type=db_type
# )
#
# if face_chunks:
# for i in range(0, len(face_chunks), 10):
# await self.api_service.milvus_insert_document(
# request=request,
# chunks=face_chunks[i:i + 10],
# upappid=service_params['vdb'],
# apiname="milvus/insertdocument",
# user=userid,
# db_type=db_type
# )
# timings["insert_images"] = time.time() - start
# image_count = len(image_chunks)
# face_count = len(face_chunks)
#
# debug(f"多模态插入完成: 图像 {image_count} 条, 人脸 {face_count} 条")
#
# if transaction_mgr and (image_count + face_count > 0):
# transaction_mgr.add_operation(
# OperationType.IMAGE_VECTORS_INSERT,
# {"images": image_count, "faces": face_count, "document_id": document_id}
# )
#
# # 记录事务操作,包含回滚函数
# if transaction_mgr:
# async def rollback_multimodal(data, context):
# try:
# # 防御性检查
# required_context = ['request', 'service_params', 'userid']
# missing_context = [k for k in required_context if k not in context or context[k] is None]
# if missing_context:
# raise ValueError(f"回滚上下文缺少字段: {', '.join(missing_context)}")
#
# required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
# missing_data = [k for k in required_data if k not in data or data[k] is None]
# if missing_data:
# raise ValueError(f"多模态回滚数据缺少字段: {', '.join(missing_data)}")
#
# await self.delete_from_vector_db(
# context['request'], data['orgid'], data['realpath'],
# data['fiid'], data['id'], context['service_params'],
# context['userid'], data['db_type']
# )
# return f"已回滚多模态向量: {data['id']}"
# except Exception as e:
# error(f"多模态回滚向量数据库失败: document_id={data.get('id', '未知')}, 错误: {str(e)}")
# raise
#
# transaction_mgr.add_operation(
# OperationType.VDB_INSERT,
# {
# 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
# 'id': id, 'db_type': db_type
# },
# rollback_func=rollback_multimodal
# )
#
# return image_count, face_count
async def insert_to_vector_text(self, request,
db_type: str, fields: Dict, service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
"""插入单一纯文本到向量数据库,支持动态 schema"""
chunk_data = {}
debug("准备单一纯文本数据并调用插入端点")
start = time.time()
for key, value in fields.items():
chunk_data[key] = value
chunks_data = [chunk_data]
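        # The caller-supplied fields are passed through to the vector store as-is.
        # Illustrative example only (the actual schema is dynamic):
        #   {"userid": "...", "knowledge_base_id": "...", "text": "...", "vector": [...]}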
debug(f"向量库插入传入的数据是:{chunks_data}")
# 调用 Milvus 插入
result = await self.api_service.milvus_insert_document(
request=request,
chunks=chunks_data,
upappid=service_params['vdb'],
apiname="milvus/insertdocument",
user=userid,
db_type=db_type
)
if result.get("status") != "success":
raise ValueError(result.get("message", "Milvus 插入失败"))
debug(f"成功插入纯文本到集合 {result.get('collection_name')}")
timings["textinsert"] = time.time() - start
debug(f"插入纯文本耗时: {timings['textinsert']:.2f}")
return chunks_data
async def extract_triples(self, request, chunks: List[Document], service_params: Dict,
userid: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[Dict]:
"""抽取三元组"""
debug("调用三元组抽取服务")
start_triples = time.time()
chunk_texts = [doc.page_content for doc in chunks]
triples = []
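        # Each extracted triple is expected to be a dict with at least 'head', 'type'
        # and 'tail' keys (the fields used by _deduplicate_triples below).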
for i, chunk in enumerate(chunk_texts):
result = await self.api_service.extract_triples(
request=request,
text=chunk,
upappid=service_params['triples'],
apiname="Babelscape/mrebel-large",
user=userid
)
if isinstance(result, list):
triples.extend(result)
debug(f"分片 {i + 1} 抽取到 {len(result)} 个三元组")
else:
error(f"分片 {i + 1} 处理失败: {str(result)}")
        # Deduplicate and optimize the triples
unique_triples = self._deduplicate_triples(triples)
timings["extract_triples"] = time.time() - start_triples
debug(f"三元组抽取耗时: {timings['extract_triples']:.2f} 秒, 抽取到 {len(unique_triples)} 个三元组")
# 记录事务操作
if transaction_mgr:
transaction_mgr.add_operation(
OperationType.TRIPLES_EXTRACT,
{'triples_count': len(unique_triples)}
)
return unique_triples
async def insert_to_graph_db(self, request, triples: List[Dict], id: str, fiid: str,
orgid: str, service_params: Dict, userid: str, timings: Dict,
transaction_mgr: TransactionManager = None):
"""插入图数据库"""
debug(f"插入 {len(triples)} 个三元组到 Neo4j")
start_neo4j = time.time()
if triples:
for i in range(0, len(triples), 30):
batch_triples = triples[i:i + 30]
neo4j_result = await self.api_service.neo4j_insert_triples(
request=request,
triples=batch_triples,
document_id=id,
knowledge_base_id=fiid,
userid=orgid,
upappid=service_params['gdb'],
apiname="neo4j/inserttriples",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(f"Neo4j 三元组插入失败: {neo4j_result.get('message', '未知错误')}")
info(f"文件三元组成功插入 Neo4j: {neo4j_result.get('message')}")
timings["insert_neo4j"] = time.time() - start_neo4j
debug(f"Neo4j 插入耗时: {timings['insert_neo4j']:.2f}")
else:
debug("未抽取到三元组")
timings["insert_neo4j"] = 0.0
# 记录事务操作,包含回滚函数
if transaction_mgr:
async def rollback_gdb_insert(data, context):
await self.delete_from_graph_db(
context['request'], data['id'],
context['service_params'], context['userid']
)
return f"已回滚图数据库插入: {data['id']}"
transaction_mgr.add_operation(
OperationType.GDB_INSERT,
{'id': id, 'triples_count': len(triples)},
rollback_func=rollback_gdb_insert
)
async def delete_from_vector_db(self, request, orgid: str, realpath: str, fiid: str,
id: str, service_params: Dict, userid: str, db_type: str):
"""从向量数据库删除文档"""
debug(f"调用删除文件端点: userid={orgid}, file_path={realpath}, knowledge_base_id={fiid}, document_id={id}")
milvus_result = await self.api_service.milvus_delete_document(
request=request,
userid=orgid,
file_path=realpath,
knowledge_base_id=fiid,
document_id=id,
db_type=db_type,
upappid=service_params['vdb'],
apiname="milvus/deletedocument",
user=userid
)
if milvus_result.get("status") != "success":
raise ValueError(milvus_result.get("message", "Milvus 删除失败"))
async def delete_from_graph_db(self, request, id: str, service_params: Dict, userid: str):
"""从图数据库删除文档"""
debug(f"调用 Neo4j 删除文档端点: document_id={id}")
neo4j_result = await self.api_service.neo4j_delete_document(
request=request,
document_id=id,
upappid=service_params['gdb'],
apiname="neo4j/deletedocument",
user=userid
)
if neo4j_result.get("status") != "success":
raise ValueError(neo4j_result.get("message", "Neo4j 删除失败"))
nodes_deleted = neo4j_result.get("nodes_deleted", 0)
rels_deleted = neo4j_result.get("rels_deleted", 0)
info(f"成功删除 document_id={id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系")
return nodes_deleted, rels_deleted
async def extract_entities(self, request, query: str, service_params: Dict, userid: str,
timings: Dict) -> List[str]:
"""提取实体"""
debug(f"提取查询实体: {query}")
start_extract = time.time()
entities = await self.api_service.extract_entities(
request=request,
query=query,
upappid=service_params['entities'],
apiname="LTP/small",
user=userid
)
timings["entity_extraction"] = time.time() - start_extract
debug(f"提取实体: {entities}, 耗时: {timings['entity_extraction']:.3f}")
return entities
async def match_triplets(self, request, query: str, entities: List[str], orgid: str,
fiids: List[str], service_params: Dict, userid: str,
timings: Dict) -> List[Dict]:
"""匹配三元组"""
debug("开始三元组匹配")
start_triplet = time.time()
all_triplets = []
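        # Query each knowledge base separately and merge the matched triplets.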
for kb_id in fiids:
debug(f"调用 Neo4j 三元组匹配: knowledge_base_id={kb_id}")
try:
neo4j_result = await self.api_service.neo4j_match_triplets(
request=request,
query=query,
query_entities=entities,
userid=orgid,
knowledge_base_id=kb_id,
upappid=service_params['gdb'],
apiname="neo4j/matchtriplets",
user=userid
)
if neo4j_result.get("status") == "success":
triplets = neo4j_result.get("triplets", [])
all_triplets.extend(triplets)
debug(f"知识库 {kb_id} 匹配到 {len(triplets)} 个三元组")
else:
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {neo4j_result.get('message', '未知错误')}")
except Exception as e:
error(f"Neo4j 三元组匹配失败: knowledge_base_id={kb_id}, 错误: {str(e)}")
continue
timings["triplet_matching"] = time.time() - start_triplet
debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f}")
return all_triplets
async def generate_query_vector(self, request, text: str, service_params: Dict,
userid: str, timings: Dict) -> List[float]:
"""生成查询向量"""
debug(f"生成查询向量: {text[:200]}...")
start_vector = time.time()
query_vector = await self.api_service.get_embeddings(
request=request,
texts=[text],
upappid=service_params['embedding'],
apiname="BAAI/bge-m3",
user=userid
)
if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
raise ValueError("查询向量必须是长度为 1024 的浮点数列表")
query_vector = query_vector[0]
timings["vector_generation"] = time.time() - start_vector
debug(f"生成查询向量耗时: {timings['vector_generation']:.3f}")
return query_vector
async def vector_search(self, request, query_vector: List[float], orgid: str,
fiids: List[str], limit: int, service_params: Dict, userid: str,
timings: Dict) -> List[Dict]:
"""向量搜索"""
debug("开始向量搜索")
start_search = time.time()
result = await self.api_service.milvus_search_query(
request=request,
query_vector=query_vector,
userid=orgid,
knowledge_base_ids=fiids,
limit=limit,
offset=0,
upappid=service_params['vdb'],
apiname="mlvus/searchquery",
user=userid
)
if result.get("status") != "success":
raise ValueError(f"向量搜索失败: {result.get('message', '未知错误')}")
search_results = result.get("results", [])
timings["vector_search"] = time.time() - start_search
debug(f"向量搜索耗时: {timings['vector_search']:.3f}")
debug(f"从向量数据中搜索到{len(search_results)}条数据")
return search_results
async def rerank_results(self, request, query: str, results: List[Dict], top_n: int,
service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
"""重排序结果"""
debug("开始重排序")
start_rerank = time.time()
reranked_results = await self.api_service.rerank_results(
request=request,
query=query,
results=results,
top_n=top_n,
upappid=service_params['reranker'],
apiname="BAAI/bge-reranker-v2-m3",
user=userid
)
reranked_results = sorted(reranked_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
timings["reranking"] = time.time() - start_rerank
debug(f"重排序耗时: {timings['reranking']:.3f}")
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in reranked_results]}")
return reranked_results
    def _deduplicate_triples(self, triples: List[Dict]) -> List[Dict]:
        """Deduplicate triples, preferring the more specific (longer) relation type
        when the same head/tail pair appears again."""
        unique_triples = []
        seen = set()
        for t in triples:
            identifier = (t['head'].lower(), t['tail'].lower(), t['type'].lower())
            if identifier in seen:
                continue  # exact duplicate
            seen.add(identifier)
            # If the same head/tail pair already exists with a less specific
            # (shorter) type, replace it instead of keeping both.
            for idx, existing in enumerate(unique_triples):
                if (existing['head'].lower() == t['head'].lower() and
                        existing['tail'].lower() == t['tail'].lower() and
                        len(t['type']) > len(existing['type'])):
                    debug(f"替换三元组为更具体类型: {t}")
                    unique_triples[idx] = t
                    break
            else:
                unique_triples.append(t)
        return unique_triples
def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
"""格式化搜索结果为统一格式"""
formatted_results = []
# for res in results[:limit]:
# score = res.get('rerank_score', res.get('distance', 0))
#
# content = res.get('text', '')
# title = res.get('metadata', {}).get('filename', 'Untitled')
# document_id = res.get('metadata', {}).get('document_id', '')
#
# formatted_results.append({
# "content": content,
# "title": title,
# "metadata": {"document_id": document_id, "score": score},
# })
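        # rerank_score is unbounded, so it is squashed through a sigmoid; results that
        # were never reranked fall back to 1 - distance from the vector search.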
        # Normalize scores to the range [0, 1]
        for res in results[:limit]:
            rerank_score = res.get('rerank_score')
            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
            score = max(0.0, min(1.0, score))
content = res.get('text', '')
title = res.get('metadata', {}).get('filename', 'Untitled')
document_id = res.get('metadata', {}).get('document_id', '')
formatted_results.append({
"content": content,
"title": title,
"metadata": {"document_id": document_id, "score": score},
})
return formatted_results
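

# Illustrative usage (not part of the module): a typical ingest pipeline would wire
# these operations together roughly as follows, assuming the calling endpoint supplies
# `request`, `service_params`, `userid`, `orgid`, `fiid`, `document_id` and `db_type`:
#
#   ops = RagOperations()
#   timings: Dict[str, float] = {}
#   chunks = await ops.load_and_chunk_document(realpath, timings)
#   embeddings = await ops.generate_embeddings(request, chunks, service_params, userid, timings)
#   await ops.insert_all_vectors(request, chunks, realpath, orgid, fiid, document_id,
#                                service_params, userid, db_type, timings,
#                                text_embeddings=embeddings)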