subiendo la demo a hf
Browse files- app.py +4 -2
- src/__pycache__/model_LN_prompt.cpython-310.pyc +0 -0
- src/model_LN_prompt.py +3 -6
app.py
CHANGED
@@ -5,6 +5,7 @@ from multiprocessing.dummy import Pool
|
|
5 |
import base64
|
6 |
from PIL import Image, ImageOps
|
7 |
import torch
|
|
|
8 |
from torchvision import transforms
|
9 |
from streamlit_drawable_canvas import st_canvas
|
10 |
from src.model_LN_prompt import Model
|
@@ -83,8 +84,9 @@ def compute_sketch(_sketch, model):
|
|
83 |
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
|
84 |
query_embedding = compute_sketch(_query, model)
|
85 |
corpus_id = 0 if corpus == "Unsplash" else 1
|
86 |
-
image_features = torch.
|
87 |
-
|
|
|
88 |
|
89 |
dot_product = (image_features @ query_embedding.T)[:, 0]
|
90 |
_, max_indices = torch.topk(
|
|
|
5 |
import base64
|
6 |
from PIL import Image, ImageOps
|
7 |
import torch
|
8 |
+
import numpy as np
|
9 |
from torchvision import transforms
|
10 |
from streamlit_drawable_canvas import st_canvas
|
11 |
from src.model_LN_prompt import Model
|
|
|
84 |
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
|
85 |
query_embedding = compute_sketch(_query, model)
|
86 |
corpus_id = 0 if corpus == "Unsplash" else 1
|
87 |
+
image_features = torch.from_numpy(
|
88 |
+
np.array([item[0] for item in embeddings[corpus_id]])
|
89 |
+
).to(device)
|
90 |
|
91 |
dot_product = (image_features @ query_embedding.T)[:, 0]
|
92 |
_, max_indices = torch.topk(
|
src/__pycache__/model_LN_prompt.cpython-310.pyc
CHANGED
Binary files a/src/__pycache__/model_LN_prompt.cpython-310.pyc and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ
|
|
src/model_LN_prompt.py
CHANGED
@@ -32,14 +32,11 @@ class Model(pl.LightningModule):
|
|
32 |
|
33 |
|
34 |
def configure_optimizers(self):
|
35 |
-
|
36 |
-
model_params = list(self.dino.parameters())
|
37 |
-
else:
|
38 |
-
model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters())
|
39 |
|
40 |
optimizer = torch.optim.Adam([
|
41 |
-
{'params': model_params, 'lr': self.opts.clip_LN_lr}
|
42 |
-
|
43 |
return optimizer
|
44 |
|
45 |
def forward(self, data, dtype='image'):
|
|
|
32 |
|
33 |
|
34 |
def configure_optimizers(self):
|
35 |
+
model_params = list(self.dino.parameters())
|
|
|
|
|
|
|
36 |
|
37 |
optimizer = torch.optim.Adam([
|
38 |
+
{'params': model_params, 'lr': self.opts.clip_LN_lr}]
|
39 |
+
)
|
40 |
return optimizer
|
41 |
|
42 |
def forward(self, data, dtype='image'):
|