llmengine

This commit is contained in:
wangmeihua 2025-11-28 16:25:40 +08:00
parent ace848f996
commit 66dec8faab
5 changed files with 111 additions and 45 deletions

View File

@ -22,7 +22,7 @@ class BgeReranker(BaseReranker):
def process_inputs(self, pairs): def process_inputs(self, pairs):
inputs = self.tokenizer(pairs, padding=True, inputs = self.tokenizer(pairs, padding=True,
truncation=True, return_tensors='pt', max_length=512) truncation=True, return_tensors='pt', max_length=8096)
if torch.cuda.is_available(): if torch.cuda.is_available():
inputs = {k: v.to('cuda') for k, v in inputs.items()} inputs = {k: v.to('cuda') for k, v in inputs.items()}
return inputs return inputs

View File

@ -16,10 +16,25 @@ path: /v1/createcollection
method: POST method: POST
headers: {"Content-Type": "application/json"} headers: {"Content-Type": "application/json"}
data: { data: {
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb
"fields": [ // 可选,动态字段定义,格式为 FieldSchema 列表
{
"name": "pk",
"dtype": "VARCHAR",
"is_primary": true,
"max_length": 36,
"auto_id": true
},
{
"name": "vector",
"dtype": "FLOAT_VECTOR",
"dim": 1024
},
...
]
} }
response: response:
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 创建成功"} - Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb" or "ragdb_textdb_<hash>", "message": "集合 <collection_name> 创建成功或已存在"}
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"} - Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"}
2. Delete Collection Endpoint: 2. Delete Collection Endpoint:
@ -216,12 +231,24 @@ response:
10. Docs Endpoint: 10. Docs Endpoint:
path: /docs path: /docs
method: GET method: POST
response: This help text response: This help text
11. Initialize Connection Endpoint:
path: /v1/initialize
method: POST
headers: {"Content-Type": "application/json"}
data: {
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "Milvus 连接已初始化,路径: <db_path>"}
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
""" """
def init(): def init():
rf = RegisterFunction() rf = RegisterFunction()
rf.register('initialize', initialize)
rf.register('createcollection', create_collection) rf.register('createcollection', create_collection)
rf.register('deletecollection', delete_collection) rf.register('deletecollection', delete_collection)
rf.register('insertdocument', insert_document) rf.register('insertdocument', insert_document)
@ -236,22 +263,45 @@ def init():
async def docs(request, params_kw, *params, **kw): async def docs(request, params_kw, *params, **kw):
return web.Response(text=helptext, content_type='text/plain') return web.Response(text=helptext, content_type='text/plain')
async def create_collection(request, params_kw, *params, **kw): async def initialize(request, params_kw, *params, **kw):
debug(f'{params_kw=}') debug(f'Received initialize params: {params_kw=}')
se = ServerEnv() se = ServerEnv()
engine = se.engine engine = se.engine
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
result = await engine.handle_connection("create_collection", {"db_type": db_type}) result = await engine.handle_connection("initialize", {"db_type": db_type})
debug(f'Initialize result: {result=}')
status = 200 if result.get("status") == "success" else 400
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
except Exception as e:
error(f'初始化连接失败: {str(e)}')
return web.json_response({
"status": "error",
"collection_name": collection_name,
"message": str(e),
"status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def create_collection(request, params_kw, *params, **kw):
debug(f'{params_kw=}')
se = ServerEnv()
engine = se.engine
db_type = params_kw.get('db_type', '')
fields = params_kw.get('fields', None)
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
result = await engine.handle_connection("create_collection", {"db_type": db_type, "fields": fields})
debug(f'{result=}') debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) status = 200 if result.get("status") == "success" else 400
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
except Exception as e: except Exception as e:
error(f'创建集合失败: {str(e)}') error(f'创建集合失败: {str(e)}')
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"message": str(e) "message": str(e),
"status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def delete_collection(request, params_kw, *params, **kw): async def delete_collection(request, params_kw, *params, **kw):
@ -313,7 +363,7 @@ async def delete_document(request, params_kw, *params, **kw):
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
required_fields = ['userid', 'file_path', 'knowledge_base_id', 'document_id'] required_fields = ['userid', 'knowledge_base_id', 'document_id']
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
@ -382,11 +432,11 @@ async def search_query(request, params_kw, *params, **kw):
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
if not query_vector or not userid or not knowledge_base_ids: if not query_vector or not userid:
debug(f'query_vector, userid 或 knowledge_base_ids 未提供') debug(f'query_vector或userid 未提供')
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"message": "query_vector, userid 或 knowledge_base_ids 未提供", "message": "query_vector或userid未提供",
"collection_name": collection_name "collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
result = await engine.handle_connection("search_query", { result = await engine.handle_connection("search_query", {
@ -419,7 +469,8 @@ async def list_user_files(request, params_kw, *params, **kw):
engine = se.engine engine = se.engine
userid = params_kw.get('userid', '') userid = params_kw.get('userid', '')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}"
try: try:
if not userid: if not userid:
debug(f'userid 未提供') debug(f'userid 未提供')
@ -428,17 +479,13 @@ async def list_user_files(request, params_kw, *params, **kw):
"message": "userid 未提供", "message": "userid 未提供",
"collection_name": collection_name "collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
# 直接调用底层逻辑
result = await engine.handle_connection("list_user_files", { result = await engine.handle_connection("list_user_files", {
"userid": userid, "userid": userid,
"db_type": db_type "db_type": db_type
}) })
debug(f'{result=}') debug(f'底层返回: {result=}')
response = { return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
"status": "success",
"files_by_knowledge_base": result,
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e: except Exception as e:
error(f'列出用户文件失败: {str(e)}') error(f'列出用户文件失败: {str(e)}')
return web.json_response({ return web.json_response({
@ -452,7 +499,6 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw):
se = ServerEnv() se = ServerEnv()
engine = se.engine engine = se.engine
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
result = await engine.handle_connection("list_all_knowledge_bases", { result = await engine.handle_connection("list_all_knowledge_bases", {
"db_type": db_type "db_type": db_type
@ -460,18 +506,18 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw):
debug(f'{result=}') debug(f'{result=}')
response = { response = {
"status": result.get("status", "success"), "status": result.get("status", "success"),
"users_knowledge_bases": result.get("users_knowledge_bases", {}), "collections": result.get("collections", {}),
"collection_name": collection_name, "collection_names": result.get("collection_names", "none"),
"message": result.get("message", ""), "message": result.get("message", ""),
"status_code": result.get("status_code", 200) "status_code": result.get("status_code", 200)
} }
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=response["status_code"]) return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=response["status_code"])
except Exception as e: except Exception as e:
error(f'列出所有用户知识库失败: {str(e)}') error(f'列出所有知识库失败: {str(e)}')
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"users_knowledge_bases": {}, "collections": {},
"collection_name": collection_name, "collection_names": "none",
"message": str(e), "message": str(e),
"status_code": 400 "status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

View File

@ -10,7 +10,7 @@ from llmengine.base_db import connection_register, BaseDBConnection
import time import time
import traceback import traceback
import hashlib import hashlib
import numpy as np
class MilvusDBConnection(BaseDBConnection): class MilvusDBConnection(BaseDBConnection):
_instance = None _instance = None
_lock = Lock() _lock = Lock()
@ -606,6 +606,21 @@ class MilvusDBConnection(BaseDBConnection):
async def _search_query(self, query_vector: List[float], userid: str, async def _search_query(self, query_vector: List[float], userid: str,
knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]: knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
"""基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)""" """基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)"""
# query_vector = np.array(query_vector, dtype=np.float32)
# original_norm = np.linalg.norm(query_vector)
# debug(f"API 返回向量范数: {original_norm:.10f}")
#
# if original_norm == 0:
# raise ValueError("zero vector")
#
# # 强制归一化
# query_vector = query_vector / original_norm
#
# # 重新计算范数
# final_norm = np.linalg.norm(query_vector)
# debug(f"强制归一化后范数: {final_norm:.10f}")
#
# query_vector = query_vector.tolist()
timing_stats = {} timing_stats = {}
start_time = time.time() start_time = time.time()
try: try:
@ -649,7 +664,7 @@ class MilvusDBConnection(BaseDBConnection):
is_ragdb = collection_name == "ragdb" is_ragdb = collection_name == "ragdb"
output_fields = ( output_fields = (
["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type", ["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type",
"knowledge_base_id"] "knowledge_base_id", "vector"]
if is_ragdb else if is_ragdb else
["document_id", "userid", "knowledge_base_id", "text"] ["document_id", "userid", "knowledge_base_id", "text"]
) )
@ -685,6 +700,7 @@ class MilvusDBConnection(BaseDBConnection):
output_fields=output_fields, # 动态 output_fields=output_fields, # 动态
offset=offset # 保留 offset=offset # 保留
) )
# debug(f"使用的召回索引是:{collection.index().describe()}")
except Exception as e: except Exception as e:
error(f"搜索失败: {str(e)}") error(f"搜索失败: {str(e)}")
return {"results": [], "timing": timing_stats} return {"results": [], "timing": timing_stats}
@ -703,11 +719,12 @@ class MilvusDBConnection(BaseDBConnection):
"file_path": hit.entity.get("file_path"), "file_path": hit.entity.get("file_path"),
"upload_time": hit.entity.get("upload_time"), "upload_time": hit.entity.get("upload_time"),
"file_type": hit.entity.get("file_type"), "file_type": hit.entity.get("file_type"),
"knowledge_base_id": hit.entity.get("knowledge_base_id") "knowledge_base_id": hit.entity.get("knowledge_base_id"),
"vector": hit.entity.get("vector")[:10],
} }
raw_results.append({ raw_results.append({
"text": hit.entity.get("text"), "text": hit.entity.get("text"),
"distance": 1.0 - hit.distance, "distance": hit.distance,
"metadata": metadata "metadata": metadata
}) })
else: else:
@ -718,7 +735,7 @@ class MilvusDBConnection(BaseDBConnection):
} }
raw_results.append({ raw_results.append({
"text": hit.entity.get("text"), "text": hit.entity.get("text"),
"distance": 1.0 - hit.distance, "distance": hit.distance,
"metadata": metadata "metadata": metadata
}) })
@ -921,7 +938,7 @@ class MilvusDBConnection(BaseDBConnection):
"file_path": result.get("file_path", ""), "file_path": result.get("file_path", ""),
"upload_time": result.get("upload_time", ""), "upload_time": result.get("upload_time", ""),
"file_type": result.get("file_type", ""), "file_type": result.get("file_type", ""),
"knowledge_base_id": kb_id "knowledge_base_id": kb_id,
} }
files_by_knowledge_base.setdefault(kb_id, []).append(file_info) files_by_knowledge_base.setdefault(kb_id, []).append(file_info)
else: else:
@ -935,7 +952,7 @@ class MilvusDBConnection(BaseDBConnection):
debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}") debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}")
info(f"找到 {len(seen_document_ids)} 个文件") debug(f"找到 {len(seen_document_ids)} 个文件")
return { return {
"status": "success", "status": "success",
"files_by_knowledge_base": files_by_knowledge_base, "files_by_knowledge_base": files_by_knowledge_base,

View File

@ -71,7 +71,7 @@ CLIP_MODEL_NAME = "/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
FRAME_SAMPLE_RATE = 1.0 FRAME_SAMPLE_RATE = 1.0
FRAME_LIMIT = 64 FRAME_LIMIT = 64
AUDIO_SR = 16000 AUDIO_SR = 16000
IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image") IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image/photo")
AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio") AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio")
VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video") VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video")
@ -80,8 +80,11 @@ for d in [IMAGE_DIR, AUDIO_DIR, VIDEO_DIR]:
# ------------------- 工具函数 ------------------- # ------------------- 工具函数 -------------------
def l2_normalize(v): def l2_normalize(v):
v = np.array(v, dtype=np.float32)
norm = np.linalg.norm(v) norm = np.linalg.norm(v)
return v / norm if norm > 1e-10 else v if norm == 0:
return v.tolist()
return (v / norm).tolist()
def chunked(lst, n): def chunked(lst, n):
for i in range(0, len(lst), n): for i in range(0, len(lst), n):
@ -137,7 +140,7 @@ class MM_Embedding:
# 1. 提取 base64 # 1. 提取 base64
try: try:
header, b64 = data_uri.split(",", 1) header, b64 = data_uri.split(",", 1)
debug(f"header: {header},b64: {b64}") debug(f"header: {header}, b64: {b64[:100]}...({len(b64)} chars total)")
binary = base64.b64decode(b64) binary = base64.b64decode(b64)
except Exception as e: except Exception as e:
error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}") error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}")
@ -210,7 +213,7 @@ class MM_Embedding:
feats = self.model.get_text_features(**inputs) feats = self.model.get_text_features(**inputs)
feats = feats.cpu().numpy() feats = feats.cpu().numpy()
for t, v in zip(batch, feats): for t, v in zip(batch, feats):
results[t] = {"type": "text", "vector": l2_normalize(v).tolist()} results[t] = {"type": "text", "vector": l2_normalize(v)}
return results return results
def _embed_images(self, paths): def _embed_images(self, paths):
@ -227,9 +230,9 @@ class MM_Embedding:
results[p] = { results[p] = {
"type": "image", "type": "image",
"path": p, "path": p,
"vector": l2_normalize(v).tolist(), "vector": l2_normalize(v),
"face_count": len(face_vecs), "face_count": len(face_vecs),
"face_vecs": [vec.tolist() for vec in face_vecs] "face_vecs": [vec for vec in face_vecs]
} }
return results return results
@ -270,9 +273,9 @@ class MM_Embedding:
results[p] = { results[p] = {
"type": "video", "type": "video",
"path": p, "path": p,
"vector": video_vec.tolist(), "vector": video_vec,
"face_count": len(face_vecs), "face_count": len(face_vecs),
"face_vecs": [vec.tolist() for vec in face_vecs] "face_vecs": [vec for vec in face_vecs]
} }
except Exception as e: except Exception as e:
exception(f"Video {p} failed: {e}") exception(f"Video {p} failed: {e}")
@ -292,7 +295,7 @@ class MM_Embedding:
with torch.no_grad(): with torch.no_grad():
with torch.amp.autocast('cuda', enabled=USE_FP16): with torch.amp.autocast('cuda', enabled=USE_FP16):
feats = self.model.get_image_features(**inputs) feats = self.model.get_image_features(**inputs)
results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0]).tolist()} results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0])}
except Exception as e: except Exception as e:
exception(f"Audio {p} failed: {e}") exception(f"Audio {p} failed: {e}")
results[p] = None results[p] = None

View File

@ -22,8 +22,8 @@
"leading":"/idfile", "leading":"/idfile",
"registerfunction":"idfile" "registerfunction":"idfile"
},{ },{
"leading": "/v1/m2m", "leading": "/v1/translate",
"registerfunction": "m2m" "registerfunction": "translate"
},{ },{
"leading": "/docs", "leading": "/docs",
"registerfunction": "docs" "registerfunction": "docs"