Load the dataset from Hugging Face
Files changed:
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +68 -120
- src/__pycache__/model_LN_prompt.cpython-310.pyc +0 -0
- src/__pycache__/options.cpython-310.pyc +0 -0
- src/model_LN_prompt.py +0 -18
- src/options.py +4 -5
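The substance of the commit: instead of reading images, the model checkpoint, and the precomputed embeddings from local paths, app.py now pulls everything from the Hugging Face Hub at startup. A minimal standalone sketch of that flow, using only the repo IDs and filenames that appear in the diff below (error handling and the Streamlit caching wrappers omitted; this is an illustration, not the app's exact code):

import os
import pickle as pkl

import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download, login

# Authenticate with the token exposed to the Space (assumed to be set as a secret).
login(token=os.environ["HUGGINGFACE_TOKEN"])

# Images now come from the Hugging Face dataset rather than a local folder.
dataset = load_dataset("CHSTR/ecommerce")

# Model checkpoint and precomputed image embeddings are fetched from the Hub.
ckpt_path = hf_hub_download(repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
emb_path = hf_hub_download(repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")

checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(emb_path, "rb") as f:
    embeddings = pkl.load(f)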
__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.75 kB)
app.py CHANGED

Old version (lines prefixed with "-" were removed):

@@ -1,90 +1,85 @@
import os
-
import streamlit as st
from io import BytesIO
-import base64
from multiprocessing.dummy import Pool
-
-
import torch
from torchvision import transforms
-
-# sketches
from streamlit_drawable_canvas import st_canvas
from src.model_LN_prompt import Model
-
-
-import pickle as pkl
from html import escape
from huggingface_hub import hf_hub_download, login
from datasets import load_dataset

-token = os.getenv("HUGGINGFACE_TOKEN")

-
-

-# Variables
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print(f"Device: {device}")
HEIGHT = 200
-N_RESULTS =
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 0, 255)
else:
    color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))

@st.cache_resource
-def
-
    dataset = load_dataset("CHSTR/ecommerce")
    path_images = "/".join(dataset['validation']
                           ['image'][0].filename.split("/")[:-3]) + "/"
-    print(f"Directorio de imágenes: {path_images}")

-    #
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
-    print(f"Archivo del modelo descargado en: {path_model}")

-    #
-    model = Model()
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()
-    # model.to(device)
-    print("Modelo cargado exitosamente")

-    #
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
-    print(f"Archivo de embeddings descargado en: {embeddings_file}")

    embeddings = {
        0: pkl.load(open(embeddings_file, "rb")),
        1: pkl.load(open(embeddings_file, "rb"))
    }

-    #
-    for
-    embeddings[
-
-
-
-    for i in range(len(embeddings[1])):
-        embeddings[1][i] = (embeddings[1][i][0], path_images +
-                            "/".join(embeddings[1][i][1].split("/")[-3:]))

    return model, path_images, embeddings

-
    with torch.no_grad():
-        sketch_feat = model(
    return sketch_feat

-
-
    corpus_id = 0 if corpus == "Unsplash" else 1
    image_features = torch.tensor(
        [item[0] for item in embeddings[corpus_id]]).to(device)
@@ -93,7 +88,6 @@ def image_search(query, corpus, n_results=N_RESULTS):
    _, max_indices = torch.topk(
        dot_product, n_results, dim=0, largest=True, sorted=True)

-    # Diccionario para mapear los paths a labels
    path_to_label = {path: idx for idx,
                     (_, path) in enumerate(embeddings[corpus_id])}
    label_to_path = {idx: path for path, idx in path_to_label.items()}
@@ -101,14 +95,14 @@ def image_search(query, corpus, n_results=N_RESULTS):
        [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)

    return [
-        (
-            label_to_path[i],
-        )
        for i in label_of_images[max_indices].cpu().numpy().tolist()
-    ], dot_product[max_indices]


-
    x, y = img.size
    size = max(x, y)
    new_img = Image.new("RGB", (x, y), fill_color)
@@ -118,18 +112,12 @@ def make_square(img, fill_color=(255, 255, 255)):

@st.cache_data
def get_images(paths):
-
-
-
-    processed = Pool(N_RESULTS).map(process_image, paths)
-    imgs, xs, ys = [], [], []
-    for img, x, y in processed:
-        imgs.append(img)
-        xs.append(x)
-        ys.append(y)
-    return imgs, xs, ys


def convert_pil_to_base64(image):
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
@@ -138,21 +126,6 @@ def convert_pil_to_base64(image):
    return base64_str


-def draw_reshape_encode(img, boxes, x, y):
-    boxes = [boxes.tolist()]
-    image = img.copy()
-    draw = ImageDraw.Draw(image)
-    new_x, new_y = int(x * HEIGHT / y), HEIGHT
-    for box in boxes:
-        print("box:", box)
-        draw.rectangle(
-            # (x_min, y_min, x_max, y_max)
-            [(box[0], box[1]), (box[2], box[3])],
-            outline=color,  # Box color
-            width=7  # Box width
-        )
-
-
def get_html(url_list, encoded_images):
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for i in range(len(url_list)):
@@ -165,63 +138,40 @@ def get_html(url_list, encoded_images):
    return html


-
-
-
-
-
-        "display": "flex",
-        "justify-content": "center",
-        "flex-wrap": "wrap",
-    }
-

-


-

-
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)

    st.markdown(
        """
        <style>
-        .block-container{
-
-        }
-
-
-
-
-        }
-        div.row-widget.stRadio > div > label{
-            margin-left: 5px;
-            margin-right: 5px;
-        }
-        .row-widget {
-            margin-top: -25px;
-        }
-        section > div:first-child {
-            padding-top: 30px;
-        }
-        div.appview-container > section:first-child{
-            max-width: 320px;
-        }
-        #MainMenu {
-            visibility: hidden;
-        }
-        .stMarkdown {
-            display: grid;
-            place-items: center;
-        }
        </style>
        """,
        unsafe_allow_html=True,
    )
-    st.sidebar.markdown(description)

-    st.title("
    _, col, _ = st.columns((1, 1, 1))
    with col:
        canvas_result = st_canvas(
@@ -233,13 +183,12 @@ def main():
            key="color_annotation_app",
        )

-    st.columns((1, 3, 1))
    corpus = ["Ecommerce"]

    if canvas_result.image_data is not None:
        draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
        draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
-        draw.save("draw.jpg")

        draw_tensor = transforms.ToTensor()(draw)
        draw_tensor = transforms.Resize((224, 224))(draw_tensor)
@@ -248,20 +197,19 @@ def main():
        )(draw_tensor)
        draw_tensor = draw_tensor.unsqueeze(0)

-        retrieved, _ = image_search(
        imgs, xs, ys = get_images([x[0] for x in retrieved])
        encoded_images = []
        for image_idx in range(len(imgs)):
            img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
-
            new_x, new_y = int(x * HEIGHT / y), HEIGHT
-
            encoded_images.append(convert_pil_to_base64(
                img0.resize((new_x, new_y))))
        st.markdown(get_html(retrieved, encoded_images),
                    unsafe_allow_html=True)
-    else:
-        return


if __name__ == "__main__":
New version (lines prefixed with "+" were added; [...] marks unchanged lines not shown in the diff):

import os
import streamlit as st
from io import BytesIO
from multiprocessing.dummy import Pool
+import base64
+from PIL import Image, ImageOps
import torch
from torchvision import transforms
from streamlit_drawable_canvas import st_canvas
from src.model_LN_prompt import Model
from html import escape
+import pickle as pkl
from huggingface_hub import hf_hub_download, login
from datasets import load_dataset


+if 'initialized' not in st.session_state:
+    st.session_state.initialized = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
HEIGHT = 200
+N_RESULTS = 20
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 0, 255)
else:
    color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))

+
@st.cache_resource
+def initialize_huggingface():
+    token = os.getenv("HUGGINGFACE_TOKEN")
+    if token:
+        login(token=token)
+    else:
+        st.error("HUGGINGFACE_TOKEN not found in environment variables")
+
+
+@st.cache_resource
+def load_model_and_data():
+    print("Loading everything...")
    dataset = load_dataset("CHSTR/ecommerce")
    path_images = "/".join(dataset['validation']
                           ['image'][0].filename.split("/")[:-3]) + "/"

+    # Download model
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")

+    # Load model
+    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()

+    # Download and load embeddings
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")

    embeddings = {
        0: pkl.load(open(embeddings_file, "rb")),
        1: pkl.load(open(embeddings_file, "rb"))
    }

+    # Update image paths
+    for corpus_id in [0, 1]:
+        embeddings[corpus_id] = [
+            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
+            for emb in embeddings[corpus_id]
+        ]

    return model, path_images, embeddings

+
+def compute_sketch(_sketch, model):
    with torch.no_grad():
+        sketch_feat = model(_sketch.to(device), dtype='sketch')
    return sketch_feat

+
+def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
+    query_embedding = compute_sketch(_query, model)
    corpus_id = 0 if corpus == "Unsplash" else 1
    image_features = torch.tensor(
        [item[0] for item in embeddings[corpus_id]]).to(device)
    [...]
    _, max_indices = torch.topk(
        dot_product, n_results, dim=0, largest=True, sorted=True)

    path_to_label = {path: idx for idx,
                     (_, path) in enumerate(embeddings[corpus_id])}
    label_to_path = {idx: path for path, idx in path_to_label.items()}
    [...]
        [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)

    return [
+        (label_to_path[i],)
        for i in label_of_images[max_indices].cpu().numpy().tolist()
+    ], dot_product[max_indices]


+@st.cache_data
+def make_square(img_path, fill_color=(255, 255, 255)):
+    img = Image.open(img_path)
    x, y = img.size
    size = max(x, y)
    new_img = Image.new("RGB", (x, y), fill_color)
    [...]

@st.cache_data
def get_images(paths):
+    processed = [make_square(path) for path in paths]
+    imgs, xs, ys = zip(*processed)
+    return list(imgs), list(xs), list(ys)


+@st.cache_data
def convert_pil_to_base64(image):
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    [...]
    return base64_str


def get_html(url_list, encoded_images):
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for i in range(len(url_list)):
    [...]
    return html


+def main():
+    if not st.session_state.initialized:
+        initialize_huggingface()
+        st.session_state.model, st.session_state.path_images, st.session_state.embeddings = load_model_and_data()
+        st.session_state.initialized = True

+    description = """
+    # Self-Supervised Sketch-based Image Retrieval (S3BIR)

+    Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal sketch-photo feature space from unpaired data without explicit sketch-photo pairs. Our experiments perform outstandingly in three diverse public datasets where the models are trained without real sketches.

+    """

+    st.sidebar.markdown(description)
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)

+    # styles
    st.markdown(
        """
        <style>
+        .block-container{ max-width: 1200px; }
+        div.row-widget > div{ flex-direction: row; display: flex; justify-content: center; color: white; }
+        div.row-widget.stRadio > div > label{ margin-left: 5px; margin-right: 5px; }
+        .row-widget { margin-top: -25px; }
+        section > div:first-child { padding-top: 30px; }
+        div.appview-container > section:first-child{ max-width: 320px; }
+        #MainMenu { visibility: hidden; }
+        .stMarkdown { display: grid; place-items: center; }
        </style>
        """,
        unsafe_allow_html=True,
    )

+    st.title("S3BIR App")
    _, col, _ = st.columns((1, 1, 1))
    with col:
        canvas_result = st_canvas(
    [...]
            key="color_annotation_app",
        )

    corpus = ["Ecommerce"]
+    st.columns((1, 3, 1))

    if canvas_result.image_data is not None:
        draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
        draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))

        draw_tensor = transforms.ToTensor()(draw)
        draw_tensor = transforms.Resize((224, 224))(draw_tensor)
    [...]
        )(draw_tensor)
        draw_tensor = draw_tensor.unsqueeze(0)

+        retrieved, _ = image_search(
+            draw_tensor, corpus[0], st.session_state.model, st.session_state.embeddings)
        imgs, xs, ys = get_images([x[0] for x in retrieved])
+
        encoded_images = []
        for image_idx in range(len(imgs)):
            img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
            new_x, new_y = int(x * HEIGHT / y), HEIGHT
            encoded_images.append(convert_pil_to_base64(
                img0.resize((new_x, new_y))))
+
        st.markdown(get_html(retrieved, encoded_images),
                    unsafe_allow_html=True)


if __name__ == "__main__":
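For context, the ranking inside image_search that the hunks above only show partially (the dot_product computation itself is unchanged and therefore not displayed) amounts to scoring every stored image embedding against the sketch embedding and keeping the top k. A minimal, self-contained sketch of that step; rank_images is an illustrative helper name, not a function from the repo, and the (1, D) query vs. (N, D) gallery shapes are assumptions:

import torch

def rank_images(query_feat: torch.Tensor, image_feats: torch.Tensor, k: int = 20):
    # Score every image against the sketch query; with L2-normalized features
    # this dot product is cosine similarity.
    scores = image_feats @ query_feat.squeeze(0)  # shape (N,)
    # Keep the k best-scoring images, highest score first.
    top_scores, top_idx = torch.topk(scores, k, largest=True, sorted=True)
    return top_idx, top_scores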
src/__pycache__/model_LN_prompt.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/model_LN_prompt.cpython-310.pyc and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ

src/__pycache__/options.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/options.cpython-310.pyc and b/src/__pycache__/options.cpython-310.pyc differ
src/model_LN_prompt.py CHANGED

Old version (lines prefixed with "-" were removed):

@@ -1,15 +1,9 @@
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torchmetrics.functional import retrieval_average_precision
import pytorch_lightning as pl

from src.dinov2.models.vision_transformer import vit_base
-
-from functools import partial
-
-# from src.clip import clip
from src.options import opts

def freeze_model(m):

@@ -31,23 +25,11 @@ class Model(pl.LightningModule):
        self.opts = opts

        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
-        print("self.dino", self.dino)

        # Prompt Engineering
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))

-        self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
-        self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss(
-            distance_function=self.distance_fn, margin=0.2)
-
-        self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2)
-
-        self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
-
-        self.best_metric = -1e3
-        # normalization layer for the representations z1 and z2
-        # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False)

    def configure_optimizers(self):
        if self.opts.model_type == 'one_encoder':
New version (deletions only; [...] marks unchanged lines not shown in the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from src.dinov2.models.vision_transformer import vit_base
from src.options import opts

def freeze_model(m):
[...]
        self.opts = opts

        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)

        # Prompt Engineering
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))


    def configure_optimizers(self):
        if self.opts.model_type == 'one_encoder':
src/options.py CHANGED

Old version (lines prefixed with "-" were removed):

@@ -1,18 +1,17 @@
import argparse

-parser = argparse.ArgumentParser(description='

-parser.add_argument('--exp_name', type=str, default='

# ----------------------
# Training Params
# ----------------------

-parser.add_argument('--
-parser.add_argument('--
parser.add_argument('--prompt_lr', type=float, default=1e-4)
parser.add_argument('--linear_lr', type=float, default=1e-4)
-parser.add_argument('--model_type', type=str, default='one_encoder', choices=['one_encoder', 'two_encoder'])

# ----------------------
# ViT Prompt Parameters

New version (lines prefixed with "+" were added):

import argparse

+parser = argparse.ArgumentParser(description='S3BIR')

+parser.add_argument('--exp_name', type=str, default='DINOv2_prompt')

# ----------------------
# Training Params
# ----------------------

+parser.add_argument('--dinov2_lr', type=float, default=1e-4)
+parser.add_argument('--dinov2_LN_lr', type=float, default=1e-6)
parser.add_argument('--prompt_lr', type=float, default=1e-4)
parser.add_argument('--linear_lr', type=float, default=1e-4)

# ----------------------
# ViT Prompt Parameters