Merge branch 'main' of git.opencomputing.cn:yumoqing/rag

This commit is contained in: commit 6155ea176b

16  json/embeddingmode.json  Normal file
@@ -0,0 +1,16 @@
{
    "tblname": "embedding_mode",
    "title": "Embedding Mode",
    "params": {
        "browserfields": {
            "exclouded": ["id"],
            "alters": {}
        },
        "editexclouded": ["id"],
        "toolbar": {
        },
        "binds": [
        ]
    }
}
BIN  models/embeddingmode.xlsx  Normal file
Binary file not shown.
341  rag/fileprocess.py  Normal file
@@ -0,0 +1,341 @@
import numpy as np
import os
import re
from pdf2image import convert_from_path
from appPublic.log import debug, error, info
from pathlib import Path
import zipfile
import xml.etree.ElementTree as ET
from PIL import Image
from typing import List

# ==================== New: path-sanitizing helper ====================
def safe_filename(name: str) -> str:
    """
    Sanitize a file or directory name:
    - strip leading and trailing whitespace
    - collapse runs of whitespace into a single space
    - replace illegal characters with underscores
    - replace spaces with underscores (recommended; never breaks)
    """
    name = name.strip()
    name = re.sub(r'\s+', ' ', name)            # collapse whitespace
    name = re.sub(r'[<>:"/\\|?*]', '_', name)   # illegal characters
    name = name.replace(' ', '_')               # spaces -> underscores (the key step)
    return name

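A quick sanity check of the rules above (the inputs are made-up examples, not names from this repository):

# Hypothetical inputs, illustrating strip -> collapse -> illegal chars -> spaces
assert safe_filename('  My  Report: v1?.pdf ') == 'My_Report__v1_.pdf'
assert safe_filename('a/b\\c') == 'a_b_c'
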
def render_pdf_to_images(pdf_path, base_output_dir, dpi=200, image_format="PNG") -> List[str]:
    """
    Render every page of a PDF file to an image.

    Args:
        pdf_path (str): path to the PDF file
        base_output_dir (str): base output directory; a subfolder named after the PDF is created in it
        dpi (int): image resolution, default 200
        image_format (str): image format, default PNG

    Returns:
        List[str]: paths of the successfully rendered page images
    """
    pdf_filename = safe_filename(Path(pdf_path).stem)
    page_output_dir = os.path.join(base_output_dir, pdf_filename)

    # Create the output directory if it does not exist
    if not os.path.exists(page_output_dir):
        os.makedirs(page_output_dir, exist_ok=True)
        debug(f"Created output directory: {page_output_dir}")

    try:
        # Make sure the PDF file exists
        if not os.path.exists(pdf_path):
            error(f"PDF file not found: {pdf_path}")
            return []

        debug(f"Rendering PDF: {pdf_path}")
        debug(f"Output directory: {page_output_dir}")
        debug(f"Resolution: {dpi} DPI, format: {image_format}")

        # Render the PDF pages to images
        pages = convert_from_path(pdf_path, dpi=dpi)

        debug(f"PDF page count: {len(pages)}")
        debug("📄 Rendering PDF pages...")

        img_paths = []
        for i, page in enumerate(pages, start=1):
            try:
                # Build the image file path
                img_path = os.path.join(page_output_dir, f"page_{i:03d}.{image_format.lower()}")
                # Save the image; record the path only after the save succeeds
                page.save(img_path, image_format)
                img_paths.append(img_path)
                debug(f"✅ Saved {img_path}")

            except Exception as e:
                error(f"Failed to save page {i}: {e}")
                continue

        debug(f"Rendering finished: saved {len(img_paths)} of {len(pages)} pages")
        return img_paths

    except Exception as e:
        error(f"Failed to render PDF: {e}")
        return []


def extract_images_from_word(doc_path, base_output_dir) -> List[str]:
    """
    Extract all images from a Word document.

    Args:
        doc_path (str): path to the Word document (.docx)
        base_output_dir (str): base output directory; a subfolder named after the document is created in it

    Returns:
        List[str]: paths of the extracted image files
    """
    # Only .docx documents are supported
    if not doc_path.lower().endswith('.docx'):
        error(f"Only .docx Word documents are supported: {doc_path}")
        return []

    # Derive the document name (without extension) from the path
    doc_filename = safe_filename(Path(doc_path).stem)

    # Create a subfolder named after the document
    image_output_dir = os.path.join(base_output_dir, doc_filename)

    # Create the output directory if it does not exist
    if not os.path.exists(image_output_dir):
        os.makedirs(image_output_dir, exist_ok=True)
        debug(f"Created output directory: {image_output_dir}")

    try:
        # Make sure the document exists
        if not os.path.exists(doc_path):
            error(f"Word document not found: {doc_path}")
            return []

        debug(f"Extracting images from Word document: {doc_path}")
        debug(f"Output directory: {image_output_dir}")

        # Treat the .docx file as a zip archive
        with zipfile.ZipFile(doc_path, 'r') as docx:
            # List every entry in the archive
            file_list = docx.namelist()

            # Keep only image entries (normally under word/media/)
            image_files = [f for f in file_list if f.startswith('word/media/') and not f.endswith('/') and os.path.basename(f)]

            debug(f"Found {len(image_files)} image files")

            img_paths = []
            for i, image_path in enumerate(image_files):
                try:
                    # Take the image file name
                    image_name = os.path.basename(image_path)

                    # Make sure the name is usable
                    if not image_name or image_name == "media":
                        # Pull a meaningful name out of the path
                        parts = image_path.split('/')
                        for part in reversed(parts):
                            if part and part != "media":
                                image_name = part
                                break
                        else:
                            image_name = f"image_{i + 1}.png"

                    # Add a file extension when it is missing
                    if not Path(image_name).suffix:
                        # Ideally detect the format from the content; default to png otherwise
                        image_name += ".png"

                    # Build the output file path
                    output_path = os.path.join(image_output_dir, f"image_{i + 1:03d}_{image_name}")
                    # Extract and save the image
                    with docx.open(image_path) as image_file:
                        image_data = image_file.read()

                    # Write the image data; record the path only after the write succeeds
                    with open(output_path, 'wb') as f:
                        f.write(image_data)
                    img_paths.append(output_path)

                    debug(f"✅ Extracted image: {output_path}")

                except Exception as e:
                    error(f"Failed to extract image {image_path}: {e}")
                    continue

            debug(f"Word image extraction finished: {len(img_paths)} images extracted")
            return img_paths

    except Exception as e:
        error(f"Failed to extract images from Word document: {e}")
        return []


def extract_images_from_ppt(ppt_path, base_output_dir) -> List[str]:
    """
    Extract all images from a PowerPoint presentation.

    Args:
        ppt_path (str): path to the PowerPoint file (.pptx)
        base_output_dir (str): base output directory; a subfolder named after the presentation is created in it

    Returns:
        List[str]: paths of the extracted image files
    """
    # Only .pptx presentations are supported
    if not ppt_path.lower().endswith('.pptx'):
        error(f"Only .pptx PowerPoint documents are supported: {ppt_path}")
        return []

    # Derive the presentation name (without extension) from the path
    ppt_filename = safe_filename(Path(ppt_path).stem)

    # Create a subfolder named after the presentation
    image_output_dir = os.path.join(base_output_dir, ppt_filename)

    # Create the output directory if it does not exist
    if not os.path.exists(image_output_dir):
        os.makedirs(image_output_dir, exist_ok=True)
        debug(f"Created output directory: {image_output_dir}")

    try:
        # Make sure the PPT file exists
        if not os.path.exists(ppt_path):
            error(f"PowerPoint document not found: {ppt_path}")
            return []

        debug(f"Extracting images from PowerPoint document: {ppt_path}")
        debug(f"Output directory: {image_output_dir}")

        # Treat the .pptx file as a zip archive
        with zipfile.ZipFile(ppt_path, 'r') as pptx:
            # List every entry in the archive
            file_list = pptx.namelist()

            # Keep only image entries (normally under ppt/media/)
            image_files = [f for f in file_list if f.startswith('ppt/media/') and not f.endswith('/') and os.path.basename(f)]

            debug(f"Found {len(image_files)} image files")

            img_paths = []
            for i, image_path in enumerate(image_files):
                try:
                    # Take the image file name
                    image_name = Path(image_path).name

                    # Make sure the name is usable
                    if not image_name or image_name == "media":
                        parts = image_path.split('/')
                        for part in reversed(parts):
                            if part and part != "media":
                                image_name = part
                                break
                        else:
                            image_name = f"image_{i + 1}.png"

                    # Make sure there is a file extension
                    if not Path(image_name).suffix:
                        image_name += ".png"

                    # Build the output file path
                    output_path = os.path.join(image_output_dir, f"image_{i + 1:03d}_{image_name}")
                    # Extract and save the image
                    with pptx.open(image_path) as image_file:
                        image_data = image_file.read()

                    # Write the image data; record the path only after the write succeeds
                    with open(output_path, 'wb') as f:
                        f.write(image_data)
                    img_paths.append(output_path)

                    debug(f"✅ Extracted image: {output_path}")

                except Exception as e:
                    error(f"Failed to extract image {image_path}: {e}")
                    continue

            debug(f"PowerPoint image extraction finished: {len(img_paths)} images extracted")
            return img_paths

    except Exception as e:
        error(f"Failed to extract images from PowerPoint document: {e}")
        return []

def extract_images_from_file(file_path, base_output_dir="/home/wangmeihua/kyrag/data/extracted_images", file_type=None):
    """
    Generic entry point: pick the extraction method from the file type.

    Args:
        file_path (str): file path
        base_output_dir (str): base output directory
        file_type (str): file type (optional; auto-detected when omitted)

    Returns:
        List[str]: paths of the extracted images / rendered pages
    """
    # Auto-detect the file type from the extension when it is not given
    if file_type is None:
        ext = Path(file_path).suffix.lower()
        if ext == '.pdf':
            file_type = 'pdf'
        elif ext == '.docx':
            file_type = 'word'
        elif ext == '.pptx':
            file_type = 'ppt'
        else:
            error(f"Unsupported file type: {ext}")
            return []

    # Dispatch to the matching function
    if file_type == 'pdf':
        return render_pdf_to_images(file_path, base_output_dir)
    elif file_type == 'word':
        return extract_images_from_word(file_path, base_output_dir)
    elif file_type == 'ppt':
        return extract_images_from_ppt(file_path, base_output_dir)
    else:
        error(f"Unsupported file type: {file_type}")
        return []

# Usage example
if __name__ == "__main__":
    base_output_dir = "/home/wangmeihua/kyrag/data/extracted_images"

    # PDF
    pdf_path = "/home/wangmeihua/kyrag/22-zh-review.pdf"
    pdf_imgs = extract_images_from_file(pdf_path, base_output_dir, 'pdf')
    debug(f"pdf_imgs: {pdf_imgs}")
    if len(pdf_imgs) > 0:
        debug(f"PDF processed: {len(pdf_imgs)} pages")
    else:
        error("PDF processing failed")

    # Word document
    doc_path = "/home/wangmeihua/kyrag/test.docx"
    if os.path.exists(doc_path):
        doc_imgs = extract_images_from_file(doc_path, base_output_dir, 'word')
        debug(f"doc_imgs: {doc_imgs}")
        if len(doc_imgs) > 0:
            debug(f"Word document processed: {len(doc_imgs)} images")
        else:
            error("Word document processing failed")
    else:
        debug(f"Word document not found: {doc_path}")

    # PowerPoint
    ppt_path = "/home/wangmeihua/kyrag/提示学习-王美华.pptx"
    if os.path.exists(ppt_path):
        ppt_imgs = extract_images_from_file(ppt_path, base_output_dir, 'ppt')
        if len(ppt_imgs) > 0:
            debug(f"PowerPoint processed: {len(ppt_imgs)} images")
        else:
            error("PowerPoint processing failed")
    else:
        debug(f"PowerPoint document not found: {ppt_path}")
@@ -17,13 +17,15 @@ import traceback
 from filetxt.loader import fileloader,File2Text
 from ahserver.serverenv import get_serverenv
 from typing import List, Dict, Any
-from rag.service_opts import get_service_params, sor_get_service_params
 from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode
 from rag.fileprocess import extract_images_from_file
 from rag.rag_operations import RagOperations
 import json
 from rag.transaction_manager import TransactionContext
 from dataclasses import dataclass
 from enum import Enum

 import base64
 from pathlib import Path

 class RagFileMgr(FileMgr):
     def __init__(self, fiid):
@@ -53,6 +55,10 @@ where a.orgid = b.orgid
             return r.quota, r.expired_date
         return None, None

     async def file_to_base64(self, path: str) -> str:
         with open(path, "rb") as f:
             return base64.b64encode(f.read()).decode("utf-8")

     async def file_uploaded(self, request, ns, userid):
         """Insert the document into Milvus and extract its triples into Neo4j"""
         debug(f'Received ns: {ns=}')
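For reference, the multimodal branch below wires this helper into a data URI; a minimal sketch, to be run inside a coroutine, where `mgr` stands for any RagFileMgr instance and the path is a placeholder:

# 'mgr' is a RagFileMgr instance; the path is hypothetical
b64 = await mgr.file_to_base64("/tmp/example.png")
data_uri = f"data:image/png;base64,{b64}"
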
@@ -104,21 +110,107 @@ where a.orgid = b.orgid
            raise ValueError("Unable to fetch service parameters")
        rollback_context["service_params"] = service_params

        # Fetch the embedding mode
        embedding_mode = await get_embedding_mode(orgid)
        debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")

        # Load and chunk the document
        chunks = await self.rag_ops.load_and_chunk_document(
            realpath, timings, transaction_mgr=transaction_mgr
        )

        text_embeddings = None
        multi_results = None
        image_paths = []

        if embedding_mode == 1:
            inputs = []
            # Text
            for chunk in chunks:
                inputs.append({"type": "text", "content": chunk.page_content})

            debug("Starting multimodal image extraction and embedding")
            image_paths = extract_images_from_file(realpath)
            debug(f"Extracted {len(image_paths)} images from the document")

            if image_paths:
                for img_path in image_paths:
                    try:
                        # 1. Detect the real format automatically
                        ext = Path(img_path).suffix.lower()
                        if ext not in {".png", ".jpg", ".jpeg", ".webp", ".bmp"}:
                            ext = ".jpg"

                        mime_map = {
                            ".png": "image/png",
                            ".jpg": "image/jpeg",
                            ".jpeg": "image/jpeg",
                            ".webp": "image/webp",
                            ".bmp": "image/bmp"
                        }
                        mime_type = mime_map.get(ext, "image/jpeg")

                        # # 2. Smart compression (only above 1 MB; saves about 70% of the traffic)
                        # img = Image.open(img_path).convert("RGB")
                        # if os.path.getsize(img_path) > 1024 * 1024:  # > 1 MB
                        #     buffer = BytesIO()
                        #     img.save(buffer, format="JPEG", quality=85, optimize=True)
                        #     b64 = base64.b64encode(buffer.getvalue()).decode()
                        #     data_uri = f"data:image/jpeg;base64,{b64}"
                        # else:
                        b64 = await self.file_to_base64(img_path)
                        data_uri = f"data:{mime_type};base64,{b64}"

                        inputs.append({
                            "type": "image",
                            "data": data_uri
                        })
                        debug(f"Added image ({mime_type}, {len(b64) / 1024:.1f} KB): {Path(img_path).name}")

                    except Exception as e:
                        debug(f"Image processing failed, skipping: {img_path} → {e}")
                        # Append a placeholder even on failure so the ordering stays intact
                        inputs.append({
                            "type": "image",
                            "data": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
                        })

            debug(f"Total mixed inputs: {len(inputs)} (text {len(chunks)} + images {len(image_paths)})")

            multi_results = await self.rag_ops.generate_multi_embeddings(
                request=request,
                inputs=inputs,
                service_params=service_params,
                userid=userid,
                timings=timings,
                transaction_mgr=transaction_mgr
            )
            debug(f"Multimodal embedding succeeded, {len(multi_results)} results returned")
        else:
-            # Generate the embedding vectors
-            embeddings = await self.rag_ops.generate_embeddings(
            debug("[text-only mode] using BGE embeddings")
            text_embeddings = await self.rag_ops.generate_embeddings(
                request, chunks, service_params, userid, timings, transaction_mgr=transaction_mgr
            )
            debug(f"BGE embedding finished: {len(text_embeddings)} vectors")

        # Insert into Milvus
-        chunks_data = await self.rag_ops.insert_to_vector_db(
-            request, chunks, embeddings, realpath, orgid, fiid, id,
-            service_params, userid, db_type, timings, transaction_mgr=transaction_mgr
-        )
        inserted = await self.rag_ops.insert_all_vectors(
            request=request,
            text_chunks=chunks,
            realpath=realpath,
            orgid=orgid,
            fiid=fiid,
            document_id=id,
            service_params=service_params,
            userid=userid,
            db_type=db_type,
            timings=timings,
            img_paths=image_paths,
            text_embeddings=text_embeddings,
            multi_results=multi_results,
            transaction_mgr=transaction_mgr
        )
        debug(f"Unified insert: text {inserted['text']}, images {inserted['image']}, faces {inserted['face']}")

        # Extract the triples
        triples = await self.rag_ops.extract_triples(
@ -3,7 +3,6 @@ from ahserver.serverenv import ServerEnv
|
||||
import aiohttp
|
||||
from aiohttp import ClientSession, ClientTimeout
|
||||
import json
|
||||
from .file import file_uploaded, file_deleted
|
||||
from .folderinfo import RagFileMgr
|
||||
from .ragprogram import set_program, get_rag_programs
|
||||
from .ragllm_utils import get_ragllms_by_catelog
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
 import re
 import time
 import math
 import numpy as np
 from datetime import datetime
 from typing import List, Dict, Any, Optional
 from langchain_core.documents import Document
@@ -12,7 +13,10 @@ from filetxt.loader import fileloader, File2Text
 from rag.uapi_service import APIService
 from rag.service_opts import get_service_params
 from rag.transaction_manager import TransactionManager, OperationType

 from pdf2image import convert_from_path
 import pytesseract
 import base64
 from pathlib import Path

 class RagOperations:
     """RAG operations class, providing all of the shared RAG operations"""
@@ -33,14 +37,39 @@ class RagOperations:
        if ext not in supported_formats:
            raise ValueError(f"Unsupported file format: {ext}; supported formats: {', '.join(supported_formats)}")

        # Load the file content
        text = fileloader(realpath)
        if ext == 'pdf':
            debug(f"Native PDF extraction result: {text}")
            if not text or len(text.strip()) == 0:  # stricter emptiness check
                debug("Native PDF extraction failed; trying scanned-document extraction")
                ocr_text = self.pdf_to_text(realpath)
                debug(f"Text extracted from the scanned PDF: {ocr_text}")
                text = ocr_text  # use the OCR result only when native extraction fails

        # Clean the text only when there is content
        if text and len(text.strip()) > 0:
            # alternatively, keep a broader set of useful characters
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        else:
            error(f"File {realpath} yielded no extractable text")
            text = ""  # guarantee an empty string

        timings["load_file"] = time.time() - start_load
        debug(f"Loading the file took {timings['load_file']:.2f} s, text length: {len(text)}")

        if not text or not text.strip():
            raise ValueError(f"File {realpath} loaded empty")
        # # Load the file content
        # text = fileloader(realpath)
        # debug(f"Native PDF extraction result: {text}")
        # if len(text) == 0:
        #     debug("Native PDF extraction failed; trying scanned-document extraction")
        #     text = self.pdf_to_text(realpath)
        #     debug(f"Text extracted from the scanned PDF: {text}")
        # text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
        # timings["load_file"] = time.time() - start_load
        # debug(f"Loading the file took {timings['load_file']:.2f} s, text length: {len(text)}")
        #
        # if not text or not text.strip():
        #     raise ValueError(f"File {realpath} loaded empty")

        # Chunking
        document = Document(page_content=text)
@@ -67,6 +96,40 @@ class RagOperations:

        return chunks

    def pdf_to_text(
            self,
            pdf_path: str,
            output_txt: Optional[str] = None,
            dpi: int = 300,
            lang: str = 'chi_sim+chi_tra+eng'
    ) -> str:
        """
        Convert a scanned PDF to text (one-call version of the original script).

        Args:
            pdf_path: path to the PDF file (string)
            output_txt: when given, the text is also saved to this txt file (optional)
            dpi: image resolution, default 300 (higher is sharper)
            lang: language packs, default simplified + traditional Chinese + English

        Returns:
            The full extracted text (string)
        """
        # PDF to images
        images = convert_from_path(pdf_path, dpi=dpi)

        # OCR
        text = ''
        for img in images:
            text += pytesseract.image_to_string(img, lang=lang) + '\n'

        # Optional: also save to a file
        if output_txt:
            with open(output_txt, 'w', encoding='utf-8') as f:
                f.write(text)

        return text

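A minimal call sketch for the OCR fallback above; the paths are placeholders, and it assumes the tesseract language packs named in `lang` plus poppler (required by pdf2image) are installed:

ops = RagOperations()
# Hypothetical paths; output_txt is optional and also writes the text to disk
text = ops.pdf_to_text("/tmp/scanned.pdf", output_txt="/tmp/scanned.txt", dpi=300)
print(text[:200])
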
    async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                  userid: str, timings: Dict,
                                  transaction_mgr: TransactionManager = None) -> List[List[float]]:
@@ -103,84 +166,488 @@ class RagOperations:

        return embeddings

-    async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
-                                  realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
-                                  userid: str, db_type: str, timings: Dict,
-                                  transaction_mgr: TransactionManager = None):
-        """Insert into the vector database"""
-        debug(f"Preparing data and calling the document-insert endpoint: {realpath}")
-        filename = os.path.basename(realpath).rsplit('.', 1)[0]
-        ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
    async def generate_multi_embeddings(self, request, inputs: List[Dict], service_params: Dict,
                                        userid: str, timings: Dict,
                                        transaction_mgr: TransactionManager = None) -> Dict[str, Dict]:
        """Call the multimodal embedding service (CLIP)"""
        debug("Calling the multimodal embedding service")
        start = time.time()

        result = await self.api_service.get_multi_embeddings(
            request=request,
            inputs=inputs,
            upappid=service_params['embedding'],
            apiname="black/clip",
            user=userid
        )
        debug(f"Multimodal service returned: {result}")
        timings["multi_embedding"] = time.time() - start
        debug(f"Multimodal embedding took {timings['multi_embedding']:.2f} s, {len(result)} items processed")

        # ==================== New: error check + filtering ====================
        valid_results = {}
        error_count = 0
        error_examples = []

        for key, info in result.items():
            if info.get("type") == "error":
                error_count += 1
                if len(error_examples) < 3:  # record only the first three
                    error_examples.append(f"{key} → {info['error']}")
                # discard erroneous entries outright
                continue
            valid_results[key] = info

        if error_count > 0:
            error(f"{error_count} multimodal embeddings failed! Examples: {'; '.join(error_examples)}")
            raise RuntimeError(f"Multimodal embedding had {error_count} failures")
        else:
            debug("All multimodal embeddings succeeded!")

        if transaction_mgr:
            transaction_mgr.add_operation(
                OperationType.EMBEDDING,
                {'count': len(result)}
            )

        return result

    # async def force_l2_normalize(self, vector: List[float]) -> List[float]:
    #     """Fail-safe L2 normalization"""
    #     arr = np.array(vector, dtype=np.float32)
    #     norm = np.linalg.norm(arr)
    #     if norm == 0:
    #         return vector  # an all-zero vector cannot be normalized
    #     return (arr / norm).tolist()

    # Unified insert into the vector store
    async def insert_all_vectors(
            self,
            request,
            text_chunks: List[Document],
            realpath: str,
            orgid: str,
            fiid: str,
            document_id: str,
            service_params: Dict,
            userid: str,
            db_type: str,
            timings: Dict,
            img_paths: List[str] = None,
            text_embeddings: List[List[float]] = None,
            multi_results: Dict = None,
            transaction_mgr: TransactionManager = None
    ) -> Dict[str, int]:
        """
        Unified insert supporting two modes:
        1. text-only mode: text_embeddings is set
        2. multimodal mode: multi_results is set (from generate_multi_embeddings)
        """
        img_paths = img_paths or []
        all_chunks = []
        start = time.time()
        filename = os.path.basename(realpath)
        upload_time = datetime.now().isoformat()

-        chunks_data = [
-            {
        # ==================== 1. Text-only mode (BGE) ====================
        if text_embeddings is not None:
            debug(f"[text-only mode] inserting {len(text_embeddings)} text vectors")
            for i, chunk in enumerate(text_chunks):
                all_chunks.append({
                    "userid": orgid,
                    "knowledge_base_id": fiid,
                    "text": chunk.page_content,
-                    "vector": embeddings[i],
-                    "document_id": id,
-                    "filename": filename + '.' + ext,
                    "vector": text_embeddings[i],
                    "document_id": document_id,
                    "filename": filename,
                    "file_path": realpath,
                    "upload_time": upload_time,
-                    "file_type": ext,
-            }
-            for i, chunk in enumerate(chunks)
-        ]
                    "file_type": "text",
                })

-        start_milvus = time.time()
-        for i in range(0, len(chunks_data), 10):
-            batch_chunks = chunks_data[i:i + 10]
-            debug(f"Payload being sent: {batch_chunks}")
        # ==================== 2. Multimodal mode (CLIP, mixed inputs) ====================
        if multi_results is not None:
            debug(f"[multimodal mode] parsing {len(multi_results)} CLIP results")
            # Walk multi_results
            for raw_key, info in multi_results.items():
                typ = info["type"]
                # vector = info["vector"]
                # debug(f"Vector returned from the backend: {vector}")
                # emb = await self.force_l2_normalize(info["vector"])
                # debug(f"Vector after normalization: {emb}")
                # --- text ---
                if typ == "text":
                    # raw_key is the original text
                    all_chunks.append({
                        "userid": orgid,
                        "knowledge_base_id": fiid,
                        "text": raw_key,
                        "vector": info["vector"],
                        "document_id": document_id,
                        "filename": filename,
                        "file_path": realpath,
                        "upload_time": upload_time,
                        "file_type": "text",
                    })
                    continue

                # --- image ---
                if typ == "image":
                    img_path = info.get("path") or raw_key
                    img_name = os.path.basename(img_path)

                    # whole-image vector
                    if "vector" in info:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Image: {img_path}] image taken from file {realpath}",
                            "vector": info["vector"],
                            "document_id": document_id,
                            "filename": img_name,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "image",
                        })

                    # face vectors
                    face_vecs = info.get("face_vecs", [])
                    face_count = len(face_vecs)
                    # if face_count > 0:
                    #     for f_idx, fvec in enumerate(face_vecs):
                    #         debug(f"Face vector dimension: {len(fvec)}")
                    #         all_chunks.append({
                    #             "userid": orgid,
                    #             "knowledge_base_id": fiid,
                    #             "text": f"[Face {f_idx + 1}/{face_count} in {img_name}] face taken from image {img_path} of {realpath}",
                    #             "vector": fvec,
                    #             "document_id": document_id,
                    #             "filename": img_name,
                    #             "file_path": realpath,
                    #             "upload_time": upload_time,
                    #             "file_type": "face",
                    #         })
                    continue

                # --- video ---
                if typ == "video":
                    video_path = info.get("path") or raw_key
                    video_name = os.path.basename(video_path)

                    if "vector" in info:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Video: {video_name}]",
                            "vector": info["vector"],
                            "document_id": document_id,
                            "filename": video_path,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "video",
                        })

                    # faces in the video
                    face_vecs = info.get("face_vecs", [])
                    face_count = len(face_vecs)
                    # if face_count > 0:
                    #     for f_idx, fvec in enumerate(face_vecs):
                    #         all_chunks.append({
                    #             "userid": orgid,
                    #             "knowledge_base_id": fiid,
                    #             "text": f"[Face {f_idx + 1}/{face_count} in video {video_name}] taken from {video_path}",
                    #             "vector": fvec,
                    #             "document_id": document_id,
                    #             "filename": video_path,
                    #             "file_path": realpath,
                    #             "upload_time": upload_time,
                    #             "file_type": "face",
                    #         })
                    continue

                # --- audio ---
                if typ == "audio":
                    audio_path = info.get("path") or raw_key
                    audio_name = os.path.basename(audio_path)

                    if "vector" in info:
                        all_chunks.append({
                            "userid": orgid,
                            "knowledge_base_id": fiid,
                            "text": f"[Audio: {audio_name}]",
                            "vector": info["vector"],
                            "document_id": document_id,
                            "filename": audio_path,
                            "file_path": realpath,
                            "upload_time": upload_time,
                            "file_type": "audio",
                        })
                    continue

                # --- unknown type ---
                debug(f"Skipping unknown type: {typ} → {raw_key}")

        # ==================== 3. Batch insert into Milvus ====================
        if not all_chunks:
            debug("No vectors to insert")
            return {"text": 0, "image": 0, "face": 0}

        for i in range(0, len(all_chunks), 10):
            batch = all_chunks[i:i + 10]
            result = await self.api_service.milvus_insert_document(
                request=request,
-                chunks=batch_chunks,
-                db_type=db_type,
                chunks=batch,
                upappid=service_params['vdb'],
                apiname="milvus/insertdocument",
-                user=userid
                user=userid,
                db_type=db_type
            )
            if result.get("status") != "success":
-                raise ValueError(result.get("message", "Milvus insert failed"))
                raise ValueError(f"Milvus insert failed: {result.get('message')}")

-        timings["insert_milvus"] = time.time() - start_milvus
-        debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
-
-        # Record the transaction operation together with its rollback function
-        if transaction_mgr:
-            async def rollback_vdb_insert(data, context):
        # ==================== 4. Unified rollback (registered only once) ====================
        if transaction_mgr and all_chunks:
            async def rollback_all(data, context):
                try:
                    # defensive checks
                    required_context = ['request', 'service_params', 'userid']
                    missing_context = [k for k in required_context if k not in context or context[k] is None]
                    if missing_context:
                        raise ValueError(f"Rollback context is missing fields: {', '.join(missing_context)}")

                    required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
                    missing_data = [k for k in required_data if k not in data or data[k] is None]
                    if missing_data:
                        raise ValueError(f"VDB_INSERT data is missing fields: {', '.join(missing_data)}")

                    await self.delete_from_vector_db(
-                        context['request'], data['orgid'], data['realpath'],
-                        data['fiid'], data['id'], context['service_params'],
-                        context['userid'], data['db_type']
                        request=context['request'],
                        orgid=data['orgid'],
                        realpath=data['realpath'],
                        fiid=data['fiid'],
                        id=data['id'],
                        service_params=context['service_params'],
                        userid=context['userid'],
                        db_type=data['db_type']
                    )
-                    return f"Rolled back vector-database insert: {data['id']}"
                    return f"Rolled back all vectors for document_id={data['id']}"
                except Exception as e:
-                    error(f"Vector-database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
                    error(f"Unified rollback failed: {e}")
                    raise

            transaction_mgr.add_operation(
                OperationType.VDB_INSERT,
                {
-                    'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
-                    'id': id, 'db_type': db_type
                    'orgid': orgid,
                    'realpath': realpath,
                    'fiid': fiid,
                    'id': document_id,
                    'db_type': db_type
                },
-                rollback_func=rollback_vdb_insert
                rollback_func=rollback_all
            )

-        return chunks_data
        # ==================== 5. Stats to return ====================
        stats = {
            "text": len([c for c in all_chunks if c["file_type"] == "text"]),
            "image": len([c for c in all_chunks if c["file_type"] == "image"]),
            "face": len([c for c in all_chunks if c["file_type"] == "face"])
        }

        timings["insert_all"] = time.time() - start
        debug(
            f"Unified insert finished: text {stats['text']}, images {stats['image']}, faces {stats['face']}, took {timings['insert_all']:.2f}s")
        return stats
    # async def insert_to_vector_db(self, request, chunks: List[Document], embeddings: List[List[float]],
    #                               realpath: str, orgid: str, fiid: str, id: str, service_params: Dict,
    #                               userid: str, db_type: str, timings: Dict,
    #                               transaction_mgr: TransactionManager = None):
    #     """Insert into the vector database"""
    #     debug(f"Preparing data and calling the document-insert endpoint: {realpath}")
    #     filename = os.path.basename(realpath).rsplit('.', 1)[0]
    #     ext = realpath.rsplit('.', 1)[1].lower() if '.' in realpath else ''
    #     upload_time = datetime.now().isoformat()
    #
    #     chunks_data = [
    #         {
    #             "userid": orgid,
    #             "knowledge_base_id": fiid,
    #             "text": chunk.page_content,
    #             "vector": embeddings[i],
    #             "document_id": id,
    #             "filename": filename + '.' + ext,
    #             "file_path": realpath,
    #             "upload_time": upload_time,
    #             "file_type": ext,
    #         }
    #         for i, chunk in enumerate(chunks)
    #     ]
    #
    #     start_milvus = time.time()
    #     for i in range(0, len(chunks_data), 10):
    #         batch_chunks = chunks_data[i:i + 10]
    #         debug(f"Payload being sent: {batch_chunks}")
    #         result = await self.api_service.milvus_insert_document(
    #             request=request,
    #             chunks=batch_chunks,
    #             db_type=db_type,
    #             upappid=service_params['vdb'],
    #             apiname="milvus/insertdocument",
    #             user=userid
    #         )
    #         if result.get("status") != "success":
    #             raise ValueError(result.get("message", "Milvus insert failed"))
    #
    #     timings["insert_milvus"] = time.time() - start_milvus
    #     debug(f"Milvus insert took {timings['insert_milvus']:.2f} s")
    #
    #     # Record the transaction operation together with its rollback function
    #     if transaction_mgr:
    #         async def rollback_vdb_insert(data, context):
    #             try:
    #                 # defensive checks
    #                 required_context = ['request', 'service_params', 'userid']
    #                 missing_context = [k for k in required_context if k not in context or context[k] is None]
    #                 if missing_context:
    #                     raise ValueError(f"Rollback context is missing fields: {', '.join(missing_context)}")
    #
    #                 required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
    #                 missing_data = [k for k in required_data if k not in data or data[k] is None]
    #                 if missing_data:
    #                     raise ValueError(f"VDB_INSERT data is missing fields: {', '.join(missing_data)}")
    #
    #                 await self.delete_from_vector_db(
    #                     context['request'], data['orgid'], data['realpath'],
    #                     data['fiid'], data['id'], context['service_params'],
    #                     context['userid'], data['db_type']
    #                 )
    #                 return f"Rolled back vector-database insert: {data['id']}"
    #             except Exception as e:
    #                 error(f"Vector-database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
    #                 raise
    #
    #         transaction_mgr.add_operation(
    #             OperationType.VDB_INSERT,
    #             {
    #                 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
    #                 'id': id, 'db_type': db_type
    #             },
    #             rollback_func=rollback_vdb_insert
    #         )
    #
    #     return chunks_data
    #
    # async def insert_image_vectors(
    #         self,
    #         request,
    #         multi_results: Dict[str, Dict],
    #         realpath: str,
    #         orgid: str,
    #         fiid: str,
    #         document_id: str,
    #         service_params: Dict,
    #         userid: str,
    #         db_type: str,
    #         timings: Dict,
    #         transaction_mgr: TransactionManager = None
    # ) -> tuple[int, int]:
    #
    #     start = time.time()
    #     image_chunks = []
    #     face_chunks = []
    #
    #     for img_path, info in multi_results.items():
    #         # img_name = os.path.basename(img_path)
    #
    #         # 1. Insert the whole image
    #         if info.get("type") in ["image", "video"] and "vector" in info:
    #             image_chunks.append({
    #                 "userid": orgid,
    #                 "knowledge_base_id": fiid,
    #                 "text": f"[Image: {img_path}]",
    #                 "vector": info["vector"],
    #                 "document_id": document_id,
    #                 "filename": os.path.basename(realpath),
    #                 "file_path": realpath,
    #                 "upload_time": datetime.now().isoformat(),
    #                 "file_type": "image"
    #             })
    #
    #         # 2. Insert every face
    #         face_vecs = info.get("face_vecs")
    #         face_count = info.get("face_count", 0)
    #
    #         if face_count > 0 and face_vecs and len(face_vecs) == face_count:
    #             for idx, face_vec in enumerate(face_vecs):
    #                 face_chunks.append({
    #                     "userid": orgid,
    #                     "knowledge_base_id": fiid,
    #                     "text": f"[Face {idx + 1}/{face_count} in {img_path}]",
    #                     "vector": face_vec,
    #                     "document_id": document_id,
    #                     "filename": os.path.basename(realpath),
    #                     "file_path": realpath,
    #                     "upload_time": datetime.now().isoformat(),
    #                     "file_type": "face",
    #                 })
    #
    #     if image_chunks:
    #         for i in range(0, len(image_chunks), 10):
    #             await self.api_service.milvus_insert_document(
    #                 request=request,
    #                 chunks=image_chunks[i:i + 10],
    #                 upappid=service_params['vdb'],
    #                 apiname="milvus/insertdocument",
    #                 user=userid,
    #                 db_type=db_type
    #             )
    #
    #     if face_chunks:
    #         for i in range(0, len(face_chunks), 10):
    #             await self.api_service.milvus_insert_document(
    #                 request=request,
    #                 chunks=face_chunks[i:i + 10],
    #                 upappid=service_params['vdb'],
    #                 apiname="milvus/insertdocument",
    #                 user=userid,
    #                 db_type=db_type
    #             )
    #     timings["insert_images"] = time.time() - start
    #     image_count = len(image_chunks)
    #     face_count = len(face_chunks)
    #
    #     debug(f"Multimodal insert finished: {image_count} images, {face_count} faces")
    #
    #     if transaction_mgr and (image_count + face_count > 0):
    #         transaction_mgr.add_operation(
    #             OperationType.IMAGE_VECTORS_INSERT,
    #             {"images": image_count, "faces": face_count, "document_id": document_id}
    #         )
    #
    #     # Record the transaction operation together with its rollback function
    #     if transaction_mgr:
    #         async def rollback_multimodal(data, context):
    #             try:
    #                 # defensive checks
    #                 required_context = ['request', 'service_params', 'userid']
    #                 missing_context = [k for k in required_context if k not in context or context[k] is None]
    #                 if missing_context:
    #                     raise ValueError(f"Rollback context is missing fields: {', '.join(missing_context)}")
    #
    #                 required_data = ['orgid', 'realpath', 'fiid', 'id', 'db_type']
    #                 missing_data = [k for k in required_data if k not in data or data[k] is None]
    #                 if missing_data:
    #                     raise ValueError(f"Multimodal rollback data is missing fields: {', '.join(missing_data)}")
    #
    #                 await self.delete_from_vector_db(
    #                     context['request'], data['orgid'], data['realpath'],
    #                     data['fiid'], data['id'], context['service_params'],
    #                     context['userid'], data['db_type']
    #                 )
    #                 return f"Rolled back multimodal vectors: {data['id']}"
    #             except Exception as e:
    #                 error(f"Multimodal vector-database rollback failed: document_id={data.get('id', 'unknown')}, error: {str(e)}")
    #                 raise
    #
    #         transaction_mgr.add_operation(
    #             OperationType.VDB_INSERT,
    #             {
    #                 'orgid': orgid, 'realpath': realpath, 'fiid': fiid,
    #                 'id': id, 'db_type': db_type
    #             },
    #             rollback_func=rollback_multimodal
    #         )
    #
    #     return image_count, face_count

    async def insert_to_vector_text(self, request,
                                    db_type: str, fields: Dict, service_params: Dict, userid: str, timings: Dict) -> List[Dict]:
@@ -383,25 +850,150 @@ class RagOperations:
        debug(f"Triplet matching took {timings['triplet_matching']:.3f} s in total")
        return all_triplets

-    async def generate_query_vector(self, request, text: str, service_params: Dict,
-                                    userid: str, timings: Dict) -> List[float]:
-        """Generate the query vector"""
-        debug(f"Generating the query vector: {text[:200]}...")
    async def generate_query_vector(
            self,
            request,
            text: str,
            service_params: Dict,
            userid: str,
            timings: Dict,
            embedding_mode: int = 0
    ) -> List[float]:
        """Generate the query vector (text or multimodal)"""
        debug(f"Generating the query vector: mode={embedding_mode}, text='{text[:100]}...'")
        start_vector = time.time()
-        query_vector = await self.api_service.get_embeddings(

        if embedding_mode == 0:
            # === mode 0: text-only embedding (BAAI/bge-m3) ===
            debug("Using BAAI/bge-m3 text embedding")
            vectors = await self.api_service.get_embeddings(
                request=request,
                texts=[text],
                upappid=service_params['embedding'],
                apiname="BAAI/bge-m3",
                user=userid
            )
-            if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
-                raise ValueError("The query vector must be a list of 1024 floats")
-            query_vector = query_vector[0]
            if not vectors or not isinstance(vectors, list) or len(vectors) == 0:
                raise ValueError("bge-m3 returned an empty result")
            query_vector = vectors[0]
            if len(query_vector) != 1024:
                raise ValueError(f"bge-m3 returned a vector of the wrong dimension: {len(query_vector)}")

        elif embedding_mode == 1:
            # === mode 1: multimodal embedding (black/clip) ===
            debug("Using black/clip multimodal embedding")
            inputs = [{"type": "text", "content": text}]

            result = await self.api_service.get_multi_embeddings(
                request=request,
                inputs=inputs,
                upappid=service_params['embedding'],
                apiname="black/clip",
                user=userid
            )

            query_vector = None
            for key, info in result.items():
                if info.get("type") == "error":
                    debug(f"Skipping a CLIP error entry: {info['error']}")
                    continue
                if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024:
                    query_vector = info["vector"]
                    debug(f"Got a CLIP vector (from {info['type']})")
                    break

            if query_vector is None:
                raise ValueError("black/clip returned no valid 1024-dimensional vector")

        else:
            raise ValueError(f"Unsupported embedding_mode: {embedding_mode}")

        # final shared validation
        if not isinstance(query_vector, list) or len(query_vector) != 1024:
            raise ValueError(f"The query vector must be a list of 1024 floats; got length {len(query_vector)}")

        timings["vector_generation"] = time.time() - start_vector
-        debug(f"Generating the query vector took {timings['vector_generation']:.3f} s")
        debug(f"Query vector generated in {timings['vector_generation']:.3f} s, mode: {embedding_mode}")
        return query_vector

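A hedged call sketch for the two query modes above; service_params, userid and orgid are assumed to come from the surrounding handlers, as elsewhere in this diff:

timings = {}
vec = await rag_ops.generate_query_vector(
    request, "how to configure the embedding mode", service_params, userid, timings,
    embedding_mode=await get_embedding_mode(orgid)  # 0 = BAAI/bge-m3, 1 = black/clip
)
assert len(vec) == 1024  # both backends are validated to 1024 dimensions
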
    async def generate_image_vector(
            self,
            request,
            img_path: str,
            service_params: Dict,
            userid: str,
            timings: Dict,
            embedding_mode: int = 0
    ) -> List[float]:
        """Generate the query vector for an image (multimodal only)"""
        debug(f"Generating the query vector: mode={embedding_mode}, image={img_path}")
        start_vector = time.time()

        if embedding_mode == 0:
            raise ValueError("Text-only mode cannot embed images; please pick a different service")

        elif embedding_mode == 1:
            # === mode 1: multimodal embedding (black/clip) ===
            debug("Using black/clip multimodal embedding")
            inputs = []
            try:
                ext = Path(img_path).suffix.lower()
                if ext not in {".png", ".jpg", ".jpeg", ".webp", ".bmp"}:
                    ext = ".jpg"

                mime_map = {
                    ".png": "image/png",
                    ".jpg": "image/jpeg",
                    ".jpeg": "image/jpeg",
                    ".webp": "image/webp",
                    ".bmp": "image/bmp"
                }
                mime_type = mime_map.get(ext, "image/jpeg")
                with open(img_path, "rb") as f:
                    b64 = base64.b64encode(f.read()).decode("utf-8")
                data_uri = f"data:{mime_type};base64,{b64}"

                inputs.append({
                    "type": "image",
                    "data": data_uri
                })
                debug(f"Added image ({mime_type}, {len(b64) / 1024:.1f} KB): {Path(img_path).name}")

            except Exception as e:
                debug(f"Image processing failed, skipping: {img_path} → {e}")

            result = await self.api_service.get_multi_embeddings(
                request=request,
                inputs=inputs,
                upappid=service_params['embedding'],
                apiname="black/clip",
                user=userid
            )

            image_vector = None
            for key, info in result.items():
                if info.get("type") == "error":
                    debug(f"Skipping a CLIP error entry: {info['error']}")
                    continue
                if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024:
                    image_vector = info["vector"]
                    debug(f"Got a CLIP vector (from {info['type']})")
                    break

            if image_vector is None:
                raise ValueError("black/clip returned no valid 1024-dimensional vector")

        else:
            raise ValueError(f"Unsupported embedding_mode: {embedding_mode}")

        # final shared validation
        if not isinstance(image_vector, list) or len(image_vector) != 1024:
            raise ValueError(f"The query vector must be a list of 1024 floats; got length {len(image_vector)}")

        timings["vector_generation"] = time.time() - start_vector
        debug(f"Query vector generated in {timings['vector_generation']:.3f} s, mode: {embedding_mode}")
        return image_vector

    async def vector_search(self, request, query_vector: List[float], orgid: str,
                            fiids: List[str], limit: int, service_params: Dict, userid: str,
                            timings: Dict) -> List[Dict]:
@@ -473,34 +1065,49 @@ class RagOperations:
        return unique_triples

    def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
        """Format the search results into a unified shape"""
-        formatted_results = []
        # for res in results[:limit]:
        #     score = res.get('rerank_score', res.get('distance', 0))
        #
        #     content = res.get('text', '')
        #     title = res.get('metadata', {}).get('filename', 'Untitled')
        #     document_id = res.get('metadata', {}).get('document_id', '')
        #
        #     formatted_results.append({
        #         "content": content,
        #         "title": title,
        #         "metadata": {"document_id": document_id, "score": score},
        #     })
-        # Score normalization
        formatted = []
        for res in results[:limit]:
-            rerank_score = res.get('rerank_score', 0)
-            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
-            score = max(0.0, min(1.0, score))
-
-            content = res.get('text', '')
-            title = res.get('metadata', {}).get('filename', 'Untitled')
-            document_id = res.get('metadata', {}).get('document_id', '')
-
-            formatted_results.append({
-                "content": content,
-                "title": title,
-                "metadata": {"document_id": document_id, "score": score},
            # # Prefer the rerank score; otherwise use the vector similarity directly (do not invert it)
            # if res.get('rerank_score') is not None:
            #     score = res.get('rerank_score')
            # else:
            #     score = res.get('distance', 0.0)
            distance = res.get('distance', 0.0)
            rerank_score = res.get('rerank_score', 0.0)
            formatted.append({
                "content": res.get('text', ''),
                "title": res.get('metadata', {}).get('filename', 'Untitled'),
                "metadata": {
                    "document_id": res.get('metadata', {}).get('document_id', ''),
                    "distance": distance,
                    "rerank_score": rerank_score,
                }
            })
        return formatted

-        return formatted_results
    # async def save_uploaded_photo(self, image_file: FileStorage, orgid: str) -> str:
    #     """
    #     Save an image uploaded from the frontend under /home/wangmeihua/kyrag/data/photo
    #     and return the saved absolute path (string) for generate_img_vector to use.
    #     """
    #     if not image_file or not hasattr(image_file, "filename"):
    #         raise ValueError("Invalid image upload object")
    #
    #     # For safety, store files in per-orgid directories (keeps different companies' files apart)
    #     org_dir = UPLOAD_PHOTO_DIR / orgid
    #     org_dir.mkdir(parents=True, exist_ok=True)
    #
    #     # Build a unique file name, keeping the original suffix
    #     suffix = Path(image_file.filename).suffix.lower()
    #     if not suffix or suffix not in {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}:
    #         suffix = ".jpg"
    #
    #     unique_name = f"{uuid.uuid4().hex}{suffix}"
    #     save_path = org_dir / unique_name
    #
    #     # Actually write to disk
    #     image_file.save(str(save_path))
    #     debug(f"Image saved: {save_path} (original name: {image_file.filename})")
    #
    #     # Return a string path; generate_img_vector takes a str directly
    #     return str(save_path)
199  rag/ragapi.py
@@ -6,8 +6,12 @@ import traceback
 import json
 import math
 import uuid
-from rag.service_opts import get_service_params, sor_get_service_params
 import os
 from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode
 from rag.rag_operations import RagOperations
 from langchain_core.documents import Document

 REAL_PHOTO_ROOT = "/home/wangmeihua/kyrag/files"

 helptext = """kyrag API:

@@ -133,7 +137,16 @@ async def fusedsearch(request, params_kw, *params):
    debug(f"params_kw: {params_kw}")
    # orgid = "04J6VbxLqB_9RPMcgOv_8"
    # userid = "04J6VbxLqB_9RPMcgOv_8"
-    query = params_kw.get('query', '')
    query = params_kw.get('query', '').strip()
    img_path = params_kw.get('image')
    if isinstance(img_path, str):
        img_path = img_path.strip()
        relative_part = img_path.lstrip("/")
        real_img_path = os.path.join(REAL_PHOTO_ROOT, relative_part)
        if not os.path.exists(real_img_path):
            raise FileNotFoundError(f"Image not found: {real_img_path}")
        img_path = real_img_path
        debug(f"Image path fixed up automatically: {img_path}")
    # Normalize the limit parameter across input styles, for the Dify and Coze integrations
    raw_limit = params_kw.get('limit') or (
        params_kw.get('retrieval_setting', {}).get('top_k')
@@ -188,7 +201,65 @@ async def fusedsearch(request, params_kw, *params):
    service_params = await get_service_params(orgid)
    if not service_params:
        raise ValueError("Unable to fetch service parameters")
    # Fetch the embedding mode
    embedding_mode = await get_embedding_mode(orgid)
    debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")

    # Case 1: query and image both empty → error
    if not query and not img_path:
        raise ValueError("The query text and the image cannot both be empty")

    # Case 2: query and image both present → error (the current business rule forbids sending both)
    if query and img_path:
        raise ValueError("Provide either the query text or the image, not both at once")

    # Case 3: image only → image-to-image search through the pure multimodal branch
    if img_path and not query:
        try:
            debug("Image-only query detected; running image-to-image search")
            rag_ops = RagOperations()

            timings = {}
            start_time = time.time()

            # Generate the image vector directly
            img_vector = await rag_ops.generate_image_vector(
                request, img_path, service_params, userid, timings, embedding_mode
            )

            # Vector search (fetch 50 extra rows and truncate, matching the text branch)
            search_results = await rag_ops.vector_search(
                request, img_vector, orgid, fiids, limit + 50, service_params, userid, timings
            )

            timings["total_time"] = time.time() - start_time

            # Optional: delete the image after searching to save disk space (as needed)
            # try:
            #     os.remove(img_path)
            # except:
            #     pass

            final_results = []
            for item in search_results[:limit]:
                final_results.append({
                    "text": item["text"],
                    "distance": item["distance"]
                })

            return {
                "results": final_results,
                "timings": timings
            }
        except Exception as e:
            error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
            return {
                "records": [],
                "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
                "error": str(e)
            }

    if not img_path and query:
        try:
            timings = {}
            start_time = time.time()
@@ -198,8 +269,8 @@ async def fusedsearch(request, params_kw, *params):
            all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
                                                        userid, timings)
            combined_text = _combine_query_with_triplets(query, all_triplets)
-            query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
-            search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
            query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
            search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
                                                         userid, timings)

            use_rerank = True
@@ -212,7 +283,7 @@ async def fusedsearch(request, params_kw, *params):

            formatted_results = rag_ops.format_search_results(final_results, limit)
            timings["total_time"] = time.time() - start_time
-            info(f"Fused search done, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
            debug(f"Fused search done, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")

            return {
                "records": formatted_results,
@@ -226,6 +297,116 @@ async def fusedsearch(request, params_kw, *params):
                "error": str(e)
            }

# async def fusedsearch(request, params_kw, *params):
#     """
#     Fused search via the service endpoints
#
#     """
#     kw = request._run_ns
#     f = kw.get('get_userorgid')
#     orgid = await f()
#     debug(f"orgid: {orgid},{f=}")
#     f = kw.get('get_user')
#     userid = await f()
#     debug(f"params_kw: {params_kw}")
#     # orgid = "04J6VbxLqB_9RPMcgOv_8"
#     # userid = "04J6VbxLqB_9RPMcgOv_8"
#     query = params_kw.get('query', '')
#     # Normalize the limit parameter across input styles, for the Dify and Coze integrations
#     raw_limit = params_kw.get('limit') or (
#         params_kw.get('retrieval_setting', {}).get('top_k')
#         if isinstance(params_kw.get('retrieval_setting'), dict)
#         else None
#     )
#
#     # Normalize to an integer
#     if raw_limit is None:
#         limit = 5  # default when neither source is present
#     elif isinstance(raw_limit, (int, float)):
#         limit = int(raw_limit)  # numeric types convert directly
#     elif isinstance(raw_limit, str):
#         try:
#             # convert the string to an integer
#             limit = int(raw_limit)
#         except (TypeError, ValueError):
#             limit = 5  # default when conversion fails
#     else:
#         limit = 5  # default for any other unexpected type
#     debug(f"limit: {limit}")
#     raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')  #
#
#     # Normalize to a list
#     if raw_fiids is None:
#         fiids = []  # neither parameter is present
#     elif isinstance(raw_fiids, list):
#         fiids = [str(item).strip() for item in raw_fiids]  # already a list
#     elif isinstance(raw_fiids, str):
#         # fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
#         try:
#             # try to parse a JSON string
#             parsed = json.loads(raw_fiids)
#             if isinstance(parsed, list):
#                 fiids = [str(item).strip() for item in parsed]  # JSON array to a list of strings
#             else:
#                 # handle a comma-separated string or a single ID string
#                 fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
#         except json.JSONDecodeError:
#             # not valid JSON: split on commas
#             fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
#     elif isinstance(raw_fiids, (int, float)):
#         fiids = [str(int(raw_fiids))]  # numeric types become a one-string list
#     else:
#         fiids = []  # any other unexpected type
#
#     debug(f"fiids: {fiids}")
#
#     # Verify that the fiids' orgid matches orgid = await f()
#     await _validate_fiids_orgid(fiids, orgid, kw)
#
#     service_params = await get_service_params(orgid)
#     if not service_params:
#         raise ValueError("Unable to fetch service parameters")
#     # Fetch the embedding mode
#     embedding_mode = await get_embedding_mode(orgid)
#     debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")
#
#     try:
#         timings = {}
#         start_time = time.time()
#         rag_ops = RagOperations()
#
#         query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
#         all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
#                                                     userid, timings)
#         combined_text = _combine_query_with_triplets(query, all_triplets)
#         query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
#         search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
#                                                      userid, timings)
#
#         use_rerank = False
#         if use_rerank and search_results:
#             final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
#                                                          userid, timings)
#             debug(f"final_results: {final_results}")
#         else:
#             final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
#
#         formatted_results = rag_ops.format_search_results(final_results, limit)
#         timings["total_time"] = time.time() - start_time
#         debug(f"Fused search done, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
#
#         return {
#             "records": formatted_results,
#             "timings": timings
#         }
#     except Exception as e:
#         error(f"Fused search failed: {str(e)}, traceback: {traceback.format_exc()}")
#         return {
#             "records": [],
#             "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
#             "error": str(e)
#         }

# async def text_insert(text: str, fiid: str, orgid: str, db_type: str):
async def textinsert(request, params_kw, *params):
    kw = request._run_ns
@@ -254,7 +435,7 @@ async def textinsert(request, params_kw, *params):
    result = {
        "status": "error",
        "userid": orgid,
-        "collection_name": "ragdb_{dbtype}",
        "collection_name": f"ragdb_{db_type}",
        "message": "",
        "status_code": 400
    }
@@ -277,10 +458,10 @@ async def textinsert(request, params_kw, *params):
    # Insert into Milvus
    fields = {
        "text": text,
-        "fiid": fiid,
-        "orgid": orgid,
        "knowledge_base_id": fiid,
        "userid": orgid,
        "vector": embedding,
-        "id": id
        "document_id": id
    }
    chunks_data = await rag_ops.insert_to_vector_text(request, db_type, fields, service_params, userid, timings)


@@ -57,11 +57,12 @@ async def sor_get_service_params(sor, orgid):
            service_params['reranker'] = service['upappid']
        elif name == 'mrebel三元组抽取':
            service_params['triples'] = service['upappid']
-        elif name == 'neo4j删除知识库':
        elif name == 'neo4j图知识库':
            service_params['gdb'] = service['upappid']
        elif name == 'small实体抽取':
            service_params['entities'] = service['upappid']
        elif name == 'clip多模态嵌入服务':
            service_params['embedding'] = service['upappid']
    # Check that every service parameter has been filled in
    missing_services = [k for k, v in service_params.items() if v is None]
    if missing_services:
@@ -76,3 +77,25 @@ async def get_service_params(orgid):
    async with db.sqlorContext(dbname) as sor:
        return await sor_get_service_params(sor, orgid)
    return None

async def sor_get_embedding_mode(sor, orgid) -> int:
    """Return the embedding mode for an orgid: 0 = text-only, 1 = multimodal"""
    sql = """
        SELECT em.mode
        FROM service_opts so
        JOIN embedding_mode em ON so.embedding_id = em.embeddingid
        WHERE so.orgid = ${orgid}$
    """
    rows = await sor.sqlExe(sql, {"orgid": orgid})
    if not rows:
        debug(f"orgid={orgid} has no embedding_mode configured; defaulting to 0 (text-only)")
        return 0
    return int(rows[0].mode)

async def get_embedding_mode(orgid):
    db = DBPools()
    # debug(f"Incoming orgid: {orgid}")
    dbname = get_serverenv('get_module_dbname')('rag')
    async with db.sqlorContext(dbname) as sor:
        return await sor_get_embedding_mode(sor, orgid)
    return None
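The query above assumes an embedding_mode table joined to service_opts; its schema is not part of this diff, so the following DDL is only a guess at the minimal shape the JOIN implies:

-- Hypothetical minimal schema; column names follow the JOIN above, types are assumptions
CREATE TABLE embedding_mode (
    embeddingid VARCHAR(40) PRIMARY KEY,  -- matched by service_opts.embedding_id
    mode        INT NOT NULL              -- 0 = text-only, 1 = multimodal
);
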
@@ -20,7 +20,6 @@ class OperationType(Enum):
    VECTOR_SEARCH = "vector_search"
    RERANK = "rerank"

-
@dataclass
class RollbackOperation:
    """Rollback operation record"""

@@ -45,6 +45,43 @@ class APIService:
            error(f"request #{request_id} embedding service call failed: {str(e)}, upappid={upappid}, apiname={apiname}")
            raise RuntimeError(f"Embedding service call failed: {str(e)}")

    # Multimodal embedding service
    async def get_multi_embeddings(
            self,
            request,
            inputs: List[Dict],
            upappid: str,
            apiname: str,
            user: str
    ) -> Dict[str, Dict]:
        """
        Unified multimodal embedding (text, image, audio, video).
        Returns the full result keyed by the original input string,
        including type / vector / face information.
        """
        request_id = str(uuid.uuid4())
        debug(f"Request #{request_id} multimodal embedding started, {len(inputs)} items")

        if not inputs or not isinstance(inputs, list):
            raise ValueError("inputs must be a non-empty list")

        try:
            uapi = UAPI(request, DictObject(**globals()))
            params_kw = {"inputs": inputs}
            b = await uapi.call(upappid, apiname, user, params_kw)
            d = await self.handle_uapi_response(b, upappid, apiname, "multimodal embedding service", request_id)

            if d.get("object") != "embedding.result" or "data" not in d:
                error(f"request #{request_id} malformed response: {d}")
                raise RuntimeError("Malformed multimodal embedding response")

            result = d["data"]  # returns {input_str: {type, vector, ...}} directly
            debug(f"request #{request_id} got {len(result)} multimodal vectors")
            return result

        except Exception as e:
            error(f"request #{request_id} multimodal embedding failed: {str(e)}")
            raise RuntimeError(f"Multimodal embedding failed: {str(e)}")

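For reference, a sketch of the result shape the callers in this diff rely on; the keys and values below are made up, while the field names match what insert_all_vectors consumes:

# Hypothetical payload, keyed by the original input string / path
result = {
    "some chunk text": {"type": "text", "vector": [0.01] * 1024},
    "word/media/image1.png": {
        "type": "image",
        "vector": [0.02] * 1024,   # whole-image CLIP vector
        "face_vecs": [],           # optional per-face vectors
    },
    "bad input": {"type": "error", "error": "decode failed"},
}
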
    # Entity extraction service (LTP/small)
    async def extract_entities(self, request, query: str, upappid: str, apiname: str, user: str) -> list:
        """Call the entity recognition service"""

@@ -18,6 +18,11 @@
        "editable": true,
        "rows": 5
    },
    {
        "uitype": "image",
        "name": "image",
        "label": "Upload a query image (optional)"
    },
    {
        "name": "fiids",
        "uitype": "checkbox",

@@ -7,14 +7,15 @@ if not orgid:
        message='Please log in first'
    )

-fiids = params_kw.fiids
query = params_kw.query
image = params_kw.image
fiids = params_kw.fiids
limit = params_kw.limit

-if not query or not fiids or not limit:
if (not query and not image) or not fiids or not limit:
    return UiError(
        title='Invalid input',
-        message='Please enter the query text and select at least one knowledge base'
        message='Please enter the query text or upload an image, select at least one knowledge base, and give the number of results to return'
    )

try:

9  wwwroot/test_textinsert.dspy  Normal file
@@ -0,0 +1,9 @@
debug(f'{params_kw=}')
text = params_kw.text
fiid = params_kw.fiid
db_type = params_kw.db_type
env = DictObject(**globals())
keys = [k for k in env.keys()]
debug(f'{keys=}')
x = await rfexe('textinsert', request, params_kw)
return x
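Assuming the script is served at its wwwroot path (an assumption about the ahserver routing, not stated in this diff), a smoke test could look like:

# Hypothetical smoke test; the host, fiid and db_type values are placeholders
import asyncio
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://example.com/test_textinsert.dspy",  # assumed route
            data={"text": "hello world", "fiid": "kb-001", "db_type": "textdb"},
        ) as resp:
            print(await resp.text())

asyncio.run(main())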