rag

Commit b2088aec49 (parent 985c5a998a)

Binary file not shown.
@@ -2,6 +2,7 @@ import os
 import re
 import time
 import math
+import numpy as np
 from datetime import datetime
 from typing import List, Dict, Any, Optional
 from langchain_core.documents import Document
@@ -12,7 +13,10 @@ from filetxt.loader import fileloader, File2Text
 from rag.uapi_service import APIService
 from rag.service_opts import get_service_params
 from rag.transaction_manager import TransactionManager, OperationType
+from pdf2image import convert_from_path
+import pytesseract
+import base64
+from pathlib import Path


 class RagOperations:
     """RAG operations class, providing all of the common RAG operations."""
@@ -33,14 +37,39 @@ class RagOperations:
         if ext not in supported_formats:
             raise ValueError(f"Unsupported file format: {ext}; supported formats: {', '.join(supported_formats)}")

-        # Load the file content
         text = fileloader(realpath)
-        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
+        if ext == 'pdf':
+            debug(f"Native PDF extraction result: {text}")
+            if not text or len(text.strip()) == 0:  # stricter empty-value check
+                debug("Native PDF extraction failed; trying OCR extraction for scanned pages")
+                ocr_text = self.pdf_to_text(realpath)
+                debug(f"Text extracted from the scanned PDF: {ocr_text}")
+                text = ocr_text  # use the OCR result only when native extraction fails
+
+        # Clean the text only when there is any content
+        if text and len(text.strip()) > 0:
+            # alternatively, a wider set of useful characters could be kept here
+            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
+        else:
+            error(f"File {realpath}: no text could be extracted")
+            text = ""  # make sure it is an empty string
+
         timings["load_file"] = time.time() - start_load
         debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")

-        if not text or not text.strip():
-            raise ValueError(f"File {realpath} loaded empty")
+        # # Load the file content
+        # text = fileloader(realpath)
+        # debug(f"Native PDF extraction result: {text}")
+        # if len(text) == 0:
+        #     debug("Native PDF extraction failed; trying OCR extraction for scanned pages")
+        #     text = self.pdf_to_text(realpath)
+        #     debug(f"Text extracted from the scanned PDF: {text}")
+        # text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
+        # timings["load_file"] = time.time() - start_load
+        # debug(f"File load took {timings['load_file']:.2f} s, text length: {len(text)}")
+        #
+        # if not text or not text.strip():
+        #     raise ValueError(f"File {realpath} loaded empty")

         # Split into chunks
         document = Document(page_content=text)
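For reference, the cleaning regex above keeps CJK characters, ASCII letters and digits, whitespace and the characters . ; , /, and drops everything else (full-width punctuation, currency symbols, brackets). A small standalone sketch of its effect, using a made-up sample string:

import re

sample = "价格:¥120.50 (含税)!详见 §3/附录A"
cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', sample)
# Full-width colon and exclamation mark, ¥, parentheses and § are stripped;
# CJK text, digits, '.', '/' and spaces survive.
print(cleaned)  # roughly: "价格120.50 含税详见 3/附录A"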
@@ -67,6 +96,40 @@ class RagOperations:

         return chunks

+    def pdf_to_text(
+        self,
+        pdf_path: str,
+        output_txt: Optional[str] = None,
+        dpi: int = 300,
+        lang: str = 'chi_sim+chi_tra+eng'
+    ) -> str:
+        """
+        Convert a scanned PDF to text (the original code, wrapped as a single call).
+
+        Args:
+            pdf_path: path to the PDF file (string)
+            output_txt: if provided, the text is also saved to this txt file (optional)
+            dpi: image resolution, default 300 (higher is sharper)
+            lang: language packs, default simplified + traditional Chinese + English
+
+        Returns:
+            The full extracted text (string).
+        """
+        # Convert the PDF pages to images
+        images = convert_from_path(pdf_path, dpi=dpi)
+
+        # OCR each page
+        text = ''
+        for img in images:
+            text += pytesseract.image_to_string(img, lang=lang) + '\n'
+
+        # Optional: save the result to a file
+        if output_txt:
+            with open(output_txt, 'w', encoding='utf-8') as f:
+                f.write(text)
+
+        return text
+
     async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
                                   userid: str, timings: Dict,
                                   transaction_mgr: TransactionManager = None) -> List[List[float]]:
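As a usage note on the new pdf_to_text helper above: pdf2image needs the poppler utilities installed, and pytesseract needs a Tesseract binary with the chi_sim / chi_tra language data, otherwise convert_from_path and image_to_string will fail. A minimal standalone sketch of the same OCR flow (the sample path is hypothetical):

from pdf2image import convert_from_path
import pytesseract

def ocr_pdf(pdf_path: str, dpi: int = 300, lang: str = 'chi_sim+chi_tra+eng') -> str:
    # Render each PDF page to an image, then OCR it page by page
    pages = convert_from_path(pdf_path, dpi=dpi)
    return '\n'.join(pytesseract.image_to_string(page, lang=lang) for page in pages)

if __name__ == '__main__':
    print(ocr_pdf('/tmp/scanned_sample.pdf'))  # hypothetical path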
@@ -149,6 +212,14 @@ class RagOperations:

         return result

+    # async def force_l2_normalize(self, vector: List[float]) -> List[float]:
+    #     """Bullet-proof L2 normalization."""
+    #     arr = np.array(vector, dtype=np.float32)
+    #     norm = np.linalg.norm(arr)
+    #     if norm == 0:
+    #         return vector  # an all-zero vector cannot be normalized
+    #     return (arr / norm).tolist()
+
     # Unified insert into the vector store
     async def insert_all_vectors(
         self,
@@ -200,7 +271,10 @@ class RagOperations:
         # Iterate over multi_results
         for raw_key, info in multi_results.items():
             typ = info["type"]
+            # vector = info["vector"]
+            # debug(f"Vector data returned from the backend: {vector}")
+            # emb = await self.force_l2_normalize(info["vector"])
+            # debug(f"Vector data after normalization: {emb}")
             # --- Text ---
             if typ == "text":
                 # raw_key is the original text itself
@@ -253,7 +327,7 @@ class RagOperations:
                 #     "upload_time": upload_time,
                 #     "file_type": "face",
                 # })
-                # continue
+                continue

             # --- Video ---
             if typ == "video":
@@ -289,7 +363,7 @@ class RagOperations:
                 #     "upload_time": upload_time,
                 #     "file_type": "face",
                 # })
-                # continue
+                continue

             # --- Audio ---
             if typ == "audio":
@@ -776,25 +850,150 @@ class RagOperations:
         debug(f"Total triplet matching time: {timings['triplet_matching']:.3f} s")
         return all_triplets

-    async def generate_query_vector(self, request, text: str, service_params: Dict,
-                                    userid: str, timings: Dict) -> List[float]:
-        """Generate the query vector."""
-        debug(f"Generating query vector: {text[:200]}...")
-        start_vector = time.time()
-        query_vector = await self.api_service.get_embeddings(
-            request=request,
-            texts=[text],
-            upappid=service_params['embedding'],
-            apiname="BAAI/bge-m3",
-            user=userid
-        )
-        if not query_vector or not all(len(vec) == 1024 for vec in query_vector):
-            raise ValueError("The query vector must be a list of 1024 floats")
-        query_vector = query_vector[0]
-        timings["vector_generation"] = time.time() - start_vector
-        debug(f"Query vector generation took: {timings['vector_generation']:.3f} s")
-        return query_vector
+    async def generate_query_vector(
+        self,
+        request,
+        text: str,
+        service_params: Dict,
+        userid: str,
+        timings: Dict,
+        embedding_mode: int = 0
+    ) -> List[float]:
+        """Generate the query vector (supports text / multimodal)."""
+        debug(f"Generating query vector: mode={embedding_mode}, text='{text[:100]}...'")
+        start_vector = time.time()
+
+        if embedding_mode == 0:
+            # === Mode 0: text-only embedding (BAAI/bge-m3) ===
+            debug("Using BAAI/bge-m3 text embedding")
+            vectors = await self.api_service.get_embeddings(
+                request=request,
+                texts=[text],
+                upappid=service_params['embedding'],
+                apiname="BAAI/bge-m3",
+                user=userid
+            )
+            if not vectors or not isinstance(vectors, list) or len(vectors) == 0:
+                raise ValueError("bge-m3 returned an empty result")
+            query_vector = vectors[0]
+            if len(query_vector) != 1024:
+                raise ValueError(f"bge-m3 returned a vector of the wrong dimension: {len(query_vector)}")
+
+        elif embedding_mode == 1:
+            # === Mode 1: multimodal embedding (black/clip) ===
+            debug("Using black/clip multimodal embedding")
+            inputs = [{"type": "text", "content": text}]
+
+            result = await self.api_service.get_multi_embeddings(
+                request=request,
+                inputs=inputs,
+                upappid=service_params['embedding'],
+                apiname="black/clip",
+                user=userid
+            )
+
+            query_vector = None
+            for key, info in result.items():
+                if info.get("type") == "error":
+                    debug(f"CLIP returned an error, skipping: {info['error']}")
+                    continue
+                if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024:
+                    query_vector = info["vector"]
+                    debug(f"Got a CLIP vector (from {info['type']})")
+                    break
+
+            if query_vector is None:
+                raise ValueError("black/clip did not return any valid 1024-dimensional vector")
+
+        else:
+            raise ValueError(f"Unsupported embedding_mode: {embedding_mode}")
+
+        # Final unified validation
+        if not isinstance(query_vector, list) or len(query_vector) != 1024:
+            raise ValueError(f"The query vector must be a list of 1024 floats, got length {len(query_vector)}")
+
+        timings["vector_generation"] = time.time() - start_vector
+        debug(f"Query vector generated, took {timings['vector_generation']:.3f} s, mode: {embedding_mode}")
+        return query_vector

+    async def generate_image_vector(
+        self,
+        request,
+        img_path: str,
+        service_params: Dict,
+        userid: str,
+        timings: Dict,
+        embedding_mode: int = 0
+    ) -> List[float]:
+        """Generate an image query vector (multimodal mode only)."""
+        debug(f"Generating query vector: mode={embedding_mode}, image={img_path}")
+        start_vector = time.time()
+
+        if embedding_mode == 0:
+            raise ValueError("Image search is not available in text-only mode; please select a different service")
+
+        elif embedding_mode == 1:
+            # === Mode 1: multimodal embedding (black/clip) ===
+            debug("Using black/clip multimodal embedding")
+            inputs = []
+            try:
+                ext = Path(img_path).suffix.lower()
+                if ext not in {".png", ".jpg", ".jpeg", ".webp", ".bmp"}:
+                    ext = ".jpg"
+
+                mime_map = {
+                    ".png": "image/png",
+                    ".jpg": "image/jpeg",
+                    ".jpeg": "image/jpeg",
+                    ".webp": "image/webp",
+                    ".bmp": "image/bmp"
+                }
+                mime_type = mime_map.get(ext, "image/jpeg")
+                with open(img_path, "rb") as f:
+                    b64 = base64.b64encode(f.read()).decode("utf-8")
+                data_uri = f"data:{mime_type};base64,{b64}"
+
+                inputs.append({
+                    "type": "image",
+                    "data": data_uri
+                })
+                debug(f"Added image ({mime_type}, {len(b64) / 1024:.1f}KB): {Path(img_path).name}")
+
+            except Exception as e:
+                debug(f"Image processing failed, skipping: {img_path} → {e}")
+
+            result = await self.api_service.get_multi_embeddings(
+                request=request,
+                inputs=inputs,
+                upappid=service_params['embedding'],
+                apiname="black/clip",
+                user=userid
+            )
+
+            image_vector = None
+            for key, info in result.items():
+                if info.get("type") == "error":
+                    debug(f"CLIP returned an error, skipping: {info['error']}")
+                    continue
+                if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024:
+                    image_vector = info["vector"]
+                    debug(f"Got a CLIP vector (from {info['type']})")
+                    break
+
+            if image_vector is None:
+                raise ValueError("black/clip did not return any valid 1024-dimensional vector")
+
+        else:
+            raise ValueError(f"Unsupported embedding_mode: {embedding_mode}")
+
+        # Final unified validation
+        if not isinstance(image_vector, list) or len(image_vector) != 1024:
+            raise ValueError(f"The query vector must be a list of 1024 floats, got length {len(image_vector)}")
+
+        timings["vector_generation"] = time.time() - start_vector
+        debug(f"Query vector generated, took {timings['vector_generation']:.3f} s, mode: {embedding_mode}")
+        return image_vector

     async def vector_search(self, request, query_vector: List[float], orgid: str,
                             fiids: List[str], limit: int, service_params: Dict, userid: str,
                             timings: Dict) -> List[Dict]:
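To show how the two query-vector methods above are meant to be selected, here is a sketch of a caller that dispatches on embedding_mode; the rag_ops instance, the request object and the contents of service_params are assumptions taken from the surrounding code, not part of the commit:

# Sketch only: rag_ops is a RagOperations instance; request, service_params,
# userid and timings come from the calling endpoint.
async def embed_query(rag_ops, request, service_params, userid, timings,
                      text=None, img_path=None, embedding_mode=0):
    if img_path:
        # Image queries only work in multimodal mode (embedding_mode == 1, black/clip)
        return await rag_ops.generate_image_vector(
            request, img_path, service_params, userid, timings, embedding_mode)
    # Text queries work in both modes (bge-m3 or black/clip)
    return await rag_ops.generate_query_vector(
        request, text or "", service_params, userid, timings, embedding_mode)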
@@ -866,34 +1065,49 @@ class RagOperations:
         return unique_triples

     def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
-        """Format the search results into a unified structure."""
-        formatted_results = []
-        # for res in results[:limit]:
-        #     score = res.get('rerank_score', res.get('distance', 0))
-        #
-        #     content = res.get('text', '')
-        #     title = res.get('metadata', {}).get('filename', 'Untitled')
-        #     document_id = res.get('metadata', {}).get('document_id', '')
-        #
-        #     formatted_results.append({
-        #         "content": content,
-        #         "title": title,
-        #         "metadata": {"document_id": document_id, "score": score},
-        #     })
-        # Score normalization
-        for res in results[:limit]:
-            rerank_score = res.get('rerank_score', 0)
-            score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0)
-            score = max(0.0, min(1.0, score))
-
-            content = res.get('text', '')
-            title = res.get('metadata', {}).get('filename', 'Untitled')
-            document_id = res.get('metadata', {}).get('document_id', '')
-
-            formatted_results.append({
-                "content": content,
-                "title": title,
-                "metadata": {"document_id": document_id, "score": score},
-            })
-
-        return formatted_results
+        formatted = []
+        for res in results[:limit]:
+            # # Prefer the rerank score, otherwise use the vector similarity (use it directly, do not invert it)
+            # if res.get('rerank_score') is not None:
+            #     score = res.get('rerank_score')
+            # else:
+            #     score = res.get('distance', 0.0)
+            distance = res.get('distance', 0.0)
+            rerank_score = res.get('rerank_score', 0.0)
+            formatted.append({
+                "content": res.get('text', ''),
+                "title": res.get('metadata', {}).get('filename', 'Untitled'),
+                "metadata": {
+                    "document_id": res.get('metadata', {}).get('document_id', ''),
+                    "distance": distance,
+                    "rerank_score": rerank_score,
+                }
+            })
+        return formatted
+
+    # async def save_uploaded_photo(self, image_file: FileStorage, orgid: str) -> str:
+    #     """
+    #     Save an image uploaded from the frontend under /home/wangmeihua/kyrag/data/photo
+    #     and return the saved absolute path (string) for generate_img_vector to use.
+    #     """
+    #     if not image_file or not hasattr(image_file, "filename"):
+    #         raise ValueError("Invalid image upload object")
+    #
+    #     # For safety, store files in a per-orgid directory (keeps different organizations' files apart)
+    #     org_dir = UPLOAD_PHOTO_DIR / orgid
+    #     org_dir.mkdir(parents=True, exist_ok=True)
+    #
+    #     # Generate a unique file name, keeping the original extension
+    #     suffix = Path(image_file.filename).suffix.lower()
+    #     if not suffix or suffix not in {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}:
+    #         suffix = ".jpg"
+    #
+    #     unique_name = f"{uuid.uuid4().hex}{suffix}"
+    #     save_path = org_dir / unique_name
+    #
+    #     # Actually write the file to disk
+    #     image_file.save(str(save_path))
+    #     debug(f"Image saved: {save_path} (original name: {image_file.filename})")
+    #
+    #     # Return the path as a string; generate_img_vector accepts a str directly
+    #     return str(save_path)
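Note that the new format_search_results no longer collapses distance and rerank_score into a single normalized score. If a 0-1 score is still needed downstream, the sigmoid mapping used by the removed code can be reproduced from the returned metadata; a minimal sketch:

import math
from typing import Optional

def to_unit_score(rerank_score: Optional[float], distance: float = 0.0) -> float:
    # Same mapping as the removed code: sigmoid of the rerank score,
    # falling back to 1 - distance when no rerank score is available.
    score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - distance
    return max(0.0, min(1.0, score))

print(to_unit_score(2.3))         # ~0.909
print(to_unit_score(None, 0.35))  # 0.65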
rag/ragapi.py: 246 changed lines
@@ -6,10 +6,13 @@ import traceback
 import json
 import math
 import uuid
-from rag.service_opts import get_service_params, sor_get_service_params
+import os
+from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode
 from rag.rag_operations import RagOperations
 from langchain_core.documents import Document

+REAL_PHOTO_ROOT = "/home/wangmeihua/kyrag/files"
+
 helptext = """kyrag API:

 1. Get the kdb tables:
@@ -134,7 +137,16 @@ async def fusedsearch(request, params_kw, *params):
     debug(f"params_kw: {params_kw}")
     # orgid = "04J6VbxLqB_9RPMcgOv_8"
     # userid = "04J6VbxLqB_9RPMcgOv_8"
-    query = params_kw.get('query', '')
+    query = params_kw.get('query', '').strip()
+    img_path = params_kw.get('image')
+    if isinstance(img_path, str):
+        img_path = img_path.strip()
+        relative_part = img_path.lstrip("/")
+        real_img_path = os.path.join(REAL_PHOTO_ROOT, relative_part)
+        if not os.path.exists(real_img_path):
+            raise FileNotFoundError(f"Image does not exist: {real_img_path}")
+        img_path = real_img_path
+        debug(f"Image path automatically resolved: {img_path}")
     # Normalize the limit parameter, for Dify / Coze integration
     raw_limit = params_kw.get('limit') or (
         params_kw.get('retrieval_setting', {}).get('top_k')
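For example, with REAL_PHOTO_ROOT defined above, a relative image value sent by the frontend (the exact value below is illustrative) resolves like this:

import os

REAL_PHOTO_ROOT = "/home/wangmeihua/kyrag/files"
img_path = "/photo/2024/sample.jpg"   # illustrative value of params_kw['image']
real_img_path = os.path.join(REAL_PHOTO_ROOT, img_path.lstrip("/"))
print(real_img_path)  # /home/wangmeihua/kyrag/files/photo/2024/sample.jpg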
@@ -189,43 +201,211 @@ async def fusedsearch(request, params_kw, *params):
     service_params = await get_service_params(orgid)
     if not service_params:
         raise ValueError("Unable to obtain service parameters")
+    # Determine the embedding mode
+    embedding_mode = await get_embedding_mode(orgid)
+    debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")

-    try:
-        timings = {}
-        start_time = time.time()
-        rag_ops = RagOperations()
-
-        query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
-        all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
-                                                    userid, timings)
-        combined_text = _combine_query_with_triplets(query, all_triplets)
-        query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
-        search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
-                                                     userid, timings)
-
-        use_rerank = True
-        if use_rerank and search_results:
-            final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
-                                                         userid, timings)
-            debug(f"final_results: {final_results}")
-        else:
-            final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
-
-        formatted_results = rag_ops.format_search_results(final_results, limit)
-        timings["total_time"] = time.time() - start_time
-        info(f"Fused search finished, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
-
-        return {
-            "records": formatted_results,
-            "timings": timings
-        }
-    except Exception as e:
-        error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
-        return {
-            "records": [],
-            "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
-            "error": str(e)
-        }
+    # Case 1: both query and image are empty -> error
+    if not query and not img_path:
+        raise ValueError("The query text and the image cannot both be empty")
+
+    # Case 2: both query and image are provided -> error (the current flow does not allow submitting both)
+    if query and img_path:
+        raise ValueError("Provide either query text or an image, not both")
+
+    # Case 3: image only -> image-to-image search, pure multimodal branch
+    if img_path and not query:
+        try:
+            debug("Image-only query detected; running image-to-image search")
+            rag_ops = RagOperations()
+
+            timings = {}
+            start_time = time.time()
+
+            # Generate the image vector directly
+            img_vector = await rag_ops.generate_image_vector(
+                request, img_path, service_params, userid, timings, embedding_mode
+            )
+
+            # Vector search (fetch 50 extra rows and truncate later, consistent with the text branch)
+            search_results = await rag_ops.vector_search(
+                request, img_vector, orgid, fiids, limit + 50, service_params, userid, timings
+            )
+
+            timings["total_time"] = time.time() - start_time
+
+            # Optional: delete the image after the search to save disk space (depends on requirements)
+            # try:
+            #     os.remove(img_path)
+            # except:
+            #     pass
+
+            final_results = []
+            for item in search_results[:limit]:
+                final_results.append({
+                    "text": item["text"],
+                    "distance": item["distance"]
+                })
+
+            return {
+                "results": final_results,
+                "timings": timings
+            }
+        except Exception as e:
+            error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
+            return {
+                "records": [],
+                "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
+                "error": str(e)
+            }
+
+    if not img_path and query:
+        try:
+            timings = {}
+            start_time = time.time()
+            rag_ops = RagOperations()
+
+            query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
+            all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
+                                                        userid, timings)
+            combined_text = _combine_query_with_triplets(query, all_triplets)
+            query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
+            search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
+                                                         userid, timings)
+
+            use_rerank = True
+            if use_rerank and search_results:
+                final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
+                                                             userid, timings)
+                debug(f"final_results: {final_results}")
+            else:
+                final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
+
+            formatted_results = rag_ops.format_search_results(final_results, limit)
+            timings["total_time"] = time.time() - start_time
+            debug(f"Fused search finished, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
+
+            return {
+                "records": formatted_results,
+                "timings": timings
+            }
+        except Exception as e:
+            error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
+            return {
+                "records": [],
+                "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
+                "error": str(e)
+            }
+
+# async def fusedsearch(request, params_kw, *params):
+#     """
+#     Fused search, calls the service endpoints
+#
+#     """
+#     kw = request._run_ns
+#     f = kw.get('get_userorgid')
+#     orgid = await f()
+#     debug(f"orgid: {orgid},{f=}")
+#     f = kw.get('get_user')
+#     userid = await f()
+#     debug(f"params_kw: {params_kw}")
+#     # orgid = "04J6VbxLqB_9RPMcgOv_8"
+#     # userid = "04J6VbxLqB_9RPMcgOv_8"
+#     query = params_kw.get('query', '')
+#     # Normalize the limit parameter, for Dify / Coze integration
+#     raw_limit = params_kw.get('limit') or (
+#         params_kw.get('retrieval_setting', {}).get('top_k')
+#         if isinstance(params_kw.get('retrieval_setting'), dict)
+#         else None
+#     )
+#
+#     # Normalize to an integer
+#     if raw_limit is None:
+#         limit = 5  # default when neither source is present
+#     elif isinstance(raw_limit, (int, float)):
+#         limit = int(raw_limit)  # numeric types convert directly
+#     elif isinstance(raw_limit, str):
+#         try:
+#             # convert the string to an integer
+#             limit = int(raw_limit)
+#         except (TypeError, ValueError):
+#             limit = 5  # default when the conversion fails
+#     else:
+#         limit = 5  # default for any other unexpected type
+#     debug(f"limit: {limit}")
+#     raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id')  #
+#
+#     # Normalize to a list
+#     if raw_fiids is None:
+#         fiids = []  # neither parameter is present
+#     elif isinstance(raw_fiids, list):
+#         fiids = [str(item).strip() for item in raw_fiids]  # already a list
+#     elif isinstance(raw_fiids, str):
+#         # fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
+#         try:
+#             # try to parse a JSON string
+#             parsed = json.loads(raw_fiids)
+#             if isinstance(parsed, list):
+#                 fiids = [str(item).strip() for item in parsed]  # JSON array becomes a list of strings
+#             else:
+#                 # handle a comma-separated string or a single ID string
+#                 fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
+#         except json.JSONDecodeError:
+#             # not valid JSON, split on commas
+#             fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
+#     elif isinstance(raw_fiids, (int, float)):
+#         fiids = [str(int(raw_fiids))]  # numeric types become a single-element string list
+#     else:
+#         fiids = []  # any other unexpected type
+#
+#     debug(f"fiids: {fiids}")
+#
+#     # Check that the orgid of the fiids matches orgid = await f()
+#     await _validate_fiids_orgid(fiids, orgid, kw)
+#
+#     service_params = await get_service_params(orgid)
+#     if not service_params:
+#         raise ValueError("Unable to obtain service parameters")
+#     # Determine the embedding mode
+#     embedding_mode = await get_embedding_mode(orgid)
+#     debug(f"Detected embedding_mode = {embedding_mode} (0 = text, 1 = multimodal)")
+#
+#     try:
+#         timings = {}
+#         start_time = time.time()
+#         rag_ops = RagOperations()
+#
+#         query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
+#         all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
+#                                                     userid, timings)
+#         combined_text = _combine_query_with_triplets(query, all_triplets)
+#         query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
+#         search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
+#                                                      userid, timings)
+#
+#         use_rerank = False
+#         if use_rerank and search_results:
+#             final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
+#                                                          userid, timings)
+#             debug(f"final_results: {final_results}")
+#         else:
+#             final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
+#
+#         formatted_results = rag_ops.format_search_results(final_results, limit)
+#         timings["total_time"] = time.time() - start_time
+#         debug(f"Fused search finished, returning {len(formatted_results)} results, total time: {timings['total_time']:.3f} s")
+#
+#         return {
+#             "records": formatted_results,
+#             "timings": timings
+#         }
+#     except Exception as e:
+#         error(f"Fused search failed: {str(e)}, stack: {traceback.format_exc()}")
+#         return {
+#             "records": [],
+#             "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
+#             "error": str(e)
+#         }

 # async def text_insert(text: str, fiid: str, orgid: str, db_type: str):
 async def textinsert(request, params_kw, *params):
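For reference, the three branches of the new fusedsearch correspond to request payloads along these lines; the field values and knowledge-base IDs are illustrative, only the field names come from the code above:

# Text-only search: entity extraction, triplet matching, vector search, optional rerank
text_request = {"query": "maintenance schedule", "fiids": ["kb_001"], "limit": 5}

# Image-only search: image-to-image, requires embedding_mode == 1 (black/clip)
image_request = {"query": "", "image": "/photo/sample.jpg", "fiids": ["kb_001"], "limit": 5}

# Dify/Coze style: top_k inside retrieval_setting is accepted in place of limit
dify_request = {"query": "maintenance schedule", "knowledge_id": "kb_001",
                "retrieval_setting": {"top_k": 5}}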
@@ -94,7 +94,7 @@ async def sor_get_embedding_mode(sor, orgid) -> int:

 async def get_embedding_mode(orgid):
     db = DBPools()
-    debug(f"Incoming orgid: {orgid}")
+    # debug(f"Incoming orgid: {orgid}")
     dbname = get_serverenv('get_module_dbname')('rag')
     async with db.sqlorContext(dbname) as sor:
         return await sor_get_embedding_mode(sor, orgid)
@@ -18,6 +18,11 @@
   "editable": true,
   "rows": 5
 },
+{
+  "uitype": "image",
+  "name": "image",
+  "label": "Upload a query image (optional)"
+},
 {
   "name": "fiids",
   "uitype": "checkbox",
@@ -7,14 +7,15 @@ if not orgid:
         message='Please log in first'
     )

-fiids = params_kw.fiids
 query = params_kw.query
+image = params_kw.image
+fiids = params_kw.fiids
 limit = params_kw.limit

-if not query or not fiids or not limit:
+if (not query and not image) or not fiids or not limit:
     return UiError(
         title='Invalid input',
-        message='Please enter query text and select at least one knowledge base'
+        message='Please enter query text or upload an image, select at least one knowledge base, and fill in the number of results to return'
     )

 try: