Merge branch 'main' of git.opencomputing.cn:yumoqing/rag

yumoqing 2025-12-25 15:28:49 +08:00
commit 6155ea176b
13 changed files with 1468 additions and 158 deletions

json/embeddingmode.json (Normal file, +16 lines)

@@ -0,0 +1,16 @@
{
    "tblname": "embedding_mode",
    "title": "Embedding Mode",
    "params": {
        "browserfields": {
            "exclouded": ["id"],
            "alters": {}
        },
        "editexclouded": ["id"],
        "toolbar": {
        },
        "binds": [
        ]
    }
}

models/embeddingmode.xlsx (Normal file, binary)

Binary file not shown.

rag/fileprocess.py (Normal file, +341 lines)

@@ -0,0 +1,341 @@
import numpy as np
import os
import re
from pdf2image import convert_from_path
from appPublic.log import debug, error, info
from pathlib import Path
import zipfile
import xml.etree.ElementTree as ET
from PIL import Image
from typing import List


# ==================== New: filename sanitization helper ====================
def safe_filename(name: str) -> str:
    """
    Sanitize a file or directory name:
    - strip leading/trailing whitespace
    - collapse runs of whitespace into a single space
    - replace illegal characters with underscores
    - replace spaces with underscores (safest for downstream paths)
    """
    name = name.strip()
    name = re.sub(r'\s+', ' ', name)            # collapse repeated whitespace
    name = re.sub(r'[<>:"/\\|?*]', '_', name)   # replace illegal characters
    name = name.replace(' ', '_')               # spaces -> underscores (the key step)
    return name
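
A quick check of the helper's behavior (the input string here is made up for illustration):

>>> safe_filename('  my report: v2/final *draft*.pdf  ')
'my_report__v2_final__draft_.pdf'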
def render_pdf_to_images(pdf_path, base_output_dir, dpi=200, image_format="PNG") -> List[str]:
    """
    Render every page of a PDF file to an image.

    Args:
        pdf_path (str): path to the PDF file
        base_output_dir (str): base output directory; a subfolder named after the PDF is created inside it
        dpi (int): image resolution, default 200
        image_format (str): image format, default PNG

    Returns:
        List[str]: paths of the successfully rendered page images
    """
    pdf_filename = safe_filename(Path(pdf_path).stem)
    page_output_dir = os.path.join(base_output_dir, pdf_filename)
    # Create the output directory if it does not exist
    if not os.path.exists(page_output_dir):
        os.makedirs(page_output_dir, exist_ok=True)
        debug(f"Created output directory: {page_output_dir}")
    try:
        # Check that the PDF file exists
        if not os.path.exists(pdf_path):
            error(f"PDF file does not exist: {pdf_path}")
            return []
        debug(f"Rendering PDF: {pdf_path}")
        debug(f"Output directory: {page_output_dir}")
        debug(f"Resolution: {dpi} DPI, format: {image_format}")
        # Render the PDF pages to images
        pages = convert_from_path(pdf_path, dpi=dpi)
        debug(f"Total PDF pages: {len(pages)}")
        debug("📄 Rendering PDF pages...")
        img_paths = []
        for i, page in enumerate(pages, start=1):
            try:
                # Build the image file path
                img_path = os.path.join(page_output_dir, f"page_{i:03d}.{image_format.lower()}")
                # Save the image; only record the path once the save succeeds
                page.save(img_path, image_format)
                img_paths.append(img_path)
                debug(f"✅ Saved {img_path}")
            except Exception as e:
                error(f"Failed to save page {i}: {e}")
                continue
        debug(f"Rendering finished: {len(img_paths)} pages saved")
        return img_paths
    except Exception as e:
        error(f"Failed to render PDF: {e}")
        return []
def extract_images_from_word(doc_path, base_output_dir) -> List[str]:
    """
    Extract all images from a Word document.

    Args:
        doc_path (str): path to the Word document (.docx format)
        base_output_dir (str): base output directory; a subfolder named after the document is created inside it

    Returns:
        List[str]: paths of the successfully extracted images
    """
    # Check that the file is in .docx format
    if not doc_path.lower().endswith('.docx'):
        error(f"Only .docx Word documents are supported: {doc_path}")
        return []
    # Take the file name (without extension) from the document path
    doc_filename = safe_filename(Path(doc_path).stem)
    # Create a subfolder named after the document
    image_output_dir = os.path.join(base_output_dir, doc_filename)
    # Create the output directory if it does not exist
    if not os.path.exists(image_output_dir):
        os.makedirs(image_output_dir, exist_ok=True)
        debug(f"Created output directory: {image_output_dir}")
    try:
        # Check that the document exists
        if not os.path.exists(doc_path):
            error(f"Word document does not exist: {doc_path}")
            return []
        debug(f"Extracting images from Word document: {doc_path}")
        debug(f"Output directory: {image_output_dir}")
        # Treat the .docx file as a zip archive
        with zipfile.ZipFile(doc_path, 'r') as docx:
            # List all entries
            file_list = docx.namelist()
            # Filter image entries (normally under word/media/)
            image_files = [f for f in file_list if f.startswith('word/media/') and not f.endswith('/') and os.path.basename(f)]
            debug(f"Found {len(image_files)} image entries")
            img_paths = []
            for i, image_path in enumerate(image_files):
                try:
                    # Take the image file name
                    image_name = os.path.basename(image_path)
                    # Make sure the name is valid
                    if not image_name or image_name == "media":
                        # Pull a meaningful name from the path
                        parts = image_path.split('/')
                        for part in reversed(parts):
                            if part and part != "media":
                                image_name = part
                                break
                        else:
                            image_name = f"image_{i + 1}.png"
                    # Add a file extension if it is missing
                    if not Path(image_name).suffix:
                        # Ideally detect the format from the file content; fall back to png
                        image_name += ".png"
                    # Build the output file path
                    output_path = os.path.join(image_output_dir, f"image_{i + 1:03d}_{image_name}")
                    # Extract and save the image
                    with docx.open(image_path) as image_file:
                        image_data = image_file.read()
                    # Write the image data; only record the path once the write succeeds
                    with open(output_path, 'wb') as f:
                        f.write(image_data)
                    img_paths.append(output_path)
                    debug(f"✅ Extracted image: {output_path}")
                except Exception as e:
                    error(f"Failed to extract image {image_path}: {e}")
                    continue
            debug(f"Word image extraction finished: {len(img_paths)} images extracted")
            return img_paths
    except Exception as e:
        error(f"Failed to extract images from Word document: {e}")
        return []
def extract_images_from_ppt(ppt_path, base_output_dir) -> List[str]:
    """
    Extract all images from a PowerPoint presentation.

    Args:
        ppt_path (str): path to the PowerPoint file (.pptx format)
        base_output_dir (str): base output directory; a subfolder named after the presentation is created inside it

    Returns:
        List[str]: paths of the successfully extracted images
    """
    # Check that the file is in .pptx format
    if not ppt_path.lower().endswith('.pptx'):
        error(f"Only .pptx PowerPoint documents are supported: {ppt_path}")
        return []
    # Take the file name (without extension) from the PPT path
    ppt_filename = safe_filename(Path(ppt_path).stem)
    # Create a subfolder named after the presentation
    image_output_dir = os.path.join(base_output_dir, ppt_filename)
    # Create the output directory if it does not exist
    if not os.path.exists(image_output_dir):
        os.makedirs(image_output_dir, exist_ok=True)
        debug(f"Created output directory: {image_output_dir}")
    try:
        # Check that the PPT file exists
        if not os.path.exists(ppt_path):
            error(f"PowerPoint document does not exist: {ppt_path}")
            return []
        debug(f"Extracting images from PowerPoint document: {ppt_path}")
        debug(f"Output directory: {image_output_dir}")
        # Treat the .pptx file as a zip archive
        with zipfile.ZipFile(ppt_path, 'r') as pptx:
            # List all entries
            file_list = pptx.namelist()
            # Filter image entries (normally under ppt/media/)
            image_files = [f for f in file_list if f.startswith('ppt/media/') and not f.endswith('/') and os.path.basename(f)]
            debug(f"Found {len(image_files)} image entries")
            img_paths = []
            for i, image_path in enumerate(image_files):
                try:
                    # Take the image file name
                    image_name = Path(image_path).name
                    # Validate the name
                    if not image_name or image_name == "media":
                        parts = image_path.split('/')
                        for part in reversed(parts):
                            if part and part != "media":
                                image_name = part
                                break
                        else:
                            image_name = f"image_{i + 1}.png"
                    # Make sure there is a file extension
                    if not Path(image_name).suffix:
                        image_name += ".png"
                    # Build the output file path
                    output_path = os.path.join(image_output_dir, f"image_{i + 1:03d}_{image_name}")
                    # Extract and save the image
                    with pptx.open(image_path) as image_file:
                        image_data = image_file.read()
                    # Write the image data; only record the path once the write succeeds
                    with open(output_path, 'wb') as f:
                        f.write(image_data)
                    img_paths.append(output_path)
                    debug(f"✅ Extracted image: {output_path}")
                except Exception as e:
                    error(f"Failed to extract image {image_path}: {e}")
                    continue
            debug(f"PowerPoint image extraction finished: {len(img_paths)} images extracted")
            return img_paths
    except Exception as e:
        error(f"Failed to extract images from PowerPoint document: {e}")
        return []
def extract_images_from_file(file_path, base_output_dir="/home/wangmeihua/kyrag/data/extracted_images", file_type=None):
    """
    Generic entry point: pick the extraction method based on the file type.

    Args:
        file_path (str): path to the file
        base_output_dir (str): base output directory
        file_type (str): file type (optional; auto-detected when omitted)

    Returns:
        List[str]: paths of the extracted images / rendered pages
    """
    # Auto-detect the file type from the extension when not given
    if file_type is None:
        ext = Path(file_path).suffix.lower()
        if ext == '.pdf':
            file_type = 'pdf'
        elif ext == '.docx':
            file_type = 'word'
        elif ext == '.pptx':
            file_type = 'ppt'
        else:
            error(f"Unsupported file type: {ext}")
            return []
    # Dispatch on the file type
    if file_type == 'pdf':
        return render_pdf_to_images(file_path, base_output_dir)
    elif file_type == 'word':
        return extract_images_from_word(file_path, base_output_dir)
    elif file_type == 'ppt':
        return extract_images_from_ppt(file_path, base_output_dir)
    else:
        error(f"Unsupported file type: {file_type}")
        return []
# Usage example
if __name__ == "__main__":
    base_output_dir = "/home/wangmeihua/kyrag/data/extracted_images"
    # PDF handling
    pdf_path = "/home/wangmeihua/kyrag/22-zh-review.pdf"
    pdf_imgs = extract_images_from_file(pdf_path, base_output_dir, 'pdf')
    debug(f"pdf_imgs: {pdf_imgs}")
    if len(pdf_imgs) > 0:
        debug(f"PDF processed successfully: {len(pdf_imgs)} pages")
    else:
        error("PDF processing failed")
    # Word document handling
    doc_path = "/home/wangmeihua/kyrag/test.docx"
    if os.path.exists(doc_path):
        doc_imgs = extract_images_from_file(doc_path, base_output_dir, 'word')
        debug(f"doc_imgs: {doc_imgs}")
        if len(doc_imgs) > 0:
            debug(f"Word document processed successfully: {len(doc_imgs)} images")
        else:
            error("Word document processing failed")
    else:
        debug(f"Word document does not exist: {doc_path}")
    # PowerPoint handling
    ppt_path = "/home/wangmeihua/kyrag/提示学习-王美华.pptx"
    if os.path.exists(ppt_path):
        ppt_imgs = extract_images_from_file(ppt_path, base_output_dir, 'ppt')
        if len(ppt_imgs) > 0:
            debug(f"PowerPoint processed successfully: {len(ppt_imgs)} images")
        else:
            error("PowerPoint processing failed")
    else:
        debug(f"PowerPoint document does not exist: {ppt_path}")


@@ -17,13 +17,15 @@ import traceback
from filetxt.loader import fileloader, File2Text
from ahserver.serverenv import get_serverenv
from typing import List, Dict, Any
-from rag.service_opts import get_service_params, sor_get_service_params
from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode
from rag.fileprocess import extract_images_from_file
from rag.rag_operations import RagOperations
import json
from rag.transaction_manager import TransactionContext
from dataclasses import dataclass
from enum import Enum
import base64
from pathlib import Path

class RagFileMgr(FileMgr):
    def __init__(self, fiid):
@@ -53,6 +55,10 @@ where a.orgid = b.orgid
            return r.quota, r.expired_date
        return None, None

    async def file_to_base64(self, path: str) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    async def file_uploaded(self, request, ns, userid):
        """Insert the document into Milvus and extract triples into Neo4j"""
        debug(f'Received ns: {ns=}')
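
The new file_to_base64 helper feeds the data-URI construction used later in this hunk; a one-line sketch (the path is illustrative, and the call must run inside an async method of this class):

    data_uri = f"data:image/png;base64,{await self.file_to_base64('/tmp/page_001.png')}"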
@@ -104,21 +110,107 @@ where a.orgid = b.orgid
            raise ValueError("Unable to obtain service parameters")
        rollback_context["service_params"] = service_params

        # Get the embedding mode
        embedding_mode = await get_embedding_mode(orgid)
        debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")

        # Load and chunk the document
        chunks = await self.rag_ops.load_and_chunk_document(
            realpath, timings, transaction_mgr=transaction_mgr
        )

        text_embeddings = None
        multi_results = None
        image_paths = []
        if embedding_mode == 1:
            inputs = []
            # Text chunks
            for chunk in chunks:
                inputs.append({"type": "text", "content": chunk.page_content})
            debug("Starting multimodal image extraction and embedding")
            image_paths = extract_images_from_file(realpath)
            debug(f"Extracted {len(image_paths)} images from the document")
            if image_paths:
                for img_path in image_paths:
                    try:
                        # 1. Detect the real image format
                        ext = Path(img_path).suffix.lower()
                        if ext not in {".png", ".jpg", ".jpeg", ".webp", ".bmp"}:
                            ext = ".jpg"
                        mime_map = {
                            ".png": "image/png",
                            ".jpg": "image/jpeg",
                            ".jpeg": "image/jpeg",
                            ".webp": "image/webp",
                            ".bmp": "image/bmp"
                        }
                        mime_type = mime_map.get(ext, "image/jpeg")
                        # # 2. Smart compression (only when > 1 MB; saves ~70% bandwidth)
                        # img = Image.open(img_path).convert("RGB")
                        # if os.path.getsize(img_path) > 1024 * 1024:  # > 1 MB
                        #     buffer = BytesIO()
                        #     img.save(buffer, format="JPEG", quality=85, optimize=True)
                        #     b64 = base64.b64encode(buffer.getvalue()).decode()
                        #     data_uri = f"data:image/jpeg;base64,{b64}"
                        # else:
                        b64 = await self.file_to_base64(img_path)
                        data_uri = f"data:{mime_type};base64,{b64}"
                        inputs.append({
                            "type": "image",
                            "data": data_uri
                        })
                        debug(f"Added image ({mime_type}, {len(b64) / 1024:.1f} KB): {Path(img_path).name}")
                    except Exception as e:
                        debug(f"Image processing failed, skipping: {img_path} -> {e}")
                        # Append a placeholder even on failure so the ordering stays intact
                        inputs.append({
                            "type": "image",
                            "data": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
                        })
            debug(f"Total mixed inputs: {len(inputs)} (text {len(chunks)} + images {len(image_paths)})")
            multi_results = await self.rag_ops.generate_multi_embeddings(
                request=request,
                inputs=inputs,
                service_params=service_params,
                userid=userid,
                timings=timings,
                transaction_mgr=transaction_mgr
            )
            debug(f"Multimodal embedding succeeded, {len(multi_results)} results returned")
        else:
            # Generate embedding vectors
            debug("[text-only mode] using BGE embeddings")
-           embeddings = await self.rag_ops.generate_embeddings(
            text_embeddings = await self.rag_ops.generate_embeddings(
                request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
            )
            debug(f"BGE embedding finished: {len(text_embeddings)} vectors")

-       # Insert into Milvus
-       chunks_data = await self.rag_ops.insert_to_vector_db(
-           request, chunks, embeddings, realpath, orgid, fiid, id,
-           service_params, userid, db_type, timings, transaction_mgr=transaction_mgr
        inserted = await self.rag_ops.insert_all_vectors(
            request=request,
            text_chunks=chunks,
            realpath=realpath,
            orgid=orgid,
            fiid=fiid,
            document_id=id,
            service_params=service_params,
            userid=userid,
            db_type=db_type,
            timings=timings,
            img_paths=image_paths,
            text_embeddings=text_embeddings,
            multi_results=multi_results,
            transaction_mgr=transaction_mgr
        )
        debug(f"Unified insert: text {inserted['text']}, images {inserted['image']}, faces {inserted['face']}")

        # Extract triples
        triples = await self.rag_ops.extract_triples(
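
For clarity, the mixed inputs list handed to generate_multi_embeddings ends up shaped like this (entries illustrative; text chunks first, then one data-URI entry per extracted image):

    inputs = [
        {"type": "text", "content": "first chunk of document text"},
        {"type": "text", "content": "second chunk of document text"},
        {"type": "image", "data": "data:image/png;base64,iVBORw0KGgo..."},  # one per image
    ]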


@@ -3,7 +3,6 @@ from ahserver.serverenv import ServerEnv
import aiohttp
from aiohttp import ClientSession, ClientTimeout
import json
-from .file import file_uploaded, file_deleted
from .folderinfo import RagFileMgr
from .ragprogram import set_program, get_rag_programs
from .ragllm_utils import get_ragllms_by_catelog


@@ -2,6 +2,7 @@ import os
import re
import time
import math
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
@@ -12,7 +13,10 @@ from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService
from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType
from pdf2image import convert_from_path
import pytesseract
import base64
from pathlib import Path

class RagOperations:
    """RAG operations class providing all the common RAG operations"""
@@ -33,14 +37,39 @@ class RagOperations:
        if ext not in supported_formats:
            raise ValueError(f"Unsupported file format: {ext}; supported formats: {', '.join(supported_formats)}")

-       # Load the file content
        text = fileloader(realpath)
        if ext == 'pdf':
            debug(f"Native PDF extraction result: {text}")
            if not text or len(text.strip()) == 0:  # stricter empty check
                debug("Native PDF extraction failed; trying OCR for a scanned PDF")
                ocr_text = self.pdf_to_text(realpath)
                debug(f"OCR text from the scanned PDF: {ocr_text}")
                text = ocr_text  # only use the OCR result when native extraction fails

        # Clean the text only when there is content
        if text and len(text.strip()) > 0:
            # Alternatively, keep a wider set of useful characters here
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        else:
            error(f"File {realpath} yielded no extractable text")
            text = ""  # make sure it is an empty string

        timings["load_file"] = time.time() - start_load
        debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")

-       if not text or not text.strip():
-           raise ValueError(f"File {realpath} loaded empty")
        # # Load the file content
        # text = fileloader(realpath)
        # debug(f"Native PDF extraction result: {text}")
        # if len(text) == 0:
        #     debug("Native PDF extraction failed; trying OCR for a scanned PDF")
        #     text = self.pdf_to_text(realpath)
        #     debug(f"OCR text from the scanned PDF: {text}")
        # text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        # timings["load_file"] = time.time() - start_load
        # debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")
        #
        # if not text or not text.strip():
        #     raise ValueError(f"File {realpath} loaded empty")

        # Chunking
        document = Document(page_content=text)
@@ -67,6 +96,40 @@ class RagOperations:
        return chunks

    def pdf_to_text(
        self,
        pdf_path: str,
        output_txt: Optional[str] = None,
        dpi: int = 300,
        lang: str = 'chi_sim+chi_tra+eng'
    ) -> str:
        """
        Convert a scanned PDF to text (one-call version of the original code).

        Args:
            pdf_path: path to the PDF file (string)
            output_txt: if given, the text is also saved to this txt file (optional)
            dpi: image resolution, default 300 (higher is sharper)
            lang: language packs, default simplified + traditional Chinese + English

        Returns:
            The full extracted text as a string.
        """
        # Convert the PDF pages to images
        images = convert_from_path(pdf_path, dpi=dpi)
        # Run OCR on each page
        text = ''
        for img in images:
            text += pytesseract.image_to_string(img, lang=lang) + '\n'
        # Optional: save the text to a file
        if output_txt:
            with open(output_txt, 'w', encoding='utf-8') as f:
                f.write(text)
        return text
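
A minimal usage sketch for the OCR fallback (the file paths are illustrative; pytesseract needs the Tesseract language packs chi_sim, chi_tra and eng installed):

    ops = RagOperations()
    # OCR a scanned PDF and also write the text to disk
    text = ops.pdf_to_text("/tmp/scanned_report.pdf", output_txt="/tmp/scanned_report.txt")
    debug(f"OCR extracted {len(text)} characters")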
    async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                  userid: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None) -> List[List[float]]:

@@ -103,84 +166,488 @@ class RagOperations:
        return embeddings
-   async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
-                                 realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
-                                 userid: str, db_type: str, timings: Dict,
-                                 transaction_mgr: TransactionManager = None):
-       """Insert into the vector database"""
-       debug(f"Preparing data and calling the file-insert endpoint: {realpath}")
-       filename = os.path.basename(realpath).rsplit('.', 1)[0]
-       ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
-       upload_time = datetime.now().isoformat()
-       chunks_data = [
-           {
-               "userid": orgid,
-               "knowledge_base_id": fiid,
-               "text": chunk.page_content,
-               "vector": embeddings[i],
-               "document_id": id,
-               "filename": filename + '.' + ext,
-               "file_path": realpath,
-               "upload_time": upload_time,
-               "file_type": ext,
-           }
-           for i, chunk in enumerate(chunks)
-       ]
-       start_milvus = time.time()
-       for i in range(0, len(chunks_data), 10):
-           batch_chunks = chunks_data[i:i + 10]
-           debug(f"Data being passed in: {batch_chunks}")
-           result = await self.api_service.milvus_insert_document(
-               request=request,
-               chunks=batch_chunks,
-               db_type=db_type,
-               upappid=service_params['vdb'],
-               apiname="milvus/insertdocument",
-               user=userid
-           )
-           if result.get("status") != "success":
-               raise ValueError(result.get("message", "Milvus insert failed"))
-       timings["insert_milvus"] = time.time() - start_milvus
-       debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
-       # Record the transaction operation, including a rollback function
-       if transaction_mgr:
-           async def rollback_vdb_insert(data, context):
-               try:
-                   await self.delete_from_vector_db(
-                       context['request'], data['orgid'], data['realpath'],
-                       data['fiid'], data['id'], context['service_params'],
-                       context['userid'], data['db_type']
-                   )
-                   return f"Rolled back vector DB insert: {data['id']}"
-               except Exception as e:
-                   error(f"Vector DB rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
-                   raise
-           transaction_mgr.add_operation(
-               OperationType.VDB_INSERT,
-               {
-                   'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
-                   'id': id, 'db_type': db_type
-               },
-               rollback_func=rollback_vdb_insert
-           )
-       return chunks_data
    async def generate_multi_embeddings(self, request, inputs: List[Dict], service_params: Dict,
                                        userid: str, timings: Dict,
                                        transaction_mgr: TransactionManager = None) -> Dict[str, Dict]:
        """Call the multimodal embedding service (CLIP)"""
        debug("Calling the multimodal embedding service")
        start = time.time()

        result = await self.api_service.get_multi_embeddings(
            request=request,
            inputs=inputs,
            upappid=service_params['embedding'],
            apiname="black/clip",
            user=userid
        )
        debug(f"Multimodal result: {result}")
        timings["multi_embedding"] = time.time() - start
        debug(f"Multimodal embedding took {timings['multi_embedding']:.2f} s, processed {len(result)} items")

        # ==================== New: error checking + filtering ====================
        valid_results = {}
        error_count = 0
        error_examples = []
        for key, info in result.items():
            if info.get("type") == "error":
                error_count += 1
                if len(error_examples) < 3:  # only record the first 3
                    error_examples.append(f"{key}: {info['error']}")
                # drop failed entries outright
                continue
            valid_results[key] = info

        if error_count > 0:
            error(f"Multimodal embedding failed for {error_count} items! Examples: {'; '.join(error_examples)}")
            raise RuntimeError(f"Multimodal embedding had {error_count} failures")
        else:
            debug("Multimodal embedding fully succeeded!")

        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'count': len(result)}
            )
        return result

    # async def force_l2_normalize(self, vector: List[float]) -> List[float]:
    #     """Foolproof L2 normalization"""
    #     arr = np.array(vector, dtype=np.float32)
    #     norm = np.linalg.norm(arr)
    #     if norm == 0:
    #         return vector  # an all-zero vector cannot be normalized
    #     return (arr / norm).tolist()

    # Unified insert into the vector store
    async def insert_all_vectors(
        self,
        request,
        text_chunks: List[Document],
        realpath: str,
        orgid: str,
        fiid: str,
        document_id: str,
        service_params: Dict,
        userid: str,
        db_type: str,
        timings: Dict,
        img_paths: List[str] = None,
        text_embeddings: List[List[float]] = None,
        multi_results: Dict = None,
        transaction_mgr: TransactionManager = None
    ) -> Dict[str, int]:
        """
        Unified insert supporting two modes:
        1. text-only mode: text_embeddings is set
        2. multimodal mode: multi_results is set (from generate_multi_embeddings)
        """
        img_paths = img_paths or []
        all_chunks = []
        start = time.time()
        filename = os.path.basename(realpath)
        upload_time = datetime.now().isoformat()

        # ==================== 1. Text-only mode (BGE) ====================
        if text_embeddings is not None:
            debug(f"[text-only mode] inserting {len(text_embeddings)} text vectors")
            for i, chunk in enumerate(text_chunks):
                all_chunks.append({
                    "userid": orgid,
                    "knowledge_base_id": fiid,
                    "text": chunk.page_content,
                    "vector": text_embeddings[i],
                    "document_id": document_id,
                    "filename": filename,
                    "file_path": realpath,
                    "upload_time": upload_time,
                    "file_type": "text",
                })

        # ==================== 2. Multimodal mode (CLIP, mixed inputs) ====================
        if multi_results is not None:
            debug(f"[multimodal mode] parsing {len(multi_results)} CLIP results")
            # Walk multi_results
            for raw_key, info in multi_results.items():
                typ = info["type"]
                # vector = info["vector"]
                # debug(f"Vector returned from the backend: {vector}")
                # emb = await self.force_l2_normalize(info["vector"])
                # debug(f"Normalized vector: {emb}")
                # --- text ---
                if typ == "text":
                    # raw_key is the original text
                    all_chunks.append({
                        "userid": orgid,
                        "knowledge_base_id": fiid,
                        "text": raw_key,
                        "vector": info["vector"],
                        "document_id": document_id,
                        "filename": filename,
                        "file_path": realpath,
                        "upload_time": upload_time,
                        "file_type": "text",
                    })
                    continue
                # --- image ---
                if typ == "image":
                    img_path = info.get("path") or raw_key
                    img_name = os.path.basename(img_path)
                    # Whole-image vector
                    if "vector" in info:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Image: {img_path}] image taken from file {realpath}",
                            "vector": info["vector"],
                            "document_id": document_id,
                            "filename": img_name,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "image",
                        })
                    # Face vectors
                    face_vecs = info.get("face_vecs", [])
                    face_count = len(face_vecs)
                    # if face_count > 0:
                    #     for f_idx, fvec in enumerate(face_vecs):
                    #         debug(f"Face vector dimension: {len(fvec)}")
                    #         all_chunks.append({
                    #             "userid": orgid,
                    #             "knowledge_base_id": fiid,
                    #             "text": f"[Face {f_idx + 1}/{face_count} in {img_name}] face taken from image {img_path} of {realpath}",
                    #             "vector": fvec,
                    #             "document_id": document_id,
                    #             "filename": img_name,
                    #             "file_path": realpath,
                    #             "upload_time": upload_time,
                    #             "file_type": "face",
                    #         })
                    continue
                # --- video ---
                if typ == "video":
                    video_path = info.get("path") or raw_key
                    video_name = os.path.basename(video_path)
                    if "vector" in info:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Video: {video_name}]",
                            "vector": info["vector"],
                            "document_id": document_id,
                            "filename": video_path,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "video",
                        })
                    # Video faces
                    face_vecs = info.get("face_vecs", [])
                    face_count = len(face_vecs)
                    # if face_count > 0:
                    #     for f_idx, fvec in enumerate(face_vecs):
                    #         all_chunks.append({
                    #             "userid": orgid,
                    #             "knowledge_base_id": fiid,
                    #             "text": f"[Face {f_idx + 1}/{face_count} in video {video_name}] taken from {video_path}",
                    #             "vector": fvec,
                    #             "document_id": document_id,
                    #             "filename": video_path,
                    #             "file_path": realpath,
                    #             "upload_time": upload_time,
                    #             "file_type": "face",
                    #         })
                    continue
                # --- audio ---
                if typ == "audio":
                    audio_path = info.get("path") or raw_key
                    audio_name = os.path.basename(audio_path)
                    if "vector" in info:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Audio: {audio_name}]",
                            "vector": info["vector"],
                            "document_id": document_id,
                            "filename": audio_path,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "audio",
                        })
                    continue
                # --- unknown type ---
                debug(f"Unknown type, skipping: {typ} -> {raw_key}")

        # ==================== 3. Batch insert into Milvus ====================
        if not all_chunks:
            debug("No vectors to insert")
            return {"text": 0, "image": 0, "face": 0}

        for i in range(0, len(all_chunks), 10):
            batch = all_chunks[i:i + 10]
            result = await self.api_service.milvus_insert_document(
                request=request,
                chunks=batch,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
                user=userid,
                db_type=db_type
            )
            if result.get("status") != "success":
                raise ValueError(f"Milvus insert failed: {result.get('message')}")

        # ==================== 4. Unified rollback (registered once) ====================
        if transaction_mgr and all_chunks:
            async def rollback_all(data, context):
                try:
                    await self.delete_from_vector_db(
                        request=context['request'],
                        orgid=data['orgid'],
                        realpath=data['realpath'],
                        fiid=data['fiid'],
                        id=data['document_id'],
                        service_params=context['service_params'],
                        userid=context['userid'],
                        db_type=data['db_type']
                    )
                    return f"Rolled back all vectors for document_id={data['document_id']}"
                except Exception as e:
                    error(f"Unified rollback failed: {e}")
                    raise

            transaction_mgr.add_operation(
                OperationType.VDB_INSERT,
                {
                    'orgid': orgid,
                    'realpath': realpath,
                    'fiid': fiid,
                    'document_id': document_id,  # key must match the data['document_id'] lookup in rollback_all
                    'db_type': db_type
                },
                rollback_func=rollback_all
            )

        # ==================== 5. Stats and return ====================
        stats = {
            "text": len([c for c in all_chunks if c["file_type"] == "text"]),
            "image": len([c for c in all_chunks if c["file_type"] == "image"]),
            "face": len([c for c in all_chunks if c["file_type"] == "face"])
        }
        timings["insert_all"] = time.time() - start
        debug(
            f"Unified insert done: text {stats['text']}, images {stats['image']}, faces {stats['face']}, took {timings['insert_all']:.2f}s")
        return stats
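
For orientation, each entry appended to all_chunks is a flat record of this shape before batching into Milvus (the values here are illustrative, not real data):

    example_chunk = {
        "userid": "org-123",                  # organization id (illustrative)
        "knowledge_base_id": "kb-456",        # target knowledge base
        "text": "chunk text, or an [Image: ...] placeholder for media entries",
        "vector": [0.0] * 1024,               # 1024-dim embedding
        "document_id": "doc-789",
        "filename": "report.pdf",
        "file_path": "/home/wangmeihua/kyrag/report.pdf",
        "upload_time": "2025-12-25T15:28:49",
        "file_type": "text",                  # or "image" / "face" / "video" / "audio"
    }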
    # async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
    #                               realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
    #                               userid: str, db_type: str, timings: Dict,
    #                               transaction_mgr: TransactionManager = None):
    #     """Insert into the vector database"""
    #     debug(f"Preparing data and calling the file-insert endpoint: {realpath}")
    #     filename = os.path.basename(realpath).rsplit('.', 1)[0]
    #     ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
    #     upload_time = datetime.now().isoformat()
    #
    #     chunks_data = [
    #         {
    #             "userid": orgid,
    #             "knowledge_base_id": fiid,
    #             "text": chunk.page_content,
    #             "vector": embeddings[i],
    #             "document_id": id,
    #             "filename": filename + '.' + ext,
    #             "file_path": realpath,
    #             "upload_time": upload_time,
    #             "file_type": ext,
    #         }
    #         for i, chunk in enumerate(chunks)
    #     ]
    #
    #     start_milvus = time.time()
    #     for i in range(0, len(chunks_data), 10):
    #         batch_chunks = chunks_data[i:i + 10]
    #         debug(f"Data being passed in: {batch_chunks}")
    #         result = await self.api_service.milvus_insert_document(
    #             request=request,
    #             chunks=batch_chunks,
    #             db_type=db_type,
    #             upappid=service_params['vdb'],
    #             apiname="milvus/insertdocument",
    #             user=userid
    #         )
    #         if result.get("status") != "success":
    #             raise ValueError(result.get("message", "Milvus insert failed"))
    #
    #     timings["insert_milvus"] = time.time() - start_milvus
    #     debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
    #
    #     # Record the transaction operation, including a rollback function
    #     if transaction_mgr:
    #         async def rollback_vdb_insert(data, context):
    #             try:
    #                 # Defensive checks
    #                 required_context = ['request', 'service_params', 'userid']
    #                 missing_context = [k for k in required_context if k not in context or context[k] is None]
    #                 if missing_context:
    #                     raise ValueError(f"Rollback context missing fields: {', '.join(missing_context)}")
    #
    #                 required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
    #                 missing_data = [k for k in required_data if k not in data or data[k] is None]
    #                 if missing_data:
    #                     raise ValueError(f"VDB_INSERT data missing fields: {', '.join(missing_data)}")
    #
    #                 await self.delete_from_vector_db(
    #                     context['request'], data['orgid'], data['realpath'],
    #                     data['fiid'], data['id'], context['service_params'],
    #                     context['userid'], data['db_type']
    #                 )
    #                 return f"Rolled back vector DB insert: {data['id']}"
    #             except Exception as e:
    #                 error(f"Vector DB rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
    #                 raise
    #
    #         transaction_mgr.add_operation(
    #             OperationType.VDB_INSERT,
    #             {
    #                 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
    #                 'id': id, 'db_type': db_type
    #             },
    #             rollback_func=rollback_vdb_insert
    #         )
    #
    #     return chunks_data
    #
    # async def insert_image_vectors(
    #         self,
    #         request,
    #         multi_results: Dict[str, Dict],
    #         realpath: str,
    #         orgid: str,
    #         fiid: str,
    #         document_id: str,
    #         service_params: Dict,
    #         userid: str,
    #         db_type: str,
    #         timings: Dict,
    #         transaction_mgr: TransactionManager = None
    # ) -> tuple[int, int]:
    #
    #     start = time.time()
    #     image_chunks = []
    #     face_chunks = []
    #
    #     for img_path, info in multi_results.items():
    #         # img_name = os.path.basename(img_path)
    #
    #         # 1. Insert the whole image
    #         if info.get("type") in ["image", "video"] and "vector" in info:
    #             image_chunks.append({
    #                 "userid": orgid,
    #                 "knowledge_base_id": fiid,
    #                 "text": f"[Image: {img_path}]",
    #                 "vector": info["vector"],
    #                 "document_id": document_id,
    #                 "filename": os.path.basename(realpath),
    #                 "file_path": realpath,
    #                 "upload_time": datetime.now().isoformat(),
    #                 "file_type": "image"
    #             })
    #
    #         # 2. Insert each face
    #         face_vecs = info.get("face_vecs")
    #         face_count = info.get("face_count", 0)
    #
    #         if face_count > 0 and face_vecs and len(face_vecs) == face_count:
    #             for idx, face_vec in enumerate(face_vecs):
    #                 face_chunks.append({
    #                     "userid": orgid,
    #                     "knowledge_base_id": fiid,
    #                     "text": f"[Face {idx + 1}/{face_count} in {img_path}]",
    #                     "vector": face_vec,
    #                     "document_id": document_id,
    #                     "filename": os.path.basename(realpath),
    #                     "file_path": realpath,
    #                     "upload_time": datetime.now().isoformat(),
    #                     "file_type": "face",
    #                 })
    #
    #     if image_chunks:
    #         for i in range(0, len(image_chunks), 10):
    #             await self.api_service.milvus_insert_document(
    #                 request=request,
    #                 chunks=image_chunks[i:i + 10],
    #                 upappid=service_params['vdb'],
    #                 apiname="milvus/insertdocument",
    #                 user=userid,
    #                 db_type=db_type
    #             )
    #
    #     if face_chunks:
    #         for i in range(0, len(face_chunks), 10):
    #             await self.api_service.milvus_insert_document(
    #                 request=request,
    #                 chunks=face_chunks[i:i + 10],
    #                 upappid=service_params['vdb'],
    #                 apiname="milvus/insertdocument",
    #                 user=userid,
    #                 db_type=db_type
    #             )
    #     timings["insert_images"] = time.time() - start
    #     image_count = len(image_chunks)
    #     face_count = len(face_chunks)
    #
    #     debug(f"Multimodal insert done: {image_count} images, {face_count} faces")
    #
    #     if transaction_mgr and (image_count + face_count > 0):
    #         transaction_mgr.add_operation(
    #             OperationType.IMAGE_VECTORS_INSERT,
    #             {"images": image_count, "faces": face_count, "document_id": document_id}
    #         )
    #
    #     # Record the transaction operation, including a rollback function
    #     if transaction_mgr:
    #         async def rollback_multimodal(data, context):
    #             try:
    #                 # Defensive checks
    #                 required_context = ['request', 'service_params', 'userid']
    #                 missing_context = [k for k in required_context if k not in context or context[k] is None]
    #                 if missing_context:
    #                     raise ValueError(f"Rollback context missing fields: {', '.join(missing_context)}")
    #
    #                 required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
    #                 missing_data = [k for k in required_data if k not in data or data[k] is None]
    #                 if missing_data:
    #                     raise ValueError(f"Multimodal rollback data missing fields: {', '.join(missing_data)}")
    #
    #                 await self.delete_from_vector_db(
    #                     context['request'], data['orgid'], data['realpath'],
    #                     data['fiid'], data['id'], context['service_params'],
    #                     context['userid'], data['db_type']
    #                 )
    #                 return f"Rolled back multimodal vectors: {data['id']}"
    #             except Exception as e:
    #                 error(f"Multimodal vector DB rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
    #                 raise
    #
    #         transaction_mgr.add_operation(
    #             OperationType.VDB_INSERT,
    #             {
    #                 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
    #                 'id': id, 'db_type': db_type
    #             },
    #             rollback_func=rollback_multimodal
    #         )
    #
    #     return image_count, face_count
    async def insert_to_vector_text(self, request,
                                    db_type: str, fields: Dict, service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
@@ -383,25 +850,150 @@ class RagOperations:
        debug(f"Triple matching total time: {timings['triplet_matching']:.3f} s")
        return all_triplets
-   async def generate_query_vector(self, request, text: str, service_params: Dict,
-                                   userid: str, timings: Dict) -> List[float]:
-       """Generate the query vector"""
-       debug(f"Generating query vector: {text[:200]}...")
-       start_vector = time.time()
-       query_vector = await self.api_service.get_embeddings(
-           request=request,
-           texts=[text],
-           upappid=service_params['embedding'],
-           apiname="BAAI/bge-m3",
-           user=userid
-       )
-       if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
-           raise ValueError("The query vector must be a list of 1024 floats")
-       query_vector = query_vector[0]
-       timings["vector_generation"] = time.time() - start_vector
-       debug(f"Query vector generation took {timings['vector_generation']:.3f} s")
-       return query_vector
    async def generate_query_vector(
        self,
        request,
        text: str,
        service_params: Dict,
        userid: str,
        timings: Dict,
        embedding_mode: int = 0
    ) -> List[float]:
        """Generate the query vector (text / multimodal)"""
        debug(f"Generating query vector: mode={embedding_mode}, text='{text[:100]}...'")
        start_vector = time.time()

        if embedding_mode == 0:
            # === Mode 0: text-only embedding (BAAI/bge-m3) ===
            debug("Using BAAI/bge-m3 text embedding")
            vectors = await self.api_service.get_embeddings(
                request=request,
                texts=[text],
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
            if not vectors or not isinstance(vectors, list) or len(vectors) == 0:
                raise ValueError("bge-m3 returned an empty result")
            query_vector = vectors[0]
            if len(query_vector) != 1024:
                raise ValueError(f"bge-m3 returned a vector of the wrong dimension: {len(query_vector)}")

        elif embedding_mode == 1:
            # === Mode 1: multimodal embedding (black/clip) ===
            debug("Using black/clip multimodal embedding")
            inputs = [{"type": "text", "content": text}]
            result = await self.api_service.get_multi_embeddings(
                request=request,
                inputs=inputs,
                upappid=service_params['embedding'],
                apiname="black/clip",
                user=userid
            )
            query_vector = None
            for key, info in result.items():
                if info.get("type") == "error":
                    debug(f"CLIP returned an error, skipping: {info['error']}")
                    continue
                if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024:
                    query_vector = info["vector"]
                    debug(f"Got a CLIP vector (from {info['type']})")
                    break
            if query_vector is None:
                raise ValueError("black/clip returned no valid 1024-dim vector")

        else:
            raise ValueError(f"Unsupported embedding_mode: {embedding_mode}")

        # Final unified validation
        if not isinstance(query_vector, list) or len(query_vector) != 1024:
            raise ValueError(f"The query vector must be a list of 1024 floats; got length {len(query_vector)}")

        timings["vector_generation"] = time.time() - start_vector
        debug(f"Query vector generated in {timings['vector_generation']:.3f} s, mode: {embedding_mode}")
        return query_vector
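
A minimal call sketch for the two modes (rag_ops, request, service_params, userid and timings come from the surrounding handler; the query string is illustrative):

    # Text-only (BGE) lookup:
    vec = await rag_ops.generate_query_vector(request, "合同违约条款", service_params, userid, timings, embedding_mode=0)
    # Multimodal (CLIP) lookup for the same text:
    vec = await rag_ops.generate_query_vector(request, "合同违约条款", service_params, userid, timings, embedding_mode=1)
    assert len(vec) == 1024  # both modes must yield a 1024-dim vector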
    async def generate_image_vector(
        self,
        request,
        img_path: str,
        service_params: Dict,
        userid: str,
        timings: Dict,
        embedding_mode: int = 0
    ) -> List[float]:
        """Generate an image query vector (multimodal mode only)"""
        debug(f"Generating query vector: mode={embedding_mode}, image={img_path}")
        start_vector = time.time()

        if embedding_mode == 0:
            raise ValueError("Text-only mode does not support image queries; please select a different service")
        elif embedding_mode == 1:
            # === Mode 1: multimodal embedding (black/clip) ===
            debug("Using black/clip multimodal embedding")
            inputs = []
            try:
                ext = Path(img_path).suffix.lower()
                if ext not in {".png", ".jpg", ".jpeg", ".webp", ".bmp"}:
                    ext = ".jpg"
                mime_map = {
                    ".png": "image/png",
                    ".jpg": "image/jpeg",
                    ".jpeg": "image/jpeg",
                    ".webp": "image/webp",
                    ".bmp": "image/bmp"
                }
                mime_type = mime_map.get(ext, "image/jpeg")
                with open(img_path, "rb") as f:
                    b64 = base64.b64encode(f.read()).decode("utf-8")
                data_uri = f"data:{mime_type};base64,{b64}"
                inputs.append({
                    "type": "image",
                    "data": data_uri
                })
                debug(f"Added image ({mime_type}, {len(b64) / 1024:.1f} KB): {Path(img_path).name}")
            except Exception as e:
                debug(f"Image processing failed, skipping: {img_path} -> {e}")
            result = await self.api_service.get_multi_embeddings(
                request=request,
                inputs=inputs,
                upappid=service_params['embedding'],
                apiname="black/clip",
                user=userid
            )
            image_vector = None
            for key, info in result.items():
                if info.get("type") == "error":
                    debug(f"CLIP returned an error, skipping: {info['error']}")
                    continue
                if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024:
                    image_vector = info["vector"]
                    debug(f"Got a CLIP vector (from {info['type']})")
                    break
            if image_vector is None:
                raise ValueError("black/clip returned no valid 1024-dim vector")
        else:
            raise ValueError(f"Unsupported embedding_mode: {embedding_mode}")

        # Final unified validation
        if not isinstance(image_vector, list) or len(image_vector) != 1024:
            raise ValueError(f"The query vector must be a list of 1024 floats; got length {len(image_vector)}")

        timings["vector_generation"] = time.time() - start_vector
        debug(f"Query vector generated in {timings['vector_generation']:.3f} s, mode: {embedding_mode}")
        return image_vector
    async def vector_search(self, request, query_vector: List[float], orgid: str,
                            fiids: List[str], limit: int, service_params: Dict, userid: str,
                            timings: Dict) -> List[Dict]:
@@ -473,34 +1065,49 @@ class RagOperations:
        return unique_triples
    def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
-       """Format search results into a unified shape"""
-       formatted_results = []
-       # Score normalization
-       for res in results[:limit]:
-           rerank_score = res.get('rerank_score', 0)
-           score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
-           score = max(0.0, min(1.0, score))
-           content = res.get('text', '')
-           title = res.get('metadata', {}).get('filename', 'Untitled')
-           document_id = res.get('metadata', {}).get('document_id', '')
-           formatted_results.append({
-               "content": content,
-               "title": title,
-               "metadata": {"document_id": document_id, "score": score},
-           })
-       return formatted_results
        formatted = []
        # for res in results[:limit]:
        #     score = res.get('rerank_score', res.get('distance', 0))
        #
        #     content = res.get('text', '')
        #     title = res.get('metadata', {}).get('filename', 'Untitled')
        #     document_id = res.get('metadata', {}).get('document_id', '')
        #
        #     formatted_results.append({
        #         "content": content,
        #         "title": title,
        #         "metadata": {"document_id": document_id, "score": score},
        #     })
        for res in results[:limit]:
            # # Prefer the rerank score; otherwise use the vector similarity directly (do not invert it)
            # if res.get('rerank_score') is not None:
            #     score = res.get('rerank_score')
            # else:
            #     score = res.get('distance', 0.0)
            distance = res.get('distance', 0.0)
            rerank_score = res.get('rerank_score', 0.0)
            formatted.append({
                "content": res.get('text', ''),
                "title": res.get('metadata', {}).get('filename', 'Untitled'),
                "metadata": {
                    "document_id": res.get('metadata', {}).get('document_id', ''),
                    "distance": distance,
                    "rerank_score": rerank_score,
                }
            })
        return formatted

    # async def save_uploaded_photo(self, image_file: FileStorage, orgid: str) -> str:
    #     """
    #     Save an uploaded image under /home/wangmeihua/kyrag/data/photo.
    #     Returns the saved absolute path (string) for generate_img_vector to use.
    #     """
    #     if not image_file or not hasattr(image_file, "filename"):
    #         raise ValueError("Invalid image upload object")
    #
    #     # For safety, store per-orgid subdirectories (keeps different organizations' files apart)
    #     org_dir = UPLOAD_PHOTO_DIR / orgid
    #     org_dir.mkdir(parents=True, exist_ok=True)
    #
    #     # Generate a unique file name, keeping the original suffix
    #     suffix = Path(image_file.filename).suffix.lower()
    #     if not suffix or suffix not in {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}:
    #         suffix = ".jpg"
    #
    #     unique_name = f"{uuid.uuid4().hex}{suffix}"
    #     save_path = org_dir / unique_name
    #
    #     # Actually write to disk
    #     image_file.save(str(save_path))
    #     debug(f"Image saved: {save_path} (original name: {image_file.filename})")
    #
    #     # Return a string path; generate_img_vector accepts a str directly
    #     return str(save_path)
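
For reference, one formatted record then looks like this (the values are illustrative):

    example_record = {
        "content": "matched chunk text",
        "title": "report.pdf",
        "metadata": {
            "document_id": "doc-789",
            "distance": 0.41,        # raw vector similarity from Milvus
            "rerank_score": 2.7,     # raw reranker score; 0.0 when reranking is off
        }
    }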


@@ -6,8 +6,12 @@ import traceback
import json
import math
import uuid
-from rag.service_opts import get_service_params, sor_get_service_params
import os
from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode
from rag.rag_operations import RagOperations
from langchain_core.documents import Document

REAL_PHOTO_ROOT = "/home/wangmeihua/kyrag/files"

helptext = """kyrag API:
@@ -133,7 +137,16 @@ async def fusedsearch(request, params_kw, *params):
    debug(f"params_kw: {params_kw}")
    # orgid = "04J6VbxLqB_9RPMcgOv_8"
    # userid = "04J6VbxLqB_9RPMcgOv_8"
-   query = params_kw.get('query', '')
    query = params_kw.get('query', '').strip()
    img_path = params_kw.get('image')
    if isinstance(img_path, str):
        img_path = img_path.strip()
        relative_part = img_path.lstrip("/")
        real_img_path = os.path.join(REAL_PHOTO_ROOT, relative_part)
        if not os.path.exists(real_img_path):
            raise FileNotFoundError(f"Image does not exist: {real_img_path}")
        img_path = real_img_path
        debug(f"Auto-fixed image path: {img_path}")
    # Normalize the limit parameter across modes, for the Dify and Coze integrations
    raw_limit = params_kw.get('limit') or (
        params_kw.get('retrieval_setting', {}).get('top_k')
@@ -188,7 +201,65 @@ async def fusedsearch(request, params_kw, *params):
    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("Unable to obtain service parameters")
    # Get the embedding mode
    embedding_mode = await get_embedding_mode(orgid)
    debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")

    # Case 1: both query and image empty -> error
    if not query and not img_path:
        raise ValueError("Query text and image cannot both be empty")
    # Case 2: both query and image present -> error (the current business rule forbids sending both)
    if query and img_path:
        raise ValueError("Provide either query text or an image, not both")

    # Case 3: image only -> search by image, pure multimodal branch
    if img_path and not query:
        try:
            debug("Image-only query detected; running search-by-image")
            rag_ops = RagOperations()
            timings = {}
            start_time = time.time()
            # Generate the image vector directly
            img_vector = await rag_ops.generate_image_vector(
                request, img_path, service_params, userid, timings, embedding_mode
            )
            # Vector search (fetch 50 extra and truncate, same as the text branch)
            search_results = await rag_ops.vector_search(
                request, img_vector, orgid, fiids, limit + 50, service_params, userid, timings
            )
            timings["total_time"] = time.time() - start_time
            # Optional: delete the image after searching to save disk space
            # try:
            #     os.remove(img_path)
            # except:
            #     pass
            final_results = []
            for item in search_results[:limit]:
                final_results.append({
                    "text": item["text"],
                    "distance": item["distance"]
                })
            return {
                "results": final_results,
                "timings": timings
            }
        except Exception as e:
            error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
            return {
                "records": [],
                "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
                "error": str(e)
            }

    if not img_path and query:
        try:
            timings = {}
            start_time = time.time()
@@ -198,8 +269,8 @@ async def fusedsearch(request, params_kw, *params):
            all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
                                                        userid, timings)
            combined_text = _combine_query_with_triplets(query, all_triplets)
-           query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
            query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
-           search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
            search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
                                                         userid, timings)

            use_rerank = True
@@ -212,7 +283,7 @@ async def fusedsearch(request, params_kw, *params):
            formatted_results = rag_ops.format_search_results(final_results, limit)
            timings["total_time"] = time.time() - start_time
-           info(f"Fused search done, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
            debug(f"Fused search done, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")

            return {
                "records": formatted_results,
@@ -226,6 +297,116 @@ async def fusedsearch(request, params_kw, *params):
            "error": str(e)
        }
# async def fusedsearch(request, params_kw, *params):
#     """
#     Fused search, calling the service endpoints
#
#     """
#     kw = request._run_ns
#     f = kw.get('get_userorgid')
#     orgid = await f()
#     debug(f"orgid: {orgid}, {f=}")
#     f = kw.get('get_user')
#     userid = await f()
#     debug(f"params_kw: {params_kw}")
#     # orgid = "04J6VbxLqB_9RPMcgOv_8"
#     # userid = "04J6VbxLqB_9RPMcgOv_8"
#     query = params_kw.get('query', '')
#     # Normalize the limit parameter across modes, for the Dify and Coze integrations
#     raw_limit = params_kw.get('limit') or (
#         params_kw.get('retrieval_setting', {}).get('top_k')
#         if isinstance(params_kw.get('retrieval_setting'), dict)
#         else None
#     )
#
#     # Normalize to an integer
#     if raw_limit is None:
#         limit = 5  # default when neither source is present
#     elif isinstance(raw_limit, (int, float)):
#         limit = int(raw_limit)  # numeric types convert directly
#     elif isinstance(raw_limit, str):
#         try:
#             # convert string to integer
#             limit = int(raw_limit)
#         except (TypeError, ValueError):
#             limit = 5  # default on conversion failure
#     else:
#         limit = 5  # default for unexpected types
#     debug(f"limit: {limit}")
#     raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')
#
#     # Normalize to a list
#     if raw_fiids is None:
#         fiids = []  # neither parameter present
#     elif isinstance(raw_fiids, list):
#         fiids = [str(item).strip() for item in raw_fiids]  # already a list
#     elif isinstance(raw_fiids, str):
#         # fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
#         try:
#             # try to parse as a JSON string
#             parsed = json.loads(raw_fiids)
#             if isinstance(parsed, list):
#                 fiids = [str(item).strip() for item in parsed]  # JSON array to a list of strings
#             else:
#                 # handle a comma-separated string or a single ID string
#                 fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
#         except json.JSONDecodeError:
#             # not valid JSON: split on commas
#             fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
#     elif isinstance(raw_fiids, (int, float)):
#         fiids = [str(int(raw_fiids))]  # numeric types to a list of strings
#     else:
#         fiids = []  # unexpected types
#
#     debug(f"fiids: {fiids}")
#
#     # Verify that the fiids' orgid matches orgid = await f()
#     await _validate_fiids_orgid(fiids, orgid, kw)
#
#     service_params = await get_service_params(orgid)
#     if not service_params:
#         raise ValueError("Unable to obtain service parameters")
#     # Get the embedding mode
#     embedding_mode = await get_embedding_mode(orgid)
#     debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")
#
#     try:
#         timings = {}
#         start_time = time.time()
#         rag_ops = RagOperations()
#
#         query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
#         all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
#                                                     userid, timings)
#         combined_text = _combine_query_with_triplets(query, all_triplets)
#         query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
#         search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
#                                                      userid, timings)
#
#         use_rerank = False
#         if use_rerank and search_results:
#             final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
#                                                          userid, timings)
#             debug(f"final_results: {final_results}")
#         else:
#             final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
#
#         formatted_results = rag_ops.format_search_results(final_results, limit)
#         timings["total_time"] = time.time() - start_time
#         debug(f"Fused search done, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
#
#         return {
#             "records": formatted_results,
#             "timings": timings
#         }
#     except Exception as e:
#         error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
#         return {
#             "records": [],
#             "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
#             "error": str(e)
#         }
# async def text_insert(text: str, fiid: str, orgid: str, db_type: str):
async def textinsert(request, params_kw, *params):
    kw = request._run_ns

@@ -254,7 +435,7 @@ async def textinsert(request, params_kw, *params):
    result = {
        "status": "error",
        "userid": orgid,
-       "collection_name": "ragdb_{dbtype}",
        "collection_name": f"ragdb_{db_type}",
        "message": "",
        "status_code": 400
    }
@@ -277,10 +458,10 @@ async def textinsert(request, params_kw, *params):
        # Insert into Milvus
        fields = {
            "text": text,
-           "fiid": fiid,
            "knowledge_base_id": fiid,
-           "orgid": orgid,
            "userid": orgid,
            "vector": embedding,
-           "id": id
            "document_id": id
        }
        chunks_data = await rag_ops.insert_to_vector_text(request, db_type, fields, service_params, userid, timings)


@@ -57,11 +57,12 @@ async def sor_get_service_params(sor, orgid):
            service_params['reranker'] = service['upappid']
        elif name == 'mrebel三元组抽取':
            service_params['triples'] = service['upappid']
-       elif name == 'neo4j删除知识库':
        elif name == 'neo4j知识库':
            service_params['gdb'] = service['upappid']
        elif name == 'small实体抽取':
            service_params['entities'] = service['upappid']
        elif name == 'clip多模态嵌入服务':
            service_params['embedding'] = service['upappid']
    # Check that every service parameter has been filled in
    missing_services = [k for k, v in service_params.items() if v is None]
    if missing_services:
@@ -76,3 +77,25 @@ async def get_service_params(orgid):
    async with db.sqlorContext(dbname) as sor:
        return await sor_get_service_params(sor, orgid)
    return None

async def sor_get_embedding_mode(sor, orgid) -> int:
    """Look up the embedding mode for an orgid (0 = text-only, 1 = multimodal)"""
    sql = """
        SELECT em.mode
        FROM service_opts so
        JOIN embedding_mode em ON so.embedding_id = em.embeddingid
        WHERE so.orgid = ${orgid}$
    """
    rows = await sor.sqlExe(sql, {"orgid": orgid})
    if not rows:
        debug(f"orgid={orgid} has no embedding_mode configured; defaulting to 0 (text-only)")
        return 0
    return int(rows[0].mode)

async def get_embedding_mode(orgid):
    db = DBPools()
    # debug(f"orgid passed in: {orgid}")
    dbname = get_serverenv('get_module_dbname')('rag')
    async with db.sqlorContext(dbname) as sor:
        return await sor_get_embedding_mode(sor, orgid)
    return None
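
A minimal call sketch from inside an async handler (the orgid value is illustrative); mode 1 routes uploads and queries through the CLIP pipeline, mode 0 through BGE:

    mode = await get_embedding_mode("04J6VbxLqB_9RPMcgOv_8")  # illustrative orgid
    if mode == 1:
        debug("multimodal (CLIP) pipeline selected")
    else:
        debug("text-only (BGE) pipeline selected")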


@@ -20,7 +20,6 @@ class OperationType(Enum):
    VECTOR_SEARCH = "vector_search"
    RERANK = "rerank"

@dataclass
class RollbackOperation:
    """Rollback operation record"""


@@ -45,6 +45,43 @@ class APIService:
            error(f"request #{request_id} embedding service call failed: {str(e)}, upappid={upappid}, apiname={apiname}")
            raise RuntimeError(f"Embedding service call failed: {str(e)}")

    # Multimodal embedding service
    async def get_multi_embeddings(
        self,
        request,
        inputs: List[Dict],
        upappid: str,
        apiname: str,
        user: str
    ) -> Dict[str, Dict]:
        """
        Unified multimodal embedding (text / image / audio / video).
        Returns the full result keyed by the original input string: type / vector / face info.
        """
        request_id = str(uuid.uuid4())
        debug(f"Request #{request_id} multimodal embedding started, {len(inputs)} inputs")
        if not inputs or not isinstance(inputs, list):
            raise ValueError("inputs must be a non-empty list")
        try:
            uapi = UAPI(request, DictObject(**globals()))
            params_kw = {"inputs": inputs}
            b = await uapi.call(upappid, apiname, user, params_kw)
            d = await self.handle_uapi_response(b, upappid, apiname, "multimodal embedding service", request_id)
            if d.get("object") != "embedding.result" or "data" not in d:
                error(f"request #{request_id} malformed response: {d}")
                raise RuntimeError("Multimodal embedding returned a malformed response")
            result = d["data"]  # returns {input_str: {type, vector, ...}} directly
            debug(f"request #{request_id} got {len(result)} multimodal vectors")
            return result
        except Exception as e:
            error(f"request #{request_id} multimodal embedding failed: {str(e)}")
            raise RuntimeError(f"Multimodal embedding failed: {str(e)}")
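
For reference, a sketch of the response shape the parsing above expects from the black/clip endpoint; the field names are inferred from the calling code, and the values are illustrative:

    # d, as returned by handle_uapi_response, is expected to look like:
    # {
    #     "object": "embedding.result",
    #     "data": {
    #         "some input text": {"type": "text", "vector": [...]},              # 1024-dim
    #         "data:image/png;base64,...": {"type": "image", "vector": [...],    # whole image
    #                                       "face_vecs": [[...], ...]},          # optional faces
    #         "bad input": {"type": "error", "error": "reason"},                 # failure entry
    #     }
    # }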
    # Entity extraction service (LTP/small)
    async def extract_entities(self, request, query: str, upappid: str, apiname: str, user: str) -> list:
        """Call the entity recognition service"""


@@ -18,6 +18,11 @@
        "editable": true,
        "rows": 5
    },
    {
        "uitype": "image",
        "name": "image",
        "label": "Upload a query image (optional)"
    },
    {
        "name": "fiids",
        "uitype": "checkbox",


@@ -7,14 +7,15 @@ if not orgid:
        message='Please log in first'
    )

-fiids = params_kw.fiids
query = params_kw.query
image = params_kw.image
fiids = params_kw.fiids
limit = params_kw.limit
-if not query or not fiids or not limit:
if (not query and not image) or not fiids or not limit:
    return UiError(
        title='Invalid input',
-       message='Please enter query text and select at least one knowledge base'
        message='Enter query text or upload an image, select at least one knowledge base, and fill in the number of results'
    )
try:


@@ -0,0 +1,9 @@
debug(f'{params_kw=}')
text = params_kw.text
fiid = params_kw.fiid
db_type = params_kw.db_type
env = DictObject(**globals())
keys = [k for k in env.keys()]
debug(f'{keys=}')
x = await rfexe('textinsert', request, params_kw)
return x