import argparse
import json
import logging
import os
import pathlib
import pickle
import sys
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torchvision
import tqdm
import webdataset
from ocl import visualizations
from ocl.cli import train
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    Dinov2Model,
    T5EncoderModel,
)

import handlers
from llm2vec import LLM2Vec

logging.getLogger().setLevel(logging.INFO)

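# Paths to a pretrained ctrlo checkpoint and its training config.
# These are environment-specific; point them at your own checkpoint and config.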
CHECKPOINTS = {
    "checkpoint": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/checkpoints/epoch=1-step=65000.ckpt",
    "config": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/config/config.yaml",
}

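# LLM2Vec text encoder used to embed the per-slot language prompts.
# encode() takes a list of strings and returns one embedding per string.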
l2v = LLM2Vec.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)


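# Helper for building webdataset shard path patterns (not used in the example below).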
def get_shard_pattern(path: str):
    base_pattern: str = "shard-%06d.tar"
    return os.path.join(path, base_pattern)


class FeatureExtractor:
    """Extracts slot features from images with a pretrained ctrlo model."""

    def __init__(self, device="cuda", batch_size=32):
        self.device = device
        self.batch_size = batch_size
        self._init_models()
        self._init_transforms()

    def _init_models(self):
        config_path = CHECKPOINTS["config"]
        encoder_checkpoint_path = CHECKPOINTS["checkpoint"]
        oclf_config = OmegaConf.load(config_path)
        self.model = train.build_model_from_config(
            oclf_config, encoder_checkpoint_path
        ).to(self.device)
        self.model.eval()

    def _init_transforms(self):
        self.base_transform = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    @torch.no_grad()
    def extract_features_batch(self, images, prompts):
        """Extract slot features for a batch of PIL images and per-slot prompts."""
        images = torch.stack([self.base_transform(image) for image in images]).to(
            self.device
        )
        # One embedding per prompt, stacked into (batch, num_slots, embed_dim).
        name_embeddings = torch.stack([l2v.encode(prompt) for prompt in prompts]).to(
            self.device
        )
        bsz = images.shape[0]
        inputs = {
            "image": images,
            # Placeholder centroids/boxes filled with -1 (no localization provided),
            # repeated for every sample in the batch.
            "bbox_centroids": torch.tensor([[-1, -1]] * 7, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(bsz, 1, 1)
            .to(self.device),
            # Exclude padding prompts ("other") from the contrastive loss.
            "contrastive_loss_mask": torch.stack(
                [
                    torch.tensor([int(p != "other") for p in prompt])
                    for prompt in prompts
                ]
            ).to(self.device),
            "name_embedding": name_embeddings,
            "instance_bbox": torch.tensor([[-1, -1, -1, -1]] * 7, dtype=torch.float32)
            .repeat(bsz, 1, 1)
            .to(self.device),
            "batch_size": bsz,
        }
        outputs = self.model(inputs)
        features = outputs["perceptual_grouping"].objects
        return features.cpu().numpy()


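# Example inputs: each image gets a fixed list of 7 prompts, padded with "other"
# for unused slots. The image paths below are placeholders.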
prompts = [
    ["The orange bag on the skier's back.", "The brown pants worn by the skier.", "Beautiful green trees in ice", "other", "other", "other", "other"],
    ["Man going down ski slope", "The snow is white", "other", "other", "other", "other", "other"],
]
image_paths = [
    "/path/to/image1.jpg",
    "/path/to/image2.jpg",
]
# The transform expects PIL images, so open the files before extraction.
images = [Image.open(path).convert("RGB") for path in image_paths]

feature_extractor = FeatureExtractor()
features = feature_extractor.extract_features_batch(images, prompts)
print(features)
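
# Optional (illustrative): persist the extracted slot features for later use.
# The output is a numpy array, typically (batch, num_slots, slot_dim); the exact
# sizes depend on the loaded ctrlo config. The filename here is just an example.
np.save("ctrlo_slot_features.npy", features)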