CHSTR committed on
Commit 0f23307 · 1 Parent(s): 3bfc386

uploading the demo to hf

app.py CHANGED
@@ -5,6 +5,7 @@ from multiprocessing.dummy import Pool
 import base64
 from PIL import Image, ImageOps
 import torch
+import numpy as np
 from torchvision import transforms
 from streamlit_drawable_canvas import st_canvas
 from src.model_LN_prompt import Model
@@ -83,8 +84,9 @@ def compute_sketch(_sketch, model):
 def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
     query_embedding = compute_sketch(_query, model)
     corpus_id = 0 if corpus == "Unsplash" else 1
-    image_features = torch.tensor(
-        list([item[0] for item in embeddings[corpus_id]])).to(device)
+    image_features = torch.from_numpy(
+        np.array([item[0] for item in embeddings[corpus_id]])
+    ).to(device)
 
     dot_product = (image_features @ query_embedding.T)[:, 0]
     _, max_indices = torch.topk(
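The rewritten tensor construction above appears aimed at PyTorch's slow path for building a tensor from a Python list of numpy.ndarrays: stacking the per-image embeddings into one contiguous ndarray first and wrapping it with torch.from_numpy avoids the element-by-element copy (and the UserWarning current PyTorch versions emit for it). A minimal, self-contained sketch of the two paths; the embedding dimension (768), the corpus size, and the random data are assumptions, not values from this repo:

import numpy as np
import torch

# Stand-in for embeddings[corpus_id]: a list of (feature_vector, metadata) pairs.
fake_corpus = [(np.random.rand(768).astype(np.float32), f"img_{i}.jpg") for i in range(1000)]

# Old path: torch.tensor over a Python list of ndarrays copies element by element,
# and PyTorch warns that this is extremely slow.
features_old = torch.tensor([item[0] for item in fake_corpus])

# New path: collapse the list into one contiguous ndarray, then wrap it with
# torch.from_numpy, which reuses the buffer instead of copying per element.
features_new = torch.from_numpy(np.array([item[0] for item in fake_corpus]))

assert features_old.shape == features_new.shape == (1000, 768)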
src/__pycache__/model_LN_prompt.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/model_LN_prompt.cpython-310.pyc and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ
 
src/model_LN_prompt.py CHANGED
@@ -32,14 +32,11 @@ class Model(pl.LightningModule):
 
 
     def configure_optimizers(self):
-        if self.opts.model_type == 'one_encoder':
-            model_params = list(self.dino.parameters())
-        else:
-            model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters())
+        model_params = list(self.dino.parameters())
 
         optimizer = torch.optim.Adam([
-            {'params': model_params, 'lr': self.opts.clip_LN_lr},
-            {'params': [self.sk_prompt] + [self.img_prompt], 'lr': self.opts.prompt_lr}])
+            {'params': model_params, 'lr': self.opts.clip_LN_lr}]
+        )
         return optimizer
 
     def forward(self, data, dtype='image'):
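The simplified configure_optimizers drops the second parameter group (the sketch/image prompt tensors) and the optional second encoder, leaving a single Adam group over the DINO backbone parameters. A toy, self-contained sketch of the resulting optimizer shape; the layer size and learning rate are placeholders, not values from this repo:

import torch
import torch.nn as nn

class ToyOpts:
    clip_LN_lr = 1e-6  # placeholder learning rate

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dino = nn.Linear(768, 768)  # stands in for the DINO backbone
        self.opts = ToyOpts()

    def configure_optimizers(self):
        # Single parameter group: only the backbone parameters are optimized now.
        model_params = list(self.dino.parameters())
        optimizer = torch.optim.Adam([
            {'params': model_params, 'lr': self.opts.clip_LN_lr}])
        return optimizer

optimizer = ToyModel().configure_optimizers()
assert len(optimizer.param_groups) == 1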