from html import escape
import requests
from io import BytesIO
import base64
from multiprocessing.dummy import Pool
from PIL import Image, ImageDraw
import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_utils import ImageFeatureExtractionMixin
import tokenizers

# DEBUG selects the small checkpoints and forces CPU; otherwise the large
# checkpoints are used, on GPU when one is available.
DEBUG = True

if DEBUG:
    MODEL = "vit-base-patch32"
    OWL_MODEL = "google/owlvit-base-patch32"
else:
    MODEL = "vit-large-patch14-336"
    # Hub id is "...-patch14" (previously misspelled "...-path14").
    OWL_MODEL = "google/owlvit-large-patch14"
CLIP_MODEL = f"openai/clip-{MODEL}"

if not DEBUG and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

HEIGHT = 200  # height in pixels of each displayed thumbnail
N_RESULTS = 6  # number of images retrieved per query
REQUEST_TIMEOUT = 10  # seconds to wait for each image download


def _parse_hex_color(value, default=(255, 75, 75)):
    """Convert a "#rrggbb" (or shorthand "#rgb") string to an (r, g, b) tuple.

    Returns *default* when *value* is None or is not a parsable hex color;
    digits beyond the first six (e.g. an alpha channel) are ignored.
    """
    if value is None:
        return default
    digits = value.strip().lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        return default
    try:
        return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        return default


# Outline color for detected boxes: the Streamlit theme's primary color.
color = _parse_hex_color(st.get_option("theme.primaryColor"))


@st.cache(allow_output_mutation=True)
def load():
    """Load both corpora, the CLIP and OWL-ViT models, and the image embeddings.

    Returns:
        (clip_model, clip_processor, owl_model, owl_processor, df, embeddings)
        where ``df`` and ``embeddings`` are dicts keyed by corpus id
        (0: Unsplash, 1: TMDB) and every embedding row is L2-normalized.
    """
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    clip_model = CLIPModel.from_pretrained(CLIP_MODEL)
    clip_model.to(device)
    clip_model.eval()
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
    owl_model = OwlViTForObjectDetection.from_pretrained(OWL_MODEL)
    owl_model.to(device)
    owl_model.eval()
    owl_processor = OwlViTProcessor.from_pretrained(OWL_MODEL)
    embeddings = {
        0: np.load(f"embeddings-{MODEL}.npy"),
        1: np.load(f"embeddings2-{MODEL}.npy"),
    }
    for k in [0, 1]:
        # Normalize rows so a dot product with a normalized query is a cosine.
        embeddings[k] = embeddings[k] / np.linalg.norm(
            embeddings[k], axis=1, keepdims=True
        )
    return clip_model, clip_processor, owl_model, owl_processor, df, embeddings


clip_model, clip_processor, owl_model, owl_processor, df, embeddings = load()
mixin = ImageFeatureExtractionMixin()
# Appended to each tooltip; the leading newline puts it on its own line.
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}


def compute_text_embeddings(list_of_strings):
    """Return L2-normalized CLIP text embeddings, one numpy row per string."""
    inputs = clip_processor(text=list_of_strings, return_tensors="pt", padding=True).to(
        device
    )
    with torch.no_grad():
        result = clip_model.get_text_features(**inputs).detach().cpu().numpy()
    return result / np.linalg.norm(result, axis=1, keepdims=True)


def image_search(query, corpus, n_results=N_RESULTS):
    """Return the *n_results* corpus images most similar to *query*, best first.

    Args:
        query: text used for CLIP retrieval.
        corpus: "Unsplash" selects corpus 0; any other value selects corpus 1.
        n_results: number of results to return.

    Returns:
        list of (path, tooltip, link) tuples, where tooltip has the corpus
        source line appended.
    """
    query_embedding = compute_text_embeddings([query])
    corpus_id = 0 if corpus == "Unsplash" else 1
    dot_product = (embeddings[corpus_id] @ query_embedding.T)[:, 0]
    # Indices of the n_results largest scores, in descending order.
    results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
    return [
        (
            df[corpus_id].iloc[i].path,
            df[corpus_id].iloc[i].tooltip + source[corpus_id],
            df[corpus_id].iloc[i].link,
        )
        for i in results
    ]


def make_square(img, fill_color=(255, 255, 255)):
    """Center *img* on a square RGB canvas of *fill_color*.

    Returns:
        (square_image, original_width, original_height)
    """
    x, y = img.size
    size = max(x, y)
    new_img = Image.new("RGB", (size, size), fill_color)
    new_img.paste(img, ((size - x) // 2, (size - y) // 2))
    return new_img, x, y


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_images(paths):
    """Download every image in *paths* concurrently and square-pad it.

    Returns:
        (imgs, xs, ys): the squared images and their original widths and
        heights, in the order of *paths*.

    Raises:
        requests.RequestException: if a download fails or times out.
    """

    def process_image(path):
        response = requests.get(path, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        return make_square(Image.open(BytesIO(response.content)))

    # The context manager shuts the worker threads down once the map is done.
    with Pool(N_RESULTS) as pool:
        processed = pool.map(process_image, paths)
    imgs = [img for img, _, _ in processed]
    xs = [x for _, x, _ in processed]
    ys = [y for _, _, y in processed]
    return imgs, xs, ys


@st.cache(
    # These objects are hashed as None, i.e. excluded from the cache key
    # (presumably because st.cache cannot hash them itself).
    hash_funcs={
        tokenizers.Tokenizer: lambda x: None,
        tokenizers.AddedToken: lambda x: None,
        torch.nn.parameter.Parameter: lambda x: None,
    },
    allow_output_mutation=True,
    show_spinner=False,
)
def apply_owl_model(owl_queries, images):
    """Run OWL-ViT on *images*, with one list of text queries per image.

    Returns the processor's post-processed detections, one dict per image,
    with boxes rescaled to each image's (height, width).
    """
    inputs = owl_processor(text=owl_queries, images=images, return_tensors="pt").to(
        device
    )
    with torch.no_grad():
        results = owl_model(**inputs)
    target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
    return owl_processor.post_process(outputs=results, target_sizes=target_sizes)


def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
    """Keep confident detections and drop redundant overlapping ones.

    Args:
        boxes: box objects exposing ``.tolist()`` -> (xmin, ymin, xmax, ymax).
        scores: one score per box.
        score_threshold: boxes scoring below this are discarded.
        max_iou: of two boxes whose IoU exceeds this, only the higher-scoring
            one is kept. Below it, a box lying >90% inside another is dropped
            unless it scores clearly (over 10%) higher than the other.

    Returns:
        list of kept boxes (lists of 4 rounded coordinates), in input order.
    """
    candidates = []
    for box, score in zip(boxes, scores):
        box = [round(coord, 0) for coord in box.tolist()]
        # Boxes that collapse to zero width/height after rounding are invisible
        # and would make the area ratios below divide by zero.
        if score >= score_threshold and box[2] > box[0] and box[3] > box[1]:
            candidates.append((box, float(score)))

    to_ignore = set()
    for i in range(len(candidates) - 1):
        if i in to_ignore:
            continue
        (xmin1, ymin1, xmax1, ymax1), score_i = candidates[i]
        area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
        for j in range(i + 1, len(candidates)):
            if j in to_ignore:
                continue
            (xmin2, ymin2, xmax2, ymax2), score_j = candidates[j]
            if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
                continue  # disjoint boxes
            # For overlapping intervals, the middle two sorted endpoints bound
            # the intersection.
            xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
            ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
            area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
            area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
            iou = area_inter / (area1 + area2 - area_inter)
            if iou > max_iou:
                if score_i > score_j:
                    to_ignore.add(j)
                else:
                    to_ignore.add(i)
            else:
                if area_inter / area1 > 0.9 and score_i < 1.1 * score_j:
                    to_ignore.add(i)
                if area_inter / area2 > 0.9 and 1.1 * score_i > score_j:
                    to_ignore.add(j)
            # A suppressed box must not go on to suppress later boxes.
            if i in to_ignore:
                break
    return [candidates[i][0] for i in range(len(candidates)) if i not in to_ignore]


def convert_pil_to_base64(image):
    """Encode *image* as JPEG and return the base64 bytes."""
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    return base64_str


def draw_reshape_encode(img, boxes, x, y):
    """Draw *boxes* on the squared *img*, undo the padding, and encode it.

    Args:
        img: square image produced by make_square.
        boxes: (xmin, ymin, xmax, ymax) boxes in *img*'s coordinates.
        x, y: original width and height before make_square padded the image.

    Returns:
        base64 JPEG bytes of the boxed image resized to HEIGHT pixels tall.
    """
    image = img.copy()
    draw = ImageDraw.Draw(image)
    new_x, new_y = int(x * HEIGHT / y), HEIGHT
    # Width in source pixels so that the outline stays roughly 2px after the
    # final resize; at least 1, because PIL skips the outline when width is 0
    # (which the bare formula yields for images shorter than HEIGHT).
    line_width = max(1, 2 * int(y / HEIGHT))
    for box in boxes:
        draw.rectangle(
            (tuple(box[:2]), tuple(box[2:])), outline=color, width=line_width
        )
    # Crop back to the original region, using the same integer offsets at
    # which make_square pasted the image.
    size = max(x, y)
    left, top = (size - x) // 2, (size - y) // 2
    image = image.crop((left, top, left + x, top + y))
    return convert_pil_to_base64(image.resize((new_x, new_y)))


def get_html(url_list, encoded_images):
    """Build the HTML gallery of result thumbnails.

    Args:
        url_list: (path, title, link) tuples as returned by image_search.
        encoded_images: base64 JPEG bytes (as returned by draw_reshape_encode),
            one per entry of *url_list*.

    Returns:
        An HTML string: a flex container holding one <img> per result, with a
        hover title, wrapped in a link when the result has one. Titles and
        links come from the dataset, so they are HTML-escaped.
    """
    container_style = "; ".join(f"{key}: {value}" for key, value in div_style.items())
    parts = [f'<div style="{container_style}">']
    for entry, encoded in zip(url_list, encoded_images):
        title, link = entry[1], entry[2]
        tag = (
            f'<img title="{escape(title)}" style="height: {HEIGHT}px; margin: 5px" '
            f'src="data:image/jpeg;base64,{encoded.decode()}">'
        )
        # pandas reads empty cells as NaN (a float), so check the type as well.
        if isinstance(link, str) and len(link) > 0:
            tag = (
                f'<a href="{escape(link)}" target="_blank" '
                f'rel="noopener noreferrer">{tag}</a>'
            )
        parts.append(tag)
    parts.append("</div>")
    return "".join(parts)


description = """
# Search and Detect

This demo illustrates how you can both retrieve images containing certain objects and locate these objects with a simple natural language query.

**Enter your query and hit enter**

**Tip 1**: if your query includes "/", the part left (resp. right) of "/" will be used to retrieve images (resp. locate objects). For example, if you want to retrieve pictures with several cats but locate individual cats, you can type "cats / cat".

**Tip 2**: change the score threshold below to adjust the sensitivity of the object detection.

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model and Google's [Owl-ViT](https://arxiv.org/abs/2205.06230) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
"""

# CSS properties of the results container built by get_html.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}


def _split_query(query):
    """Split a "retrieve / locate" query into (clip_query, owl_query).

    Text before the last "/" drives retrieval and text after it drives
    detection; without "/", both use the whole query. Both parts are stripped,
    and an empty part falls back to the other. A blank query gives ("", "").
    """
    retrieve, sep, locate = query.rpartition("/")
    if not sep:
        retrieve = locate = query
    retrieve, locate = retrieve.strip(), locate.strip()
    return retrieve or locate, locate or retrieve


def main():
    """Render the page: sidebar help, query box, corpus picker, and results."""
    # NOTE(review): this call presumably injected custom CSS that is missing
    # from this copy of the file; as written it renders only whitespace.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    score_threshold = st.sidebar.slider(
        "Score threshold", min_value=0.01, max_value=0.3, value=0.1, step=0.01
    )

    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input("", value="clouds at sunset")
    corpus = st.radio("", ["Unsplash", "Movies"])

    clip_query, owl_query = _split_query(query)
    if not clip_query:
        return

    retrieved = image_search(clip_query, corpus)
    imgs, xs, ys = get_images([path for path, _, _ in retrieved])
    results = apply_owl_model([[owl_query]] * len(imgs), imgs)
    encoded_images = []
    for img, x, y, result in zip(imgs, xs, ys, results):
        boxes = keep_best_boxes(
            result["boxes"],
            result["scores"],
            score_threshold=score_threshold,
        )
        encoded_images.append(draw_reshape_encode(img, boxes, x, y))
    st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)


if __name__ == "__main__":
    main()