Spaces:
Running
Running
from html import escape | |
import requests | |
from io import BytesIO | |
import base64 | |
from multiprocessing.dummy import Pool | |
from PIL import Image, ImageDraw | |
import streamlit as st | |
import pandas as pd, numpy as np | |
import torch | |
from transformers import CLIPProcessor, CLIPModel | |
from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
from transformers.image_utils import ImageFeatureExtractionMixin | |
import tokenizers | |
# Flip to True to use the smaller "base" checkpoints and stay on the CPU.
DEBUG = False

# Checkpoint identifiers. MODEL is the CLIP variant suffix; it is also
# interpolated into the embedding file names read by load().
MODEL = "vit-base-patch32" if DEBUG else "vit-large-patch14-336"
OWL_MODEL = "google/owlvit-base-patch32" if DEBUG else "google/owlvit-large-patch14"
CLIP_MODEL = f"openai/clip-{MODEL}"

# CUDA is selected only outside debug mode and only when it is available.
device = torch.device("cuda" if not DEBUG and torch.cuda.is_available() else "cpu")

# Height in pixels of every thumbnail rendered on the page.
HEIGHT = 200
# Default number of images retrieved per query (also the download pool size).
N_RESULTS = 6

# Outline colour for detection boxes, as an (R, G, B) tuple: the Streamlit
# theme's primary colour when one is configured, otherwise a fixed fallback.
# NOTE(review): the parsing below assumes a "#RRGGBB" value - confirm that the
# theme option cannot hold other colour notations.
color = st.get_option("theme.primaryColor")
if color is not None:
    color = tuple(
        int(color.lstrip("#")[start : start + 2], 16) for start in (0, 2, 4)
    )
else:
    color = (255, 196, 35)
def load():
    """Load the image metadata, both models with their processors, and embeddings.

    Returns:
        A tuple ``(clip_model, clip_processor, owl_model, owl_processor, df,
        embeddings)``. ``df`` and ``embeddings`` are dicts keyed by corpus id
        (0 and 1). Every row of each embedding matrix is divided by its L2
        norm before being returned.
    """
    # NOTE(review): nothing visible here caches the result, so the models are
    # rebuilt each time this function is called - confirm whether a caching
    # decorator was intended.
    frames = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}

    clip = CLIPModel.from_pretrained(CLIP_MODEL)
    clip.to(device)
    clip.eval()
    clip_proc = CLIPProcessor.from_pretrained(CLIP_MODEL)

    owl = OwlViTForObjectDetection.from_pretrained(OWL_MODEL)
    owl.to(device)
    owl.eval()
    owl_proc = OwlViTProcessor.from_pretrained(OWL_MODEL)

    raw_embeddings = {
        0: np.load(f"embeddings-{MODEL}.npy"),
        1: np.load(f"embeddings2-{MODEL}.npy"),
    }
    # Unit-length rows let a plain dot product be used as the similarity.
    unit_embeddings = {
        key: matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
        for key, matrix in raw_embeddings.items()
    }
    return clip, clip_proc, owl, owl_proc, frames, unit_embeddings
# Models, processors, metadata tables and embeddings shared by the helpers
# defined in this module.
(
    clip_model,
    clip_processor,
    owl_model,
    owl_processor,
    df,
    embeddings,
) = load()

# NOTE(review): `mixin` is not referenced by any code visible in this file.
mixin = ImageFeatureExtractionMixin()

# Attribution suffix appended to each image tooltip, keyed by corpus id.
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
def compute_text_embeddings(list_of_strings):
    """Embed texts with the CLIP text encoder and scale each row to unit norm.

    Args:
        list_of_strings: List of strings to embed.

    Returns:
        A NumPy array with one row per input string; each row is divided by
        its L2 norm.
    """
    # truncation=True asks the tokenizer to cut over-long inputs down to its
    # configured maximum length; without it a very long query typed into the
    # search box is forwarded to the text encoder at its full length.
    inputs = clip_processor(
        text=list_of_strings,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs).detach().cpu().numpy()
    return features / np.linalg.norm(features, axis=1, keepdims=True)
def image_search(query, corpus, n_results=N_RESULTS):
    """Return the images of a corpus that best match a text query.

    Args:
        query: Text to search for.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects
            corpus 1.
        n_results: Number of hits to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples ordered from highest to
        lowest similarity. The tooltip carries the corpus attribution suffix.
    """
    corpus_id = 0 if corpus == "Unsplash" else 1
    query_embedding = compute_text_embeddings([query])

    # One similarity value per corpus image.
    similarities = (embeddings[corpus_id] @ query_embedding.T)[:, 0]
    # argsort is ascending, so reverse it and keep the leading entries.
    best_first = np.argsort(similarities)[::-1][:n_results]

    table = df[corpus_id]
    attribution = source[corpus_id]
    hits = []
    for row_index in best_first:
        row = table.iloc[row_index]
        hits.append((row.path, row.tooltip + attribution, row.link))
    return hits
def make_square(img, fill_color=(255, 255, 255)):
    """Centre an image on a square RGB canvas sized to its longer side.

    Args:
        img: Image to pad; its ``size`` is read as ``(x, y)``.
        fill_color: Colour used for the padding.

    Returns:
        A tuple ``(square_image, x, y)`` where ``x`` and ``y`` are the
        dimensions of the input image.
    """
    width, height = img.size
    side = max(width, height)
    # Equal padding on both sides of the shorter dimension.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new("RGB", (side, side), fill_color)
    canvas.paste(img, offset)
    return canvas, width, height
def get_images(paths):
    """Download images concurrently and pad each one to a square.

    Args:
        paths: Iterable of image URLs.

    Returns:
        A tuple ``(imgs, xs, ys)`` of three lists, in the order of ``paths``:
        the squared images and the width and height each image had before
        padding.

    Raises:
        requests.RequestException: If a download times out or the server
            answers with an HTTP error status.
    """

    def process_image(path):
        # A timeout keeps one stalled download from blocking the page forever.
        response = requests.get(path, timeout=30)
        # Fail with a clear HTTP error instead of handing an error page to
        # the image decoder.
        response.raise_for_status()
        return make_square(Image.open(BytesIO(response.content)))

    paths = list(paths)
    if not paths:
        return [], [], []

    # The context manager releases the worker threads once the map is done;
    # otherwise every call would leave a pool behind.
    with Pool(N_RESULTS) as pool:
        processed = pool.map(process_image, paths)

    imgs, xs, ys = (list(column) for column in zip(*processed))
    return imgs, xs, ys
def apply_owl_model(owl_queries, images):
    """Run the OWL-ViT detector on a batch of images and post-process it.

    Args:
        owl_queries: Text queries passed to the processor as ``text``.
        images: Images passed to the processor; each must expose ``size``.

    Returns:
        Whatever ``owl_processor.post_process`` returns for the batch, using
        each image's reversed ``size`` as its target size.
    """
    batch = owl_processor(text=owl_queries, images=images, return_tensors="pt")
    batch = batch.to(device)

    with torch.no_grad():
        outputs = owl_model(**batch)

    reversed_sizes = [image.size[::-1] for image in images]
    target_sizes = torch.Tensor(reversed_sizes).to(device)
    return owl_processor.post_process(outputs=outputs, target_sizes=target_sizes)
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 if the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
    """Filter detections by score and drop redundant overlapping boxes.

    Boxes are read as ``(xmin, ymin, xmax, ymax)`` and rounded to whole
    numbers. Each pair of overlapping candidates is then compared:

    * if their IoU exceeds ``max_iou``, the lower-scoring box is dropped
      (on a tie, the earlier one);
    * otherwise, a box with more than 90% of its area inside the other is
      dropped unless its score is at least 1.1 times the other's.

    Args:
        boxes: Iterable of boxes; each must provide ``tolist()`` returning
            four coordinates.
        scores: Iterable of scores aligned with ``boxes``.
        score_threshold: Minimum score for a box to be considered.
        max_iou: IoU above which two boxes are treated as duplicates.

    Returns:
        The surviving boxes as lists of four rounded floats, in input order.
    """
    candidates = [
        ([round(coord, 0) for coord in box.tolist()], float(score))
        for box, score in zip(boxes, scores)
        if score >= score_threshold
    ]

    to_ignore = set()
    for i in range(len(candidates) - 1):
        if i in to_ignore:
            continue
        (xmin1, ymin1, xmax1, ymax1), score1 = candidates[i]
        area1 = (xmax1 - xmin1) * (ymax1 - ymin1)

        for j in range(i + 1, len(candidates)):
            if j in to_ignore:
                continue
            (xmin2, ymin2, xmax2, ymax2), score2 = candidates[j]

            # Disjoint boxes never suppress each other.
            if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
                continue

            inter_width = min(xmax1, xmax2) - max(xmin1, xmin2)
            inter_height = min(ymax1, ymax2) - max(ymin1, ymin2)
            area_inter = inter_width * inter_height
            area2 = (xmax2 - xmin2) * (ymax2 - ymin2)

            # Rounding can collapse a thin box to zero area; the guarded
            # ratios keep such a box from raising ZeroDivisionError.
            iou = _safe_ratio(area_inter, area1 + area2 - area_inter)
            if iou > max_iou:
                if score1 > score2:
                    to_ignore.add(j)
                else:
                    to_ignore.add(i)
                    break
            else:
                # Box i lies almost entirely inside box j.
                if _safe_ratio(area_inter, area1) > 0.9 and score1 < 1.1 * score2:
                    to_ignore.add(i)
                # Box j lies almost entirely inside box i.
                if _safe_ratio(area_inter, area2) > 0.9 and 1.1 * score1 > score2:
                    to_ignore.add(j)

    return [
        box for index, (box, _) in enumerate(candidates) if index not in to_ignore
    ]
def convert_pil_to_base64(image):
    """Encode an image as JPEG and return the result base64-encoded.

    Args:
        image: Object whose ``save(fp, format=...)`` writes the image to a
            file-like object.

    Returns:
        The base64 encoding of the JPEG data, as ``bytes``.
    """
    with BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue())
def draw_reshape_encode(img, boxes, x, y):
    """Draw boxes on a copy of a squared image, crop, resize and encode it.

    Args:
        img: Square image to annotate; it is copied, not modified.
        boxes: Boxes whose first two items and remaining items are used as
            the two corners of each rectangle.
        x: Width used for cropping and for the output aspect ratio.
        y: Height used for cropping and for the output aspect ratio.

    Returns:
        The base64-encoded JPEG bytes of the result, resized to ``HEIGHT``
        pixels high.
    """
    canvas = img.copy()
    pen = ImageDraw.Draw(canvas)

    # The outline width grows with the image height relative to HEIGHT.
    outline_width = 2 * int(y / HEIGHT)
    for box in boxes:
        corners = (tuple(box[:2]), tuple(box[2:]))
        pen.rectangle(corners, outline=color, width=outline_width)

    # Cut away the margin along the shorter of the two dimensions.
    if x > y:
        margin = (x - y) / 2
        crop_box = (0, margin, x, x - margin)
    else:
        margin = (y - x) / 2
        crop_box = (margin, 0, y - margin, y)
    cropped = canvas.crop(crop_box)

    output_size = (int(x * HEIGHT / y), HEIGHT)
    return convert_pil_to_base64(cropped.resize(output_size))
def get_html(url_list, encoded_images):
    """Build the HTML gallery for the retrieved images.

    Args:
        url_list: Sequence of records whose item 1 is the tooltip text and
            item 2 is the link target.
        encoded_images: Base64-encoded JPEG ``bytes``, aligned with
            ``url_list``.

    Returns:
        A ``<div>`` containing one ``<img>`` per image. The image is wrapped
        in an ``<a>`` only when its link is a non-empty string.
    """
    parts = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for record, encoded in zip(url_list, encoded_images):
        title, link = record[1], record[2]
        tag = f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded.decode()}'>"
        # A missing link may arrive as a non-string placeholder (presumably
        # NaN when the CSV cell is empty - verify against the data files);
        # calling len() on that would raise, so test the type first.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        parts.append(tag)
    parts.append("</div>")
    # A single join avoids rebuilding the growing string on every iteration.
    return "".join(parts)
# Markdown shown in the sidebar: what the demo does and how to phrase queries.
description = """
# Search and Detect
This demo illustrates how you can both retrieve images containing certain objects and locate these objects with a simple natural language query.
**Enter your query and hit enter**
**Tip 1**: if your query includes "/", the part left (resp. right) of "/" will be used to retrieve images (resp. locate objects). For example, if you want to retrieve pictures with several cats but locate individual cats, you can type "cats / cat".
**Tip 2**: change the score threshold below to adjust the sensitivity of the object detection.
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model and Google's [OWL-ViT](https://arxiv.org/abs/2205.06230) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
"""

# CSS properties for a centred, wrapping flex container.
# NOTE(review): not referenced by any code visible in this file.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}
# Style sheet injected at the top of the page by main().
_PAGE_CSS = """
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
.row-widget {
margin-top: -25px;
}
section>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>"""


def main():
    """Render the page: sidebar, query widgets and the annotated results.

    For a non-empty query, images are retrieved for the retrieval part of the
    query, the detector is run with the detection part, and the images are
    displayed with the kept boxes drawn on them.
    """
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)

    st.sidebar.markdown(description)
    score_threshold = st.sidebar.slider(
        "Score threshold", min_value=0.01, max_value=0.3, value=0.1, step=0.01
    )

    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input("", value="koala")
    corpus = st.radio("", ["Unsplash", "Movies"])

    if not query:
        return

    # Text before the last "/" drives retrieval; text after it drives
    # detection. Without a "/", the whole query is used for both.
    before, slash, after = query.rpartition("/")
    if slash:
        clip_query, owl_query = before.strip(), after.strip()
    else:
        clip_query = owl_query = query

    retrieved = image_search(clip_query, corpus)
    imgs, xs, ys = get_images([hit[0] for hit in retrieved])
    # The same detection query is applied to every retrieved image.
    results = apply_owl_model([[owl_query]] * len(imgs), imgs)

    encoded_images = []
    for index, (image, x, y) in enumerate(zip(imgs, xs, ys)):
        detections = results[index]
        boxes = keep_best_boxes(
            detections["boxes"],
            detections["scores"],
            score_threshold=score_threshold,
        )
        encoded_images.append(draw_reshape_encode(image, boxes, x, y))

    st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)


if __name__ == "__main__":
    main()