clip-italian-demo / localization.py
4rtemi5's picture
add localization and examples
ea3b7ec
raw
history blame
4.94 kB
import gc
from io import BytesIO

import jax
import numpy as np
import pandas as pd
import requests
import streamlit as st
from jax import numpy as jnp
from PIL import Image
from torchvision import transforms

from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder
# Image preprocessing pipeline used before the image encoder: ToTensor
# followed by a per-channel Normalize with fixed mean / std triples.
# NOTE(review): the constants look like the standard CLIP RGB statistics --
# confirm against the model card before changing them.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def pad_to_square(image, size=224):
    """Fit ``image`` inside a ``size`` x ``size`` square and pad it with gray.

    The aspect ratio is preserved: the image is scaled so that its longer
    side becomes ``size`` pixels, then pasted centered on an RGB canvas
    filled with (128, 128, 128).

    Args:
        image: A PIL image (uses its ``.size`` and ``.resize``).
        size: Side length in pixels of the square result.

    Returns:
        A new ``size`` x ``size`` RGB PIL image.
    """
    scale = float(size) / max(image.size)
    # Round rather than truncate so floating-point error cannot leave the
    # longer side one pixel short of ``size``; clamp to [1, size] so extreme
    # aspect ratios never yield a zero-sized (invalid) resize target.
    new_size = tuple(min(size, max(1, round(side * scale))) for side in image.size)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # and is available in both old and new Pillow releases.
    image = image.resize(new_size, Image.LANCZOS)
    canvas = Image.new("RGB", size=(size, size), color=(128, 128, 128))
    offset = ((size - new_size[0]) // 2, (size - new_size[1]) // 2)
    canvas.paste(image, offset)
    return canvas
def image_encoder(image, model):
    """Encode a batch of images into L2-normalized feature vectors.

    Args:
        image: 4-D array-like batch. Axis 1 is moved to the last position
            before encoding (presumably NCHW -> NHWC; confirm with the model).
        model: Object exposing ``get_image_features(batch)``.

    Returns:
        The features returned by the model, with every row scaled to unit
        L2 norm.
    """
    image = np.transpose(image, (0, 2, 3, 1))
    features = model.get_image_features(image)
    # Normalize each embedding on its own (axis=-1). Without an axis the norm
    # is taken over the whole batch matrix, so rows would not be unit vectors
    # and dot products with a text embedding would not be cosine similarities.
    norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
    return features / norms
def gen_image_batch(image_url, image_size=224, pixel_size=10, timeout=10):
    """Download an image and build a batch of band-masked variants of it.

    The first element of the batch is the padded, unmasked image. Every other
    element keeps one horizontal or vertical band of the image visible and
    replaces everything outside the band with gray (128).

    Args:
        image_url: URL of the image to download.
        image_size: Side length of the square image fed to the model.
        pixel_size: Step, in pixels, between candidate band edges.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(image_batch, masks)``: two lists of equal length. ``image_batch[0]``
        is a PIL image, the remaining entries are uint8 arrays; ``masks[k]``
        is a 0/1 array (same shape as the image array) marking the pixels
        of ``image_batch[k]`` that were kept.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    # Fetch the full body with a timeout and an explicit status check instead
    # of handing the undecoded ``.raw`` stream to PIL: this handles
    # gzip/deflate-encoded responses, releases the connection, and fails with
    # a clear HTTP error rather than an obscure image-decoding error.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = pad_to_square(image, size=image_size)

    pixels = np.asarray(image)
    gray = np.ones_like(pixels) * 128
    mask = np.ones_like(pixels)

    # Entry 0 is the reference: the whole image with an all-ones mask.
    image_batch = [image]
    masks = [mask]

    n_pixels = image_size // pixel_size + 1

    # Horizontal bands: rows up to edge i and rows after edge j are masked.
    # NOTE(review): the row and column loops use slightly different bounds
    # (``min(...) + 1`` vs ``min(... + 1, ...)`` and a different range end).
    # Kept exactly as-is so existing heatmaps do not change.
    for i in range(0, n_pixels):
        for j in range(i + 1, n_pixels):
            m = mask.copy()
            m[:min(i * pixel_size, image_size) + 1, :] = 0
            m[min(j * pixel_size, image_size) + 1:, :] = 0
            image_batch.append(pixels * m + gray * (1 - m))
            masks.append(m)

    # Vertical bands: columns before edge i and from edge j on are masked.
    for i in range(0, n_pixels + 1):
        for j in range(i + 1, n_pixels + 1):
            m = mask.copy()
            m[:, :min(i * pixel_size + 1, image_size)] = 0
            m[:, min(j * pixel_size + 1, image_size):] = 0
            image_batch.append(pixels * m + gray * (1 - m))
            masks.append(m)

    return image_batch, masks
def get_heatmap(image_url, text, pixel_size=10, iterations=3):
    """Compute a localization heatmap of ``text`` over the image at ``image_url``.

    The image is encoded once unmasked and once per band-masked variant from
    ``gen_image_batch``. Each variant's similarity to the text is spread over
    its mask, the unmasked similarity is subtracted, negative values are
    clipped, and the results are averaged. The map is then sharpened by
    repeatedly subtracting its mean and clipping at zero, and finally rescaled
    to [0, 1].

    Args:
        image_url: URL of the image to analyze.
        text: Label / caption to localize.
        pixel_size: Band step passed to ``gen_image_batch``.
        iterations: Number of mean-subtract-and-clip refinement passes.

    Returns:
        ``(heatmap, input_image)``: the heatmap as a NumPy array with the
        shape of the masks from ``gen_image_batch``, and a copy of the padded
        input image.
    """
    tokenizer = get_tokenizer()
    model = get_model()
    image_size = model.config.vision_config.image_size
    text_embedding = text_encoder(text, model, tokenizer)
    images, masks = gen_image_batch(image_url, image_size=image_size, pixel_size=pixel_size)
    input_image = images[0].copy()
    images = np.stack([preprocess(image) for image in images], axis=0)
    image_embeddings = jnp.asarray(image_encoder(images, model))

    # The first embedding belongs to the unmasked image; its similarity is the
    # baseline every masked variant is compared against.
    baseline = None
    scores = []
    for embedding, m in zip(image_embeddings, masks):
        sim = jnp.matmul(embedding, text_embedding.T)
        if baseline is None:
            baseline = sim
            continue
        scores.append(sim * m)

    score = jnp.mean(jnp.clip(jnp.array(scores) - baseline, 0, jnp.inf), axis=0)
    for _ in range(iterations):
        score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)

    # Min-max rescale to [0, 1]. A constant map (e.g. no band beats the
    # baseline) has zero range; return zeros instead of dividing by zero and
    # producing a NaN heatmap.
    low = jnp.min(score)
    span = jnp.max(score) - low
    safe_span = jnp.where(span > 0, span, 1)
    score = jnp.where(span > 0, (score - low) / safe_span, 0.0)
    return np.asarray(score), input_image
def app():
    """Render the Streamlit "Zero-Shot Localization" page.

    Shows an image-URL input, a label input, and controls for the band size
    and the number of refinement steps. Pressing "LOCATE" runs
    ``get_heatmap`` and displays the input image, the heatmap, and the image
    weighted by the heatmap; otherwise the image at the URL is previewed.
    """
    st.title("Zero-Shot Localization")
    st.markdown(
        "\n"
        "### 👋 Ciao!\n"
        "Here you can find an example for zero shot localization that will show you where in an image the model sees an object.\n"
        "🤌 Italian mode on! 🤌\n"
        "For example, try typing \"gatto\" (cat) or \"cane\" (dog) in the space for label and click \"locate\"!\n"
    )

    image_url = st.text_input(
        "You can input the URL of an image here...",
        value="https://www.tuttosuigatti.it/files/styles/full_width/public/images/featured/205/cani-e-gatti.jpg?itok=WAAiTGS6",
    )

    # ``st.beta_columns`` was replaced by ``st.columns``; prefer the current
    # name and fall back only on old Streamlit releases that lack it.
    make_columns = st.columns if hasattr(st, "columns") else st.beta_columns
    col1, col2 = make_columns([3, 1])

    with col2:
        pixel_size = st.selectbox(
            "Pixel Size", options=range(10, 21, 5), index=0
        )
        iterations = st.selectbox(
            "Refinement Steps", options=range(3, 30, 3), index=0
        )
        compute = st.button("LOCATE")

    with col1:
        caption = st.text_input("Insert label...")

    if compute:
        if not caption or not image_url:
            st.error("Please choose one image and at least one label")
        else:
            with st.spinner("Computing..."):
                heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations)

            with col1:
                st.image(image, use_column_width=True)
                st.image(heatmap, use_column_width=True)
                st.image(np.asarray(image) / 255.0 * heatmap, use_column_width=True)
        # Free the large intermediate arrays created by the heatmap run.
        gc.collect()
    elif image_url:
        # Preview only. Download with a timeout and status check, and report
        # failures in the page instead of surfacing a raw traceback.
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")
        else:
            with col1:
                st.image(image)