import os
import base64
import pickle as pkl
from html import escape
from io import BytesIO

import numpy as np
import streamlit as st
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download, login
from PIL import Image, ImageOps
from streamlit_drawable_canvas import st_canvas
from torchvision import transforms

from src.model_LN_prompt import Model

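# Flag tracking whether the Hugging Face login, model and embeddings have been
# initialized for this session.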
if 'initialized' not in st.session_state:
    st.session_state.initialized = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

HEIGHT = 200
N_RESULTS = 20

# Theme primary color as an RGB tuple; falls back to blue if none is configured.
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 0, 255)
else:
    color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))

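# Log in to the Hugging Face Hub using the HUGGINGFACE_TOKEN environment variable.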
@st.cache_resource
def initialize_huggingface():
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        login(token=token)
    else:
        st.error("HUGGINGFACE_TOKEN not found in environment variables")

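# Download the dataset, the model checkpoint and the precomputed image embeddings;
# cached by Streamlit so this only runs once per session.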
@st.cache_resource
def load_model_and_data():
    print("Loading everything...")
    # Download the e-commerce dataset from the Hugging Face Hub.
    load_dataset("CHSTR/ecommerce")
    path_images = "./data/"

    # S3BIR-DINOv2 checkpoint for the e-commerce corpus.
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")

    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()

    # Precomputed image embeddings; the demo uses the same file for both corpora.
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
    with open(embeddings_file, "rb") as f:
        demo_embeddings = pkl.load(f)
    embeddings = {0: demo_embeddings, 1: demo_embeddings}

    # Rewrite the stored paths so they point at the local ./data directory.
    for corpus_id in [0, 1]:
        embeddings[corpus_id] = [
            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
            for emb in embeddings[corpus_id]
        ]

    return model, path_images, embeddings

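# Encode a sketch tensor with the model's sketch branch; gradients are not needed
# at inference time.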
def compute_sketch(_sketch, model):
    with torch.no_grad():
        sketch_feat = model(_sketch.to(device), dtype='sketch')
    return sketch_feat

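# Rank the selected corpus by dot-product similarity between the sketch embedding
# and the precomputed image embeddings, returning the top-n image paths and scores.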
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
    query_embedding = compute_sketch(_query, model)
    corpus_id = 0 if corpus == "Unsplash" else 1
    image_features = torch.from_numpy(
        np.array([item[0] for item in embeddings[corpus_id]])
    ).to(device)

    dot_product = (image_features @ query_embedding.T)[:, 0]
    _, max_indices = torch.topk(
        dot_product, n_results, dim=0, largest=True, sorted=True)

    path_to_label = {
        path: idx for idx, (_, path) in enumerate(embeddings[corpus_id])
    }
    label_to_path = {idx: path for path, idx in path_to_label.items()}
    label_of_images = torch.tensor(
        [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)

    return [
        (label_to_path[i],)
        for i in label_of_images[max_indices].cpu().numpy().tolist()
    ], dot_product[max_indices]

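# Load an image and pad it onto a square canvas so thumbnails keep their aspect
# ratio when resized for display.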
@st.cache_data
def make_square(img_path, fill_color=(255, 255, 255)):
    img = Image.open(img_path)
    x, y = img.size
    size = max(x, y)
    new_img = Image.new("RGB", (size, size), fill_color)
    # Center the original image on the square canvas.
    new_img.paste(img, ((size - x) // 2, (size - y) // 2))
    return new_img, size, size

@st.cache_data
def get_images(paths):
    processed = [make_square(path) for path in paths]
    imgs, xs, ys = zip(*processed)
    return list(imgs), list(xs), list(ys)

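# Encode a PIL image as base64 JPEG bytes for inline embedding in HTML.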
@st.cache_data
def convert_pil_to_base64(image):
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    return base64_str

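# Build an HTML grid of the retrieved images from their base64-encoded thumbnails.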
def get_html(url_list, encoded_images):
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for i in range(len(url_list)):
        title, encoded = url_list[i][0], encoded_images[i]
        html += (
            f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' "
            f"src='data:image/jpeg;base64,{encoded.decode()}'>"
        )
    html += "</div>"
    return html

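# Streamlit entry point: the user draws a sketch on a canvas and the app displays
# the most similar e-commerce product images.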
def main():
    if not st.session_state.initialized:
        initialize_huggingface()
        (st.session_state.model,
         st.session_state.path_images,
         st.session_state.embeddings) = load_model_and_data()
        st.session_state.initialized = True

    description = """
# Self-Supervised Sketch-based Image Retrieval (S3BIR)

Our approaches, S3BIR-CLIP and S3BIR-DINOv2, learn a bimodal sketch-photo feature
space from unpaired data, without explicit sketch-photo pairs. The models achieve
outstanding results on three diverse public datasets even though they are trained
without real sketches.
"""

    st.sidebar.markdown(description)
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)

    st.markdown(
        """
        <style>
        .block-container { max-width: 1200px; }
        div.row-widget > div { flex-direction: row; display: flex; justify-content: center; color: white; }
        div.row-widget.stRadio > div > label { margin-left: 5px; margin-right: 5px; }
        .row-widget { margin-top: -25px; }
        section > div:first-child { padding-top: 30px; }
        div.appview-container > section:first-child { max-width: 320px; }
        #MainMenu { visibility: hidden; }
        .stMarkdown { display: grid; place-items: center; }
        </style>
        """,
        unsafe_allow_html=True,
    )

    st.title("S3BIR App")
    _, col, _ = st.columns((1, 1, 1))
    with col:
        canvas_result = st_canvas(
            background_color="#eee",
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )

    corpus = ["Ecommerce"]
    st.columns((1, 3, 1))

    if canvas_result.image_data is not None:
        # Convert the canvas drawing to a normalized 224x224 tensor batch.
        draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
        draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))

        draw_tensor = transforms.ToTensor()(draw)
        draw_tensor = transforms.Resize((224, 224))(draw_tensor)
        draw_tensor = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )(draw_tensor)
        draw_tensor = draw_tensor.unsqueeze(0)

        retrieved, _ = image_search(
            draw_tensor, corpus[0], st.session_state.model, st.session_state.embeddings)
        imgs, xs, ys = get_images([x[0] for x in retrieved])

        # Resize each result to the display height and encode it for the HTML grid.
        encoded_images = []
        for img0, x, y in zip(imgs, xs, ys):
            new_x, new_y = int(x * HEIGHT / y), HEIGHT
            encoded_images.append(convert_pil_to_base64(img0.resize((new_x, new_y))))

        st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)

if __name__ == "__main__":
    main()