ymq1 2025-10-15 16:25:31 +08:00
parent 86589ec2e8
commit 04905889bf


@@ -91,6 +91,11 @@ class MM_Embedding:
         if USE_FP16:
             self.model.half()
 
+    def detect_faces(self, img):
+        faces = self.extract_faces(img)
+        face_vecs = self.embed_faces(img)
+        return face_vecs, faces
+
     # ------------------- Image -------------------
     def embed_images(self, paths, batch_size=16):
         results = {}
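The new detect_faces helper returns the face embeddings first and the detections second, which is why the call sites in the next hunk swap their fs[0]/fs[1] indexing. A minimal usage sketch, assuming mm is an MM_Embedding instance and img an already-loaded image; both names are illustrative, not from the commit:

# Hedged sketch: `mm` and `img` are assumed, not defined in this commit.
face_vecs, faces = mm.detect_faces(img)  # note the order: vectors first, detections second
assert len(face_vecs) == len(faces)      # expected to hold: one embedding per detected face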
@@ -106,16 +111,14 @@ class MM_Embedding:
             feats = feats.cpu().numpy()
             faces_list = []
             for img in imgs:
-                faces = self.extract_faces(img)
-                face_vecs = self.embed_faces(img)
-                faces_list.append([faces, face_vecs])
+                faces_list.append(self.detect_faces(img))
             for p, v, fs in zip(batch, feats, faces_list):
                 results[p] = {
                     'type':'image',
                     'path': p,
-                    'faces': fs[0],
-                    'face_vecs': fs[1],
+                    'faces': fs[1],
+                    'face_vecs': fs[0],
                     'face_count':len(fs[0]),
                     'vector': l2_normalize(v)
                 }
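After this hunk each image entry is a structured record rather than a bare vector. A sketch of reading one back; the embedder variable and the path are assumptions, while the field names come straight from the diff:

# Assumed call site; `embedder` is an MM_Embedding instance, "photo.jpg" a sample path.
results = embedder.embed_images(["photo.jpg"])
rec = results["photo.jpg"]
print(rec["type"], rec["face_count"])  # e.g. "image" and the number of detected faces
print(rec["vector"].shape)             # the l2-normalized image feature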
@@ -134,7 +137,10 @@ class MM_Embedding:
             feats = self.model.get_text_features(**inputs)
             feats = feats.cpu().numpy()
             for t, v in zip(batch, feats):
-                results[t] = l2_normalize(v)
+                results[t] = {
+                    "type": "text",
+                    "vector": l2_normalize(v)
+                }
         return results
     # ------------------- Video -------------------
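Text results now mirror the image record shape, and because every stored vector is already l2-normalized, cross-modal scoring can stay a plain dot product. A hedged sketch; img_rec and txt_rec are assumed records produced by the image and text embedders, not names from the commit:

import numpy as np

# `img_rec` and `txt_rec` are assumed result records with a "vector" field.
score = float(np.dot(img_rec["vector"], txt_rec["vector"]))  # cosine similarity on unit vectors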
@@ -158,6 +164,7 @@ class MM_Embedding:
                 continue
             # batch embed
             emb_list = []
+            faces_list = []
             for batch in chunked(frames, 16):
                 inputs = self.processor(images=batch, return_tensors="pt", padding=True).to(DEVICE)
                 with torch.no_grad():
@@ -166,11 +173,20 @@ class MM_Embedding:
                         feats = self.model.get_image_features(**inputs)
                     else:
                         feats = self.model.get_image_features(**inputs)
+                for img in batch:
+                    faces_list += self.detect_faces(img)[0]
                 emb_list.append(feats.cpu().numpy())
+            face_vecs = deduplicate_faces(faces_list)
             emb_array = np.vstack(emb_list)
             video_vec = l2_normalize(emb_array.mean(axis=0))
-            face_vecs =
-            results[p] = video_vec
+            # face_vecs =
+            results[p] = {
+                "type": "video",
+                "path": p,
+                "vector": video_vec,
+                "face_count": len(face_vecs),
+                "face_vecs": face_vecs
+            }
         return results
     # ------------------- Audio -------------------
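deduplicate_faces is called by the video path above but not defined in this commit. One plausible sketch of what such a helper could do, keeping a face vector only when it is not too similar to anything already kept; the greedy strategy, the threshold, and the whole body are assumptions, and the project's real helper may differ:

import numpy as np

def deduplicate_faces(face_vecs, threshold=0.9):
    # Hypothetical greedy cosine-similarity dedup (not the project's actual code):
    # normalize each candidate, then keep it only if its similarity to every
    # already-kept vector stays below `threshold`.
    kept = []
    for v in face_vecs:
        v = np.asarray(v, dtype=np.float32)
        v = v / (np.linalg.norm(v) + 1e-12)  # unit-normalize so dot product = cosine
        if all(float(np.dot(v, k)) < threshold for k in kept):
            kept.append(v)
    return kept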