# -*- coding:utf-8 -*-
"""Lazy-loading wrapper around the LAION CLIP ViT-H/14 model.

The model and processor are loaded on first use (see ``_load``) and cached in
module globals. ``embed_texts`` and ``embed_images`` return L2-normalised
embeddings as nested Python lists of floats (one inner list per input).
"""
import os
import torch
import numpy as np
from PIL import Image
from io import BytesIO
import base64
import urllib.request

# Local checkpoint directory. The CLIP_MODEL_PATH environment variable can
# override it without editing this file; the default is unchanged.
MODEL_PATH = os.environ.get(
    'CLIP_MODEL_PATH', '/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
)

_model = None      # CLIPModel once loaded; `is not None` doubles as the "loaded" flag
_processor = None  # CLIPProcessor matching _model
_device = None     # torch.device the model lives on
_dtype = None      # torch dtype of the model weights


def _dtype_for(device):
    """Return the weight dtype to use on *device*.

    float16 halves GPU memory, but many CPU kernels have no (or very slow)
    half-precision implementation, so the CPU fallback uses float32.
    """
    return torch.float16 if device.type == 'cuda' else torch.float32


def _load():
    """Load the CLIP model and processor once and cache them in module globals."""
    global _model, _processor, _device, _dtype
    if _model is not None:
        return
    # CUDA_VISIBLE_DEVICES is set in start.sh, so GPU 0 in visible devices is our target
    _device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    _dtype = _dtype_for(_device)
    from transformers import CLIPModel, CLIPProcessor
    _processor = CLIPProcessor.from_pretrained(MODEL_PATH)
    model = CLIPModel.from_pretrained(MODEL_PATH, torch_dtype=_dtype)
    model = model.to(_device)
    model.eval()
    # Publish the model only once it is fully on the device: if `.to()` fails
    # (e.g. CUDA OOM), `_model` stays None and the next call retries cleanly
    # instead of using a half-initialised model.
    _model = model
    dtype_name = str(_dtype).rsplit('.', 1)[-1]  # 'torch.float16' -> 'float16'
    print(f'[CLIP] Model loaded on {_device}, dtype={dtype_name}')


def _to_model_inputs(inputs):
    """Move processor outputs onto the model device.

    Floating-point tensors (e.g. ``pixel_values``) are also cast to the model
    dtype, because the processor emits float32 while the weights may be
    float16. Integer tensors (token ids, attention masks) keep their dtype.
    """
    return {
        key: (value.to(_device, dtype=_dtype)
              if torch.is_floating_point(value) else value.to(_device))
        for key, value in inputs.items()
    }


def _normalize(features):
    """L2-normalise each row in float32 and return the result as nested lists.

    Normalising after the float32 cast avoids half-precision rounding; the
    eps clamp inside ``normalize`` avoids NaN for an all-zero row.
    """
    features = torch.nn.functional.normalize(features.float(), dim=-1)
    return features.cpu().numpy().tolist()


def embed_texts(texts):
    """Return L2-normalised CLIP text embeddings, one list of floats per text.

    Texts are padded to the longest in the batch and truncated to 77 tokens.
    An empty list/tuple returns ``[]`` without loading the model.
    """
    if isinstance(texts, (list, tuple)) and not texts:
        return []
    _load()
    inputs = _processor(text=texts, return_tensors='pt', padding=True,
                        truncation=True, max_length=77)
    with torch.no_grad():
        features = _model.get_text_features(**_to_model_inputs(inputs))
    return _normalize(features)


def _load_image(src):
    """Open *src* and return it as an RGB PIL image.

    *src* may be a ``data:`` URI (the payload after the first comma is
    base64-decoded), an ``http://``/``https://`` URL fetched with a 30 s
    timeout, or a local file path.

    NOTE(review, security): URLs and local paths are opened as given. If
    *src* can come from an untrusted client, this allows requests to internal
    URLs and reading any file the process can read; validate upstream.
    """
    if src.startswith('data:'):
        _, b64 = src.split(',', 1)
        data = base64.b64decode(b64)
    elif src.startswith(('http://', 'https://')):
        with urllib.request.urlopen(src, timeout=30) as resp:
            data = resp.read()
    else:
        # Close the file handle deterministically; convert() forces a full
        # decode into a new image that does not depend on the open file.
        with Image.open(src) as img:
            return img.convert('RGB')
    return Image.open(BytesIO(data)).convert('RGB')


def embed_images(sources):
    """Return L2-normalised CLIP image embeddings, one list of floats per source.

    Each source is anything ``_load_image`` accepts. An empty input returns
    ``[]`` without loading the model.
    """
    images = [_load_image(src) for src in sources]
    if not images:
        return []
    _load()
    inputs = _processor(images=images, return_tensors='pt')
    with torch.no_grad():
        features = _model.get_image_features(**_to_model_inputs(inputs))
    return _normalize(features)