import os
import streamlit as st
from io import BytesIO
from multiprocessing.dummy import Pool
import base64
from PIL import Image, ImageOps
import torch
import numpy as np
from torchvision import transforms
from streamlit_drawable_canvas import st_canvas
from src.model_LN_prompt import Model
from html import escape
import pickle as pkl
from huggingface_hub import hf_hub_download, login
from datasets import load_dataset

# Flag used by main() so the model/data are attached to the session only once.
if 'initialized' not in st.session_state:
    st.session_state.initialized = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
HEIGHT = 200  # display height in pixels of each retrieved thumbnail
N_RESULTS = 20  # default number of images returned per sketch query

# Theme primary colour as an (R, G, B) tuple; falls back to pure blue.
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 0, 255)
else:
    color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))


@st.cache_resource
def initialize_huggingface():
    """Log in to the Hugging Face Hub with the HUGGINGFACE_TOKEN env variable.

    Shows a Streamlit error (instead of raising) when the variable is unset.
    """
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        login(token=token)
    else:
        st.error("HUGGINGFACE_TOKEN not found in environment variables")


@st.cache_resource
def load_model_and_data():
    """Download and load the sketch model and the precomputed image embeddings.

    Cached with ``st.cache_resource`` so the heavy work is shared across reruns.

    Returns:
        tuple: ``(model, path_images, embeddings)`` where ``embeddings`` maps a
        corpus id (0 or 1) to a list of ``(embedding, image_path)`` pairs whose
        paths have been re-rooted under ``path_images``.
    """
    print("Loading everything...")
    # NOTE(review): the returned dataset object is not used below; the call is
    # presumably kept for its download side effect — confirm before removing.
    dataset = load_dataset("CHSTR/ecommerce")
    path_images = "./data/"

    # Download and load the model checkpoint.
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()

    # Download the embeddings and unpickle them once, closing the handle.
    # NOTE(review): pickle.load can execute arbitrary code; this is only
    # acceptable because the file comes from a trusted Hugging Face repo.
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
    with open(embeddings_file, "rb") as f:
        raw_embeddings = pkl.load(f)

    # Re-root each stored path under path_images, keeping its last three
    # components. Each corpus id gets its own list built from the same data.
    embeddings = {
        corpus_id: [
            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
            for emb in raw_embeddings
        ]
        for corpus_id in (0, 1)
    }
    return model, path_images, embeddings


def compute_sketch(_sketch, model):
    """Encode a preprocessed sketch tensor with the model's sketch branch.

    Args:
        _sketch: Preprocessed sketch tensor; moved to ``device`` before encoding.
        model: Model returned by ``load_model_and_data``.

    Returns:
        The feature tensor from ``model(..., dtype='sketch')``, computed
        without gradient tracking.
    """
    with torch.no_grad():
        sketch_feat = model(_sketch.to(device), dtype='sketch')
    return sketch_feat


def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
    """Return the corpus images scoring highest against a sketch query.

    Args:
        _query: Preprocessed sketch tensor passed to ``compute_sketch``.
        corpus: Corpus name; ``"Unsplash"`` selects corpus 0, anything else 1.
        model: Sketch/image model.
        embeddings: Mapping corpus id -> list of ``(embedding, image_path)``.
        n_results: Maximum number of results; clamped to the corpus size.

    Returns:
        tuple: ``(results, scores)`` where ``results`` is a list of 1-tuples
        ``(image_path,)`` ordered by descending dot-product score and
        ``scores`` is the 1-D tensor of the matching scores.
    """
    corpus_id = 0 if corpus == "Unsplash" else 1
    corpus_items = embeddings[corpus_id]
    if not corpus_items:
        return [], torch.empty(0, device=device)

    query_embedding = compute_sketch(_query, model)
    # Match the query dtype so the matmul cannot fail on e.g. float64 pickles.
    image_features = torch.from_numpy(
        np.array([item[0] for item in corpus_items])
    ).to(device=device, dtype=query_embedding.dtype)
    # Keep column 0 of the similarity matrix: one score per corpus image
    # (assumes a single query row — TODO confirm for batched queries).
    dot_product = (image_features @ query_embedding.T)[:, 0]

    # torch.topk raises if k exceeds the number of elements, so clamp it.
    k = min(n_results, dot_product.shape[0])
    scores, max_indices = torch.topk(
        dot_product, k, dim=0, largest=True, sorted=True)

    results = [(corpus_items[i][1],) for i in max_indices.cpu().tolist()]
    return results, scores


@st.cache_data
def make_square(img_path, fill_color=(255, 255, 255)):
    """Load an image and flatten it onto an RGB canvas of the same size.

    Despite the name, the canvas keeps the image's own (width, height);
    callers use those to resize while preserving aspect ratio. Transparent
    pixels are composited onto ``fill_color``.

    Returns:
        tuple: ``(rgb_image, width, height)`` of the original image.
    """
    with Image.open(img_path) as img:
        width, height = img.size
        new_img = Image.new("RGB", (width, height), fill_color)
        if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
            # Use the alpha channel as the paste mask so fill_color shows
            # through transparent regions instead of their hidden RGB values.
            rgba = img.convert("RGBA")
            new_img.paste(rgba, (0, 0), rgba)
        else:
            new_img.paste(img)
    return new_img, width, height


@st.cache_data
def get_images(paths):
    """Load every path with ``make_square``.

    Returns:
        tuple: parallel lists ``(images, widths, heights)``; all empty when
        ``paths`` is empty.
    """
    processed = [make_square(path) for path in paths]
    if not processed:
        return [], [], []
    imgs, xs, ys = zip(*processed)
    return list(imgs), list(xs), list(ys)


@st.cache_data
def convert_pil_to_base64(image):
    """Encode a PIL image as JPEG and return the base64 bytes."""
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    return base64_str


def get_html(url_list, encoded_images):
    """Build an HTML gallery of inline base64 JPEG thumbnails.

    Args:
        url_list: Items whose first element is the image path (used as the
            hover title).
        encoded_images: Base64-encoded JPEG bytes, parallel to ``url_list``.

    Returns:
        str: A flex-wrapped ``<div>`` of ``<img>`` tags, each ``HEIGHT`` px tall.
    """
    parts = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for item, encoded in zip(url_list, encoded_images):
        # Escape the path: the HTML is rendered with unsafe_allow_html=True.
        parts.append(
            f"<img title='{escape(item[0])}' "
            f"style='height: {HEIGHT}px; margin: 5px' "
            f"src='data:image/jpeg;base64,{encoded.decode('ascii')}'>"
        )
    parts.append("</div>")
    return "".join(parts)


def main():
    """Render the S3BIR page and show retrieval results for the current sketch."""
    if not st.session_state.initialized:
        initialize_huggingface()
        (
            st.session_state.model,
            st.session_state.path_images,
            st.session_state.embeddings,
        ) = load_model_and_data()
        st.session_state.initialized = True

    description = """
# Self-Supervised Sketch-based Image Retrieval (S3BIR)

Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal sketch-photo feature space from unpaired data without explicit sketch-photo pairs. Our experiments perform outstandingly in three diverse public datasets where the models are trained without real sketches.
"""
    st.sidebar.markdown(description)
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)

    # styles
    # NOTE(review): this style block is empty in this copy of the file —
    # restore the intended <style> CSS here if it was lost.
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )

    st.title("S3BIR App")

    _, col, _ = st.columns((1, 1, 1))
    with col:
        canvas_result = st_canvas(
            background_color="#eee",
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )

    corpus = ["Ecommerce"]
    if canvas_result.image_data is None:
        return

    # Canvas RGBA -> RGB, letterboxed to the model's 224x224 input, then
    # normalised with the ImageNet mean/std.
    draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
    draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
    draw_tensor = transforms.ToTensor()(draw)
    draw_tensor = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(draw_tensor)
    draw_tensor = draw_tensor.unsqueeze(0)  # add the batch dimension

    retrieved, _ = image_search(
        draw_tensor, corpus[0], st.session_state.model,
        st.session_state.embeddings)
    imgs, xs, ys = get_images([x[0] for x in retrieved])

    # Scale every thumbnail to HEIGHT pixels tall, preserving aspect ratio.
    encoded_images = [
        convert_pil_to_base64(img.resize((int(width * HEIGHT / height), HEIGHT)))
        for img, width, height in zip(imgs, xs, ys)
    ]
    st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)


if __name__ == "__main__":
    main()