S3BIR / app.py
CHSTR's picture
se modifica el path de lectura
9fb1bd4
import os
import streamlit as st
from io import BytesIO
from multiprocessing.dummy import Pool
import base64
from PIL import Image, ImageOps
import torch
import numpy as np
from torchvision import transforms
from streamlit_drawable_canvas import st_canvas
from src.model_LN_prompt import Model
from html import escape
import pickle as pkl
from huggingface_hub import hf_hub_download, login
from datasets import load_dataset
# Seed the flag that main() checks to run its one-time initialization.
if "initialized" not in st.session_state:
    st.session_state.initialized = False

# Use the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

HEIGHT = 200  # height in pixels of every thumbnail in the results gallery
N_RESULTS = 20  # default number of images returned by image_search

# Theme primary colour as an (R, G, B) tuple, falling back to blue when unset.
# Assumes a 6-digit "#RRGGBB" string — TODO confirm for other theme formats.
# NOTE(review): `color` is not referenced by the code visible in this file.
color = st.get_option("theme.primaryColor")
color = (
    (0, 0, 255)
    if color is None
    else tuple(int(color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4))
)
@st.cache_resource
def initialize_huggingface():
    """Log in to the Hugging Face Hub with the ``HUGGINGFACE_TOKEN`` env var.

    When the variable is unset or empty, no login is attempted and a
    Streamlit error message is shown instead. Returns ``None`` either way.
    """
    hf_token = os.getenv("HUGGINGFACE_TOKEN")
    if not hf_token:
        st.error("HUGGINGFACE_TOKEN not found in environment variables")
        return
    login(token=hf_token)
@st.cache_resource
def load_model_and_data():
    """Download and load the sketch model and the precomputed image embeddings.

    Decorated with ``st.cache_resource`` so the downloads and model load are
    not repeated on every call.

    Returns:
        tuple: ``(model, path_images, embeddings)`` where
            - ``model`` is a ``Model`` on ``device`` in eval mode, with weights
              from the ``state_dict`` entry of the downloaded checkpoint;
            - ``path_images`` is the local directory prefix ("./data/") that
              image paths are rebased onto;
            - ``embeddings`` maps corpus ids 0 and 1 to lists of
              ``(embedding, image_path)`` tuples built from the same pickle.
    """
    print("Loading everything...")

    # NOTE(review): the returned dataset is no longer used (image paths used to
    # be derived from it; they are now read from ./data/). The call is kept so
    # startup behaviour is unchanged — TODO confirm and remove if unneeded.
    load_dataset("CHSTR/ecommerce")
    path_images = "./data/"

    # Download and load the model checkpoint.
    # SECURITY: torch.load unpickles the file; this trusts the Hub repo.
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()

    # Download and load the embeddings once; the context manager closes the
    # file (the previous version opened it twice and never closed either).
    # SECURITY: pickle can execute arbitrary code; this trusts the Hub repo.
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
    with open(embeddings_file, "rb") as fh:
        raw_embeddings = pkl.load(fh)

    # Rebase each stored path onto path_images, keeping its last 3 components.
    rebased = [
        (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
        for emb in raw_embeddings
    ]
    # Both corpus ids hold the same entries, each in its own list object.
    embeddings = {0: rebased, 1: list(rebased)}
    return model, path_images, embeddings
def compute_sketch(_sketch, model):
    """Embed a preprocessed sketch tensor with ``model`` (gradients disabled).

    The tensor is moved to the module-level ``device`` and passed to the
    model with ``dtype='sketch'``; the model output is returned unchanged.
    """
    with torch.no_grad():
        sketch_on_device = _sketch.to(device)
        return model(sketch_on_device, dtype='sketch')
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
    """Return the corpus images whose embeddings score highest against a sketch.

    Args:
        _query: preprocessed sketch tensor, encoded with ``compute_sketch``.
        corpus: corpus name; ``"Unsplash"`` selects ``embeddings[0]``, any
            other value selects ``embeddings[1]``.
        model: sketch encoder passed through to ``compute_sketch``.
        embeddings: mapping from corpus id to a list of
            ``(embedding, image_path)`` entries.
        n_results: maximum number of results; clamped to the corpus size
            (``torch.topk`` raises when k exceeds the number of elements).

    Returns:
        tuple: ``(results, scores)`` where ``results`` is a list of 1-tuples
        ``(image_path,)`` ordered by decreasing dot-product score and
        ``scores`` is a tensor with the matching scores.
    """
    corpus_id = 0 if corpus == "Unsplash" else 1
    entries = embeddings[corpus_id]
    if not entries:
        return [], torch.empty(0, device=device)

    query_embedding = compute_sketch(_query, model)
    image_features = torch.from_numpy(
        np.array([item[0] for item in entries])
    ).to(device)

    # One dot-product score per corpus entry (assumes each stored embedding is
    # a 1-D vector so image_features is 2-D — TODO confirm).
    dot_product = (image_features @ query_embedding.T)[:, 0]
    k = min(n_results, dot_product.shape[0])
    scores, top_indices = torch.topk(
        dot_product, k, dim=0, largest=True, sorted=True)

    # The old path->label->path round trip always mapped index i back to
    # entries[i][1]; index the entries directly instead.
    results = [(entries[i][1],) for i in top_indices.cpu().tolist()]
    return results, scores
@st.cache_data
def make_square(img_path, fill_color=(255, 255, 255)):
    """Load ``img_path`` as an RGB image that keeps its original dimensions.

    Despite the name, the image is NOT padded to a square: callers treat the
    returned ``x``/``y`` as the size of the returned image, so the canvas is
    ``(x, y)``. Transparent pixels (alpha band or palette transparency) are
    composited over ``fill_color``.

    Args:
        img_path: path of the image file to open.
        fill_color: RGB colour shown through transparent pixels.

    Returns:
        tuple: ``(image, x, y)`` with ``image`` in RGB mode and ``x``/``y``
        its width and height.
    """
    # The context manager releases the file handle once pixels are copied.
    with Image.open(img_path) as img:
        x, y = img.size
        new_img = Image.new("RGB", (x, y), fill_color)
        if "A" in img.getbands() or "transparency" in img.info:
            # An unmasked full-size paste covers the whole canvas, so
            # fill_color would never show; use the alpha band as the mask.
            rgba = img.convert("RGBA")
            new_img.paste(rgba, (0, 0), rgba)
        else:
            new_img.paste(img)
    return new_img, x, y
@st.cache_data
def get_images(paths):
    """Load every image in ``paths`` with ``make_square``.

    Args:
        paths: iterable of image file paths.

    Returns:
        tuple: three parallel lists ``(images, widths, heights)``. All three
        are empty when ``paths`` is empty; the previous ``zip(*[])`` unpacking
        raised ``ValueError`` in that case.
    """
    images, widths, heights = [], [], []
    for path in paths:
        img, width, height = make_square(path)
        images.append(img)
        widths.append(width)
        heights.append(height)
    return images, widths, heights
@st.cache_data
def convert_pil_to_base64(image):
    """Encode a PIL image as JPEG and return its base64 encoding as ``bytes``.

    NOTE: Pillow's JPEG writer rejects modes with an alpha band, so ``image``
    is expected to be RGB (or another JPEG-compatible mode).
    """
    with BytesIO() as jpeg_buffer:
        image.save(jpeg_buffer, format="JPEG")
        return base64.b64encode(jpeg_buffer.getvalue())
def get_html(url_list, encoded_images):
    """Build an HTML flexbox gallery of inline base64 JPEG thumbnails.

    Args:
        url_list: sequence whose items' first element is the image title,
            used (HTML-escaped) as the ``title`` tooltip.
        encoded_images: base64 ``bytes`` per entry of ``url_list``, in the
            same order; it must be at least as long as ``url_list``.

    Returns:
        str: a ``<div>`` wrapping one ``<img>`` per entry, each ``HEIGHT``
        pixels tall.
    """
    img_tags = []
    for idx, entry in enumerate(url_list):
        title = entry[0]
        payload = encoded_images[idx].decode()
        img_tags.append(
            f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' "
            f"src='data:image/jpeg;base64,{payload}'>"
        )
    opening = (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    )
    return opening + "".join(img_tags) + "</div>"
def _inject_styles():
    """Inject the page-level CSS overrides used by the app."""
    st.markdown(
        """
<style>
.block-container{ max-width: 1200px; }
div.row-widget > div{ flex-direction: row; display: flex; justify-content: center; color: white; }
div.row-widget.stRadio > div > label{ margin-left: 5px; margin-right: 5px; }
.row-widget { margin-top: -25px; }
section > div:first-child { padding-top: 30px; }
div.appview-container > section:first-child{ max-width: 320px; }
#MainMenu { visibility: hidden; }
.stMarkdown { display: grid; place-items: center; }
</style>
""",
        unsafe_allow_html=True,
    )


def _sketch_to_tensor(image_data):
    """Turn canvas pixel data into a normalized (1, 3, 224, 224) model input.

    The pixels are converted to RGB, padded/resized to 224x224, converted to a
    tensor and normalized with the ImageNet mean/std constants below.
    """
    sketch_img = ImageOps.pad(
        Image.fromarray(image_data.astype("uint8")).convert("RGB"),
        size=(224, 224),
    )
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ])
    return preprocess(sketch_img).unsqueeze(0)


def _render_results(retrieved):
    """Render retrieved images as a gallery of HEIGHT-pixel-tall thumbnails.

    Each image is resized to height HEIGHT, keeping its aspect ratio, before
    JPEG/base64 encoding.
    """
    imgs, xs, ys = get_images([entry[0] for entry in retrieved])
    encoded_images = [
        convert_pil_to_base64(img.resize((int(x * HEIGHT / y), HEIGHT)))
        for img, x, y in zip(imgs, xs, ys)
    ]
    st.markdown(get_html(retrieved, encoded_images),
                unsafe_allow_html=True)


def main():
    """Render the S3BIR page and search the corpus with the current sketch.

    On the first run (session flag unset) this logs in to the Hub and loads
    the model and embeddings into ``st.session_state``. It then draws the
    sidebar, the styles and the drawing canvas. When the canvas returns pixel
    data, it shows the best-matching images.
    """
    state = st.session_state
    if not state.initialized:
        initialize_huggingface()
        state.model, state.path_images, state.embeddings = load_model_and_data()
        state.initialized = True

    description = """
# Self-Supervised Sketch-based Image Retrieval (S3BIR)
Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal sketch-photo feature space from unpaired data without explicit sketch-photo pairs. Our experiments perform outstandingly in three diverse public datasets where the models are trained without real sketches.
"""
    st.sidebar.markdown(description)
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)

    _inject_styles()
    st.title("S3BIR App")

    # Centre the 300x300 canvas by placing it in the middle of three columns.
    _, canvas_col, _ = st.columns((1, 1, 1))
    with canvas_col:
        canvas_result = st_canvas(
            background_color="#eee",
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )

    corpus = ["Ecommerce"]
    # NOTE(review): these columns are created but never used; kept so the
    # page layout is unchanged.
    st.columns((1, 3, 1))

    if canvas_result.image_data is None:
        return

    sketch_tensor = _sketch_to_tensor(canvas_result.image_data)
    retrieved, _ = image_search(
        sketch_tensor, corpus[0], state.model, state.embeddings)
    _render_results(retrieved)


if __name__ == "__main__":
    # Run the app when this file is executed as the main script.
    main()