62 lines
2.0 KiB
Python
62 lines
2.0 KiB
Python
# -*- coding:utf-8 -*-
|
|
"""CLIP ViT-H/14 lazy-loading wrapper."""
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import base64
|
|
import urllib.request
|
|
|
|
# Path handed to CLIPProcessor/CLIPModel.from_pretrained in _load().
# Overridable through the CLIP_MODEL_PATH environment variable (an unset or
# empty variable falls back to the original hard-coded checkpoint location).
MODEL_PATH = (
    os.environ.get('CLIP_MODEL_PATH')
    or '/data/ymq/models/laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
)

# Lazily-initialized singletons, populated by _load() on first use.
_model = None      # CLIPModel once loaded; None means "not loaded yet"
_processor = None  # CLIPProcessor paired with _model
_device = None     # torch.device the model was moved to
|
|
|
|
|
|
def _load():
    """Load the CLIP processor and model once, caching them in module globals.

    Calls after a successful load return immediately. The model goes to
    ``cuda:0`` when CUDA is available, otherwise to the CPU. Half precision is
    used only on the GPU; on the CPU the model stays in float32, because
    float16 inference on CPU is poorly supported by PyTorch (it can be very
    slow, and some ops may be unimplemented depending on the version).

    The globals are assigned only after every step has succeeded. If loading
    fails midway (e.g. out of memory while moving to the GPU), the module
    stays "unloaded" and the next call retries. It never reuses a partially
    initialized model.
    """
    global _model, _processor, _device

    if _model is not None:
        return

    # CUDA_VISIBLE_DEVICES is set in start.sh, so GPU 0 in visible devices is our target
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16 if device.type == 'cuda' else torch.float32

    # Deferred import: transformers is only needed once a model is loaded.
    from transformers import CLIPModel, CLIPProcessor

    processor = CLIPProcessor.from_pretrained(MODEL_PATH)
    model = CLIPModel.from_pretrained(MODEL_PATH, torch_dtype=dtype)
    model = model.to(device)
    model.eval()

    # _model doubles as the "already loaded" flag checked above, so publish
    # it last, after the processor and device are in place.
    _processor = processor
    _device = device
    _model = model

    dtype_name = 'float16' if dtype == torch.float16 else 'float32'
    print(f'[CLIP] Model loaded on {_device}, dtype={dtype_name}')
|
|
|
|
|
|
def embed_texts(texts):
    """Embed texts with CLIP and return L2-normalized vectors as Python lists.

    Args:
        texts: a list/tuple of strings. (Presumably a single string is also
            accepted, because the processor batches it; TODO confirm.) Each
            text is truncated to at most 77 tokens.

    Returns:
        A list of embeddings, each a list of floats with unit L2 norm.
        An empty list/tuple yields ``[]`` without loading the model.
    """
    # Skip the expensive model load and avoid feeding an empty batch to the
    # tokenizer/model when there is nothing to embed.
    if isinstance(texts, (list, tuple)) and not texts:
        return []

    _load()

    inputs = _processor(text=texts, return_tensors='pt', padding=True, truncation=True, max_length=77)
    inputs = {k: v.to(_device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = _model.get_text_features(**inputs)

    # The model may run in float16; normalize at float32 precision to avoid
    # extra rounding in the norm and the division.
    outputs = outputs.float()
    outputs = outputs / outputs.norm(dim=-1, keepdim=True)

    return outputs.cpu().numpy().tolist()
|
|
|
|
|
|
def _load_image(src):
|
|
if src.startswith('data:'):
|
|
_, b64 = src.split(',', 1)
|
|
return Image.open(BytesIO(base64.b64decode(b64))).convert('RGB')
|
|
elif src.startswith('http://') or src.startswith('https://'):
|
|
with urllib.request.urlopen(src, timeout=30) as resp:
|
|
return Image.open(BytesIO(resp.read())).convert('RGB')
|
|
else:
|
|
return Image.open(src).convert('RGB')
|
|
|
|
|
|
def embed_images(sources):
    """Embed images with CLIP and return L2-normalized vectors as Python lists.

    Args:
        sources: iterable of image sources understood by ``_load_image``
            (``data:`` URIs, HTTP(S) URLs, or local file paths).

    Returns:
        One embedding per source, in input order. Each embedding is a list of
        floats with unit L2 norm. Empty input yields ``[]`` without loading
        the model.
    """
    # Decode every image first: bad inputs fail before the expensive model
    # load, and an empty input is detected even when ``sources`` is a generator.
    images = [_load_image(s) for s in sources]
    if not images:
        return []

    _load()

    inputs = _processor(images=images, return_tensors='pt')
    inputs = {k: v.to(_device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = _model.get_image_features(**inputs)

    # The model may run in float16; normalize at float32 precision to avoid
    # extra rounding in the norm and the division.
    outputs = outputs.float()
    outputs = outputs / outputs.norm(dim=-1, keepdim=True)

    return outputs.cpu().numpy().tolist()
|