# Don’t Judge Before You CLIP: Memorability Prediction Model

This model is part of our paper:
*"Don’t Judge Before You CLIP: A Unified Approach for Perceptual Tasks"*
It was trained on the *LaMem dataset* to predict image memorability scores.

## Model Overview

Visual perceptual tasks, such as image memorability prediction, aim to estimate how humans perceive and interpret images. Unlike objective tasks (e.g., object recognition), these tasks rely on subjective human judgment, making labeled data scarce.

Our approach leverages *CLIP* as a prior for perceptual tasks, inspired by cognitive research suggesting that CLIP implicitly captures human biases, emotions, and preferences. We fine-tune CLIP minimally using *LoRA* to adapt it to each specific task.
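
For context, the sketch below shows one way to attach LoRA adapters to a CLIP vision encoder and add a small regression head. It assumes the Hugging Face `transformers` and `peft` libraries; the rank and alpha values (r=16, alpha=8) follow the checkpoint name, while the wrapper class, target modules, and head size are illustrative assumptions rather than the released architecture.

```python
import torch.nn as nn
from transformers import CLIPVisionModel
from peft import LoraConfig, get_peft_model


class MemorabilityModel(nn.Module):  # hypothetical wrapper, not the released architecture
    def __init__(self):
        super().__init__()
        backbone = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
        lora_cfg = LoraConfig(
            r=16,                                 # rank, following the checkpoint name
            lora_alpha=8,                         # alpha, following the checkpoint name
            target_modules=["q_proj", "v_proj"],  # assumed attention projections
        )
        self.backbone = get_peft_model(backbone, lora_cfg)  # only LoRA params in the backbone stay trainable
        self.head = nn.Sequential(                # small MLP regression head (assumed)
            nn.Linear(backbone.config.hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, pixel_values):
        feats = self.backbone(pixel_values=pixel_values).pooler_output
        return self.head(feats)                   # one memorability score per image
```
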
## Training Details

- *Dataset*: [LaMem](http://memorability.csail.mit.edu/download.html) (Large-Scale Image Memorability)
- *Architecture*: CLIP Vision Encoder (ViT-L/14) with *LoRA adaptation*
- *Loss Function*: Mean Squared Error (MSE) Loss for memorability prediction
- *Optimizer*: AdamW
- *Learning Rate*: 0.00005
- *Batch Size*: 32
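
With these settings, a single training step might look like the sketch below (a standard PyTorch loop, not the paper's exact code; `MemorabilityModel` is the hypothetical wrapper from the Model Overview sketch, and data loading is omitted):

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MemorabilityModel().to(device)    # hypothetical wrapper sketched above

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),  # LoRA adapters + MLP head only
    lr=0.00005,
)

def train_step(images, targets):
    # images: (32, 3, 224, 224) preprocessed batch; targets: (32,) human memorability scores
    images, targets = images.to(device), targets.to(device)
    optimizer.zero_grad()
    preds = model(images).squeeze(-1)     # (32,) predicted scores
    loss = criterion(preds, targets)      # MSE, as listed above
    loss.backward()
    optimizer.step()
    return loss.item()
```
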
## Performance

The model was trained on the *LaMem dataset* and exhibits *state-of-the-art generalization* to the *THINGS memorability dataset*.
For more models and results on the *five common splits* of LaMem, please refer to our paper.
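
Memorability prediction is commonly evaluated with the Spearman rank correlation between predicted and human scores. The sketch below shows that computation with `scipy` on illustrative values; it is not a result from the paper:

```python
# Sketch of the usual evaluation protocol: Spearman rank correlation
# between predicted and ground-truth memorability scores.
import numpy as np
from scipy.stats import spearmanr

predicted = np.array([0.71, 0.64, 0.88, 0.52])      # model outputs (illustrative values)
ground_truth = np.array([0.75, 0.60, 0.91, 0.55])   # human memorability scores (illustrative)

rho, _ = spearmanr(predicted, ground_truth)
print(f"Spearman rank correlation: {rho:.3f}")
```
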
## Usage

To use the model for inference:

```python
import torch
from PIL import Image
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model (the checkpoint stores the full pickled module).
# Note: on recent PyTorch versions (>= 2.6) you may need to pass
# weights_only=False to torch.load for this to succeed.
model = torch.load("lamem_all_clip_Lora_16.0R_8.0alphaLora_32_batch_0.00005_lossmse_headmlp.pth").to(device).eval()

# Load an image
image = Image.open("image_path.jpg").convert("RGB")

# Preprocessing: resize, center-crop to 224x224, and convert to a tensor
def Mem_augmentations():
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        # Note: The model normalizes the image inside the forward pass
        # using mean = (0.48145466, 0.4578275, 0.40821073) and
        # std = (0.26862954, 0.26130258, 0.27577711)
    ])
    return transform

# Preprocess and predict
image = Mem_augmentations()(image).unsqueeze(0).to(device)

with torch.no_grad():
    mem_score = model(image).item()

print(f"Predicted Memorability Score: {mem_score:.4f}")
```