path filename
Browse files
app.py
CHANGED
@@ -40,6 +40,7 @@ def initialize_huggingface():
|
|
40 |
def load_model_and_data():
|
41 |
print("Loading everything...")
|
42 |
dataset = load_dataset("CHSTR/ecommerce")
|
|
|
43 |
path_images = "/".join(dataset['validation']
|
44 |
['image'][0].filename.split("/")[:-3]) + "/"
|
45 |
|
@@ -81,7 +82,7 @@ def compute_sketch(_sketch, model):
|
|
81 |
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
|
82 |
query_embedding = compute_sketch(_query, model)
|
83 |
corpus_id = 0 if corpus == "Unsplash" else 1
|
84 |
-
image_features = torch.
|
85 |
[item[0] for item in embeddings[corpus_id]]).to(device)
|
86 |
|
87 |
dot_product = (image_features @ query_embedding.T)[:, 0]
|
|
|
40 |
def load_model_and_data():
|
41 |
print("Loading everything...")
|
42 |
dataset = load_dataset("CHSTR/ecommerce")
|
43 |
+
print(dataset['validation']['image'][0].filename)
|
44 |
path_images = "/".join(dataset['validation']
|
45 |
['image'][0].filename.split("/")[:-3]) + "/"
|
46 |
|
|
|
82 |
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
|
83 |
query_embedding = compute_sketch(_query, model)
|
84 |
corpus_id = 0 if corpus == "Unsplash" else 1
|
85 |
+
image_features = torch.from_numpy(
|
86 |
[item[0] for item in embeddings[corpus_id]]).to(device)
|
87 |
|
88 |
dot_product = (image_features @ query_embedding.T)[:, 0]
|