llmengine
This commit is contained in:
parent
ace848f996
commit
66dec8faab
@ -22,7 +22,7 @@ class BgeReranker(BaseReranker):
|
||||
|
||||
def process_inputs(self, pairs):
|
||||
inputs = self.tokenizer(pairs, padding=True,
|
||||
truncation=True, return_tensors='pt', max_length=512)
|
||||
truncation=True, return_tensors='pt', max_length=8096)
|
||||
if torch.cuda.is_available():
|
||||
inputs = {k: v.to('cuda') for k, v in inputs.items()}
|
||||
return inputs
|
||||
|
||||
@ -16,10 +16,25 @@ path: /v1/createcollection
|
||||
method: POST
|
||||
headers: {"Content-Type": "application/json"}
|
||||
data: {
|
||||
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
|
||||
"db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb
|
||||
"fields": [ // 可选,动态字段定义,格式为 FieldSchema 列表
|
||||
{
|
||||
"name": "pk",
|
||||
"dtype": "VARCHAR",
|
||||
"is_primary": true,
|
||||
"max_length": 36,
|
||||
"auto_id": true
|
||||
},
|
||||
{
|
||||
"name": "vector",
|
||||
"dtype": "FLOAT_VECTOR",
|
||||
"dim": 1024
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
response:
|
||||
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 创建成功"}
|
||||
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb" or "ragdb_textdb_<hash>", "message": "集合 <collection_name> 创建成功或已存在"}
|
||||
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"}
|
||||
|
||||
2. Delete Collection Endpoint:
|
||||
@ -216,12 +231,24 @@ response:
|
||||
|
||||
10. Docs Endpoint:
|
||||
path: /docs
|
||||
method: GET
|
||||
method: POST
|
||||
response: This help text
|
||||
|
||||
11. Initialize Connection Endpoint:
|
||||
path: /v1/initialize
|
||||
method: POST
|
||||
headers: {"Content-Type": "application/json"}
|
||||
data: {
|
||||
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
|
||||
}
|
||||
response:
|
||||
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "Milvus 连接已初始化,路径: <db_path>"}
|
||||
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
|
||||
"""
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
rf.register('initialize', initialize)
|
||||
rf.register('createcollection', create_collection)
|
||||
rf.register('deletecollection', delete_collection)
|
||||
rf.register('insertdocument', insert_document)
|
||||
@ -236,22 +263,45 @@ def init():
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
return web.Response(text=helptext, content_type='text/plain')
|
||||
|
||||
async def create_collection(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw=}')
|
||||
async def initialize(request, params_kw, *params, **kw):
|
||||
debug(f'Received initialize params: {params_kw=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
db_type = params_kw.get('db_type', '')
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
result = await engine.handle_connection("create_collection", {"db_type": db_type})
|
||||
result = await engine.handle_connection("initialize", {"db_type": db_type})
|
||||
debug(f'Initialize result: {result=}')
|
||||
status = 200 if result.get("status") == "success" else 400
|
||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
|
||||
except Exception as e:
|
||||
error(f'初始化连接失败: {str(e)}')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"collection_name": collection_name,
|
||||
"message": str(e),
|
||||
"status_code": 400
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
async def create_collection(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
db_type = params_kw.get('db_type', '')
|
||||
fields = params_kw.get('fields', None)
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
result = await engine.handle_connection("create_collection", {"db_type": db_type, "fields": fields})
|
||||
debug(f'{result=}')
|
||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||
status = 200 if result.get("status") == "success" else 400
|
||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
|
||||
except Exception as e:
|
||||
error(f'创建集合失败: {str(e)}')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"collection_name": collection_name,
|
||||
"message": str(e)
|
||||
"message": str(e),
|
||||
"status_code": 400
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
async def delete_collection(request, params_kw, *params, **kw):
|
||||
@ -313,7 +363,7 @@ async def delete_document(request, params_kw, *params, **kw):
|
||||
db_type = params_kw.get('db_type', '')
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
required_fields = ['userid', 'file_path', 'knowledge_base_id', 'document_id']
|
||||
required_fields = ['userid', 'knowledge_base_id', 'document_id']
|
||||
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
|
||||
if missing_fields:
|
||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||
@ -382,11 +432,11 @@ async def search_query(request, params_kw, *params, **kw):
|
||||
db_type = params_kw.get('db_type', '')
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
if not query_vector or not userid or not knowledge_base_ids:
|
||||
debug(f'query_vector, userid 或 knowledge_base_ids 未提供')
|
||||
if not query_vector or not userid:
|
||||
debug(f'query_vector或userid 未提供')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"message": "query_vector, userid 或 knowledge_base_ids 未提供",
|
||||
"message": "query_vector或userid未提供",
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
result = await engine.handle_connection("search_query", {
|
||||
@ -419,7 +469,8 @@ async def list_user_files(request, params_kw, *params, **kw):
|
||||
engine = se.engine
|
||||
userid = params_kw.get('userid', '')
|
||||
db_type = params_kw.get('db_type', '')
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}"
|
||||
|
||||
try:
|
||||
if not userid:
|
||||
debug(f'userid 未提供')
|
||||
@ -428,17 +479,13 @@ async def list_user_files(request, params_kw, *params, **kw):
|
||||
"message": "userid 未提供",
|
||||
"collection_name": collection_name
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
# 直接调用底层逻辑
|
||||
result = await engine.handle_connection("list_user_files", {
|
||||
"userid": userid,
|
||||
"db_type": db_type
|
||||
})
|
||||
debug(f'{result=}')
|
||||
response = {
|
||||
"status": "success",
|
||||
"files_by_knowledge_base": result,
|
||||
"collection_name": collection_name
|
||||
}
|
||||
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||
debug(f'底层返回: {result=}')
|
||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
error(f'列出用户文件失败: {str(e)}')
|
||||
return web.json_response({
|
||||
@ -452,7 +499,6 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw):
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
db_type = params_kw.get('db_type', '')
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
result = await engine.handle_connection("list_all_knowledge_bases", {
|
||||
"db_type": db_type
|
||||
@ -460,18 +506,18 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw):
|
||||
debug(f'{result=}')
|
||||
response = {
|
||||
"status": result.get("status", "success"),
|
||||
"users_knowledge_bases": result.get("users_knowledge_bases", {}),
|
||||
"collection_name": collection_name,
|
||||
"collections": result.get("collections", {}),
|
||||
"collection_names": result.get("collection_names", "none"),
|
||||
"message": result.get("message", ""),
|
||||
"status_code": result.get("status_code", 200)
|
||||
}
|
||||
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=response["status_code"])
|
||||
except Exception as e:
|
||||
error(f'列出所有用户知识库失败: {str(e)}')
|
||||
error(f'列出所有知识库失败: {str(e)}')
|
||||
return web.json_response({
|
||||
"status": "error",
|
||||
"users_knowledge_bases": {},
|
||||
"collection_name": collection_name,
|
||||
"collections": {},
|
||||
"collection_names": "none",
|
||||
"message": str(e),
|
||||
"status_code": 400
|
||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||
|
||||
@ -10,7 +10,7 @@ from llmengine.base_db import connection_register, BaseDBConnection
|
||||
import time
|
||||
import traceback
|
||||
import hashlib
|
||||
|
||||
import numpy as np
|
||||
class MilvusDBConnection(BaseDBConnection):
|
||||
_instance = None
|
||||
_lock = Lock()
|
||||
@ -606,6 +606,21 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
async def _search_query(self, query_vector: List[float], userid: str,
|
||||
knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
|
||||
"""基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)"""
|
||||
# query_vector = np.array(query_vector, dtype=np.float32)
|
||||
# original_norm = np.linalg.norm(query_vector)
|
||||
# debug(f"API 返回向量范数: {original_norm:.10f}")
|
||||
#
|
||||
# if original_norm == 0:
|
||||
# raise ValueError("zero vector")
|
||||
#
|
||||
# # 强制归一化
|
||||
# query_vector = query_vector / original_norm
|
||||
#
|
||||
# # 重新计算范数
|
||||
# final_norm = np.linalg.norm(query_vector)
|
||||
# debug(f"强制归一化后范数: {final_norm:.10f}")
|
||||
#
|
||||
# query_vector = query_vector.tolist()
|
||||
timing_stats = {}
|
||||
start_time = time.time()
|
||||
try:
|
||||
@ -649,7 +664,7 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
is_ragdb = collection_name == "ragdb"
|
||||
output_fields = (
|
||||
["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type",
|
||||
"knowledge_base_id"]
|
||||
"knowledge_base_id", "vector"]
|
||||
if is_ragdb else
|
||||
["document_id", "userid", "knowledge_base_id", "text"]
|
||||
)
|
||||
@ -685,6 +700,7 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
output_fields=output_fields, # 动态
|
||||
offset=offset # 保留
|
||||
)
|
||||
# debug(f"使用的召回索引是:{collection.index().describe()}")
|
||||
except Exception as e:
|
||||
error(f"搜索失败: {str(e)}")
|
||||
return {"results": [], "timing": timing_stats}
|
||||
@ -703,11 +719,12 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
"file_path": hit.entity.get("file_path"),
|
||||
"upload_time": hit.entity.get("upload_time"),
|
||||
"file_type": hit.entity.get("file_type"),
|
||||
"knowledge_base_id": hit.entity.get("knowledge_base_id")
|
||||
"knowledge_base_id": hit.entity.get("knowledge_base_id"),
|
||||
"vector": hit.entity.get("vector")[:10],
|
||||
}
|
||||
raw_results.append({
|
||||
"text": hit.entity.get("text"),
|
||||
"distance": 1.0 - hit.distance,
|
||||
"distance": hit.distance,
|
||||
"metadata": metadata
|
||||
})
|
||||
else:
|
||||
@ -718,7 +735,7 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
}
|
||||
raw_results.append({
|
||||
"text": hit.entity.get("text"),
|
||||
"distance": 1.0 - hit.distance,
|
||||
"distance": hit.distance,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
@ -921,7 +938,7 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
"file_path": result.get("file_path", ""),
|
||||
"upload_time": result.get("upload_time", ""),
|
||||
"file_type": result.get("file_type", ""),
|
||||
"knowledge_base_id": kb_id
|
||||
"knowledge_base_id": kb_id,
|
||||
}
|
||||
files_by_knowledge_base.setdefault(kb_id, []).append(file_info)
|
||||
else:
|
||||
@ -935,7 +952,7 @@ class MilvusDBConnection(BaseDBConnection):
|
||||
|
||||
debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}")
|
||||
|
||||
info(f"找到 {len(seen_document_ids)} 个文件")
|
||||
debug(f"找到 {len(seen_document_ids)} 个文件")
|
||||
return {
|
||||
"status": "success",
|
||||
"files_by_knowledge_base": files_by_knowledge_base,
|
||||
|
||||
@ -71,7 +71,7 @@ CLIP_MODEL_NAME = "/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
||||
FRAME_SAMPLE_RATE = 1.0
|
||||
FRAME_LIMIT = 64
|
||||
AUDIO_SR = 16000
|
||||
IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image")
|
||||
IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image/photo")
|
||||
AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio")
|
||||
VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video")
|
||||
|
||||
@ -80,8 +80,11 @@ for d in [IMAGE_DIR, AUDIO_DIR, VIDEO_DIR]:
|
||||
|
||||
# ------------------- 工具函数 -------------------
|
||||
def l2_normalize(v):
|
||||
v = np.array(v, dtype=np.float32)
|
||||
norm = np.linalg.norm(v)
|
||||
return v / norm if norm > 1e-10 else v
|
||||
if norm == 0:
|
||||
return v.tolist()
|
||||
return (v / norm).tolist()
|
||||
|
||||
def chunked(lst, n):
|
||||
for i in range(0, len(lst), n):
|
||||
@ -137,7 +140,7 @@ class MM_Embedding:
|
||||
# 1. 提取 base64
|
||||
try:
|
||||
header, b64 = data_uri.split(",", 1)
|
||||
debug(f"header: {header},b64: {b64}")
|
||||
debug(f"header: {header}, b64: {b64[:100]}...({len(b64)} chars total)")
|
||||
binary = base64.b64decode(b64)
|
||||
except Exception as e:
|
||||
error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
@ -210,7 +213,7 @@ class MM_Embedding:
|
||||
feats = self.model.get_text_features(**inputs)
|
||||
feats = feats.cpu().numpy()
|
||||
for t, v in zip(batch, feats):
|
||||
results[t] = {"type": "text", "vector": l2_normalize(v).tolist()}
|
||||
results[t] = {"type": "text", "vector": l2_normalize(v)}
|
||||
return results
|
||||
|
||||
def _embed_images(self, paths):
|
||||
@ -227,9 +230,9 @@ class MM_Embedding:
|
||||
results[p] = {
|
||||
"type": "image",
|
||||
"path": p,
|
||||
"vector": l2_normalize(v).tolist(),
|
||||
"vector": l2_normalize(v),
|
||||
"face_count": len(face_vecs),
|
||||
"face_vecs": [vec.tolist() for vec in face_vecs]
|
||||
"face_vecs": [vec for vec in face_vecs]
|
||||
}
|
||||
return results
|
||||
|
||||
@ -270,9 +273,9 @@ class MM_Embedding:
|
||||
results[p] = {
|
||||
"type": "video",
|
||||
"path": p,
|
||||
"vector": video_vec.tolist(),
|
||||
"vector": video_vec,
|
||||
"face_count": len(face_vecs),
|
||||
"face_vecs": [vec.tolist() for vec in face_vecs]
|
||||
"face_vecs": [vec for vec in face_vecs]
|
||||
}
|
||||
except Exception as e:
|
||||
exception(f"Video {p} failed: {e}")
|
||||
@ -292,7 +295,7 @@ class MM_Embedding:
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast('cuda', enabled=USE_FP16):
|
||||
feats = self.model.get_image_features(**inputs)
|
||||
results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0]).tolist()}
|
||||
results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0])}
|
||||
except Exception as e:
|
||||
exception(f"Audio {p} failed: {e}")
|
||||
results[p] = None
|
||||
|
||||
@ -22,8 +22,8 @@
|
||||
"leading":"/idfile",
|
||||
"registerfunction":"idfile"
|
||||
},{
|
||||
"leading": "/v1/m2m",
|
||||
"registerfunction": "m2m"
|
||||
"leading": "/v1/translate",
|
||||
"registerfunction": "translate"
|
||||
},{
|
||||
"leading": "/docs",
|
||||
"registerfunction": "docs"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user