Spaces:
Runtime error
Runtime error
File size: 4,604 Bytes
5aa316f 20d1a10 804efd3 5aa316f 4dad5d2 523ffdf 804efd3 5aa316f eaaab0d 5aa316f 804efd3 eaaab0d 804efd3 eaaab0d 4dad5d2 eaaab0d 534fa09 5aa316f e6c364b 523ffdf eaaab0d 5aa316f eaaab0d 5aa316f 534fa09 804efd3 5aa316f 804efd3 5aa316f eaaab0d 4dad5d2 eaaab0d 534fa09 eaaab0d 534fa09 eaaab0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from PIL import Image
import gradio as gr
from featup.util import norm, unnorm, pca, remove_axes
from pytorch_lightning import seed_everything
import os
import requests
import os
import csv
def plot_feats(image, lr, hr):
assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3
seed_everything(0)
[lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)], dim=9)
fig, ax = plt.subplots(3, 3, figsize=(15, 15))
ax[0, 0].imshow(image.permute(1, 2, 0).detach().cpu())
ax[1, 0].imshow(image.permute(1, 2, 0).detach().cpu())
ax[2, 0].imshow(image.permute(1, 2, 0).detach().cpu())
ax[0, 0].set_title("Image", fontsize=22)
ax[0, 1].set_title("Original", fontsize=22)
ax[0, 2].set_title("Upsampled Features", fontsize=22)
ax[0, 1].imshow(lr_feats_pca[0, :3].permute(1, 2, 0).detach().cpu())
ax[0, 0].set_ylabel("PCA Components 1-3", fontsize=22)
ax[0, 2].imshow(hr_feats_pca[0, :3].permute(1, 2, 0).detach().cpu())
ax[1, 1].imshow(lr_feats_pca[0, 3:6].permute(1, 2, 0).detach().cpu())
ax[1, 0].set_ylabel("PCA Components 4-6", fontsize=22)
ax[1, 2].imshow(hr_feats_pca[0, 3:6].permute(1, 2, 0).detach().cpu())
ax[2, 1].imshow(lr_feats_pca[0, 6:9].permute(1, 2, 0).detach().cpu())
ax[2, 0].set_ylabel("PCA Components 7-9", fontsize=22)
ax[2, 2].imshow(hr_feats_pca[0, 6:9].permute(1, 2, 0).detach().cpu())
remove_axes(ax)
plt.tight_layout()
plt.close(fig) # Close plt to avoid additional empty plots
return fig
if __name__ == "__main__":
def download_image(url, save_path):
response = requests.get(url)
with open(save_path, 'wb') as file:
file.write(response.content)
base_url = "https://marhamilresearch4.blob.core.windows.net/feature-upsampling-public/sample_images/"
sample_images_urls = {
"skate.jpg": base_url + "skate.jpg",
"car.jpg": base_url + "car.jpg",
"plant.png": base_url + "plant.png",
}
sample_images_dir = "/tmp/sample_images"
# Ensure the directory for sample images exists
os.makedirs(sample_images_dir, exist_ok=True)
# Download each sample image
for filename, url in sample_images_urls.items():
save_path = os.path.join(sample_images_dir, filename)
# Download the image if it doesn't already exist
if not os.path.exists(save_path):
print(f"Downloading {filename}...")
download_image(url, save_path)
else:
print(f"{filename} already exists. Skipping download.")
os.environ['TORCH_HOME'] = '/tmp/.cache'
os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
csv.field_size_limit(100000000)
options = ['dino16', 'vit', 'dinov2', 'clip', 'resnet50']
image_input = gr.Image(label="Choose an image to featurize",
height=480,
type="pil",
image_mode='RGB',
sources=['upload', 'webcam', 'clipboard']
)
model_option = gr.Radio(options, value="dino16", label='Choose a backbone to upsample')
models = {o: torch.hub.load("mhamilton723/FeatUp", o) for o in options}
def upsample_features(image, model_option):
# Image preprocessing
input_size = 224
transform = T.Compose([
T.Resize(input_size),
T.CenterCrop((input_size, input_size)),
T.ToTensor(),
norm
])
image_tensor = transform(image).unsqueeze(0).cuda()
# Load the selected model
upsampler = models[model_option].cuda()
hr_feats = upsampler(image_tensor)
lr_feats = upsampler.model(image_tensor)
upsampler.cpu()
return plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0])
demo = gr.Interface(fn=upsample_features,
inputs=[image_input, model_option],
outputs="plot",
title="Feature Upsampling Demo",
description="This demo allows you to upsample features of an image using selected models.",
examples=[
["/tmp/sample_images/skate.jpg", "dino16"],
["/tmp/sample_images/car.jpg", "dinov2"],
["/tmp/sample_images/plant.png", "dino16"],
]
)
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|