CHSTR commited on
Commit
04923de
·
1 Parent(s): 09892bf

path filename

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -40,6 +40,7 @@ def initialize_huggingface():
40
  def load_model_and_data():
41
  print("Loading everything...")
42
  dataset = load_dataset("CHSTR/ecommerce")
 
43
  path_images = "/".join(dataset['validation']
44
  ['image'][0].filename.split("/")[:-3]) + "/"
45
 
@@ -81,7 +82,7 @@ def compute_sketch(_sketch, model):
81
  def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
82
  query_embedding = compute_sketch(_query, model)
83
  corpus_id = 0 if corpus == "Unsplash" else 1
84
- image_features = torch.tensor(
85
  [item[0] for item in embeddings[corpus_id]]).to(device)
86
 
87
  dot_product = (image_features @ query_embedding.T)[:, 0]
 
40
  def load_model_and_data():
41
  print("Loading everything...")
42
  dataset = load_dataset("CHSTR/ecommerce")
43
+ print(dataset['validation']['image'][0].filename)
44
  path_images = "/".join(dataset['validation']
45
  ['image'][0].filename.split("/")[:-3]) + "/"
46
 
 
82
  def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
83
  query_embedding = compute_sketch(_query, model)
84
  corpus_id = 0 if corpus == "Unsplash" else 1
85
+ image_features = torch.from_numpy(
86
  [item[0] for item in embeddings[corpus_id]]).to(device)
87
 
88
  dot_product = (image_features @ query_embedding.T)[:, 0]