rag
This commit is contained in:
parent 0c6c8d9a3f
commit ace848f996
@@ -22,5 +22,5 @@ class BaseM2M:
    def __init__(self, model_id, **kw):
        self.model_id = model_id

    def m2m(self, texts: str, src_lang: str, tgt_lang: str) -> str:
    def translate(self, texts: str, src_lang: str, tgt_lang: str) -> str:
        raise NotImplementedError
@@ -15,7 +15,7 @@ class M2M100Translator(BaseM2M):
        self.model.eval()
        self.model_name = model_id.split('/')[-1]

    def m2m(self, text: str, src_lang: str, tgt_lang: str) -> str:
    def translate(self, text: str, src_lang: str, tgt_lang: str) -> str:
        """Translate a passage of text."""
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
@@ -13,7 +13,7 @@ from .base_m2m import get_llm_class


helptext = """M2M100 Translation API:
POST /v1/m2m
POST /v1/translate
Headers:
    Content-Type: application/json

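For reference, a minimal client sketch for the renamed endpoint; the host and port are assumptions, while the text, src_lang and tgt_lang fields follow the handler shown further down in this diff:

# Illustrative client sketch only; host/port are assumptions.
import requests

resp = requests.post(
    "http://localhost:8880/v1/translate",  # port is an assumption
    headers={"Content-Type": "application/json"},
    json={"text": "Hello, world", "src_lang": "en", "tgt_lang": "zh"},
)
print(resp.json())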
@@ -39,7 +39,7 @@ Response:

def init():
    rf = RegisterFunction()
    rf.register('m2m', m2m)
    rf.register('translate', translate)
    rf.register('docs', docs)

@@ -47,7 +47,7 @@ async def docs(request, params_kw, *params, **kw):
    return helptext


async def m2m(request, params_kw, *params, **kw):
async def translate(request, params_kw, *params, **kw):
    debug(f'{params_kw=}')
    se = ServerEnv()
    engine = se.engine
@@ -60,7 +60,7 @@ async def m2m(request, params_kw, *params, **kw):
    if not text or not isinstance(text, str):
        raise Exception("`text` must be a non-empty string")

    f = awaitify(engine.m2m)
    f = awaitify(engine.translate)
    translation = await f(text, src_lang, tgt_lang)

    ret = {
@@ -88,6 +88,19 @@ def main():
    se = ServerEnv()
    se.engine = Klass(args.model_path)

    if torch.cuda.is_available():
        logical_id = torch.cuda.current_device()
        physical_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0])
        gpu_name = torch.cuda.get_device_name(logical_id)
        mem_used = torch.cuda.memory_allocated(logical_id) / 1024 ** 3
        mem_total = torch.cuda.get_device_properties(logical_id).total_memory / 1024 ** 3

        debug(f"\nUsing physical GPU {physical_id} → logical GPU {logical_id}")
        debug(f"GPU model: {gpu_name}")
        debug(f"GPU memory in use: {mem_used:.1f} GB / {mem_total:.1f} GB\n")
    else:
        debug("\nRunning on CPU\n")

    debug(f"Starting M2M100 service on port {args.port}")
    webserver(init, args.workdir, args.port)

File diff suppressed because it is too large
@@ -1,81 +1,84 @@
# embed_all_unified.py
"""
Unified multimodal embedder (text, image, video, audio)
Features:
- All modalities mapped to the same embedding space (CLIP or CLAP)
- GPU/CPU/MPS auto detection
- FP16 autocast for speed
- Batch processing
- Video frame sampling + average pooling
- Audio resampling + CLAP embedding
- L2 normalized output for similarity search

model_name='/data/ymq/models/laion/CLIP-ViT-B-32-laion2B-s34B-b79K'

input:

text:
{
    "type": "text",
    "text": "...."
}
image:
"""

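As an illustration of the features listed above, a sketch of direct in-process use of the MM_Embedding class defined later in this file; the model path and file paths are placeholders, and the routing of plain strings and file paths follows the helptext below:

# Hypothetical direct usage; model and file paths are placeholders.
embedder = MM_Embedding("/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
results = embedder.embed_batch([
    "a plain text string",   # embedded with CLIP text features
    "/path/to/image.jpg",    # embedded with CLIP image features (plus face vectors)
])
for key, item in results.items():
    print(key, item["type"], len(item["vector"]))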
import os
from pathlib import Path
import argparse
import numpy as np
import torch
from PIL import Image
import av
import librosa
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from math import ceil
from appPublic.jsonConfig import getConfig
from appPublic.worker import awaitify
from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from traceback import format_exc

from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception

from ahserver.serverenv import ServerEnv
from ahserver.webapp import webserver

from transformers import CLIPProcessor, CLIPModel
from sklearn.preprocessing import normalize
from sklearn.cluster import DBSCAN
import base64
import traceback
import uuid

helptext = """CLIP 多模态统一嵌入服务
|
||||
API 地址:
|
||||
POST http://localhost:8883/v1/embed
|
||||
功能:
|
||||
将 文本 / 图片 / 视频 / 音频 统一转为 1024 维 L2 归一化向量
|
||||
支持人脸检测 + 去重(基于 DBSCAN 余弦距离)
|
||||
|
||||
输入格式 (JSON):
|
||||
{
|
||||
"inputs": [
|
||||
"文本字符串",
|
||||
"/path/to/image.jpg",
|
||||
"/path/to/video.mp4",
|
||||
"/path/to/audio.wav"
|
||||
]
|
||||
}
|
||||
|
||||
输出格式:
|
||||
{
|
||||
"data": { ... },
|
||||
"object": "embedding.result",
|
||||
"model": "CLIP-ViT-H-14-laion2B-s32B-b79K"
|
||||
}
|
||||
|
||||
特性:
|
||||
- 自动识别文件类型(文本/图像/音频/视频)
|
||||
- 视频抽帧 + 人脸去重
|
||||
- 文件不存在 → 自动降级为文本嵌入
|
||||
- 所有向量 L2 归一化,可直接余弦相似度
|
||||
- 人脸向量支持聚类去重(eps=0.4)
|
||||
|
||||
文档查看:
|
||||
GET http://localhost:8883/docs
|
||||
"""
|
||||
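As a usage illustration of the endpoint described above, a minimal client sketch; the plain-path items follow the helptext, and the dict item carrying a base64 data URI follows the handling in embed_batch further down. All file paths are placeholders:

# Illustrative client call; file paths are placeholders.
import base64
import requests

with open("/path/to/image.jpg", "rb") as f:
    data_uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()

payload = {
    "inputs": [
        "a text string",                      # embedded as text
        "/path/to/video.mp4",                 # embedded as a video file
        {"type": "image", "data": data_uri},  # dict item carrying a data URI
    ]
}
resp = requests.post("http://localhost:8883/v1/embed", json=payload)
body = resp.json()
print(body["model"], list(body["data"].keys()))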
try:
    import face_recognition
    FACE_LIB_AVAILABLE = True
except Exception:
    debug('failed to import the face_recognition library')
    FACE_LIB_AVAILABLE = False

# ------------------- Configuration -------------------
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu"
USE_FP16 = DEVICE == "cuda"

def choose_device():
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# Unified model for all modalities
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
FRAME_SAMPLE_RATE = 1.0  # fps for video
CLIP_MODEL_NAME = "/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
FRAME_SAMPLE_RATE = 1.0
FRAME_LIMIT = 64
AUDIO_SR = 16000  # resample audio
AUDIO_SR = 16000
IMAGE_DIR = Path("/share/wangmeihua/data/mmembedding/image")
AUDIO_DIR = Path("/share/wangmeihua/data/mmembedding/audio")
VIDEO_DIR = Path("/share/wangmeihua/data/mmembedding/video")

# ------------------- Load model -------------------
from transformers import CLIPProcessor, CLIPModel

# ------------------- Utils -------------------

def deduplicate_faces(face_embeddings, eps=0.4, min_samples=2):
    emb_norm = normalize(face_embeddings)
    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(emb_norm)
    unique_faces = []
    for label in set(clustering.labels_):
        if label == -1:  # noise
            continue
        cluster_embs = emb_norm[clustering.labels_ == label]
        unique_faces.append(np.mean(cluster_embs, axis=0))
    return np.array(unique_faces)
for d in [IMAGE_DIR, AUDIO_DIR, VIDEO_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# ------------------- Utility functions -------------------
def l2_normalize(v):
    norm = np.linalg.norm(v)
    return v / norm if norm > 1e-10 else v
@@ -84,192 +87,275 @@ def chunked(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def deduplicate_faces(face_embeddings, eps=0.4, min_samples=2):
    if len(face_embeddings) == 0:
        return np.array([])
    emb_norm = normalize(face_embeddings)
    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(emb_norm)
    unique = []
    for label in set(clustering.labels_):
        if label == -1:
            continue
        cluster = emb_norm[clustering.labels_ == label]
        unique.append(np.mean(cluster, axis=0))
    return np.array(unique)

# ------------------- Main model class -------------------
class MM_Embedding:
    def __init__(self, model_name):
        self.model = CLIPModel.from_pretrained(model_name).to(DEVICE)
        debug(f"Loading CLIP model: {model_name}")
        self.model_name = Path(model_name).name
        self.model = CLIPModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if USE_FP16 else torch.float32,
            device_map="auto" if USE_FP16 else None
        ).to(DEVICE).eval()
        self.processor = CLIPProcessor.from_pretrained(model_name)
        if USE_FP16:
            self.model.half()

    def detect_faces(self, img):
        faces = self.extract_faces(img)
        face_vecs = self.embed_faces(img)
        return face_vecs, faces
    def embed_batch(self, inputs):
        if not isinstance(inputs, list):
            raise ValueError("inputs must be a list")
        if len(inputs) == 0:
            return {}

    # ------------------- Image -------------------
    def embed_images(self, paths, batch_size=16):
        groups = {"image": [], "video": [], "audio": [], "text": []}
        results = {}
        for batch in chunked(paths, batch_size):
            imgs = [Image.open(p).convert("RGB") for p in batch]
            inputs = self.processor(images=imgs, return_tensors="pt", padding=True).to(DEVICE)
            with torch.no_grad():
                if USE_FP16:
                    with torch.cuda.amp.autocast():
                        feats = self.model.get_image_features(**inputs)
                else:
                    feats = self.model.get_image_features(**inputs)
            feats = feats.cpu().numpy()
            faces_list = []
            for img in imgs:
                faces_list.append(self.detect_faces(img))
        for item in inputs:
            # ------------------- New format: dict -------------------
            if isinstance(item, dict):
                typ = item.get("type")
                data_uri = item.get("data")

            for p, v, fs in zip(batch, feats, faces_list):
                results[p] = {
                    'type': 'image',
                    'path': p,
                    'faces': fs[1],
                    'face_vecs': fs[0],
                    'face_count': len(fs[0]),
                    'vector': l2_normalize(v)
                if typ == "text":
                    content = item.get("content", "")
                    if content:
                        groups["text"].append(content)
                    continue

if typ in {"image", "video", "audio"} and data_uri:
|
||||
try:
|
||||
# 1. 提取 base64
|
||||
try:
|
||||
header, b64 = data_uri.split(",", 1)
|
||||
debug(f"header: {header},b64: {b64}")
|
||||
binary = base64.b64decode(b64)
|
||||
except Exception as e:
|
||||
error(f"解码失败: {str(e)}, 堆栈: {traceback.format_exc()}")
|
||||
|
||||
# 2. 确定扩展名
|
||||
mime_to_ext = {
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
"image/webp": ".webp",
|
||||
"video/mp4": ".mp4",
|
||||
"video/webm": ".webm",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/wav": ".wav",
|
||||
"audio/ogg": ".ogg",
|
||||
}
|
||||
mime = header.split(";")[0].split(":")[1]
|
||||
ext = mime_to_ext.get(mime, ".bin")
|
||||
|
||||
# 3. 生成唯一文件名 + 存储
|
||||
uid = uuid.uuid4().hex[:12]
|
||||
if typ == "image":
|
||||
save_dir = IMAGE_DIR
|
||||
fake_path = save_dir / f"{uid}{ext}"
|
||||
fake_path = str(fake_path)
|
||||
elif typ == "video":
|
||||
save_dir = VIDEO_DIR
|
||||
fake_path = save_dir / f"{uid}{ext}"
|
||||
fake_path = str(fake_path)
|
||||
elif typ == "audio":
|
||||
save_dir = AUDIO_DIR
|
||||
fake_path = save_dir / f"{uid}{ext}"
|
||||
fake_path = str(fake_path)
|
||||
|
||||
Path(fake_path).write_bytes(binary)
|
||||
debug(f"保存多媒体文件: {fake_path} ({len(binary) / 1024 / 1024:.2f}MB)")
|
||||
|
||||
# 4. 放入对应 group(CLIP 直接用路径)
|
||||
if typ == "image":
|
||||
groups["image"].append(fake_path)
|
||||
elif typ == "video":
|
||||
groups["video"].append(fake_path)
|
||||
elif typ == "audio":
|
||||
groups["audio"].append(fake_path)
|
||||
|
||||
# 记录原始来源(可选)
|
||||
results[fake_path] = {"type": typ, "source": "data_uri", "original_mime": mime}
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
results[id(item)] = {"type": "error", "error": f"data URI 解码失败: {e}"}
|
||||
continue
|
||||
|
||||
if groups["image"]:
|
||||
results.update(self._embed_images(groups["image"]))
|
||||
if groups["video"]:
|
||||
results.update(self._embed_videos(groups["video"]))
|
||||
if groups["audio"]:
|
||||
results.update(self._embed_audios(groups["audio"]))
|
||||
if groups["text"]:
|
||||
results.update(self._embed_texts(groups["text"]))
|
||||
debug(f"最终返回结果是:{results}")
|
||||
return results
|
||||
|
||||
    # ------------------- Text -------------------
    def embed_texts(self, texts, batch_size=64):
    def _embed_texts(self, texts):
        results = {}
        for batch in chunked(texts, batch_size):
        for batch in chunked(texts, 64):
            inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
            with torch.no_grad():
                if USE_FP16:
                    with torch.cuda.amp.autocast():
                        feats = self.model.get_text_features(**inputs)
                else:
                with torch.amp.autocast('cuda', enabled=USE_FP16):
                    feats = self.model.get_text_features(**inputs)
            feats = feats.cpu().numpy()
            for t, v in zip(batch, feats):
                results[t] = {
                    "type": "text",
                    "vector": l2_normalize(v)
                results[t] = {"type": "text", "vector": l2_normalize(v).tolist()}
        return results

    def _embed_images(self, paths):
        results = {}
        for batch in chunked(paths, 16):
            imgs = [Image.open(p).convert("RGB") for p in batch]
            inputs = self.processor(images=imgs, return_tensors="pt", padding=True).to(DEVICE)
            with torch.no_grad():
                with torch.amp.autocast('cuda', enabled=USE_FP16):
                    feats = self.model.get_image_features(**inputs)
            feats = feats.cpu().numpy()
            for p, v, img in zip(batch, feats, imgs):
                face_vecs, _ = self._detect_faces(img)
                results[p] = {
                    "type": "image",
                    "path": p,
                    "vector": l2_normalize(v).tolist(),
                    "face_count": len(face_vecs),
                    "face_vecs": [vec.tolist() for vec in face_vecs]
                }
        return results

    # ------------------- Video -------------------
    def embed_videos(self, paths, frame_rate=FRAME_SAMPLE_RATE, frame_limit=FRAME_LIMIT):
    def _embed_videos(self, paths):
        results = {}
        for p in paths:
            try:
                container = av.open(p)
                frames = []
                fps = float(container.streams.video[0].average_rate) if container.streams.video else 30.0
                step = max(1, int(fps / max(1, frame_rate)))
                count = 0
                fps = float(container.streams.video[0].average_rate) or 30.0
                step = max(1, int(fps / FRAME_SAMPLE_RATE))
                for i, frame in enumerate(container.decode(video=0)):
                    if i % step == 0:
                        frames.append(frame.to_image().convert("RGB"))
                        count += 1
                        if count >= frame_limit:
                        if len(frames) >= FRAME_LIMIT:
                            break
                container.close()

                if not frames:
                    results[p] = None
                    continue
                # batch embed

                emb_list = []
                faces_list = []
                all_faces = []
                for batch in chunked(frames, 16):
                    inputs = self.processor(images=batch, return_tensors="pt", padding=True).to(DEVICE)
                    with torch.no_grad():
                        if USE_FP16:
                            with torch.cuda.amp.autocast():
                                feats = self.model.get_image_features(**inputs)
                        else:
                        with torch.amp.autocast('cuda', enabled=USE_FP16):
                            feats = self.model.get_image_features(**inputs)
                    for img in batch:
                        faces_list += self.detect_faces(img)[0]
                        fv, _ = self._detect_faces(img)
                        all_faces.extend(fv)
                    emb_list.append(feats.cpu().numpy())
                face_vecs = deduplicate_faces(faces_list)
                emb_array = np.vstack(emb_list)
                video_vec = l2_normalize(emb_array.mean(axis=0))
                # face_vecs =

                face_vecs = deduplicate_faces(all_faces)
                video_vec = l2_normalize(np.vstack(emb_list).mean(axis=0))

                results[p] = {
                    "type": "video",
                    "path": p,
                    "vector": video_vec,
                    "vector": video_vec.tolist(),
                    "face_count": len(face_vecs),
                    "face_vecs": face_vecs
                    "face_vecs": [vec.tolist() for vec in face_vecs]
                }
            except Exception as e:
                exception(f"Video {p} failed: {e}")
                results[p] = None
        return results

    # ------------------- Audio -------------------
    def embed_audios(self, paths, batch_size=4):
    def _embed_audios(self, paths):
        results = {}
        for p in paths:
            try:
                y, sr = librosa.load(p, sr=AUDIO_SR, mono=True)
                # convert to mel spectrogram image
                S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=224)
                S_db = librosa.power_to_db(S, ref=np.max)
                img = Image.fromarray(np.uint8((S_db - S_db.min())/(S_db.max()-S_db.min()+1e-9)*255)).convert("RGB").resize((224,224))
                norm_val = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-9)
                img = Image.fromarray(np.uint8(norm_val * 255)).convert("RGB").resize((224, 224))
                inputs = self.processor(images=img, return_tensors="pt").to(DEVICE)
                with torch.no_grad():
                    if USE_FP16:
                        with torch.cuda.amp.autocast():
                            feat = self.model.get_image_features(**inputs)
                    else:
                        feat = self.model.get_image_features(**inputs)
                results[p] = l2_normalize(feat.cpu().numpy()[0])
                    with torch.amp.autocast('cuda', enabled=USE_FP16):
                        feats = self.model.get_image_features(**inputs)
                results[p] = {"type": "audio", "vector": l2_normalize(feats.cpu().numpy()[0]).tolist()}
            except Exception as e:
                exception(f"Audio {p} failed: {e}")
                results[p] = None
        return results

    def extract_faces(self, img: Image.Image):
        """Return a list of cropped face regions."""
    def _detect_faces(self, img):
        arr = np.array(img)
        face_locs = face_recognition.face_locations(arr)
        faces = []
        for (top, right, bottom, left) in face_locs:
            face = arr[top:bottom, left:right]
            faces.append(Image.fromarray(face))
        return faces
        locs = face_recognition.face_locations(arr)
        debug(f'face locations in the image: {locs}')
        encodings = face_recognition.face_encodings(arr, known_face_locations=locs)
        vecs = [l2_normalize(np.array(e)) for e in encodings] if encodings else []
        debug(f'face vectors for the image: {vecs}')
        return vecs, []

    def embed_faces(self, img: Image.Image):
        """Extract face vectors (face_recognition + CLIP)."""
        arr = np.array(img)
        encodings = face_recognition.face_encodings(arr)
        if not encodings:
            return []
        return [l2_normalize(np.array(e)) for e in encodings]
# ------------------- API routes (modeled directly on m2m) -------------------
async def embed(request, params_kw, *params, **kw):
    debug(f'{params_kw=}')

    # ------------------- Dispatcher -------------------
    def embed_batch(self, inputs):
        groups = {"image":[], "video":[], "audio":[], "text":[]}
        for item in inputs:
            p = Path(item)
            ext = item.lower()
            if p.exists():
                if any(ext.endswith(e) for e in [".jpg",".jpeg",".png",".bmp",".webp",".heic"]):
                    groups["image"].append(item)
                elif any(ext.endswith(e) for e in [".mp4",".mov",".avi",".mkv"]):
                    groups["video"].append(item)
                elif any(ext.endswith(e) for e in [".mp3",".wav",".flac"]):
                    groups["audio"].append(item)
                else:
                    groups["text"].append(item)
            else:
                groups["text"].append(item)
        outputs = {}
        if groups["image"]:
            outputs.update(embed_images(groups["image"]))
        if groups["video"]:
            outputs.update(embed_videos(groups["video"]))
        if groups["audio"]:
            outputs.update(embed_audios(groups["audio"]))
        if groups["text"]:
            outputs.update(embed_texts(groups["text"]))
        return outputs
    se = ServerEnv()
    engine = se.engine

    # Get parameters from params_kw
    inputs = getattr(params_kw, 'inputs', None)
    if not inputs or not isinstance(inputs, list):
        raise Exception("`inputs` must be a non-empty list")

    # Run the embedding
    raw_result = engine.embed_batch(inputs)

    # Build the standard response
    ret = {
        "data": raw_result,
        "object": "embedding.result",
        "model": engine.model_name
    }
    return ret


async def docs(request, *args, **kw):
    return helptext

# ------------------- Service initialization -------------------
def init():
    env = ServerEnv()
    config = getConfig()
    env.mm_model = MM_Embedding(config.model_name)
    env.embeded_batch = awaitify(env.mm_model.embed_batch)
# ------------------- CLI -------------------
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("inputs", nargs="+", help="file paths or text strings")
    parser.add_argument("--out", default="embeddings.npy")
    rf = RegisterFunction()
    rf.register('embed', embed)
    rf.register('docs', docs)
    debug("Registered: POST /v1/embed")

# ------------------- Service startup -------------------
def main():
    parser = argparse.ArgumentParser(prog="CLIP Embedding Service")
    parser.add_argument('model_path', nargs='?', help="CLIP model path")
    parser.add_argument('-p', '--port', type=int, default=8883)
    parser.add_argument('-w', '--workdir', default=os.getcwd())
    args = parser.parse_args()

    embeddings = embed_batch(args.inputs)
    # save dict of name->vector
    out_dict = {k:v.tolist() for k,v in embeddings.items()}
    np.save(args.out, out_dict)
    print(f"Saved embeddings to {args.out}")
    config = getConfig()
    model_name = args.model_path or config.get("model_name") or CLIP_MODEL_NAME

    se = ServerEnv()
    se.engine = MM_Embedding(model_name)

    debug(f"Starting embedding service on port {args.port}")
    webserver(init, args.workdir, args.port)


if __name__ == '__main__':
    main()