From 66dec8faabee99c7c2d1ba2cfe0dbe763ad82146 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Fri, 28 Nov 2025 16:25:40 +0800 Subject: [PATCH] llmengine --- llmengine/bge_reranker.py | 2 +- llmengine/db_service.py | 98 ++++++++++++++++++++++++++++----------- llmengine/milvus_db.py | 31 ++++++++++--- llmengine/mm_embedding.py | 21 +++++---- test/m2m/conf/config.json | 4 +- 5 files changed, 111 insertions(+), 45 deletions(-) diff --git a/llmengine/bge_reranker.py b/llmengine/bge_reranker.py index 3c1bd0b..6d7549d 100644 --- a/llmengine/bge_reranker.py +++ b/llmengine/bge_reranker.py @@ -22,7 +22,7 @@ class BgeReranker(BaseReranker): def process_inputs(self, pairs): inputs = self.tokenizer(pairs, padding=True, - truncation=True, return_tensors='pt', max_length=512) + truncation=True, return_tensors='pt', max_length=8096) if torch.cuda.is_available(): inputs = {k: v.to('cuda') for k, v in inputs.items()} return inputs diff --git a/llmengine/db_service.py b/llmengine/db_service.py index 33400fd..97a6457 100755 --- a/llmengine/db_service.py +++ b/llmengine/db_service.py @@ -16,10 +16,25 @@ path: /v1/createcollection method: POST headers: {"Content-Type": "application/json"} data: { - "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb + "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb + "fields": [ // 可选,动态字段定义,格式为 FieldSchema 列表 + { + "name": "pk", + "dtype": "VARCHAR", + "is_primary": true, + "max_length": 36, + "auto_id": true + }, + { + "name": "vector", + "dtype": "FLOAT_VECTOR", + "dim": 1024 + }, + ... + ] } response: -- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 创建成功"} +- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb" or "ragdb_textdb_", "message": "集合 创建成功或已存在"} - Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": ""} 2. Delete Collection Endpoint: @@ -216,12 +231,24 @@ response: 10. Docs Endpoint: path: /docs -method: GET +method: POST response: This help text + +11. 
Initialize Connection Endpoint: +path: /v1/initialize +method: POST +headers: {"Content-Type": "application/json"} +data: { + "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb +} +response: +- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "Milvus 连接已初始化,路径: "} +- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "", "status_code": 400} """ def init(): rf = RegisterFunction() + rf.register('initialize', initialize) rf.register('createcollection', create_collection) rf.register('deletecollection', delete_collection) rf.register('insertdocument', insert_document) @@ -236,22 +263,45 @@ def init(): async def docs(request, params_kw, *params, **kw): return web.Response(text=helptext, content_type='text/plain') -async def create_collection(request, params_kw, *params, **kw): - debug(f'{params_kw=}') +async def initialize(request, params_kw, *params, **kw): + debug(f'Received initialize params: {params_kw=}') se = ServerEnv() engine = se.engine db_type = params_kw.get('db_type', '') collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - result = await engine.handle_connection("create_collection", {"db_type": db_type}) + result = await engine.handle_connection("initialize", {"db_type": db_type}) + debug(f'Initialize result: {result=}') + status = 200 if result.get("status") == "success" else 400 + return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) + except Exception as e: + error(f'初始化连接失败: {str(e)}') + return web.json_response({ + "status": "error", + "collection_name": collection_name, + "message": str(e), + "status_code": 400 + }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + +async def create_collection(request, params_kw, *params, **kw): + debug(f'{params_kw=}') + se = ServerEnv() + engine = se.engine + db_type = params_kw.get('db_type', '') + fields = params_kw.get('fields', None) + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + try: + result = await engine.handle_connection("create_collection", {"db_type": db_type, "fields": fields}) debug(f'{result=}') - return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + status = 200 if result.get("status") == "success" else 400 + return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) except Exception as e: error(f'创建集合失败: {str(e)}') return web.json_response({ "status": "error", "collection_name": collection_name, - "message": str(e) + "message": str(e), + "status_code": 400 }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) async def delete_collection(request, params_kw, *params, **kw): @@ -313,7 +363,7 @@ async def delete_document(request, params_kw, *params, **kw): db_type = params_kw.get('db_type', '') collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - required_fields = ['userid', 'file_path', 'knowledge_base_id', 'document_id'] + required_fields = ['userid', 'knowledge_base_id', 'document_id'] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] if missing_fields: raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") @@ -382,11 +432,11 @@ async def search_query(request, params_kw, *params, **kw): db_type = params_kw.get('db_type', '') collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - if not query_vector or not userid or not knowledge_base_ids: - 
debug(f'query_vector, userid 或 knowledge_base_ids 未提供') + if not query_vector or not userid: + debug(f'query_vector或userid 未提供') return web.json_response({ "status": "error", - "message": "query_vector, userid 或 knowledge_base_ids 未提供", + "message": "query_vector或userid未提供", "collection_name": collection_name }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) result = await engine.handle_connection("search_query", { @@ -419,7 +469,8 @@ async def list_user_files(request, params_kw, *params, **kw): engine = se.engine userid = params_kw.get('userid', '') db_type = params_kw.get('db_type', '') - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}" + try: if not userid: debug(f'userid 未提供') @@ -428,17 +479,13 @@ async def list_user_files(request, params_kw, *params, **kw): "message": "userid 未提供", "collection_name": collection_name }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + # 直接调用底层逻辑 result = await engine.handle_connection("list_user_files", { "userid": userid, "db_type": db_type }) - debug(f'{result=}') - response = { - "status": "success", - "files_by_knowledge_base": result, - "collection_name": collection_name - } - return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + debug(f'底层返回: {result=}') + return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: error(f'列出用户文件失败: {str(e)}') return web.json_response({ @@ -452,7 +499,6 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw): se = ServerEnv() engine = se.engine db_type = params_kw.get('db_type', '') - collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: result = await engine.handle_connection("list_all_knowledge_bases", { "db_type": db_type @@ -460,18 +506,18 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw): debug(f'{result=}') response = { "status": result.get("status", "success"), - "users_knowledge_bases": result.get("users_knowledge_bases", {}), - "collection_name": collection_name, + "collections": result.get("collections", {}), + "collection_names": result.get("collection_names", "none"), "message": result.get("message", ""), "status_code": result.get("status_code", 200) } return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=response["status_code"]) except Exception as e: - error(f'列出所有用户知识库失败: {str(e)}') + error(f'列出所有知识库失败: {str(e)}') return web.json_response({ "status": "error", - "users_knowledge_bases": {}, - "collection_name": collection_name, + "collections": {}, + "collection_names": "none", "message": str(e), "status_code": 400 }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) diff --git a/llmengine/milvus_db.py b/llmengine/milvus_db.py index 04c132d..3f51459 100755 --- a/llmengine/milvus_db.py +++ b/llmengine/milvus_db.py @@ -10,7 +10,7 @@ from llmengine.base_db import connection_register, BaseDBConnection import time import traceback import hashlib - +import numpy as np class MilvusDBConnection(BaseDBConnection): _instance = None _lock = Lock() @@ -606,6 +606,21 @@ class MilvusDBConnection(BaseDBConnection): async def _search_query(self, query_vector: List[float], userid: str, knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]: """基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)""" + # query_vector = 
np.array(query_vector, dtype=np.float32) + # original_norm = np.linalg.norm(query_vector) + # debug(f"API 返回向量范数: {original_norm:.10f}") + # + # if original_norm == 0: + # raise ValueError("zero vector") + # + # # 强制归一化 + # query_vector = query_vector / original_norm + # + # # 重新计算范数 + # final_norm = np.linalg.norm(query_vector) + # debug(f"强制归一化后范数: {final_norm:.10f}") + # + # query_vector = query_vector.tolist() timing_stats = {} start_time = time.time() try: @@ -649,7 +664,7 @@ class MilvusDBConnection(BaseDBConnection): is_ragdb = collection_name == "ragdb" output_fields = ( ["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type", - "knowledge_base_id"] + "knowledge_base_id", "vector"] if is_ragdb else ["document_id", "userid", "knowledge_base_id", "text"] ) @@ -685,6 +700,7 @@ class MilvusDBConnection(BaseDBConnection): output_fields=output_fields, # 动态 offset=offset # 保留 ) + # debug(f"使用的召回索引是:{collection.index().describe()}") except Exception as e: error(f"搜索失败: {str(e)}") return {"results": [], "timing": timing_stats} @@ -703,11 +719,12 @@ class MilvusDBConnection(BaseDBConnection): "file_path": hit.entity.get("file_path"), "upload_time": hit.entity.get("upload_time"), "file_type": hit.entity.get("file_type"), - "knowledge_base_id": hit.entity.get("knowledge_base_id") + "knowledge_base_id": hit.entity.get("knowledge_base_id"), + "vector": hit.entity.get("vector")[:10], } raw_results.append({ "text": hit.entity.get("text"), - "distance": 1.0 - hit.distance, + "distance": hit.distance, "metadata": metadata }) else: @@ -718,7 +735,7 @@ class MilvusDBConnection(BaseDBConnection): } raw_results.append({ "text": hit.entity.get("text"), - "distance": 1.0 - hit.distance, + "distance": hit.distance, "metadata": metadata }) @@ -921,7 +938,7 @@ class MilvusDBConnection(BaseDBConnection): "file_path": result.get("file_path", ""), "upload_time": result.get("upload_time", ""), "file_type": result.get("file_type", ""), - "knowledge_base_id": kb_id + "knowledge_base_id": kb_id, } files_by_knowledge_base.setdefault(kb_id, []).append(file_info) else: @@ -935,7 +952,7 @@ class MilvusDBConnection(BaseDBConnection): debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}") - info(f"找到 {len(seen_document_ids)} 个文件") + debug(f"找到 {len(seen_document_ids)} 个文件") return { "status": "success", "files_by_knowledge_base": files_by_knowledge_base, diff --git a/llmengine/mm_embedding.py b/llmengine/mm_embedding.py index 0d9cd33..be0b2af 100644 --- a/llmengine/mm_embedding.py +++ b/llmengine/mm_embedding.py @@ -71,7 +71,7 @@ CLIP_MODEL_NAME = "/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K" FRAME_SAMPLE_RATE = 1.0 FRAME_LIMIT = 64 AUDIO_SR = 16000 -IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image") +IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image/photo") AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio") VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video") @@ -80,8 +80,11 @@ for d in [IMAGE_DIR, AUDIO_DIR, VIDEO_DIR]: # ------------------- 工具函数 ------------------- def l2_normalize(v): + v = np.array(v, dtype=np.float32) norm = np.linalg.norm(v) - return v / norm if norm > 1e-10 else v + if norm == 0: + return v.tolist() + return (v / norm).tolist() def chunked(lst, n): for i in range(0, len(lst), n): @@ -137,7 +140,7 @@ class MM_Embedding: # 1. 
提取 base64 try: header, b64 = data_uri.split(",", 1) - debug(f"header: {header},b64: {b64}") + debug(f"header: {header}, b64: {b64[:100]}...({len(b64)} chars total)") binary = base64.b64decode(b64) except Exception as e: error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}") @@ -210,7 +213,7 @@ class MM_Embedding: feats = self.model.get_text_features(**inputs) feats = feats.cpu().numpy() for t, v in zip(batch, feats): - results[t] = {"type": "text", "vector": l2_normalize(v).tolist()} + results[t] = {"type": "text", "vector": l2_normalize(v)} return results def _embed_images(self, paths): @@ -227,9 +230,9 @@ class MM_Embedding: results[p] = { "type": "image", "path": p, - "vector": l2_normalize(v).tolist(), + "vector": l2_normalize(v), "face_count": len(face_vecs), - "face_vecs": [vec.tolist() for vec in face_vecs] + "face_vecs": [vec for vec in face_vecs] } return results @@ -270,9 +273,9 @@ class MM_Embedding: results[p] = { "type": "video", "path": p, - "vector": video_vec.tolist(), + "vector": video_vec, "face_count": len(face_vecs), - "face_vecs": [vec.tolist() for vec in face_vecs] + "face_vecs": [vec for vec in face_vecs] } except Exception as e: exception(f"Video {p} failed: {e}") @@ -292,7 +295,7 @@ class MM_Embedding: with torch.no_grad(): with torch.amp.autocast('cuda', enabled=USE_FP16): feats = self.model.get_image_features(**inputs) - results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0]).tolist()} + results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0])} except Exception as e: exception(f"Audio {p} failed: {e}") results[p] = None diff --git a/test/m2m/conf/config.json b/test/m2m/conf/config.json index 09a49ac..85c17ca 100644 --- a/test/m2m/conf/config.json +++ b/test/m2m/conf/config.json @@ -22,8 +22,8 @@ "leading":"/idfile", "registerfunction":"idfile" },{ - "leading": "/v1/m2m", - "registerfunction": "m2m" + "leading": "/v1/translate", + "registerfunction": "translate" },{ "leading": "/docs", "registerfunction": "docs"
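
Usage sketch for the two request formats documented in the updated help text: the optional dynamic "fields" list accepted by /v1/createcollection, and the new /v1/initialize endpoint. This is an illustration only; the base URL/port and the use of the requests library are assumptions not defined in this patch, while the payload and response shapes follow the help-text examples above.

    import requests  # assumed to be available in the test environment

    BASE_URL = "http://localhost:8080"  # hypothetical address; adjust to the deployed service

    def initialize(db_type="textdb"):
        # POST /v1/initialize; expects {"status", "collection_name", "message"} back.
        resp = requests.post(f"{BASE_URL}/v1/initialize", json={"db_type": db_type})
        return resp.status_code, resp.json()

    def create_collection(db_type="textdb", fields=None):
        # POST /v1/createcollection; "fields" is the optional FieldSchema-style list.
        payload = {"db_type": db_type}
        if fields is not None:
            payload["fields"] = fields
        resp = requests.post(f"{BASE_URL}/v1/createcollection", json=payload)
        return resp.status_code, resp.json()

    if __name__ == "__main__":
        # Field list mirroring the help-text sample: VARCHAR primary key + 1024-dim vector.
        fields = [
            {"name": "pk", "dtype": "VARCHAR", "is_primary": True, "max_length": 36, "auto_id": True},
            {"name": "vector", "dtype": "FLOAT_VECTOR", "dim": 1024},
        ]
        print(initialize("textdb"))
        print(create_collection("textdb", fields))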
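
The mm_embedding hunk moves the array conversion and the .tolist() call inside l2_normalize, so every caller now receives a plain Python list and no longer converts the result itself. A standalone check of that contract (the function body is copied from the hunk above; only the test values are made up):

    import numpy as np

    def l2_normalize(v):
        # Same body as the patched helper in llmengine/mm_embedding.py:
        # accepts a list or ndarray, returns a plain Python list.
        v = np.array(v, dtype=np.float32)
        norm = np.linalg.norm(v)
        if norm == 0:
            return v.tolist()
        return (v / norm).tolist()

    vec = l2_normalize([3.0, 4.0])
    assert isinstance(vec, list)                    # callers no longer need .tolist()
    assert abs(np.linalg.norm(vec) - 1.0) < 1e-6    # result has unit norm
    assert l2_normalize([0.0, 0.0]) == [0.0, 0.0]   # zero vector passes through unchanged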