bugfic
This commit is contained in:
parent
f2f6390608
commit
86589ec2e8
@ -9,8 +9,20 @@ Features:
|
|||||||
- Video frame sampling + average pooling
|
- Video frame sampling + average pooling
|
||||||
- Audio resampling + CLAP embedding
|
- Audio resampling + CLAP embedding
|
||||||
- L2 normalized output for similarity search
|
- L2 normalized output for similarity search
|
||||||
|
|
||||||
|
model_name='/data/ymq/models/laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
|
||||||
|
|
||||||
|
impput:
|
||||||
|
|
||||||
|
text:
|
||||||
|
{
|
||||||
|
"type":"text,
|
||||||
|
"text":"...."
|
||||||
|
}
|
||||||
|
image:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -35,6 +47,13 @@ except Exception:
|
|||||||
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu"
|
||||||
USE_FP16 = DEVICE == "cuda"
|
USE_FP16 = DEVICE == "cuda"
|
||||||
|
|
||||||
|
def choose_device():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
# Unified model for all modalities
|
# Unified model for all modalities
|
||||||
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
|
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
|
||||||
FRAME_SAMPLE_RATE = 1.0 # fps for video
|
FRAME_SAMPLE_RATE = 1.0 # fps for video
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user