import argparse
import json
import logging
import os
import pathlib
import pickle
import sys
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torchvision
import tqdm
import webdataset
from ocl import visualizations
from ocl.cli import train
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import (AutoImageProcessor, AutoTokenizer, CLIPModel,
                          CLIPProcessor, Dinov2Model, T5EncoderModel)

import handlers
from llm2vec import LLM2Vec

logging.getLogger().setLevel(logging.INFO)

# TODO: switch to the CVPR submission checkpoints; these are the most recent ones available for now.
CHECKPOINTS = {
    "checkpoint": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/checkpoints/epoch=1-step=65000.ckpt",
    "config": "/home/mila/r/rabiul.awal/scratch/oclf/checkpoints/ctrlo_rebuttal/config/config.yaml",
}

l2v = LLM2Vec.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)


def get_shard_pattern(path: str):
    base_pattern: str = "shard-%06d.tar"
    return os.path.join(path, base_pattern)


class FeatureExtractor:
    """Handles feature extraction with the Ctrl-O vision model."""

    def __init__(self, device="cuda", batch_size=32):
        self.device = device
        self.batch_size = batch_size
        self._init_models()
        self._init_transforms()

    def _init_models(self):
        # Initialize the Ctrl-O model from its config and checkpoint.
        config_path = CHECKPOINTS["config"]
        encoder_checkpoint_path = CHECKPOINTS["checkpoint"]
        oclf_config = OmegaConf.load(config_path)
        self.model = train.build_model_from_config(
            oclf_config, encoder_checkpoint_path
        ).to(self.device)
        self.model.eval()

    def _init_transforms(self):
        self.base_transform = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    @torch.no_grad()
    def extract_features_batch(self, images, prompts):
        """Extract features for a batch of PIL images and their prompt lists."""
        images = torch.stack([self.base_transform(image) for image in images]).to(
            self.device
        )
        # TODO: verify that the name-embedding shape and dtype match what the model expects.
        name_embeddings = torch.stack([l2v.encode(prompt) for prompt in prompts]).to(
            self.device
        )
        bsz = images.shape[0]
        inputs = {
            "image": images,
            # Placeholder centroids for all 7 slots, tiled across the batch.
            "bbox_centroids": torch.tensor([[-1, -1]] * 7, dtype=torch.float32)
            .repeat(bsz, 1, 1)
            .to(self.device),
            # Mask out slots whose phrase is the placeholder "other".
            "contrastive_loss_mask": torch.stack(
                [
                    torch.tensor([int(p != "other") for p in prompt])
                    for prompt in prompts
                ]
            ).to(self.device),
            "name_embedding": name_embeddings,
            # Placeholder instance boxes, tiled across the batch.
            "instance_bbox": torch.tensor([[-1, -1, -1, -1]] * 7, dtype=torch.float32)
            .repeat(bsz, 1, 1)
            .to(self.device),  # TODO: this field was not present before
            "batch_size": bsz,
        }
        outputs = self.model(inputs)
        # Slot (object) representations from the perceptual grouping module;
        # check that the feature shape makes sense.
        features = outputs["perceptual_grouping"].objects
        return features.cpu().numpy()


# You can specify up to 7 region or object phrases per image; the rest should be "other".
prompts = [
    [
        "The orange bag on the skier's back.",
        "The brown pants worn by the skier.",
        "Beautiful green trees in ice",
        "other",
        "other",
        "other",
        "other",
    ],
    [
        "Man going down ski slope",
        "The snow is white",
        "other",
        "other",
        "other",
        "other",
        "other",
    ],
]
image_paths = [
    "/path/to/image1.jpg",
    "/path/to/image2.jpg",
]
images = [Image.open(path).convert("RGB") for path in image_paths]

feature_extractor = FeatureExtractor()
features = feature_extractor.extract_features_batch(images, prompts)
print(features)