This commit is contained in:
wangmeihua 2025-11-28 16:20:20 +08:00
parent 985c5a998a
commit b2088aec49
6 changed files with 488 additions and 88 deletions

Binary file not shown.

View File

@ -2,6 +2,7 @@ import os
import re import re
import time import time
import math import math
import numpy as np
from datetime import datetime from datetime import datetime
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from langchain_core.documents import Document from langchain_core.documents import Document
@ -12,7 +13,10 @@ from filetxt.loader import fileloader, File2Text
from rag.uapi_service import APIService from rag.uapi_service import APIService
from rag.service_opts import get_service_params from rag.service_opts import get_service_params
from rag.transaction_manager import TransactionManager, OperationType from rag.transaction_manager import TransactionManager, OperationType
from pdf2image import convert_from_path
import pytesseract
import base64
from pathlib import Path
class RagOperations: class RagOperations:
"""RAG 操作类,提供所有通用的 RAG 操作""" """RAG 操作类,提供所有通用的 RAG 操作"""
@ -33,14 +37,39 @@ class RagOperations:
if ext not in supported_formats: if ext not in supported_formats:
raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
# 加载文件内容
text = fileloader(realpath) text = fileloader(realpath)
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) if ext == 'pdf':
debug(f"pdf原生提取结果是{text}")
if not text or len(text.strip()) == 0: # 更严格的空值检查
debug(f"pdf原生提取失败尝试扫描件提取")
ocr_text = self.pdf_to_text(realpath)
debug(f"pdf扫描件抽取的文本内容是{ocr_text}")
text = ocr_text # 只在原生提取失败时使用OCR结果
# 只在有文本内容时进行清洗
if text and len(text.strip()) > 0:
# 或者保留更多有用字符
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
else:
error(f"文件 {realpath} 无法提取任何文本内容")
text = "" # 确保为空字符串
timings["load_file"] = time.time() - start_load timings["load_file"] = time.time() - start_load
debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
if not text or not text.strip(): # # 加载文件内容
raise ValueError(f"文件 {realpath} 加载为空") # text = fileloader(realpath)
# debug(f"pdf原生提取结果是{text}")
# if len(text) == 0:
# debug(f"pdf原生提取失败尝试扫描件提取")
# text = self.pdf_to_text(realpath)
# debug(f"pdf扫描件抽取的文本内容是{text}")
# text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text)
# timings["load_file"] = time.time() - start_load
# debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}")
#
# if not text or not text.strip():
# raise ValueError(f"文件 {realpath} 加载为空")
# 分片处理 # 分片处理
document = Document(page_content=text) document = Document(page_content=text)
@ -67,6 +96,40 @@ class RagOperations:
return chunks return chunks
def pdf_to_text(
        self,
        pdf_path: str,
        output_txt: Optional[str] = None,
        dpi: int = 300,
        lang: str = 'chi_sim+chi_tra+eng'
) -> str:
    """
    Extract text from a scanned (image-only) PDF through OCR.

    Every page is rendered to an image with ``convert_from_path`` and passed
    to ``pytesseract.image_to_string``; the per-page results are concatenated
    in page order, each followed by a newline.

    Args:
        pdf_path: Path of the PDF file.
        output_txt: Optional path; when given, the extracted text is also
            written to this file as UTF-8.
        dpi: Resolution used when rendering the pages (default 300).
        lang: Language spec handed to pytesseract (default: simplified
            Chinese + traditional Chinese + English).

    Returns:
        The concatenated text of all pages.
    """
    # Render the PDF pages to images.
    pages = convert_from_path(pdf_path, dpi=dpi)

    # OCR each page; a newline terminates every page's text.
    recognised = ''.join(
        pytesseract.image_to_string(page, lang=lang) + '\n'
        for page in pages
    )

    # Optionally persist the result.
    if output_txt:
        with open(output_txt, 'w', encoding='utf-8') as out_file:
            out_file.write(recognised)

    return recognised
async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict, async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict,
userid: str, timings: Dict, userid: str, timings: Dict,
transaction_mgr: TransactionManager = None) -> List[List[float]]: transaction_mgr: TransactionManager = None) -> List[List[float]]:
@ -149,6 +212,14 @@ class RagOperations:
return result return result
# async def force_l2_normalize(self, vector: List[float]) -> List[float]:
# """万无一失的 L2 归一化"""
# arr = np.array(vector, dtype=np.float32)
# norm = np.linalg.norm(arr)
# if norm == 0:
# return vector # 全零向量无法归一化
# return (arr / norm).tolist()
# 统一插入向量库 # 统一插入向量库
async def insert_all_vectors( async def insert_all_vectors(
self, self,
@ -200,7 +271,10 @@ class RagOperations:
# 遍历 multi_results # 遍历 multi_results
for raw_key, info in multi_results.items(): for raw_key, info in multi_results.items():
typ = info["type"] typ = info["type"]
# vector = info["vector"]
# debug(f"从后端传回来的向量数据是:{vector}")
# emb = await self.force_l2_normalize(info["vector"])
# debug(f"归一化后的向量数据是:{emb}")
# --- 文本 --- # --- 文本 ---
if typ == "text": if typ == "text":
# raw_key 就是原文 # raw_key 就是原文
@ -253,7 +327,7 @@ class RagOperations:
# "upload_time": upload_time, # "upload_time": upload_time,
# "file_type": "face", # "file_type": "face",
# }) # })
# continue continue
# --- 视频 --- # --- 视频 ---
if typ == "video": if typ == "video":
@ -289,7 +363,7 @@ class RagOperations:
# "upload_time": upload_time, # "upload_time": upload_time,
# "file_type": "face", # "file_type": "face",
# }) # })
# continue continue
# --- 音频 --- # --- 音频 ---
if typ == "audio": if typ == "audio":
@ -776,25 +850,150 @@ class RagOperations:
debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f}") debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f}")
return all_triplets return all_triplets
async def generate_query_vector(
        self,
        request,
        text: str,
        service_params: Dict,
        userid: str,
        timings: Dict,
        embedding_mode: int = 0
) -> List[float]:
    """Generate the 1024-dimensional embedding vector for a text query.

    Mode 0 calls ``self.api_service.get_embeddings`` with apiname
    "BAAI/bge-m3". Mode 1 sends the text as a single
    ``{"type": "text", "content": text}`` input to
    ``self.api_service.get_multi_embeddings`` with apiname "black/clip" and
    uses the first returned entry whose "vector" is a 1024-element list.

    Args:
        request: Request object forwarded unchanged to the API service.
        text: Query text to embed.
        service_params: Service parameters; ``service_params['embedding']``
            is passed as ``upappid``.
        userid: User identifier forwarded as ``user``.
        timings: Dict that receives the elapsed seconds under
            "vector_generation".
        embedding_mode: 0 for text embedding, 1 for multimodal embedding.

    Returns:
        A list of length 1024.

    Raises:
        ValueError: If the mode is unsupported, the service returns nothing
            usable, or the vector does not have 1024 elements.
    """
    debug(f"生成查询向量: mode={embedding_mode}, text='{text[:100]}...'")
    start_vector = time.time()

    if embedding_mode == 0:
        # Mode 0: text-only embedding (BAAI/bge-m3).
        debug("使用 BAAI/bge-m3 文本嵌入")
        vectors = await self.api_service.get_embeddings(
            request=request,
            texts=[text],
            upappid=service_params['embedding'],
            apiname="BAAI/bge-m3",
            user=userid
        )
        if not vectors or not isinstance(vectors, list):
            raise ValueError("bge-m3 返回空结果")

        query_vector = vectors[0]
        # A None/scalar entry would make len() raise TypeError; report it as
        # the same ValueError used for any other wrong dimension.
        try:
            dimension = len(query_vector)
        except TypeError:
            dimension = None
        if dimension != 1024:
            raise ValueError(f"bge-m3 返回向量维度错误: {dimension}")

    elif embedding_mode == 1:
        # Mode 1: multimodal embedding (black/clip).
        debug("使用 black/clip 多模态嵌入")
        inputs = [{"type": "text", "content": text}]

        result = await self.api_service.get_multi_embeddings(
            request=request,
            inputs=inputs,
            upappid=service_params['embedding'],
            apiname="black/clip",
            user=userid
        )

        query_vector = None
        # A missing or malformed response is treated like "no valid vector".
        entries = result.values() if isinstance(result, dict) else ()
        for info in entries:
            if not isinstance(info, dict):
                continue
            if info.get("type") == "error":
                # .get(): an error entry without an "error" key must not
                # crash the logging statement.
                debug(f"CLIP 返回错误跳过: {info.get('error')}")
                continue
            candidate = info.get("vector")
            if isinstance(candidate, list) and len(candidate) == 1024:
                query_vector = candidate
                debug(f"成功获取 CLIP 向量(来自 {info.get('type')}")
                break

        if query_vector is None:
            raise ValueError("black/clip 未返回任何有效 1024 维向量")

    else:
        raise ValueError(f"不支持的 embedding_mode: {embedding_mode}")

    # Final check shared by both modes: must be a real list of length 1024.
    # len() is safe here because both branches above already verified it.
    if not isinstance(query_vector, list) or len(query_vector) != 1024:
        raise ValueError(f"查询向量必须是长度为 1024 的浮点数列表,实际: {len(query_vector)}")

    timings["vector_generation"] = time.time() - start_vector
    debug(f"生成查询向量成功,耗时: {timings['vector_generation']:.3f},模式: {embedding_mode}")
    return query_vector
async def generate_image_vector(
        self,
        request,
        img_path: str,
        service_params: Dict,
        userid: str,
        timings: Dict,
        embedding_mode: int = 0
) -> List[float]:
    """Generate the 1024-dimensional embedding vector for an image file.

    The file is read, base64-encoded into a ``data:<mime>;base64,...`` URI
    and sent as a single ``{"type": "image", "data": ...}`` input to
    ``self.api_service.get_multi_embeddings`` with apiname "black/clip".
    The first returned entry whose "vector" is a 1024-element list is used.

    Args:
        request: Request object forwarded unchanged to the API service.
        img_path: Path of the image file. Suffixes other than .png, .jpg,
            .jpeg, .webp and .bmp are labelled as image/jpeg.
        service_params: Service parameters; ``service_params['embedding']``
            is passed as ``upappid``.
        userid: User identifier forwarded as ``user``.
        timings: Dict that receives the elapsed seconds under
            "vector_generation".
        embedding_mode: Must be 1 (multimodal); 0 (text-only) is rejected.

    Returns:
        A list of length 1024.

    Raises:
        ValueError: If the mode is 0 or unsupported, the image cannot be
            read, or the service returns no valid 1024-element vector.
    """
    debug(f"生成查询向量: mode={embedding_mode}, image={img_path}")
    start_vector = time.time()

    if embedding_mode == 0:
        raise ValueError("纯文本没有这个功能,请重新选择服务")
    if embedding_mode != 1:
        raise ValueError(f"不支持的 embedding_mode: {embedding_mode}")

    # Mode 1: multimodal embedding (black/clip).
    debug("使用 black/clip 多模态嵌入")

    mime_map = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".bmp": "image/bmp"
    }
    try:
        # Unknown suffixes fall back to JPEG.
        mime_type = mime_map.get(Path(img_path).suffix.lower(), "image/jpeg")
        with open(img_path, "rb") as f:
            b64 = base64.b64encode(f.read()).decode("utf-8")
    except (OSError, TypeError) as e:
        # Fail fast: without the image there is nothing to embed, so do not
        # call the embedding service with an empty input list.
        raise ValueError(f"图像处理失败: {img_path}: {e}") from e

    inputs = [{
        "type": "image",
        "data": f"data:{mime_type};base64,{b64}"
    }]
    debug(f"已添加图像({mime_type}, {len(b64) / 1024:.1f}KB: {Path(img_path).name}")

    result = await self.api_service.get_multi_embeddings(
        request=request,
        inputs=inputs,
        upappid=service_params['embedding'],
        apiname="black/clip",
        user=userid
    )

    image_vector = None
    # A missing or malformed response is treated like "no valid vector".
    entries = result.values() if isinstance(result, dict) else ()
    for info in entries:
        if not isinstance(info, dict):
            continue
        if info.get("type") == "error":
            # .get(): an error entry without an "error" key must not crash
            # the logging statement.
            debug(f"CLIP 返回错误跳过: {info.get('error')}")
            continue
        candidate = info.get("vector")
        if isinstance(candidate, list) and len(candidate) == 1024:
            image_vector = candidate
            debug(f"成功获取 CLIP 向量(来自 {info.get('type')}")
            break

    if image_vector is None:
        raise ValueError("black/clip 未返回任何有效 1024 维向量")

    timings["vector_generation"] = time.time() - start_vector
    debug(f"生成查询向量成功,耗时: {timings['vector_generation']:.3f} 秒,模式: {embedding_mode}")
    return image_vector
async def vector_search(self, request, query_vector: List[float], orgid: str, async def vector_search(self, request, query_vector: List[float], orgid: str,
fiids: List[str], limit: int, service_params: Dict, userid: str, fiids: List[str], limit: int, service_params: Dict, userid: str,
timings: Dict) -> List[Dict]: timings: Dict) -> List[Dict]:
@ -866,34 +1065,49 @@ class RagOperations:
return unique_triples return unique_triples
def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]:
    """Convert raw search hits into the unified record format.

    Only the first ``limit`` hits are kept. Each record carries the hit's
    text as "content", the metadata filename as "title", and a "metadata"
    dict holding "document_id", the raw "distance" and the raw
    "rerank_score" (no normalisation is applied to either score).

    Args:
        results: Raw hits; each is a dict that may contain "text",
            "distance", "rerank_score" and a "metadata" dict.
        limit: Maximum number of records to return.

    Returns:
        The formatted records, in the same order as ``results``.
    """
    formatted = []
    for res in (results or [])[:limit]:
        # "metadata" may be present but None (or not a dict); fall back to an
        # empty dict so the lookups below use their defaults instead of
        # raising AttributeError.
        meta = res.get('metadata')
        if not isinstance(meta, dict):
            meta = {}

        formatted.append({
            "content": res.get('text', ''),
            "title": meta.get('filename', 'Untitled'),
            "metadata": {
                "document_id": meta.get('document_id', ''),
                "distance": res.get('distance', 0.0),
                "rerank_score": res.get('rerank_score', 0.0),
            }
        })
    return formatted
return formatted_results # async def save_uploaded_photo(self, image_file: FileStorage, orgid: str) -> str:
# """
# 把前端上传的图片保存到 /home/wangmeihua/kyrag/data/photo 目录下
# 返回保存后的绝对路径(字符串),供 generate_img_vector 使用
# """
# if not image_file or not hasattr(image_file, "filename"):
# raise ValueError("无效的图片上传对象")
#
# # 为了安全,按 orgid 分目录存放(避免不同公司文件混在一起)
# org_dir = UPLOAD_PHOTO_DIR / orgid
# org_dir.mkdir(parents=True, exist_ok=True)
#
# # 生成唯一文件名,保留原始后缀
# suffix = Path(image_file.filename).suffix.lower()
# if not suffix or suffix not in {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}:
# suffix = ".jpg"
#
# unique_name = f"{uuid.uuid4().hex}{suffix}"
# save_path = org_dir / unique_name
#
# # 真正落盘
# image_file.save(str(save_path))
# debug(f"图片已保存: {save_path} (原始名: {image_file.filename})")
#
# # 返回字符串路径generate_img_vector 直接收 str 就行
# return str(save_path)

View File

@ -6,10 +6,13 @@ import traceback
import json import json
import math import math
import uuid import uuid
from rag.service_opts import get_service_params, sor_get_service_params import os
from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode
from rag.rag_operations import RagOperations from rag.rag_operations import RagOperations
from langchain_core.documents import Document from langchain_core.documents import Document
REAL_PHOTO_ROOT = "/home/wangmeihua/kyrag/files"
helptext = """kyrag API: helptext = """kyrag API:
1. 得到kdb表: 1. 得到kdb表:
@ -134,7 +137,16 @@ async def fusedsearch(request, params_kw, *params):
debug(f"params_kw: {params_kw}") debug(f"params_kw: {params_kw}")
# orgid = "04J6VbxLqB_9RPMcgOv_8" # orgid = "04J6VbxLqB_9RPMcgOv_8"
# userid = "04J6VbxLqB_9RPMcgOv_8" # userid = "04J6VbxLqB_9RPMcgOv_8"
query = params_kw.get('query', '') query = params_kw.get('query', '').strip()
img_path = params_kw.get('image')
if isinstance(img_path, str):
img_path = img_path.strip()
relative_part = img_path.lstrip("/")
real_img_path = os.path.join(REAL_PHOTO_ROOT, relative_part)
if not os.path.exists(real_img_path):
raise FileNotFoundError(f"图片不存在: {real_img_path}")
img_path = real_img_path
debug(f"自动修复图片路径成功: {img_path}")
# 统一模式处理 limit 参数,为了对接dify和coze # 统一模式处理 limit 参数,为了对接dify和coze
raw_limit = params_kw.get('limit') or ( raw_limit = params_kw.get('limit') or (
params_kw.get('retrieval_setting', {}).get('top_k') params_kw.get('retrieval_setting', {}).get('top_k')
@ -189,43 +201,211 @@ async def fusedsearch(request, params_kw, *params):
service_params = await get_service_params(orgid) service_params = await get_service_params(orgid)
if not service_params: if not service_params:
raise ValueError("无法获取服务参数") raise ValueError("无法获取服务参数")
# 获取嵌入模式
embedding_mode = await get_embedding_mode(orgid)
debug(f"检测到 embedding_mode = {embedding_mode}0=文本, 1=多模态)")
try: # 情况1query 和 image 都为空 → 报错
timings = {} if not query and not img_path:
start_time = time.time() raise ValueError("查询文本和图片不能同时为空")
rag_ops = RagOperations()
query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings) # 情况2query 和 image 都存在 → 报错(你当前业务不允许同时传)
all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, if query and img_path:
userid, timings) raise ValueError("查询文本和图片只能二选一,不能同时提交")
combined_text = _combine_query_with_triplets(query, all_triplets)
query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings)
search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params,
userid, timings)
use_rerank = True # 3. 只有图片 → 以图搜图 走纯多模态分支
if use_rerank and search_results: if img_path and not query:
final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params, try:
debug("检测到纯图片查询,执行以图搜图")
rag_ops = RagOperations()
timings = {}
start_time = time.time()
# 直接生成图片向量
img_vector = await rag_ops.generate_image_vector(
request, img_path, service_params, userid, timings, embedding_mode
)
# 向量搜索(多取 50 条再截断,和文本分支保持一致)
search_results = await rag_ops.vector_search(
request, img_vector, orgid, fiids, limit + 50, service_params, userid, timings
)
timings["total_time"] = time.time() - start_time
# 可选:搜索完后删除图片,省磁盘(看你需求)
# try:
# os.remove(img_path)
# except:
# pass
final_results = []
for item in search_results[:limit]:
final_results.append({
"text": item["text"],
"distance": item["distance"]
})
return {
"results": final_results,
"timings": timings
}
except Exception as e:
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
return {
"records": [],
"timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
"error": str(e)
}
if not img_path and query:
try:
timings = {}
start_time = time.time()
rag_ops = RagOperations()
query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
userid, timings)
combined_text = _combine_query_with_triplets(query, all_triplets)
query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
userid, timings) userid, timings)
debug(f"final_results: {final_results}")
else:
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
formatted_results = rag_ops.format_search_results(final_results, limit) use_rerank = True
timings["total_time"] = time.time() - start_time if use_rerank and search_results:
info(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f}") final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
userid, timings)
debug(f"final_results: {final_results}")
else:
final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
return { formatted_results = rag_ops.format_search_results(final_results, limit)
"records": formatted_results, timings["total_time"] = time.time() - start_time
"timings": timings debug(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f}")
}
except Exception as e: return {
error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") "records": formatted_results,
return { "timings": timings
"records": [], }
"timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, except Exception as e:
"error": str(e) error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
} return {
"records": [],
"timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
"error": str(e)
}
# async def fusedsearch(request, params_kw, *params):
# """
# 融合搜索,调用服务化端点
#
# """
# kw = request._run_ns
# f = kw.get('get_userorgid')
# orgid = await f()
# debug(f"orgid: {orgid}{f=}")
# f = kw.get('get_user')
# userid = await f()
# debug(f"params_kw: {params_kw}")
# # orgid = "04J6VbxLqB_9RPMcgOv_8"
# # userid = "04J6VbxLqB_9RPMcgOv_8"
# query = params_kw.get('query', '')
# # 统一模式处理 limit 参数,为了对接dify和coze
# raw_limit = params_kw.get('limit') or (
# params_kw.get('retrieval_setting', {}).get('top_k')
# if isinstance(params_kw.get('retrieval_setting'), dict)
# else None
# )
#
# # 标准化为整数值
# if raw_limit is None:
# limit = 5 # 两个来源都不存在时使用默认值
# elif isinstance(raw_limit, (int, float)):
# limit = int(raw_limit) # 数值类型直接转换
# elif isinstance(raw_limit, str):
# try:
# # 字符串转换为整数
# limit = int(raw_limit)
# except (TypeError, ValueError):
# limit = 5 # 转换失败使用默认值
# else:
# limit = 5 # 其他意外类型使用默认值
# debug(f"limit: {limit}")
# raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') #
#
# # 标准化为列表格式
# if raw_fiids is None:
# fiids = [] # 两个参数都不存在
# elif isinstance(raw_fiids, list):
# fiids = [str(item).strip() for item in raw_fiids] # 已经是列表
# elif isinstance(raw_fiids, str):
# # fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
# try:
# # 尝试解析 JSON 字符串
# parsed = json.loads(raw_fiids)
# if isinstance(parsed, list):
# fiids = [str(item).strip() for item in parsed] # JSON 数组转为字符串列表
# else:
# # 处理逗号分隔的字符串或单个 ID 字符串
# fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
# except json.JSONDecodeError:
# # 如果不是合法 JSON按逗号分隔
# fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()]
# elif isinstance(raw_fiids, (int, float)):
# fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表
# else:
# fiids = [] # 其他意外类型
#
# debug(f"fiids: {fiids}")
#
# # 验证 fiids的orgid与orgid = await f()是否一致
# await _validate_fiids_orgid(fiids, orgid, kw)
#
# service_params = await get_service_params(orgid)
# if not service_params:
# raise ValueError("无法获取服务参数")
# # 获取嵌入模式
# embedding_mode = await get_embedding_mode(orgid)
# debug(f"检测到 embedding_mode = {embedding_mode}0=文本, 1=多模态)")
#
# try:
# timings = {}
# start_time = time.time()
# rag_ops = RagOperations()
#
# query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings)
# all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params,
# userid, timings)
# combined_text = _combine_query_with_triplets(query, all_triplets)
# query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode)
# search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params,
# userid, timings)
#
# use_rerank = False
# if use_rerank and search_results:
# final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params,
# userid, timings)
# debug(f"final_results: {final_results}")
# else:
# final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results]
#
# formatted_results = rag_ops.format_search_results(final_results, limit)
# timings["total_time"] = time.time() - start_time
# debug(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f} 秒")
#
# return {
# "records": formatted_results,
# "timings": timings
# }
# except Exception as e:
# error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}")
# return {
# "records": [],
# "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0},
# "error": str(e)
# }
# async def text_insert(text: str, fiid: str, orgid: str, db_type: str): # async def text_insert(text: str, fiid: str, orgid: str, db_type: str):
async def textinsert(request, params_kw, *params): async def textinsert(request, params_kw, *params):

View File

@ -94,7 +94,7 @@ async def sor_get_embedding_mode(sor, orgid) -> int:
async def get_embedding_mode(orgid): async def get_embedding_mode(orgid):
db = DBPools() db = DBPools()
debug(f"传入的orgid是{orgid}") # debug(f"传入的orgid是{orgid}")
dbname = get_serverenv('get_module_dbname')('rag') dbname = get_serverenv('get_module_dbname')('rag')
async with db.sqlorContext(dbname) as sor: async with db.sqlorContext(dbname) as sor:
return await sor_get_embedding_mode(sor, orgid) return await sor_get_embedding_mode(sor, orgid)

View File

@ -18,6 +18,11 @@
"editable": true, "editable": true,
"rows": 5 "rows": 5
}, },
{
"uitype": "image",
"name": "image",
"label": "上传查询图片(可选)"
},
{ {
"name": "fiids", "name": "fiids",
"uitype": "checkbox", "uitype": "checkbox",

View File

@ -7,14 +7,15 @@ if not orgid:
message='请先登录' message='请先登录'
) )
fiids = params_kw.fiids
query = params_kw.query query = params_kw.query
image = params_kw.image
fiids = params_kw.fiids
limit = params_kw.limit limit = params_kw.limit
if not query or not fiids or not limit: if (not query and not image) or not fiids or not limit:
return UiError( return UiError(
title='无效输入', title='无效输入',
message='请输入查询文本并选择至少一个知识库' message='请输入查询文本或上传image并选择至少一个知识库和填写返回条数'
) )
try: try: