import logging
import os

import torch
import torchvision
from llm2vec import LLM2Vec
from ocl.cli import train
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

logging.getLogger().setLevel(logging.INFO)


# TODO: switch to the CVPR submission checkpoints; the paths below point to a more
# recent rebuttal run and should be adjusted to your own environment.
CHECKPOINTS = {
    "checkpoint": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/checkpoints/epoch=1-step=65000.ckpt",
    "config": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/config/config.yaml",
}


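# LLM2Vec (Llama-3-8B, MNTP + unsupervised SimCSE) embeds the free-form region/object
# phrases that are passed to the model as name embeddings. It is loaded once at module
# level so the large encoder is only initialized a single time.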
l2v = LLM2Vec.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)


def get_shard_pattern(path: str):
    """Return the shard filename pattern (e.g. shard-000000.tar) under `path`."""
    base_pattern: str = "shard-%06d.tar"
    return os.path.join(path, base_pattern)

class FeatureExtractor:
    """Handles feature extraction for multiple vision models."""

    def __init__(self, device=None, batch_size=32):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self._init_models()
        self._init_transforms()

    def _init_models(self):
        # Build the Ctrlo model from its training config and checkpoint, then
        # switch it to eval mode for feature extraction.
        oclf_config = OmegaConf.load(CHECKPOINTS["config"])
        self.model = train.build_model_from_config(
            oclf_config, CHECKPOINTS["checkpoint"]
        ).to(self.device)
        self.model.eval()

    def _init_transforms(self):
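        # Resize to the model's 224x224 input resolution and apply standard
        # ImageNet normalization statistics.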
        self.base_transform = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    @torch.no_grad()
    def extract_features_batch(self, images, prompts):
        """Extract per-slot features for a batch of PIL images and their prompts.

        Each entry of `prompts` must contain exactly 7 phrases; unused positions
        should hold the string "other".
        """
        images = torch.stack([self.base_transform(image) for image in images]).to(
            self.device
        )
        bsz = images.shape[0]
        # LLM2Vec runs in bfloat16, so cast the phrase embeddings to float32 to
        # match the grouping model's parameters.
        name_embeddings = (
            torch.stack([l2v.encode(prompt) for prompt in prompts])
            .float()
            .to(self.device)
        )
        inputs = {
            "image": images,
            # Dummy centroids: point conditioning is not used at inference time.
            "bbox_centroids": torch.tensor([[-1, -1]] * 7, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(bsz, 1, 1)
            .to(self.device),
            # Mask out the padding phrases ("other") from the contrastive loss.
            "contrastive_loss_mask": torch.stack(
                [
                    torch.tensor([int(p != "other") for p in prompt])
                    for prompt in prompts
                ]
            ).to(self.device),
            "name_embedding": name_embeddings,
            # Dummy boxes: no ground-truth instance boxes are available here.
            "instance_bbox": torch.tensor([[-1, -1, -1, -1]] * 7, dtype=torch.float32)
            .repeat(bsz, 1, 1)
            .to(self.device),
            "batch_size": bsz,
        }
        outputs = self.model(inputs)
        # Slot representations from the perceptual grouping module, one per phrase.
        features = outputs["perceptual_grouping"].objects
        return features.cpu().numpy()


# You can specify up to 7 region or object phrases per image; fill the remaining
# positions with "other".
prompts = [
    ["The orange bag on the skier's back.", "The brown pants worn by the skier.", "Beautiful green trees in ice", "other", "other", "other", "other"],
    ["Man going down ski slope", "The snow is white", "other", "other", "other", "other", "other"],
]
image_paths = [
    "/path/to/image1.jpg",
    "/path/to/image2.jpg",
]
# The transform pipeline expects PIL images, so load the files first.
images = [Image.open(p).convert("RGB") for p in image_paths]

feature_extractor = FeatureExtractor()
features = feature_extractor.extract_features_batch(images, prompts)
print(features)
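# Quick sanity check: the exact slot count and feature dimension depend on the loaded
# config, but the output should contain one set of slot vectors per image,
# i.e. roughly (batch_size, num_slots, slot_dim).
print(features.shape)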