CHSTR commited on
Commit
3e0e9f4
·
1 Parent(s): db47434

S3BIR app demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +271 -0
  2. data/valid/Almohadas_y_cojines/6474eb7beca04255e216354707604caa.jpg +0 -0
  3. data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d.jpg +0 -0
  4. data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_1.jpg +0 -0
  5. data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_2.jpg +0 -0
  6. data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_3.jpg +0 -0
  7. data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_4.jpg +0 -0
  8. data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_5.jpg +0 -0
  9. data/valid/Almohadas_y_cojines/d319582ad5976fa0526871af907d75e5.jpg +0 -0
  10. data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb.jpg +0 -0
  11. data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_1.jpg +0 -0
  12. data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_2.jpg +0 -0
  13. data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_3.jpg +0 -0
  14. data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_4.jpg +0 -0
  15. data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_5.jpg +0 -0
  16. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4.jpg +0 -0
  17. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_1.jpg +0 -0
  18. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_2.jpg +0 -0
  19. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_3.jpg +0 -0
  20. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_4.jpg +0 -0
  21. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_5.jpg +0 -0
  22. data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_6.jpg +0 -0
  23. data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea.jpg +0 -0
  24. data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_1.jpg +0 -0
  25. data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_2.jpg +0 -0
  26. data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_3.jpg +0 -0
  27. data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_4.jpg +0 -0
  28. data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_5.jpg +0 -0
  29. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c.jpg +0 -0
  30. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_1.jpg +0 -0
  31. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_2.jpg +0 -0
  32. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_3.jpg +0 -0
  33. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_4.jpg +0 -0
  34. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_5.jpg +0 -0
  35. data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_6.jpg +0 -0
  36. data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59.jpg +0 -0
  37. data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_1.jpg +0 -0
  38. data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_2.jpg +0 -0
  39. data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_3.jpg +0 -0
  40. data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_4.jpg +0 -0
  41. data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_5.jpg +0 -0
  42. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2.jpg +0 -0
  43. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_1.jpg +0 -0
  44. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_2.jpg +0 -0
  45. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_3.jpg +0 -0
  46. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_4.jpg +0 -0
  47. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_5.jpg +0 -0
  48. data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_6.jpg +0 -0
  49. data/valid/Baberos/e3188f410d687d5e9c939cf9dcc85bc8.jpg +0 -0
  50. data/valid/Baberos/e3188f410d687d5e9c939cf9dcc85bc8_1.jpg +0 -0
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ from io import BytesIO
5
+ import base64
6
+ from multiprocessing.dummy import Pool
7
+ from PIL import Image, ImageDraw
8
+
9
+ import torch
10
+ from torchvision import transforms
11
+
12
+ # sketches
13
+ from streamlit_drawable_canvas import st_canvas
14
+ from PIL import Image, ImageOps
15
+ from torchvision import transforms
16
+ from src.model_LN_prompt import Model
17
+
18
+
19
+ import pickle as pkl
20
+ from html import escape
21
+ from huggingface_hub import hf_hub_download,login
22
+
23
+ token = os.getenv("HUGGINGFACE_TOKEN")
24
+
25
+ # Autentica usando el token
26
+ login(token=token)
27
+
28
+ # Variables
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ HEIGHT = 200
31
+ N_RESULTS = 15
32
+ color = st.get_option("theme.primaryColor")
33
+ if color is None:
34
+ color = (0, 0, 255)
35
+ else:
36
+ color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))
37
+
38
+
39
+ @st.cache_resource
40
+ def load():
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ path_images = 'data'
43
+
44
+ # Descargar el modelo desde Hugging Face
45
+ path_model = hf_hub_download(repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
46
+ print(f"Archivo del modelo descargado en: {path_model}")
47
+
48
+ # Cargar el modelo
49
+ model = Model().to(device)
50
+ model_checkpoint = torch.load(path_model, map_location=device)
51
+ model.load_state_dict(model_checkpoint['state_dict'])
52
+ model.eval()
53
+ print("Modelo cargado exitosamente")
54
+
55
+ # Descargar y cargar los embeddings desde Hugging Face
56
+ embeddings_file = hf_hub_download(repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
57
+ print(f"Archivo de embeddings descargado en: {embeddings_file}")
58
+
59
+ embeddings = {
60
+ 0: pkl.load(open(embeddings_file, "rb")),
61
+ 1: pkl.load(open(embeddings_file, "rb"))
62
+ }
63
+
64
+ # Actualizar los paths de las imágenes en los embeddings
65
+ for i in range(len(embeddings[0])):
66
+ embeddings[0][i] = (embeddings[0][i][0], path_images + embeddings[0][i][1].split("/images")[-1])
67
+
68
+ for i in range(len(embeddings[1])):
69
+ embeddings[1][i] = (embeddings[1][i][0], path_images + embeddings[1][i][1].split("/images")[-1])
70
+
71
+ return model, path_images, embeddings
72
+
73
+
74
+ def compute_text_embeddings(sketch):
75
+ with torch.no_grad():
76
+ sketch_feat = model(sketch.to(device), dtype='sketch')
77
+ return sketch_feat
78
+
79
+
80
+ def image_search(query, corpus, n_results=N_RESULTS):
81
+ query_embedding = compute_text_embeddings(query)
82
+ corpus_id = 0 if corpus == "Unsplash" else 1
83
+ image_features = torch.tensor(
84
+ [item[0] for item in embeddings[corpus_id]]).to(device)
85
+
86
+ dot_product = (image_features @ query_embedding.T)[:, 0]
87
+ _, max_indices = torch.topk(
88
+ dot_product, n_results, dim=0, largest=True, sorted=True)
89
+
90
+ # Diccionario para mapear los paths a labels
91
+ path_to_label = {path: idx for idx,
92
+ (_, path) in enumerate(embeddings[corpus_id])}
93
+ label_to_path = {idx: path for path, idx in path_to_label.items()}
94
+ label_of_images = torch.tensor(
95
+ [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)
96
+
97
+ return [
98
+ (
99
+ # path_images + "page" + str(i) + ".jpg", # DocExplore
100
+ label_to_path[i], # DocExplore
101
+ )
102
+ for i in label_of_images[max_indices].cpu().numpy().tolist()
103
+ ], dot_product[max_indices] # bbox_of_images[max_indices], dot_product[max_indices]
104
+
105
+
106
+ def make_square(img, fill_color=(255, 255, 255)):
107
+ x, y = img.size
108
+ size = max(x, y)
109
+ new_img = Image.new("RGB", (x, y), fill_color)
110
+ new_img.paste(img)
111
+ return new_img, x, y
112
+
113
+
114
+ @st.cache_data
115
+ def get_images(paths):
116
+ def process_image(path):
117
+ return make_square(Image.open(path))
118
+
119
+ processed = Pool(N_RESULTS).map(process_image, paths)
120
+ imgs, xs, ys = [], [], []
121
+ for img, x, y in processed:
122
+ imgs.append(img)
123
+ xs.append(x)
124
+ ys.append(y)
125
+ return imgs, xs, ys
126
+
127
+
128
+ def convert_pil_to_base64(image):
129
+ img_buffer = BytesIO()
130
+ image.save(img_buffer, format="JPEG")
131
+ byte_data = img_buffer.getvalue()
132
+ base64_str = base64.b64encode(byte_data)
133
+ return base64_str
134
+
135
+
136
+ def draw_reshape_encode(img, boxes, x, y):
137
+ boxes = [boxes.tolist()]
138
+ image = img.copy()
139
+ draw = ImageDraw.Draw(image)
140
+ new_x, new_y = int(x * HEIGHT / y), HEIGHT
141
+ for box in boxes:
142
+ print("box:", box)
143
+ draw.rectangle(
144
+ # (x_min, y_min, x_max, y_max)
145
+ [(box[0], box[1]), (box[2], box[3])],
146
+ outline=color, # Box color
147
+ width=7 # Box width
148
+ )
149
+
150
+
151
+ def get_html(url_list, encoded_images):
152
+ html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
153
+ for i in range(len(url_list)):
154
+ title, encoded = url_list[i][0], encoded_images[i]
155
+ html = (
156
+ html
157
+ + f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded.decode()}'>"
158
+ )
159
+ html += "</div>"
160
+ return html
161
+
162
+
163
+ description = """
164
+ # Sketch-based Image Retrieval (SBIR)
165
+ """
166
+
167
+ div_style = {
168
+ "display": "flex",
169
+ "justify-content": "center",
170
+ "flex-wrap": "wrap",
171
+ }
172
+
173
+
174
+ print("Cargando modelos...")
175
+ model, path_images, embeddings = load()
176
+ source = {0: "\Ecommerce", 1: "\nNone"}
177
+
178
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)
179
+
180
+ dataset_transforms = transforms.Compose([
181
+ transforms.Resize((224, 224)),
182
+ transforms.ToTensor(),
183
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
184
+ ])
185
+
186
+
187
+ def main():
188
+ st.markdown(
189
+ """
190
+ <style>
191
+ .block-container{
192
+ max-width: 1200px;
193
+ }
194
+ div.row-widget > div{
195
+ flex-direction: row;
196
+ display: flex;
197
+ justify-content: center;
198
+ }
199
+ div.row-widget.stRadio > div > label{
200
+ margin-left: 5px;
201
+ margin-right: 5px;
202
+ }
203
+ .row-widget {
204
+ margin-top: -25px;
205
+ }
206
+ section > div:first-child {
207
+ padding-top: 30px;
208
+ }
209
+ div.appview-container > section:first-child{
210
+ max-width: 320px;
211
+ }
212
+ #MainMenu {
213
+ visibility: hidden;
214
+ }
215
+ .stMarkdown {
216
+ display: grid;
217
+ place-items: center;
218
+ }
219
+ </style>
220
+ """,
221
+ unsafe_allow_html=True,
222
+ )
223
+ st.sidebar.markdown(description)
224
+
225
+ st.title("SBIR App")
226
+ _, col, _ = st.columns((1, 1, 1))
227
+ with col:
228
+ canvas_result = st_canvas(
229
+ background_color="#eee",
230
+ stroke_width=stroke_width,
231
+ update_streamlit=True,
232
+ height=300,
233
+ width=300,
234
+ key="color_annotation_app",
235
+ )
236
+
237
+ _, c, _ = st.columns((1, 3, 1))
238
+ query = ["koala"] # c.text_input("", value="koala")
239
+ corpus = c.radio("", ["Ecommerce"])
240
+
241
+ if canvas_result.image_data is not None:
242
+ draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
243
+ draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
244
+ draw.save("draw.jpg")
245
+
246
+ draw_tensor = transforms.ToTensor()(draw)
247
+ draw_tensor = transforms.Resize((224, 224))(draw_tensor)
248
+ draw_tensor = transforms.Normalize(
249
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
250
+ )(draw_tensor)
251
+ draw_tensor = draw_tensor.unsqueeze(0)
252
+ else:
253
+ return
254
+
255
+ if len(query) > 0:
256
+ retrieved, dot_product = image_search(draw_tensor, corpus)
257
+ imgs, xs, ys = get_images([x[0] for x in retrieved])
258
+ encoded_images = []
259
+ for image_idx in range(len(imgs)):
260
+ img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
261
+
262
+ new_x, new_y = int(x * HEIGHT / y), HEIGHT
263
+
264
+ encoded_images.append(convert_pil_to_base64(
265
+ img0.resize((new_x, new_y))))
266
+ st.markdown(get_html(retrieved, encoded_images),
267
+ unsafe_allow_html=True)
268
+
269
+
270
+ if __name__ == "__main__":
271
+ main()
data/valid/Almohadas_y_cojines/6474eb7beca04255e216354707604caa.jpg ADDED
data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d.jpg ADDED
data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_1.jpg ADDED
data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_2.jpg ADDED
data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_3.jpg ADDED
data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_4.jpg ADDED
data/valid/Almohadas_y_cojines/b6f4b43eb8e47193358b0b3b69b9cd4d_5.jpg ADDED
data/valid/Almohadas_y_cojines/d319582ad5976fa0526871af907d75e5.jpg ADDED
data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb.jpg ADDED
data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_1.jpg ADDED
data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_2.jpg ADDED
data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_3.jpg ADDED
data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_4.jpg ADDED
data/valid/Baberos/134673c99a13f9f17bb4a3420aa830bb_5.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_1.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_2.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_3.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_4.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_5.jpg ADDED
data/valid/Baberos/80ff9b872165c3f13c72859dbbcbd4a4_6.jpg ADDED
data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea.jpg ADDED
data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_1.jpg ADDED
data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_2.jpg ADDED
data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_3.jpg ADDED
data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_4.jpg ADDED
data/valid/Baberos/9daca95e1bd0ca6aad1812e44007a2ea_5.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_1.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_2.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_3.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_4.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_5.jpg ADDED
data/valid/Baberos/c4bb79af1cdae49467eea9efca2ee32c_6.jpg ADDED
data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59.jpg ADDED
data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_1.jpg ADDED
data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_2.jpg ADDED
data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_3.jpg ADDED
data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_4.jpg ADDED
data/valid/Baberos/c939ccd756d45577cb28f93ee5486a59_5.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_1.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_2.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_3.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_4.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_5.jpg ADDED
data/valid/Baberos/cce281a309ee213c364cfa0bd62ba1f2_6.jpg ADDED
data/valid/Baberos/e3188f410d687d5e9c939cf9dcc85bc8.jpg ADDED
data/valid/Baberos/e3188f410d687d5e9c939cf9dcc85bc8_1.jpg ADDED