import logging
import os

import numpy as np
import torch
import torchvision
import webdataset
from llm2vec import LLM2Vec
from ocl.cli import train
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
logging.getLogger().setLevel(logging.INFO)
# TODO: switch to the CVPR submission checkpoints; the paths below appear to be more recent rebuttal checkpoints.
CHECKPOINTS = {
"checkpoint": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/checkpoints/epoch=1-step=65000.ckpt",
"config": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/config/config.yaml",
}
l2v = LLM2Vec.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
device_map="cuda" if torch.cuda.is_available() else "cpu",
torch_dtype=torch.bfloat16,
)
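# Quick sanity check of the phrase embedder (sketch; the hidden size depends on the
# Llama-3 backbone, so the exact dimension is an assumption):
#   emb = l2v.encode(["a red car", "other"])   # -> tensor of shape (2, hidden_dim)
# Encoding the 7 phrases of one prompt therefore yields a (7, hidden_dim) tensor.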
def get_shard_pattern(path: str):
base_pattern: str = "shard-%06d.tar"
return os.path.join(path, base_pattern)
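# get_shard_pattern() yields a printf-style shard name such as
# "/data/out/shard-%06d.tar". A minimal sketch of how extracted features could be
# written out with webdataset (output path and keys here are hypothetical):
#
#   with webdataset.ShardWriter(get_shard_pattern("/data/out"), maxcount=10_000) as sink:
#       for idx, feat in enumerate(features):
#           sink.write({"__key__": f"sample{idx:06d}", "features.npy": feat})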
class FeatureExtractor:
"""Handles feature extraction for multiple vision models."""
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", batch_size=32):
        self.device = device
        self.batch_size = batch_size
        self._init_models()
        self._init_transforms()
    def _init_models(self):
        # Build the Ctrl-O model from its training config and load the checkpoint.
        config_path = CHECKPOINTS["config"]
        encoder_checkpoint_path = CHECKPOINTS["checkpoint"]
        oclf_config = OmegaConf.load(config_path)
        self.model = train.build_model_from_config(
            oclf_config, encoder_checkpoint_path
        ).to(self.device)
        self.model.eval()
def _init_transforms(self):
self.base_transform = transforms.Compose(
[
transforms.Resize(
(224, 224),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
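    # Example (sketch): base_transform(Image.open("img.jpg").convert("RGB")) returns a
    # normalized float tensor of shape (3, 224, 224); a stacked batch is (bsz, 3, 224, 224).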
    @torch.no_grad()
    def extract_features_batch(self, images, prompts):
        """Extract slot features for a batch of PIL images and their phrase prompts."""
        images = torch.stack([self.base_transform(image) for image in images]).to(
            self.device
        )
        # Each prompt is a list of 7 phrases; encode each to (7, hidden_dim) and stack
        # to (bsz, 7, hidden_dim). Cast from bfloat16 to float32 before feeding the model.
        name_embeddings = (
            torch.stack([l2v.encode(prompt) for prompt in prompts])
            .float()
            .to(self.device)
        )
        bsz = images.shape[0]
        inputs = {
            "image": images,
            # Dummy centroids: the model is conditioned only on the phrases here.
            "bbox_centroids": torch.tensor([[-1, -1]] * 7, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(bsz, 1, 1)
            .to(self.device),
            # Mask out padding phrases ("other") from the contrastive loss.
            "contrastive_loss_mask": torch.stack(
                [
                    torch.tensor([int(phrase != "other") for phrase in prompt])
                    for prompt in prompts
                ]
            ).to(self.device),
            "name_embedding": name_embeddings,
            # NOTE: instance_bbox was not required by earlier checkpoints.
            "instance_bbox": torch.tensor(
                [[-1, -1, -1, -1]] * 7, dtype=torch.float32
            )
            .repeat(bsz, 1, 1)
            .to(self.device),
            "batch_size": bsz,
        }
        outputs = self.model(inputs)
        # Slot features from the perceptual grouping module: (bsz, num_slots, slot_dim).
        features = outputs["perceptual_grouping"].objects
        return features.cpu().numpy()
# You can specify up to 7 region or object phrases per image; pad the rest with "other".
prompts = [
    [
        "The orange bag on the skier's back.",
        "The brown pants worn by the skier.",
        "Beautiful green trees in ice",
        "other",
        "other",
        "other",
        "other",
    ],
    ["Man going down ski slope", "The snow is white", "other", "other", "other", "other", "other"],
]
image_paths = [
    "/path/to/image1.jpg",
    "/path/to/image2.jpg",
]
images = [Image.open(path).convert("RGB") for path in image_paths]
feature_extractor = FeatureExtractor()
features = feature_extractor.extract_features_batch(images, prompts)
print(features)
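# Optionally persist the slot features for downstream use (sketch; the output
# filename is hypothetical):
# np.save("ctrlo_slot_features.npy", features)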