from html import escape import requests from io import BytesIO import base64 from multiprocessing.dummy import Pool from PIL import Image, ImageDraw import streamlit as st import pandas as pd, numpy as np import torch from transformers import CLIPProcessor, CLIPModel from transformers import OwlViTProcessor, OwlViTForObjectDetection from transformers.image_utils import ImageFeatureExtractionMixin import tokenizers DEBUG = False if DEBUG: MODEL = "vit-base-patch32" OWL_MODEL = f"google/owlvit-base-patch32" else: MODEL = "vit-large-patch14-336" OWL_MODEL = f"google/owlvit-large-patch14" CLIP_MODEL = f"openai/clip-{MODEL}" if not DEBUG and torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") HEIGHT = 200 N_RESULTS = 6 color = st.get_option("theme.primaryColor") if color is None: color = (255, 196, 35) else: color = tuple(int(color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4)) @st.cache(allow_output_mutation=True) def load(): df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")} clip_model = CLIPModel.from_pretrained(CLIP_MODEL) clip_model.to(device) clip_model.eval() clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL) owl_model = OwlViTForObjectDetection.from_pretrained(OWL_MODEL) owl_model.to(device) owl_model.eval() owl_processor = OwlViTProcessor.from_pretrained(OWL_MODEL) embeddings = { 0: np.load(f"embeddings-{MODEL}.npy"), 1: np.load(f"embeddings2-{MODEL}.npy"), } for k in [0, 1]: embeddings[k] = embeddings[k] / np.linalg.norm( embeddings[k], axis=1, keepdims=True ) return clip_model, clip_processor, owl_model, owl_processor, df, embeddings clip_model, clip_processor, owl_model, owl_processor, df, embeddings = load() mixin = ImageFeatureExtractionMixin() source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"} def compute_text_embeddings(list_of_strings): inputs = clip_processor(text=list_of_strings, return_tensors="pt", padding=True).to( device ) with torch.no_grad(): result = clip_model.get_text_features(**inputs).detach().cpu().numpy() return result / np.linalg.norm(result, axis=1, keepdims=True) def image_search(query, corpus, n_results=N_RESULTS): query_embedding = compute_text_embeddings([query]) corpus_id = 0 if corpus == "Unsplash" else 1 dot_product = (embeddings[corpus_id] @ query_embedding.T)[:, 0] results = np.argsort(dot_product)[-1 : -n_results - 1 : -1] return [ ( df[corpus_id].iloc[i].path, df[corpus_id].iloc[i].tooltip + source[corpus_id], df[corpus_id].iloc[i].link, ) for i in results ] def make_square(img, fill_color=(255, 255, 255)): x, y = img.size size = max(x, y) new_img = Image.new("RGB", (size, size), fill_color) new_img.paste(img, (int((size - x) / 2), int((size - y) / 2))) return new_img, x, y @st.cache(allow_output_mutation=True, show_spinner=False) def get_images(paths): def process_image(path): return make_square(Image.open(BytesIO(requests.get(path).content))) processed = Pool(N_RESULTS).map(process_image, paths) imgs, xs, ys = [], [], [] for img, x, y in processed: imgs.append(img) xs.append(x) ys.append(y) return imgs, xs, ys @st.cache( hash_funcs={ tokenizers.Tokenizer: lambda x: None, tokenizers.AddedToken: lambda x: None, torch.nn.parameter.Parameter: lambda x: None, }, allow_output_mutation=True, show_spinner=False, ) def apply_owl_model(owl_queries, images): inputs = owl_processor(text=owl_queries, images=images, return_tensors="pt").to( device ) with torch.no_grad(): results = owl_model(**inputs) target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device) return owl_processor.post_process(outputs=results, target_sizes=target_sizes) def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8): candidates = [] for box, score in zip(boxes, scores): box = [round(i, 0) for i in box.tolist()] if score >= score_threshold: candidates.append((box, float(score))) to_ignore = set() for i in range(len(candidates) - 1): if i in to_ignore: continue for j in range(i + 1, len(candidates)): if j in to_ignore: continue xmin1, ymin1, xmax1, ymax1 = candidates[i][0] xmin2, ymin2, xmax2, ymax2 = candidates[j][0] if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1: continue else: xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3] ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3] area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter) area1 = (xmax1 - xmin1) * (ymax1 - ymin1) area2 = (xmax2 - xmin2) * (ymax2 - ymin2) iou = area_inter / (area1 + area2 - area_inter) if iou > max_iou: if candidates[i][1] > candidates[j][1]: to_ignore.add(j) else: to_ignore.add(i) break else: if area_inter / area1 > 0.9: if candidates[i][1] < 1.1 * candidates[j][1]: to_ignore.add(i) if area_inter / area2 > 0.9: if 1.1 * candidates[i][1] > candidates[j][1]: to_ignore.add(j) return [candidates[i][0] for i in range(len(candidates)) if i not in to_ignore] def convert_pil_to_base64(image): img_buffer = BytesIO() image.save(img_buffer, format="JPEG") byte_data = img_buffer.getvalue() base64_str = base64.b64encode(byte_data) return base64_str def draw_reshape_encode(img, boxes, x, y): image = img.copy() draw = ImageDraw.Draw(image) new_x, new_y = int(x * HEIGHT / y), HEIGHT for box in boxes: draw.rectangle( (tuple(box[:2]), tuple(box[2:])), outline=color, width=2 * int(y / HEIGHT) ) if x > y: image = image.crop((0, (x - y) / 2, x, x - (x - y) / 2)) else: image = image.crop(((y - x) / 2, 0, y - (y - x) / 2, y)) return convert_pil_to_base64(image.resize((new_x, new_y))) def get_html(url_list, encoded_images): html = "