import os import streamlit as st from io import BytesIO from multiprocessing.dummy import Pool import base64 from PIL import Image, ImageOps import torch import numpy as np from torchvision import transforms from streamlit_drawable_canvas import st_canvas from src.model_LN_prompt import Model from html import escape import pickle as pkl from huggingface_hub import hf_hub_download, login from datasets import load_dataset if 'initialized' not in st.session_state: st.session_state.initialized = False device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") HEIGHT = 200 N_RESULTS = 20 color = st.get_option("theme.primaryColor") if color is None: color = (0, 0, 255) else: color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4)) @st.cache_resource def initialize_huggingface(): token = os.getenv("HUGGINGFACE_TOKEN") if token: login(token=token) else: st.error("HUGGINGFACE_TOKEN not found in environment variables") @st.cache_resource def load_model_and_data(): print("Loading everything...") dataset = load_dataset("CHSTR/ecommerce") path_images = "./data/" + "" #"/".join(dataset['validation']['image'][0].filename.split("/")[:-3]) + "/" # Download model path_model = hf_hub_download( repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt") # Load model model = Model().to(device) model_checkpoint = torch.load(path_model, map_location=device) model.load_state_dict(model_checkpoint['state_dict']) model.eval() # Download and load embeddings embeddings_file = hf_hub_download( repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl") embeddings = { 0: pkl.load(open(embeddings_file, "rb")), 1: pkl.load(open(embeddings_file, "rb")) } # Update image paths for corpus_id in [0, 1]: embeddings[corpus_id] = [ (emb[0], path_images + "/".join(emb[1].split("/")[-3:])) for emb in embeddings[corpus_id] ] return model, path_images, embeddings def compute_sketch(_sketch, model): with torch.no_grad(): sketch_feat = model(_sketch.to(device), dtype='sketch') return sketch_feat def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS): query_embedding = compute_sketch(_query, model) corpus_id = 0 if corpus == "Unsplash" else 1 image_features = torch.from_numpy( np.array([item[0] for item in embeddings[corpus_id]]) ).to(device) dot_product = (image_features @ query_embedding.T)[:, 0] _, max_indices = torch.topk( dot_product, n_results, dim=0, largest=True, sorted=True) path_to_label = {path: idx for idx, (_, path) in enumerate(embeddings[corpus_id])} label_to_path = {idx: path for path, idx in path_to_label.items()} label_of_images = torch.tensor( [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device) return [ (label_to_path[i],) for i in label_of_images[max_indices].cpu().numpy().tolist() ], dot_product[max_indices] @st.cache_data def make_square(img_path, fill_color=(255, 255, 255)): img = Image.open(img_path) x, y = img.size size = max(x, y) new_img = Image.new("RGB", (x, y), fill_color) new_img.paste(img) return new_img, x, y @st.cache_data def get_images(paths): processed = [make_square(path) for path in paths] imgs, xs, ys = zip(*processed) return list(imgs), list(xs), list(ys) @st.cache_data def convert_pil_to_base64(image): img_buffer = BytesIO() image.save(img_buffer, format="JPEG") byte_data = img_buffer.getvalue() base64_str = base64.b64encode(byte_data) return base64_str def get_html(url_list, encoded_images): html = "