llmengine
This commit is contained in:
parent
ace848f996
commit
66dec8faab
@ -22,7 +22,7 @@ class BgeReranker(BaseReranker):
|
|||||||
|
|
||||||
def process_inputs(self, pairs):
|
def process_inputs(self, pairs):
|
||||||
inputs = self.tokenizer(pairs, padding=True,
|
inputs = self.tokenizer(pairs, padding=True,
|
||||||
truncation=True, return_tensors='pt', max_length=512)
|
truncation=True, return_tensors='pt', max_length=8096)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
inputs = {k: v.to('cuda') for k, v in inputs.items()}
|
inputs = {k: v.to('cuda') for k, v in inputs.items()}
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@ -16,10 +16,25 @@ path: /v1/createcollection
|
|||||||
method: POST
|
method: POST
|
||||||
headers: {"Content-Type": "application/json"}
|
headers: {"Content-Type": "application/json"}
|
||||||
data: {
|
data: {
|
||||||
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
|
"db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb
|
||||||
|
"fields": [ // 可选,动态字段定义,格式为 FieldSchema 列表
|
||||||
|
{
|
||||||
|
"name": "pk",
|
||||||
|
"dtype": "VARCHAR",
|
||||||
|
"is_primary": true,
|
||||||
|
"max_length": 36,
|
||||||
|
"auto_id": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "vector",
|
||||||
|
"dtype": "FLOAT_VECTOR",
|
||||||
|
"dim": 1024
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
}
|
}
|
||||||
response:
|
response:
|
||||||
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 创建成功"}
|
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb" or "ragdb_textdb_<hash>", "message": "集合 <collection_name> 创建成功或已存在"}
|
||||||
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"}
|
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>"}
|
||||||
|
|
||||||
2. Delete Collection Endpoint:
|
2. Delete Collection Endpoint:
|
||||||
@ -216,12 +231,24 @@ response:
|
|||||||
|
|
||||||
10. Docs Endpoint:
|
10. Docs Endpoint:
|
||||||
path: /docs
|
path: /docs
|
||||||
method: GET
|
method: POST
|
||||||
response: This help text
|
response: This help text
|
||||||
|
|
||||||
|
11. Initialize Connection Endpoint:
|
||||||
|
path: /v1/initialize
|
||||||
|
method: POST
|
||||||
|
headers: {"Content-Type": "application/json"}
|
||||||
|
data: {
|
||||||
|
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
|
||||||
|
}
|
||||||
|
response:
|
||||||
|
- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "Milvus 连接已初始化,路径: <db_path>"}
|
||||||
|
- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
|
rf.register('initialize', initialize)
|
||||||
rf.register('createcollection', create_collection)
|
rf.register('createcollection', create_collection)
|
||||||
rf.register('deletecollection', delete_collection)
|
rf.register('deletecollection', delete_collection)
|
||||||
rf.register('insertdocument', insert_document)
|
rf.register('insertdocument', insert_document)
|
||||||
@ -236,22 +263,45 @@ def init():
|
|||||||
async def docs(request, params_kw, *params, **kw):
|
async def docs(request, params_kw, *params, **kw):
|
||||||
return web.Response(text=helptext, content_type='text/plain')
|
return web.Response(text=helptext, content_type='text/plain')
|
||||||
|
|
||||||
async def create_collection(request, params_kw, *params, **kw):
|
async def initialize(request, params_kw, *params, **kw):
|
||||||
debug(f'{params_kw=}')
|
debug(f'Received initialize params: {params_kw=}')
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
engine = se.engine
|
engine = se.engine
|
||||||
db_type = params_kw.get('db_type', '')
|
db_type = params_kw.get('db_type', '')
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
try:
|
try:
|
||||||
result = await engine.handle_connection("create_collection", {"db_type": db_type})
|
result = await engine.handle_connection("initialize", {"db_type": db_type})
|
||||||
|
debug(f'Initialize result: {result=}')
|
||||||
|
status = 200 if result.get("status") == "success" else 400
|
||||||
|
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
|
||||||
|
except Exception as e:
|
||||||
|
error(f'初始化连接失败: {str(e)}')
|
||||||
|
return web.json_response({
|
||||||
|
"status": "error",
|
||||||
|
"collection_name": collection_name,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 400
|
||||||
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
|
|
||||||
|
async def create_collection(request, params_kw, *params, **kw):
|
||||||
|
debug(f'{params_kw=}')
|
||||||
|
se = ServerEnv()
|
||||||
|
engine = se.engine
|
||||||
|
db_type = params_kw.get('db_type', '')
|
||||||
|
fields = params_kw.get('fields', None)
|
||||||
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
|
try:
|
||||||
|
result = await engine.handle_connection("create_collection", {"db_type": db_type, "fields": fields})
|
||||||
debug(f'{result=}')
|
debug(f'{result=}')
|
||||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
status = 200 if result.get("status") == "success" else 400
|
||||||
|
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f'创建集合失败: {str(e)}')
|
error(f'创建集合失败: {str(e)}')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"collection_name": collection_name,
|
"collection_name": collection_name,
|
||||||
"message": str(e)
|
"message": str(e),
|
||||||
|
"status_code": 400
|
||||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
|
|
||||||
async def delete_collection(request, params_kw, *params, **kw):
|
async def delete_collection(request, params_kw, *params, **kw):
|
||||||
@ -313,7 +363,7 @@ async def delete_document(request, params_kw, *params, **kw):
|
|||||||
db_type = params_kw.get('db_type', '')
|
db_type = params_kw.get('db_type', '')
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
try:
|
try:
|
||||||
required_fields = ['userid', 'file_path', 'knowledge_base_id', 'document_id']
|
required_fields = ['userid', 'knowledge_base_id', 'document_id']
|
||||||
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
|
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
|
||||||
@ -382,11 +432,11 @@ async def search_query(request, params_kw, *params, **kw):
|
|||||||
db_type = params_kw.get('db_type', '')
|
db_type = params_kw.get('db_type', '')
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
try:
|
try:
|
||||||
if not query_vector or not userid or not knowledge_base_ids:
|
if not query_vector or not userid:
|
||||||
debug(f'query_vector, userid 或 knowledge_base_ids 未提供')
|
debug(f'query_vector或userid 未提供')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "query_vector, userid 或 knowledge_base_ids 未提供",
|
"message": "query_vector或userid未提供",
|
||||||
"collection_name": collection_name
|
"collection_name": collection_name
|
||||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
result = await engine.handle_connection("search_query", {
|
result = await engine.handle_connection("search_query", {
|
||||||
@ -419,7 +469,8 @@ async def list_user_files(request, params_kw, *params, **kw):
|
|||||||
engine = se.engine
|
engine = se.engine
|
||||||
userid = params_kw.get('userid', '')
|
userid = params_kw.get('userid', '')
|
||||||
db_type = params_kw.get('db_type', '')
|
db_type = params_kw.get('db_type', '')
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type or db_type == "ragdb" else f"ragdb_{db_type}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not userid:
|
if not userid:
|
||||||
debug(f'userid 未提供')
|
debug(f'userid 未提供')
|
||||||
@ -428,17 +479,13 @@ async def list_user_files(request, params_kw, *params, **kw):
|
|||||||
"message": "userid 未提供",
|
"message": "userid 未提供",
|
||||||
"collection_name": collection_name
|
"collection_name": collection_name
|
||||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
|
# 直接调用底层逻辑
|
||||||
result = await engine.handle_connection("list_user_files", {
|
result = await engine.handle_connection("list_user_files", {
|
||||||
"userid": userid,
|
"userid": userid,
|
||||||
"db_type": db_type
|
"db_type": db_type
|
||||||
})
|
})
|
||||||
debug(f'{result=}')
|
debug(f'底层返回: {result=}')
|
||||||
response = {
|
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||||
"status": "success",
|
|
||||||
"files_by_knowledge_base": result,
|
|
||||||
"collection_name": collection_name
|
|
||||||
}
|
|
||||||
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f'列出用户文件失败: {str(e)}')
|
error(f'列出用户文件失败: {str(e)}')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
@ -452,7 +499,6 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw):
|
|||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
engine = se.engine
|
engine = se.engine
|
||||||
db_type = params_kw.get('db_type', '')
|
db_type = params_kw.get('db_type', '')
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
|
||||||
try:
|
try:
|
||||||
result = await engine.handle_connection("list_all_knowledge_bases", {
|
result = await engine.handle_connection("list_all_knowledge_bases", {
|
||||||
"db_type": db_type
|
"db_type": db_type
|
||||||
@ -460,18 +506,18 @@ async def list_all_knowledge_bases(request, params_kw, *params, **kw):
|
|||||||
debug(f'{result=}')
|
debug(f'{result=}')
|
||||||
response = {
|
response = {
|
||||||
"status": result.get("status", "success"),
|
"status": result.get("status", "success"),
|
||||||
"users_knowledge_bases": result.get("users_knowledge_bases", {}),
|
"collections": result.get("collections", {}),
|
||||||
"collection_name": collection_name,
|
"collection_names": result.get("collection_names", "none"),
|
||||||
"message": result.get("message", ""),
|
"message": result.get("message", ""),
|
||||||
"status_code": result.get("status_code", 200)
|
"status_code": result.get("status_code", 200)
|
||||||
}
|
}
|
||||||
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=response["status_code"])
|
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=response["status_code"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f'列出所有用户知识库失败: {str(e)}')
|
error(f'列出所有知识库失败: {str(e)}')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"users_knowledge_bases": {},
|
"collections": {},
|
||||||
"collection_name": collection_name,
|
"collection_names": "none",
|
||||||
"message": str(e),
|
"message": str(e),
|
||||||
"status_code": 400
|
"status_code": 400
|
||||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from llmengine.base_db import connection_register, BaseDBConnection
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import numpy as np
|
||||||
class MilvusDBConnection(BaseDBConnection):
|
class MilvusDBConnection(BaseDBConnection):
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = Lock()
|
_lock = Lock()
|
||||||
@ -606,6 +606,21 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
async def _search_query(self, query_vector: List[float], userid: str,
|
async def _search_query(self, query_vector: List[float], userid: str,
|
||||||
knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
|
knowledge_base_ids: List[str] | None, limit: int = 5, offset: int = 0, db_type: str = "") -> Dict[str, Any]:
|
||||||
"""基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)"""
|
"""基于向量搜索 Milvus 集合(knowledge_base_ids 可为空)"""
|
||||||
|
# query_vector = np.array(query_vector, dtype=np.float32)
|
||||||
|
# original_norm = np.linalg.norm(query_vector)
|
||||||
|
# debug(f"API 返回向量范数: {original_norm:.10f}")
|
||||||
|
#
|
||||||
|
# if original_norm == 0:
|
||||||
|
# raise ValueError("zero vector")
|
||||||
|
#
|
||||||
|
# # 强制归一化
|
||||||
|
# query_vector = query_vector / original_norm
|
||||||
|
#
|
||||||
|
# # 重新计算范数
|
||||||
|
# final_norm = np.linalg.norm(query_vector)
|
||||||
|
# debug(f"强制归一化后范数: {final_norm:.10f}")
|
||||||
|
#
|
||||||
|
# query_vector = query_vector.tolist()
|
||||||
timing_stats = {}
|
timing_stats = {}
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
@ -649,7 +664,7 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
is_ragdb = collection_name == "ragdb"
|
is_ragdb = collection_name == "ragdb"
|
||||||
output_fields = (
|
output_fields = (
|
||||||
["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type",
|
["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type",
|
||||||
"knowledge_base_id"]
|
"knowledge_base_id", "vector"]
|
||||||
if is_ragdb else
|
if is_ragdb else
|
||||||
["document_id", "userid", "knowledge_base_id", "text"]
|
["document_id", "userid", "knowledge_base_id", "text"]
|
||||||
)
|
)
|
||||||
@ -685,6 +700,7 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
output_fields=output_fields, # 动态
|
output_fields=output_fields, # 动态
|
||||||
offset=offset # 保留
|
offset=offset # 保留
|
||||||
)
|
)
|
||||||
|
# debug(f"使用的召回索引是:{collection.index().describe()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"搜索失败: {str(e)}")
|
error(f"搜索失败: {str(e)}")
|
||||||
return {"results": [], "timing": timing_stats}
|
return {"results": [], "timing": timing_stats}
|
||||||
@ -703,11 +719,12 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
"file_path": hit.entity.get("file_path"),
|
"file_path": hit.entity.get("file_path"),
|
||||||
"upload_time": hit.entity.get("upload_time"),
|
"upload_time": hit.entity.get("upload_time"),
|
||||||
"file_type": hit.entity.get("file_type"),
|
"file_type": hit.entity.get("file_type"),
|
||||||
"knowledge_base_id": hit.entity.get("knowledge_base_id")
|
"knowledge_base_id": hit.entity.get("knowledge_base_id"),
|
||||||
|
"vector": hit.entity.get("vector")[:10],
|
||||||
}
|
}
|
||||||
raw_results.append({
|
raw_results.append({
|
||||||
"text": hit.entity.get("text"),
|
"text": hit.entity.get("text"),
|
||||||
"distance": 1.0 - hit.distance,
|
"distance": hit.distance,
|
||||||
"metadata": metadata
|
"metadata": metadata
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
@ -718,7 +735,7 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
}
|
}
|
||||||
raw_results.append({
|
raw_results.append({
|
||||||
"text": hit.entity.get("text"),
|
"text": hit.entity.get("text"),
|
||||||
"distance": 1.0 - hit.distance,
|
"distance": hit.distance,
|
||||||
"metadata": metadata
|
"metadata": metadata
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -921,7 +938,7 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
"file_path": result.get("file_path", ""),
|
"file_path": result.get("file_path", ""),
|
||||||
"upload_time": result.get("upload_time", ""),
|
"upload_time": result.get("upload_time", ""),
|
||||||
"file_type": result.get("file_type", ""),
|
"file_type": result.get("file_type", ""),
|
||||||
"knowledge_base_id": kb_id
|
"knowledge_base_id": kb_id,
|
||||||
}
|
}
|
||||||
files_by_knowledge_base.setdefault(kb_id, []).append(file_info)
|
files_by_knowledge_base.setdefault(kb_id, []).append(file_info)
|
||||||
else:
|
else:
|
||||||
@ -935,7 +952,7 @@ class MilvusDBConnection(BaseDBConnection):
|
|||||||
|
|
||||||
debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}")
|
debug(f"找到文件: document_id={document_id}, kb_id={kb_id}, file_info={file_info}")
|
||||||
|
|
||||||
info(f"找到 {len(seen_document_ids)} 个文件")
|
debug(f"找到 {len(seen_document_ids)} 个文件")
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"files_by_knowledge_base": files_by_knowledge_base,
|
"files_by_knowledge_base": files_by_knowledge_base,
|
||||||
|
|||||||
@ -71,7 +71,7 @@ CLIP_MODEL_NAME = "/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
|||||||
FRAME_SAMPLE_RATE = 1.0
|
FRAME_SAMPLE_RATE = 1.0
|
||||||
FRAME_LIMIT = 64
|
FRAME_LIMIT = 64
|
||||||
AUDIO_SR = 16000
|
AUDIO_SR = 16000
|
||||||
IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image")
|
IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image/photo")
|
||||||
AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio")
|
AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio")
|
||||||
VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video")
|
VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video")
|
||||||
|
|
||||||
@ -80,8 +80,11 @@ for d in [IMAGE_DIR, AUDIO_DIR, VIDEO_DIR]:
|
|||||||
|
|
||||||
# ------------------- 工具函数 -------------------
|
# ------------------- 工具函数 -------------------
|
||||||
def l2_normalize(v):
|
def l2_normalize(v):
|
||||||
|
v = np.array(v, dtype=np.float32)
|
||||||
norm = np.linalg.norm(v)
|
norm = np.linalg.norm(v)
|
||||||
return v / norm if norm > 1e-10 else v
|
if norm == 0:
|
||||||
|
return v.tolist()
|
||||||
|
return (v / norm).tolist()
|
||||||
|
|
||||||
def chunked(lst, n):
|
def chunked(lst, n):
|
||||||
for i in range(0, len(lst), n):
|
for i in range(0, len(lst), n):
|
||||||
@ -137,7 +140,7 @@ class MM_Embedding:
|
|||||||
# 1. 提取 base64
|
# 1. 提取 base64
|
||||||
try:
|
try:
|
||||||
header, b64 = data_uri.split(",", 1)
|
header, b64 = data_uri.split(",", 1)
|
||||||
debug(f"header: {header},b64: {b64}")
|
debug(f"header: {header}, b64: {b64[:100]}...({len(b64)} chars total)")
|
||||||
binary = base64.b64decode(b64)
|
binary = base64.b64decode(b64)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||||
@ -210,7 +213,7 @@ class MM_Embedding:
|
|||||||
feats = self.model.get_text_features(**inputs)
|
feats = self.model.get_text_features(**inputs)
|
||||||
feats = feats.cpu().numpy()
|
feats = feats.cpu().numpy()
|
||||||
for t, v in zip(batch, feats):
|
for t, v in zip(batch, feats):
|
||||||
results[t] = {"type": "text", "vector": l2_normalize(v).tolist()}
|
results[t] = {"type": "text", "vector": l2_normalize(v)}
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _embed_images(self, paths):
|
def _embed_images(self, paths):
|
||||||
@ -227,9 +230,9 @@ class MM_Embedding:
|
|||||||
results[p] = {
|
results[p] = {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"path": p,
|
"path": p,
|
||||||
"vector": l2_normalize(v).tolist(),
|
"vector": l2_normalize(v),
|
||||||
"face_count": len(face_vecs),
|
"face_count": len(face_vecs),
|
||||||
"face_vecs": [vec.tolist() for vec in face_vecs]
|
"face_vecs": [vec for vec in face_vecs]
|
||||||
}
|
}
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -270,9 +273,9 @@ class MM_Embedding:
|
|||||||
results[p] = {
|
results[p] = {
|
||||||
"type": "video",
|
"type": "video",
|
||||||
"path": p,
|
"path": p,
|
||||||
"vector": video_vec.tolist(),
|
"vector": video_vec,
|
||||||
"face_count": len(face_vecs),
|
"face_count": len(face_vecs),
|
||||||
"face_vecs": [vec.tolist() for vec in face_vecs]
|
"face_vecs": [vec for vec in face_vecs]
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exception(f"Video {p} failed: {e}")
|
exception(f"Video {p} failed: {e}")
|
||||||
@ -292,7 +295,7 @@ class MM_Embedding:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.amp.autocast('cuda', enabled=USE_FP16):
|
with torch.amp.autocast('cuda', enabled=USE_FP16):
|
||||||
feats = self.model.get_image_features(**inputs)
|
feats = self.model.get_image_features(**inputs)
|
||||||
results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0]).tolist()}
|
results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0])}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exception(f"Audio {p} failed: {e}")
|
exception(f"Audio {p} failed: {e}")
|
||||||
results[p] = None
|
results[p] = None
|
||||||
|
|||||||
@ -22,8 +22,8 @@
|
|||||||
"leading":"/idfile",
|
"leading":"/idfile",
|
||||||
"registerfunction":"idfile"
|
"registerfunction":"idfile"
|
||||||
},{
|
},{
|
||||||
"leading": "/v1/m2m",
|
"leading": "/v1/translate",
|
||||||
"registerfunction": "m2m"
|
"registerfunction": "translate"
|
||||||
},{
|
},{
|
||||||
"leading": "/docs",
|
"leading": "/docs",
|
||||||
"registerfunction": "docs"
|
"registerfunction": "docs"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user