Spaces:
Runtime error
Runtime error
Samuel Schmidt
commited on
Commit
·
cb03df9
1
Parent(s):
49a4ce1
Bugfix, comma
Browse files- src/app.py +5 -5
src/app.py
CHANGED
|
@@ -49,14 +49,14 @@ def check_index(ds):
|
|
| 49 |
|
| 50 |
else:
|
| 51 |
return index_dataset(ds)
|
| 52 |
-
|
| 53 |
|
| 54 |
dataset_with_embeddings = check_index(candidate_subset)
|
| 55 |
|
| 56 |
# Main function, to find similar images
|
| 57 |
# TODO: implement different distance measures
|
| 58 |
|
| 59 |
-
def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
|
| 60 |
"""Returns the top k nearest examples to the query image.
|
| 61 |
|
| 62 |
Args:
|
|
@@ -75,10 +75,10 @@ def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
|
|
| 75 |
'color_embeddings', qi_np, k=top_k)
|
| 76 |
elif selected_distance == "Chi-squared":
|
| 77 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
|
| 78 |
-
retrieved_examples = tmp_dataset.sort("distance")
|
| 79 |
-
else:
|
| 80 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
|
| 81 |
-
retrieved_examples = tmp_dataset.sort("distance")
|
| 82 |
images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
|
| 83 |
return images
|
| 84 |
if "CLIP" == selected_descriptor:
|
|
|
|
| 49 |
|
| 50 |
else:
|
| 51 |
return index_dataset(ds)
|
| 52 |
+
|
| 53 |
|
| 54 |
dataset_with_embeddings = check_index(candidate_subset)
|
| 55 |
|
| 56 |
# Main function, to find similar images
|
| 57 |
# TODO: implement different distance measures
|
| 58 |
|
| 59 |
+
def get_neighbors(query_image, selected_descriptor, selected_distance, top_k=5):
|
| 60 |
"""Returns the top k nearest examples to the query image.
|
| 61 |
|
| 62 |
Args:
|
|
|
|
| 75 |
'color_embeddings', qi_np, k=top_k)
|
| 76 |
elif selected_distance == "Chi-squared":
|
| 77 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
|
| 78 |
+
retrieved_examples = tmp_dataset.sort("distance")[:5]
|
| 79 |
+
else:
|
| 80 |
tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
|
| 81 |
+
retrieved_examples = tmp_dataset.sort("distance")[:5]
|
| 82 |
images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
|
| 83 |
return images
|
| 84 |
if "CLIP" == selected_descriptor:
|