import os |
import streamlit as st |
from io import BytesIO |
from multiprocessing.dummy import Pool |
import base64 |
from PIL import Image, ImageOps |
import torch |
import numpy as np |
from torchvision import transforms |
from streamlit_drawable_canvas import st_canvas |
from src.model_LN_prompt import Model |
from html import escape |
import pickle as pkl |
from huggingface_hub import hf_hub_download, login |
from datasets import load_dataset |
# Session flag read by main() to decide whether the model and embeddings
# still have to be loaded.
if "initialized" not in st.session_state:
    st.session_state.initialized = False

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

HEIGHT = 200  # pixel height of every result thumbnail
N_RESULTS = 20  # default number of images returned per query

# Theme primary colour as an (R, G, B) tuple; blue when the theme sets none.
_theme_color = st.get_option("theme.primaryColor")
if _theme_color is not None:
    _hex_digits = _theme_color.lstrip("#")
    color = tuple(
        int(_hex_digits[start:start + 2], 16) for start in (0, 2, 4)
    )
else:
    color = (0, 0, 255)
@st.cache_resource
def initialize_huggingface():
    """Log in to the Hugging Face Hub with the ``HUGGINGFACE_TOKEN`` variable.

    When the environment variable is missing or empty, no login is attempted
    and a Streamlit error message is shown instead.
    """
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")
    if not hf_token:
        st.error("HUGGINGFACE_TOKEN not found in environment variables")
        return
    login(token=hf_token)
@st.cache_resource
def load_model_and_data():
    """Load the retrieval model and the precomputed image embeddings.

    Returns:
        A ``(model, path_images, embeddings)`` tuple:

        * ``model`` -- a ``Model`` moved to ``device``, initialised from the
          checkpoint's ``state_dict`` and switched to eval mode.
        * ``path_images`` -- the local image root (``"./data/"``).
        * ``embeddings`` -- dict mapping corpus ids ``0`` and ``1`` to lists
          of ``(embedding, image_path)`` pairs, where ``image_path`` is
          ``path_images`` plus the last three components of the stored path.
    """
    print("Loading everything...")
    # The returned dataset object is not used below; the call is kept for
    # its side effects. NOTE(review): confirm whether it is still required.
    load_dataset("CHSTR/ecommerce")
    path_images = "./data/" + ""

    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint["state_dict"])
    model.eval()

    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
    # SECURITY: unpickling executes arbitrary code; only load this file from
    # a repository you trust.
    # Read the file once through a context manager so the handle is closed
    # (previously two handles were opened and never closed).
    with open(embeddings_file, "rb") as handle:
        raw_embeddings = pkl.load(handle)

    # Both corpus ids are backed by the same file; each gets its own list of
    # (embedding, rebased_path) pairs.
    embeddings = {
        corpus_id: [
            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
            for emb in raw_embeddings
        ]
        for corpus_id in (0, 1)
    }
    return model, path_images, embeddings
def compute_sketch(_sketch, model):
    """Return ``model``'s sketch features for ``_sketch``.

    The tensor is moved to ``device`` and the forward pass is run with
    ``dtype='sketch'`` under ``torch.no_grad()``.
    """
    with torch.no_grad():
        return model(_sketch.to(device), dtype='sketch')
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
    """Rank the corpus images by dot-product similarity to a sketch query.

    Args:
        _query: Sketch tensor forwarded to ``compute_sketch``.
        corpus: Corpus name; ``"Unsplash"`` selects corpus id 0, any other
            value selects corpus id 1.
        model: Model used to embed the sketch.
        embeddings: Mapping from corpus id to a sequence of
            ``(embedding, image_path)`` pairs.
        n_results: Maximum number of results; clamped to the corpus size.

    Returns:
        ``(results, scores)`` where ``results`` is a list of one-element
        tuples ``(image_path,)`` ordered by decreasing score and ``scores``
        is the tensor of the matching dot products.
    """
    query_embedding = compute_sketch(_query, model)
    corpus_id = 0 if corpus == "Unsplash" else 1
    entries = embeddings[corpus_id]

    # Nothing to rank: avoid building an empty feature matrix.
    if not entries:
        return [], torch.empty(0, device=device)

    image_features = torch.from_numpy(
        np.array([item[0] for item in entries])
    ).to(device)
    dot_product = (image_features @ query_embedding.T)[:, 0]

    # torch.topk raises when k exceeds the number of candidates.
    k = min(n_results, len(entries))
    top_scores, max_indices = torch.topk(
        dot_product, k, dim=0, largest=True, sorted=True)

    # Each index maps straight back to its entry's path.
    results = [(entries[idx][1],) for idx in max_indices.tolist()]
    return results, top_scores
@st.cache_data
def make_square(img_path, fill_color=(255, 255, 255)):
    """Load an image and paste it onto an RGB canvas of the same size.

    Despite the name, the canvas has the image's own dimensions, so no
    padding is added; callers use the returned width and height to keep the
    aspect ratio when resizing.

    Args:
        img_path: Path of the image file to open.
        fill_color: RGB colour the canvas is created with.

    Returns:
        ``(image, width, height)`` -- the new RGB image and the original
        dimensions.
    """
    # Context manager releases the underlying file once the pixels are copied.
    with Image.open(img_path) as img:
        width, height = img.size
        canvas = Image.new("RGB", (width, height), fill_color)
        canvas.paste(img)
    return canvas, width, height
@st.cache_data
def get_images(paths):
    """Load every image in ``paths`` through ``make_square``.

    Args:
        paths: Iterable of image paths.

    Returns:
        Three parallel lists ``(images, widths, heights)``; all three are
        empty when ``paths`` is empty.
    """
    # Accumulate directly so that empty input yields three empty lists
    # instead of failing while unpacking ``zip(*[])``.
    imgs, xs, ys = [], [], []
    for path in paths:
        img, width, height = make_square(path)
        imgs.append(img)
        xs.append(width)
        ys.append(height)
    return imgs, xs, ys
@st.cache_data
def convert_pil_to_base64(image):
    """Return ``image`` serialised as JPEG and base64-encoded (as bytes)."""
    with BytesIO() as jpeg_buffer:
        image.save(jpeg_buffer, format="JPEG")
        return base64.b64encode(jpeg_buffer.getvalue())
def get_html(url_list, encoded_images):
    """Build the HTML gallery for the retrieved images.

    Args:
        url_list: Sequence whose items carry the image title at index 0.
        encoded_images: Base64-encoded JPEG bytes, one per ``url_list`` item.

    Returns:
        A ``<div>`` containing one ``<img>`` per entry, each ``HEIGHT``
        pixels tall, with the HTML-escaped title as tooltip.
    """
    fragments = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for position, entry in enumerate(url_list):
        title = entry[0]
        encoded = encoded_images[position]
        fragments.append(
            f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded.decode()}'>"
        )
    fragments.append("</div>")
    return "".join(fragments)
def main():
    """Render the Streamlit page and run retrieval for the drawn sketch.

    On the first run of a session the Hugging Face login is performed and the
    model, image root and embeddings are stored in ``st.session_state``.
    The page then shows a sidebar description and stroke-width slider, a
    drawing canvas, and, whenever the canvas has image data, the gallery of
    retrieved images.
    """
    if not st.session_state.initialized:
        initialize_huggingface()
        model, path_images, embeddings = load_model_and_data()
        st.session_state.model = model
        st.session_state.path_images = path_images
        st.session_state.embeddings = embeddings
        st.session_state.initialized = True

    description = """
# Self-Supervised Sketch-based Image Retrieval (S3BIR)
Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal sketch-photo feature space from unpaired data without explicit sketch-photo pairs. Our experiments perform outstandingly in three diverse public datasets where the models are trained without real sketches.
"""
    # Page-wide CSS overrides injected through st.markdown.
    page_css = """
<style>
.block-container{ max-width: 1200px; }
div.row-widget > div{ flex-direction: row; display: flex; justify-content: center; color: white; }
div.row-widget.stRadio > div > label{ margin-left: 5px; margin-right: 5px; }
.row-widget { margin-top: -25px; }
section > div:first-child { padding-top: 30px; }
div.appview-container > section:first-child{ max-width: 320px; }
#MainMenu { visibility: hidden; }
.stMarkdown { display: grid; place-items: center; }
</style>
"""

    st.sidebar.markdown(description)
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)
    st.markdown(page_css, unsafe_allow_html=True)
    st.title("S3BIR App")

    # Place the canvas in the middle of three equal columns.
    _, col, _ = st.columns((1, 1, 1))
    with col:
        canvas_result = st_canvas(
            background_color="#eee",
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )

    corpus = ["Ecommerce"]
    st.columns((1, 3, 1))

    if canvas_result.image_data is None:
        return

    # Canvas pixels -> 224x224 RGB image -> normalised batch of one.
    draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
    draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ])
    draw_tensor = preprocess(draw).unsqueeze(0)

    retrieved, _ = image_search(
        draw_tensor, corpus[0], st.session_state.model,
        st.session_state.embeddings)
    imgs, xs, ys = get_images([hit[0] for hit in retrieved])

    # Scale every image to HEIGHT pixels tall, keeping its aspect ratio.
    encoded_images = [
        convert_pil_to_base64(
            img.resize((int(width * HEIGHT / height), HEIGHT)))
        for img, width, height in zip(imgs, xs, ys)
    ]
    st.markdown(get_html(retrieved, encoded_images),
                unsafe_allow_html=True)
# Build the page only when this file is executed as the main module.
if __name__ == "__main__":
    main()