# SliderSpace/utils/clip_util.py
from typing import List, Tuple
import math
import os
import random
import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
from sklearn.decomposition import PCA

def extract_clip_features(clip, image, encoder):
    """
    Extract feature embeddings from an image using either a CLIP or DINOv2 model.

    Args:
        clip (torch.nn.Module): The feature extraction model (CLIP or DINOv2).
        image (torch.Tensor): Input image tensor, already resized and normalized
            according to the model's requirements.
        encoder (str): Type of encoder in use (e.g. 'dinov2-small' or 'clip').

    Returns:
        torch.Tensor: Feature embeddings extracted from the image.

    Note:
        - For DINOv2 models, the pooled output features are used.
        - For CLIP models, the image features from the vision encoder are used.
    """
    # Handle DINOv2 models: use the pooled output features
    if 'dino' in encoder:
        features = clip(image).pooler_output
    # Handle CLIP models: use the vision encoder's image features
    else:
        features = clip.get_image_features(image)
    return features
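
# Usage sketch (illustrative, not part of the original module): a minimal call
# with a CLIP backbone from Hugging Face transformers. The checkpoint name and
# the example image path below are assumptions for demonstration only.
#
#   from transformers import CLIPModel, CLIPProcessor
#   from PIL import Image
#
#   clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#   clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#   inputs = clip_processor(images=Image.open("example.png"), return_tensors="pt")
#   feats = extract_clip_features(clip_model, inputs["pixel_values"], encoder="clip")
#
# For a DINOv2 backbone, load e.g. AutoModel.from_pretrained("facebook/dinov2-small")
# and pass encoder="dinov2-small" so the pooled output features are used.
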
@torch.no_grad()
def compute_clip_pca(
    diverse_prompts: List[str],
    pipe,
    clip_model,
    clip_processor,
    device,
    guidance_scale,
    params,
    total_samples: int = 5000,
    num_pca_components: int = 100,
    batch_size: int = 10,
) -> Tuple[torch.Tensor, List[str], List[str]]:
"""
Extract CLIP features from generated images based on prompts.
Args:
diverse_prompts: List of prompts to generate images from
model_components: Various model components needed for generation
args: Training arguments
Returns:
Tensor of CLIP principle components
"""
    # Calculate how many batches we need to reach total_samples
    num_batches = math.ceil(total_samples / batch_size)
    # Randomly sample one prompt per batch (random.choices samples with replacement)
    sampled_prompts_clip = random.choices(diverse_prompts, k=num_batches)
    clip_features_path = f"{params['savepath_training_images']}/clip_principle_directions.pt"
    # If the principal directions were already computed, reuse the cached results
    if os.path.exists(clip_features_path):
        df = pd.read_csv(f"{params['savepath_training_images']}/training_data.csv")
        prompts_training = list(df.prompt)
        image_paths = list(df.image_path)
        return torch.load(clip_features_path).to(device), prompts_training, image_paths
    os.makedirs(params['savepath_training_images'], exist_ok=True)

    # Generate images and extract features
    img_idx = 0
    clip_features = []
    image_paths = []
    prompts_training = []
    print('Calculating Semantic PCA')
    for prompt in tqdm(sampled_prompts_clip):
        # Build the generation arguments once; pass max_sequence_length only
        # when the pipeline accepts it (e.g. T5-based pipelines such as SD3/Flux)
        pipe_kwargs = dict(
            num_images_per_prompt=batch_size,
            num_inference_steps=params['max_denoising_steps'],
            guidance_scale=guidance_scale,
            height=params['height'],
            width=params['width'],
        )
        if 'max_sequence_length' in params:
            pipe_kwargs['max_sequence_length'] = params['max_sequence_length']
        images = pipe(prompt, **pipe_kwargs).images
        # Preprocess generated images for CLIP
        clip_inputs = clip_processor(images=images, return_tensors="pt", padding=True)
        pixel_values = clip_inputs['pixel_values'].to(device)

        # Get image embeddings (gradients are already disabled by the
        # @torch.no_grad() decorator, so no extra context manager is needed)
        image_features = clip_model.get_image_features(pixel_values)

        # L2-normalize embeddings so PCA captures directions of semantic
        # variation rather than differences in embedding magnitude
        clip_feats = image_features / image_features.norm(dim=1, keepdim=True)
        clip_features.append(clip_feats)

        # Save each image and record its prompt/path for the training CSV
        for im in images:
            image_path = f"{params['savepath_training_images']}/{img_idx}.png"
            im.save(image_path)
            image_paths.append(image_path)
            prompts_training.append(prompt)
            img_idx += 1
    clip_features = torch.cat(clip_features)

    # Calculate principal components; each row of pca.components_ is a unit
    # direction in CLIP embedding space, ordered by explained variance
    pca = PCA(n_components=num_pca_components)
    clip_embeds_np = clip_features.float().cpu().numpy()
    pca.fit(clip_embeds_np)
    clip_principles = torch.from_numpy(pca.components_).to(device, dtype=pipe.vae.dtype)

    # Save results so later runs can reuse the cached directions and metadata
    torch.save(clip_principles, clip_features_path)
    pd.DataFrame({
        'prompt': prompts_training,
        'image_path': image_paths
    }).to_csv(f"{params['savepath_training_images']}/training_data.csv", index=False)
    return clip_principles, prompts_training, image_paths
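
# Usage sketch (illustrative; the checkpoint names, prompts, and params below
# are assumptions for demonstration, not part of SliderSpace itself). A minimal
# end-to-end run with a Stable Diffusion pipeline and the OpenAI CLIP model:
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline
    from transformers import CLIPModel, CLIPProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    params = {
        'savepath_training_images': './training_images',
        'max_denoising_steps': 25,
        'height': 512,
        'width': 512,
    }
    principles, prompts, paths = compute_clip_pca(
        diverse_prompts=["a photo of a person", "a photo of a landscape"],
        pipe=pipe,
        clip_model=clip_model,
        clip_processor=clip_processor,
        device=device,
        guidance_scale=7.5,
        params=params,
        total_samples=100,       # small run for illustration
        num_pca_components=10,
        batch_size=5,
    )
    # New image features can be projected onto the principal directions via
    # scores = feats.to(principles.dtype) @ principles.T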