This commit is contained in:
yumoqing 2025-08-14 21:49:05 +08:00
parent 5093e345b1
commit 5ed451e47d

View File

@ -4,24 +4,24 @@ from PIL import Image
import torch import torch
class CLIPEmbedder: class CLIPEmbedder:
def __init__(self): def __init__(self):
# model_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K" # model_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
self.config = getConfig() self.config = getConfig()
model_path = self.config.clip_model_path model_path = self.config.clip_model_path
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = CLIPModel.from_pretrained(model_path).to(self.device) self.model = CLIPModel.from_pretrained(model_path).to(self.device)
self.processor = CLIPProcessor.from_pretrained(model_path) self.processor = CLIPProcessor.from_pretrained(model_path)
def embed_image(self, image_path): def embed_image(self, image_path):
image = Image.open(image_path).convert("RGB") image = Image.open(image_path).convert("RGB")
inputs = self.processor(images=image, return_tensors="pt").to(self.device) inputs = self.processor(images=image, return_tensors="pt").to(self.device)
with torch.no_grad(): with torch.no_grad():
embedding = self.model.get_image_features(**inputs) embedding = self.model.get_image_features(**inputs)
return embedding[0].cpu().numpy() return embedding[0].cpu().numpy()
def embed_text(self, text): def embed_text(self, text):
inputs = self.processor(text=text, return_tensors="pt").to(self.device) inputs = self.processor(text=text, return_tensors="pt").to(self.device)
with torch.no_grad(): with torch.no_grad():
embedding = self.model.get_text_features(**inputs) embedding = self.model.get_text_features(**inputs)
return embedding[0].cpu().numpy() return embedding[0].cpu().numpy()