diff --git a/models/embeddingmode.xlsx b/models/embeddingmode.xlsx index 55a8a76..af4ba5d 100644 Binary files a/models/embeddingmode.xlsx and b/models/embeddingmode.xlsx differ diff --git a/rag/rag_operations.py b/rag/rag_operations.py index 06e20fd..ec04e76 100644 --- a/rag/rag_operations.py +++ b/rag/rag_operations.py @@ -2,6 +2,7 @@ import os import re import time import math +import numpy as np from datetime import datetime from typing import List, Dict, Any, Optional from langchain_core.documents import Document @@ -12,7 +13,10 @@ from filetxt.loader import fileloader, File2Text from rag.uapi_service import APIService from rag.service_opts import get_service_params from rag.transaction_manager import TransactionManager, OperationType - +from pdf2image import convert_from_path +import pytesseract +import base64 +from pathlib import Path class RagOperations: """RAG 操作类,提供所有通用的 RAG 操作""" @@ -33,14 +37,39 @@ class RagOperations: if ext not in supported_formats: raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") - # 加载文件内容 text = fileloader(realpath) - text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) + if ext == 'pdf': + debug(f"pdf原生提取结果是:{text}") + if not text or len(text.strip()) == 0: # 更严格的空值检查 + debug(f"pdf原生提取失败,尝试扫描件提取") + ocr_text = self.pdf_to_text(realpath) + debug(f"pdf扫描件抽取的文本内容是:{ocr_text}") + text = ocr_text # 只在原生提取失败时使用OCR结果 + + # 只在有文本内容时进行清洗 + if text and len(text.strip()) > 0: + # 或者保留更多有用字符 + text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) + else: + error(f"文件 {realpath} 无法提取任何文本内容") + text = "" # 确保为空字符串 + timings["load_file"] = time.time() - start_load debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") - if not text or not text.strip(): - raise ValueError(f"文件 {realpath} 加载为空") + # # 加载文件内容 + # text = fileloader(realpath) + # debug(f"pdf原生提取结果是:{text}") + # if len(text) == 0: + # debug(f"pdf原生提取失败,尝试扫描件提取") + # text = self.pdf_to_text(realpath) + # debug(f"pdf扫描件抽取的文本内容是:{text}") + # text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s.;,\n/]', '', text) + # timings["load_file"] = time.time() - start_load + # debug(f"加载文件耗时: {timings['load_file']:.2f} 秒, 文本长度: {len(text)}") + # + # if not text or not text.strip(): + # raise ValueError(f"文件 {realpath} 加载为空") # 分片处理 document = Document(page_content=text) @@ -67,6 +96,40 @@ class RagOperations: return chunks + def pdf_to_text( + self, + pdf_path: str, + output_txt: Optional[str] = None, + dpi: int = 300, + lang: str = 'chi_sim+chi_tra+eng' + ) -> str: + """ + 将扫描版 PDF 转为文字(你原来的代码,一行调用版) + + 参数: + pdf_path: PDF 文件路径(字符串) + output_txt: 如果提供,会自动保存到这个 txt 文件(可选) + dpi: 图片分辨率,默认 300(越高越清晰) + lang: 语言包,默认中文简体+繁体+英文 + + 返回: + 提取出的完整文字(字符串) + """ + # PDF 转图片 + images = convert_from_path(pdf_path, dpi=dpi) + + # OCR 识别 + text = '' + for img in images: + text += pytesseract.image_to_string(img, lang=lang) + '\n' + + # 可选:自动保存到文件 + if output_txt: + with open(output_txt, 'w', encoding='utf-8') as f: + f.write(text) + + return text + async def generate_embeddings(self, request, chunks: List[Document], service_params: Dict, userid: str, timings: Dict, transaction_mgr: TransactionManager = None) -> List[List[float]]: @@ -149,6 +212,14 @@ class RagOperations: return result + # async def force_l2_normalize(self, vector: List[float]) -> List[float]: + # """万无一失的 L2 归一化""" + # arr = np.array(vector, dtype=np.float32) + # norm = np.linalg.norm(arr) + # if norm == 0: + # return vector # 全零向量无法归一化 + # return (arr / norm).tolist() + # 统一插入向量库 async def insert_all_vectors( self, @@ -200,7 +271,10 @@ class RagOperations: # 遍历 multi_results for raw_key, info in multi_results.items(): typ = info["type"] - + # vector = info["vector"] + # debug(f"从后端传回来的向量数据是:{vector}") + # emb = await self.force_l2_normalize(info["vector"]) + # debug(f"归一化后的向量数据是:{emb}") # --- 文本 --- if typ == "text": # raw_key 就是原文 @@ -253,7 +327,7 @@ class RagOperations: # "upload_time": upload_time, # "file_type": "face", # }) - # continue + continue # --- 视频 --- if typ == "video": @@ -289,7 +363,7 @@ class RagOperations: # "upload_time": upload_time, # "file_type": "face", # }) - # continue + continue # --- 音频 --- if typ == "audio": @@ -776,25 +850,150 @@ class RagOperations: debug(f"三元组匹配总耗时: {timings['triplet_matching']:.3f} 秒") return all_triplets - async def generate_query_vector(self, request, text: str, service_params: Dict, - userid: str, timings: Dict) -> List[float]: - """生成查询向量""" - debug(f"生成查询向量: {text[:200]}...") + async def generate_query_vector( + self, + request, + text: str, + service_params: Dict, + userid: str, + timings: Dict, + embedding_mode: int = 0 + ) -> List[float]: + """生成查询向量(支持文本/多模态)""" + debug(f"生成查询向量: mode={embedding_mode}, text='{text[:100]}...'") start_vector = time.time() - query_vector = await self.api_service.get_embeddings( - request=request, - texts=[text], - upappid=service_params['embedding'], - apiname="BAAI/bge-m3", - user=userid - ) - if not query_vector or not all(len(vec) == 1024 for vec in query_vector): - raise ValueError("查询向量必须是长度为 1024 的浮点数列表") - query_vector = query_vector[0] + + if embedding_mode == 0: + # === 模式 0:纯文本嵌入(BAAI/bge-m3)=== + debug("使用 BAAI/bge-m3 文本嵌入") + vectors = await self.api_service.get_embeddings( + request=request, + texts=[text], + upappid=service_params['embedding'], + apiname="BAAI/bge-m3", + user=userid + ) + if not vectors or not isinstance(vectors, list) or len(vectors) == 0: + raise ValueError("bge-m3 返回空结果") + query_vector = vectors[0] + if len(query_vector) != 1024: + raise ValueError(f"bge-m3 返回向量维度错误: {len(query_vector)}") + + elif embedding_mode == 1: + # === 模式 1:多模态嵌入(black/clip)=== + debug("使用 black/clip 多模态嵌入") + inputs = [{"type": "text", "content": text}] + + result = await self.api_service.get_multi_embeddings( + request=request, + inputs=inputs, + upappid=service_params['embedding'], + apiname="black/clip", + user=userid + ) + + query_vector = None + for key, info in result.items(): + if info.get("type") == "error": + debug(f"CLIP 返回错误跳过: {info['error']}") + continue + if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024: + query_vector = info["vector"] + debug(f"成功获取 CLIP 向量(来自 {info['type']})") + break + + if query_vector is None: + raise ValueError("black/clip 未返回任何有效 1024 维向量") + + else: + raise ValueError(f"不支持的 embedding_mode: {embedding_mode}") + + # 最终统一校验 + if not isinstance(query_vector, list) or len(query_vector) != 1024: + raise ValueError(f"查询向量必须是长度为 1024 的浮点数列表,实际: {len(query_vector)}") + timings["vector_generation"] = time.time() - start_vector - debug(f"生成查询向量耗时: {timings['vector_generation']:.3f} 秒") + debug(f"生成查询向量成功,耗时: {timings['vector_generation']:.3f} 秒,模式: {embedding_mode}") return query_vector + async def generate_image_vector( + self, + request, + img_path: str, + service_params: Dict, + userid: str, + timings: Dict, + embedding_mode: int = 0 + ) -> List[float]: + """生成查询向量(支持文本/多模态)""" + debug(f"生成查询向量: mode={embedding_mode}, image={img_path}") + start_vector = time.time() + + if embedding_mode == 0: + raise ValueError(f"纯文本没有这个功能,请重新选择服务") + + elif embedding_mode == 1: + # === 模式 1:多模态嵌入(black/clip)=== + debug("使用 black/clip 多模态嵌入") + inputs = [] + try: + ext = Path(img_path).suffix.lower() + if ext not in {".png", ".jpg", ".jpeg", ".webp", ".bmp"}: + ext = ".jpg" + + mime_map = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".webp": "image/webp", + ".bmp": "image/bmp" + } + mime_type = mime_map.get(ext, "image/jpeg") + with open(img_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("utf-8") + data_uri = f"data:{mime_type};base64,{b64}" + + inputs.append({ + "type": "image", + "data": data_uri + }) + debug(f"已添加图像({mime_type}, {len(b64) / 1024:.1f}KB): {Path(img_path).name}") + + except Exception as e: + debug(f"图像处理失败,跳过: {img_path} → {e}") + + result = await self.api_service.get_multi_embeddings( + request=request, + inputs=inputs, + upappid=service_params['embedding'], + apiname="black/clip", + user=userid + ) + + image_vector = None + for key, info in result.items(): + if info.get("type") == "error": + debug(f"CLIP 返回错误跳过: {info['error']}") + continue + if "vector" in info and isinstance(info["vector"], list) and len(info["vector"]) == 1024: + image_vector = info["vector"] + debug(f"成功获取 CLIP 向量(来自 {info['type']})") + break + + if image_vector is None: + raise ValueError("black/clip 未返回任何有效 1024 维向量") + + else: + raise ValueError(f"不支持的 embedding_mode: {embedding_mode}") + + # 最终统一校验 + if not isinstance(image_vector, list) or len(image_vector) != 1024: + raise ValueError(f"查询向量必须是长度为 1024 的浮点数列表,实际: {len(image_vector)}") + + timings["vector_generation"] = time.time() - start_vector + debug(f"生成查询向量成功,耗时: {timings['vector_generation']:.3f} 秒,模式: {embedding_mode}") + return image_vector + async def vector_search(self, request, query_vector: List[float], orgid: str, fiids: List[str], limit: int, service_params: Dict, userid: str, timings: Dict) -> List[Dict]: @@ -866,34 +1065,49 @@ class RagOperations: return unique_triples def format_search_results(self, results: List[Dict], limit: int) -> List[Dict]: - """格式化搜索结果为统一格式""" - formatted_results = [] - # for res in results[:limit]: - # score = res.get('rerank_score', res.get('distance', 0)) - # - # content = res.get('text', '') - # title = res.get('metadata', {}).get('filename', 'Untitled') - # document_id = res.get('metadata', {}).get('document_id', '') - # - # formatted_results.append({ - # "content": content, - # "title": title, - # "metadata": {"document_id": document_id, "score": score}, - # }) - #得分归一化 + formatted = [] for res in results[:limit]: - rerank_score = res.get('rerank_score', 0) - score = 1 / (1 + math.exp(-rerank_score)) if rerank_score is not None else 1 - res.get('distance', 0) - score = max(0.0, min(1.0, score)) - - content = res.get('text', '') - title = res.get('metadata', {}).get('filename', 'Untitled') - document_id = res.get('metadata', {}).get('document_id', '') - - formatted_results.append({ - "content": content, - "title": title, - "metadata": {"document_id": document_id, "score": score}, + # # 优先 rerank,其次用向量相似度(直接用,不要反) + # if res.get('rerank_score') is not None: + # score = res.get('rerank_score') + # else: + # score = res.get('distance', 0.0) + distance = res.get('distance', 0.0) + rerank_score = res.get('rerank_score', 0.0) + formatted.append({ + "content": res.get('text', ''), + "title": res.get('metadata', {}).get('filename', 'Untitled'), + "metadata": { + "document_id": res.get('metadata', {}).get('document_id', ''), + "distance": distance, + "rerank_score": rerank_score, + } }) + return formatted - return formatted_results \ No newline at end of file + # async def save_uploaded_photo(self, image_file: FileStorage, orgid: str) -> str: + # """ + # 把前端上传的图片保存到 /home/wangmeihua/kyrag/data/photo 目录下 + # 返回保存后的绝对路径(字符串),供 generate_img_vector 使用 + # """ + # if not image_file or not hasattr(image_file, "filename"): + # raise ValueError("无效的图片上传对象") + # + # # 为了安全,按 orgid 分目录存放(避免不同公司文件混在一起) + # org_dir = UPLOAD_PHOTO_DIR / orgid + # org_dir.mkdir(parents=True, exist_ok=True) + # + # # 生成唯一文件名,保留原始后缀 + # suffix = Path(image_file.filename).suffix.lower() + # if not suffix or suffix not in {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}: + # suffix = ".jpg" + # + # unique_name = f"{uuid.uuid4().hex}{suffix}" + # save_path = org_dir / unique_name + # + # # 真正落盘 + # image_file.save(str(save_path)) + # debug(f"图片已保存: {save_path} (原始名: {image_file.filename})") + # + # # 返回字符串路径,generate_img_vector 直接收 str 就行 + # return str(save_path) \ No newline at end of file diff --git a/rag/ragapi.py b/rag/ragapi.py index 355c20f..96b783c 100644 --- a/rag/ragapi.py +++ b/rag/ragapi.py @@ -6,10 +6,13 @@ import traceback import json import math import uuid -from rag.service_opts import get_service_params, sor_get_service_params +import os +from rag.service_opts import get_service_params, sor_get_service_params, sor_get_embedding_mode, get_embedding_mode from rag.rag_operations import RagOperations from langchain_core.documents import Document +REAL_PHOTO_ROOT = "/home/wangmeihua/kyrag/files" + helptext = """kyrag API: 1. 得到kdb表: @@ -134,7 +137,16 @@ async def fusedsearch(request, params_kw, *params): debug(f"params_kw: {params_kw}") # orgid = "04J6VbxLqB_9RPMcgOv_8" # userid = "04J6VbxLqB_9RPMcgOv_8" - query = params_kw.get('query', '') + query = params_kw.get('query', '').strip() + img_path = params_kw.get('image') + if isinstance(img_path, str): + img_path = img_path.strip() + relative_part = img_path.lstrip("/") + real_img_path = os.path.join(REAL_PHOTO_ROOT, relative_part) + if not os.path.exists(real_img_path): + raise FileNotFoundError(f"图片不存在: {real_img_path}") + img_path = real_img_path + debug(f"自动修复图片路径成功: {img_path}") # 统一模式处理 limit 参数,为了对接dify和coze raw_limit = params_kw.get('limit') or ( params_kw.get('retrieval_setting', {}).get('top_k') @@ -189,43 +201,211 @@ async def fusedsearch(request, params_kw, *params): service_params = await get_service_params(orgid) if not service_params: raise ValueError("无法获取服务参数") + # 获取嵌入模式 + embedding_mode = await get_embedding_mode(orgid) + debug(f"检测到 embedding_mode = {embedding_mode}(0=文本, 1=多模态)") - try: - timings = {} - start_time = time.time() - rag_ops = RagOperations() + # 情况1:query 和 image 都为空 → 报错 + if not query and not img_path: + raise ValueError("查询文本和图片不能同时为空") - query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings) - all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, - userid, timings) - combined_text = _combine_query_with_triplets(query, all_triplets) - query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings) - search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 5, service_params, - userid, timings) + # 情况2:query 和 image 都存在 → 报错(你当前业务不允许同时传) + if query and img_path: + raise ValueError("查询文本和图片只能二选一,不能同时提交") - use_rerank = True - if use_rerank and search_results: - final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params, + # 3. 只有图片 → 以图搜图 走纯多模态分支 + if img_path and not query: + try: + debug("检测到纯图片查询,执行以图搜图") + rag_ops = RagOperations() + + timings = {} + start_time = time.time() + + # 直接生成图片向量 + img_vector = await rag_ops.generate_image_vector( + request, img_path, service_params, userid, timings, embedding_mode + ) + + # 向量搜索(多取 50 条再截断,和文本分支保持一致) + search_results = await rag_ops.vector_search( + request, img_vector, orgid, fiids, limit + 50, service_params, userid, timings + ) + + timings["total_time"] = time.time() - start_time + + # 可选:搜索完后删除图片,省磁盘(看你需求) + # try: + # os.remove(img_path) + # except: + # pass + + final_results = [] + for item in search_results[:limit]: + final_results.append({ + "text": item["text"], + "distance": item["distance"] + }) + + return { + "results": final_results, + "timings": timings + } + except Exception as e: + error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") + return { + "records": [], + "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, + "error": str(e) + } + + if not img_path and query: + try: + timings = {} + start_time = time.time() + rag_ops = RagOperations() + + query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings) + all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, + userid, timings) + combined_text = _combine_query_with_triplets(query, all_triplets) + query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode) + search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params, userid, timings) - debug(f"final_results: {final_results}") - else: - final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results] - formatted_results = rag_ops.format_search_results(final_results, limit) - timings["total_time"] = time.time() - start_time - info(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f} 秒") + use_rerank = True + if use_rerank and search_results: + final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params, + userid, timings) + debug(f"final_results: {final_results}") + else: + final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results] - return { - "records": formatted_results, - "timings": timings - } - except Exception as e: - error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") - return { - "records": [], - "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, - "error": str(e) - } + formatted_results = rag_ops.format_search_results(final_results, limit) + timings["total_time"] = time.time() - start_time + debug(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f} 秒") + + return { + "records": formatted_results, + "timings": timings + } + except Exception as e: + error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") + return { + "records": [], + "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, + "error": str(e) + } + +# async def fusedsearch(request, params_kw, *params): +# """ +# 融合搜索,调用服务化端点 +# +# """ +# kw = request._run_ns +# f = kw.get('get_userorgid') +# orgid = await f() +# debug(f"orgid: {orgid},{f=}") +# f = kw.get('get_user') +# userid = await f() +# debug(f"params_kw: {params_kw}") +# # orgid = "04J6VbxLqB_9RPMcgOv_8" +# # userid = "04J6VbxLqB_9RPMcgOv_8" +# query = params_kw.get('query', '') +# # 统一模式处理 limit 参数,为了对接dify和coze +# raw_limit = params_kw.get('limit') or ( +# params_kw.get('retrieval_setting', {}).get('top_k') +# if isinstance(params_kw.get('retrieval_setting'), dict) +# else None +# ) +# +# # 标准化为整数值 +# if raw_limit is None: +# limit = 5 # 两个来源都不存在时使用默认值 +# elif isinstance(raw_limit, (int, float)): +# limit = int(raw_limit) # 数值类型直接转换 +# elif isinstance(raw_limit, str): +# try: +# # 字符串转换为整数 +# limit = int(raw_limit) +# except (TypeError, ValueError): +# limit = 5 # 转换失败使用默认值 +# else: +# limit = 5 # 其他意外类型使用默认值 +# debug(f"limit: {limit}") +# raw_fiids = params_kw.get('fiids') or params_kw.get('knowledge_id') # +# +# # 标准化为列表格式 +# if raw_fiids is None: +# fiids = [] # 两个参数都不存在 +# elif isinstance(raw_fiids, list): +# fiids = [str(item).strip() for item in raw_fiids] # 已经是列表 +# elif isinstance(raw_fiids, str): +# # fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()] +# try: +# # 尝试解析 JSON 字符串 +# parsed = json.loads(raw_fiids) +# if isinstance(parsed, list): +# fiids = [str(item).strip() for item in parsed] # JSON 数组转为字符串列表 +# else: +# # 处理逗号分隔的字符串或单个 ID 字符串 +# fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()] +# except json.JSONDecodeError: +# # 如果不是合法 JSON,按逗号分隔 +# fiids = [f.strip() for f in raw_fiids.split(',') if f.strip()] +# elif isinstance(raw_fiids, (int, float)): +# fiids = [str(int(raw_fiids))] # 数值类型转为字符串列表 +# else: +# fiids = [] # 其他意外类型 +# +# debug(f"fiids: {fiids}") +# +# # 验证 fiids的orgid与orgid = await f()是否一致 +# await _validate_fiids_orgid(fiids, orgid, kw) +# +# service_params = await get_service_params(orgid) +# if not service_params: +# raise ValueError("无法获取服务参数") +# # 获取嵌入模式 +# embedding_mode = await get_embedding_mode(orgid) +# debug(f"检测到 embedding_mode = {embedding_mode}(0=文本, 1=多模态)") +# +# try: +# timings = {} +# start_time = time.time() +# rag_ops = RagOperations() +# +# query_entities = await rag_ops.extract_entities(request, query, service_params, userid, timings) +# all_triplets = await rag_ops.match_triplets(request, query, query_entities, orgid, fiids, service_params, +# userid, timings) +# combined_text = _combine_query_with_triplets(query, all_triplets) +# query_vector = await rag_ops.generate_query_vector(request, combined_text, service_params, userid, timings, embedding_mode) +# search_results = await rag_ops.vector_search(request, query_vector, orgid, fiids, limit + 50, service_params, +# userid, timings) +# +# use_rerank = False +# if use_rerank and search_results: +# final_results = await rag_ops.rerank_results(request, combined_text, search_results, limit, service_params, +# userid, timings) +# debug(f"final_results: {final_results}") +# else: +# final_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in search_results] +# +# formatted_results = rag_ops.format_search_results(final_results, limit) +# timings["total_time"] = time.time() - start_time +# debug(f"融合搜索完成,返回 {len(formatted_results)} 条结果,总耗时: {timings['total_time']:.3f} 秒") +# +# return { +# "records": formatted_results, +# "timings": timings +# } +# except Exception as e: +# error(f"融合搜索失败: {str(e)}, 堆栈: {traceback.format_exc()}") +# return { +# "records": [], +# "timings": {"total_time": time.time() - start_time if 'start_time' in locals() else 0}, +# "error": str(e) +# } # async def text_insert(text: str, fiid: str, orgid: str, db_type: str): async def textinsert(request, params_kw, *params): diff --git a/rag/service_opts.py b/rag/service_opts.py index 9cb7248..559ade5 100644 --- a/rag/service_opts.py +++ b/rag/service_opts.py @@ -94,7 +94,7 @@ async def sor_get_embedding_mode(sor, orgid) -> int: async def get_embedding_mode(orgid): db = DBPools() - debug(f"传入的orgid是:{orgid}") + # debug(f"传入的orgid是:{orgid}") dbname = get_serverenv('get_module_dbname')('rag') async with db.sqlorContext(dbname) as sor: return await sor_get_embedding_mode(sor, orgid) diff --git a/wwwroot/test.ui b/wwwroot/test.ui index 622ccf2..1a5caa5 100644 --- a/wwwroot/test.ui +++ b/wwwroot/test.ui @@ -18,6 +18,11 @@ "editable": true, "rows": 5 }, + { + "uitype": "image", + "name": "image", + "label": "上传查询图片(可选)" + }, { "name": "fiids", "uitype": "checkbox", diff --git a/wwwroot/test_query.dspy b/wwwroot/test_query.dspy index 381290a..6e49ffa 100644 --- a/wwwroot/test_query.dspy +++ b/wwwroot/test_query.dspy @@ -7,14 +7,15 @@ if not orgid: message='请先登录' ) -fiids = params_kw.fiids query = params_kw.query +image = params_kw.image +fiids = params_kw.fiids limit = params_kw.limit -if not query or not fiids or not limit: +if (not query and not image) or not fiids or not limit: return UiError( title='无效输入', - message='请输入查询文本并选择至少一个知识库' + message='请输入查询文本或上传image并选择至少一个知识库和填写返回条数' ) try: