Spaces:
Running
Running
Revert "update code to download dataset files from separate repo"
Browse filesThis reverts commit 07356cd41d62a22354471068093aa528e5724930.
- .gitignore +0 -7
- app.py +57 -23
- big_indx_to_id_dict.pickle +3 -0
- bioscan_5m_dna_IndexFlatIP.index +3 -0
- bioscan_5m_image_IndexFlatIP.index +3 -0
- data.py +0 -34
- prepare_index.py +33 -18
.gitignore
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
.build
|
2 |
-
.data
|
3 |
-
.singularity
|
4 |
-
slurm/
|
5 |
-
.env
|
6 |
-
*.sif
|
7 |
-
__pycache__/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1 |
-
import pickle
|
2 |
-
import random
|
3 |
-
|
4 |
import gradio as gr
|
|
|
5 |
import numpy as np
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
def getRandID():
|
@@ -13,17 +16,37 @@ def getRandID():
|
|
13 |
|
14 |
|
15 |
def get_image_index(indexType):
|
16 |
-
|
17 |
-
return
|
18 |
-
|
19 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def get_dna_index(indexType):
|
23 |
-
|
24 |
-
return
|
25 |
-
|
26 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
|
@@ -63,13 +86,24 @@ with gr.Blocks() as demo:
|
|
63 |
# for hf: change all file paths, indx_to_id_dict as well
|
64 |
|
65 |
# load indexes
|
66 |
-
|
67 |
-
|
68 |
-
)
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}
|
74 |
|
75 |
with gr.Column():
|
@@ -79,8 +113,8 @@ with gr.Blocks() as demo:
|
|
79 |
rand_id_indx = gr.Textbox(label="Index:")
|
80 |
id_btn = gr.Button("Get Random ID")
|
81 |
with gr.Column():
|
82 |
-
|
83 |
-
|
84 |
|
85 |
index_type = gr.Radio(
|
86 |
choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
|
@@ -88,7 +122,7 @@ with gr.Blocks() as demo:
|
|
88 |
num_results = gr.Number(label="Number of Results:", value=10, precision=0)
|
89 |
|
90 |
process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
|
91 |
-
process_id_list = gr.Textbox(label="Closest matches:")
|
92 |
search_btn = gr.Button("Search")
|
93 |
id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
|
94 |
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
import numpy as np
|
4 |
+
import h5py
|
5 |
+
import faiss
|
6 |
+
from PIL import Image
|
7 |
+
import io
|
8 |
+
import pickle
|
9 |
+
import random
|
10 |
+
import click
|
11 |
|
12 |
|
13 |
def getRandID():
|
|
|
16 |
|
17 |
|
18 |
def get_image_index(indexType):
|
19 |
+
if indexType == "FlatIP(default)":
|
20 |
+
return image_index_IP
|
21 |
+
elif indexType == "FlatL2":
|
22 |
+
raise NotImplementedError
|
23 |
+
return image_index_L2
|
24 |
+
elif indexType == "HNSWFlat":
|
25 |
+
raise NotImplementedError
|
26 |
+
return image_index_HNSW
|
27 |
+
elif indexType == "IVFFlat":
|
28 |
+
raise NotImplementedError
|
29 |
+
return image_index_IVF
|
30 |
+
elif indexType == "LSH":
|
31 |
+
raise NotImplementedError
|
32 |
+
return image_index_LSH
|
33 |
|
34 |
|
35 |
def get_dna_index(indexType):
|
36 |
+
if indexType == "FlatIP(default)":
|
37 |
+
return dna_index_IP
|
38 |
+
elif indexType == "FlatL2":
|
39 |
+
raise NotImplementedError
|
40 |
+
return dna_index_L2
|
41 |
+
elif indexType == "HNSWFlat":
|
42 |
+
raise NotImplementedError
|
43 |
+
return dna_index_HNSW
|
44 |
+
elif indexType == "IVFFlat":
|
45 |
+
raise NotImplementedError
|
46 |
+
return dna_index_IVF
|
47 |
+
elif indexType == "LSH":
|
48 |
+
raise NotImplementedError
|
49 |
+
return dna_index_LSH
|
50 |
|
51 |
|
52 |
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
|
|
|
86 |
# for hf: change all file paths, indx_to_id_dict as well
|
87 |
|
88 |
# load indexes
|
89 |
+
image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
|
90 |
+
# image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
|
91 |
+
# image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
|
92 |
+
# image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
|
93 |
+
# image_index_LSH = faiss.read_index("big_image_index_LSH.index")
|
94 |
+
|
95 |
+
dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
|
96 |
+
# dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
|
97 |
+
# dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
|
98 |
+
# dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
|
99 |
+
# dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")
|
100 |
+
|
101 |
+
# with open("dataset_processid_list.pickle", "rb") as f:
|
102 |
+
# dataset_processid_list = pickle.load(f)
|
103 |
+
# with open("processid_to_index.pickle", "rb") as f:
|
104 |
+
# processid_to_index = pickle.load(f)
|
105 |
+
with open("big_indx_to_id_dict.pickle", "rb") as f:
|
106 |
+
index_to_id_dict = pickle.load(f)
|
107 |
id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}
|
108 |
|
109 |
with gr.Column():
|
|
|
113 |
rand_id_indx = gr.Textbox(label="Index:")
|
114 |
id_btn = gr.Button("Get Random ID")
|
115 |
with gr.Column():
|
116 |
+
key_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
|
117 |
+
query_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")
|
118 |
|
119 |
index_type = gr.Radio(
|
120 |
choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
|
|
|
122 |
num_results = gr.Number(label="Number of Results:", value=10, precision=0)
|
123 |
|
124 |
process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
|
125 |
+
process_id_list = gr.Textbox(label="Closest 10 matches:")
|
126 |
search_btn = gr.Button("Search")
|
127 |
id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
|
128 |
|
big_indx_to_id_dict.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee0a9044e054f640b704247a2fa2e74219180b78ded6ba07f551bfc222657fc5
|
3 |
+
size 885457
|
bioscan_5m_dna_IndexFlatIP.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83fe6599724756652689b76ef942ffaca2f8d5863ff3dd7fe7ac655199e0968d
|
3 |
+
size 136009773
|
bioscan_5m_image_IndexFlatIP.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:01e5d74fd5194551e2b8e43aba8e41153efeb29589fa82a7839791d2e057c21d
|
3 |
+
size 136009773
|
data.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import pickle
|
3 |
-
from typing import Any
|
4 |
-
|
5 |
-
import faiss
|
6 |
-
from huggingface_hub import hf_hub_download
|
7 |
-
|
8 |
-
|
9 |
-
def load_indexes_local(index_files: dict[str, str], *, data_folder: str, **kw) -> dict[str, Any]:
|
10 |
-
indexes = {}
|
11 |
-
for index_type, index_file in index_files.items():
|
12 |
-
indexes[index_type] = faiss.read_index(os.path.join(data_folder, index_file))
|
13 |
-
|
14 |
-
return indexes
|
15 |
-
|
16 |
-
|
17 |
-
def load_indexes_hf(index_files: dict[str, str], *, repo_name: str, **kw) -> dict[str, Any]:
|
18 |
-
indexes = {}
|
19 |
-
for index_type, index_file in index_files.items():
|
20 |
-
indexes[index_type] = faiss.read_index(
|
21 |
-
hf_hub_download(repo_id=repo_name, filename=index_file, repo_type="dataset")
|
22 |
-
)
|
23 |
-
|
24 |
-
return indexes
|
25 |
-
|
26 |
-
|
27 |
-
def load_index_pickle(index_file: str, repo_name: str) -> Any:
|
28 |
-
index_to_id_dict_file = hf_hub_download(
|
29 |
-
repo_id=repo_name,
|
30 |
-
filename=index_file,
|
31 |
-
repo_type="dataset",
|
32 |
-
)
|
33 |
-
with open(index_to_id_dict_file, "rb") as f:
|
34 |
-
return pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prepare_index.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import pickle
|
2 |
from pathlib import Path
|
3 |
|
4 |
import click
|
@@ -10,33 +9,55 @@ ALL_INDEX_TYPES = ["IndexFlatIP", "IndexFlatL2", "IndexIVFFlat", "IndexHNSWFlat"
|
|
10 |
EMBEDDING_SIZE = 768
|
11 |
|
12 |
|
13 |
-
def process(
|
14 |
# load embeddings
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# FlatIP and FlatL2
|
18 |
if index_type == "IndexFlatIP":
|
19 |
-
test_index = faiss.IndexFlatIP(
|
20 |
elif index_type == "IndexFlatL2":
|
21 |
-
test_index = faiss.IndexFlatL2(
|
22 |
elif index_type == "IndexIVFFlat":
|
23 |
# IVFFlat
|
24 |
-
quantizer = faiss.IndexFlatIP(
|
25 |
-
test_index = faiss.IndexIVFFlat(quantizer,
|
26 |
-
test_index.train(
|
|
|
|
|
|
|
|
|
27 |
elif index_type == "IndexHNSWFlat":
|
28 |
# HNSW
|
29 |
# 16: connections for each vertex. efSearch: depth of search during search. efConstruction: depth of search during build
|
30 |
-
test_index = faiss.IndexHNSWFlat(
|
31 |
test_index.hnsw.efSearch = 32
|
32 |
test_index.hnsw.efConstruction = 64
|
33 |
elif index_type == "IndexLSH":
|
34 |
# LSH
|
35 |
-
test_index = faiss.IndexLSH(
|
36 |
else:
|
37 |
raise ValueError(f"Index type {index_type} is not supported")
|
38 |
|
39 |
-
test_index.add(
|
|
|
|
|
|
|
|
|
40 |
|
41 |
faiss.write_index(test_index, str(output / f"bioscan_5m_{key_type}_{index_type}.index"))
|
42 |
print("Saved index to", output / f"bioscan_5m_{key_type}_{index_type}.index")
|
@@ -75,15 +96,9 @@ def main(input, output, key_type, index_type):
|
|
75 |
else:
|
76 |
index_types = [index_type]
|
77 |
|
78 |
-
embedding_data = h5py.File(input / "extracted_features_for_all_5m_data.hdf5", "r", libver="latest")
|
79 |
for key_type in key_types:
|
80 |
for index_type in index_types:
|
81 |
-
process(
|
82 |
-
|
83 |
-
sample_ids = [raw_id.decode("utf-8") for raw_id in embedding_data["file_name_list"][:]]
|
84 |
-
index_to_id = {index: id for index, id in enumerate(sample_ids)}
|
85 |
-
with open(output / "big_indx_to_id_dict.pickle", "wb") as f:
|
86 |
-
pickle.dump(index_to_id, f)
|
87 |
|
88 |
|
89 |
if __name__ == "__main__":
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
|
3 |
import click
|
|
|
9 |
EMBEDDING_SIZE = 768
|
10 |
|
11 |
|
12 |
+
def process(input: Path, output: Path, key_type: str, index_type: str):
|
13 |
# load embeddings
|
14 |
+
all_keys = h5py.File(input / "extracted_features_of_all_keys.hdf5", "r", libver="latest")[
|
15 |
+
f"encoded_{key_type}_feature"
|
16 |
+
][:]
|
17 |
+
seen_test = h5py.File(input / "extracted_features_of_seen_test.hdf5", "r", libver="latest")[
|
18 |
+
f"encoded_{key_type}_feature"
|
19 |
+
][:]
|
20 |
+
unseen_test = h5py.File(input / "extracted_features_of_unseen_test.hdf5", "r", libver="latest")[
|
21 |
+
f"encoded_{key_type}_feature"
|
22 |
+
][:]
|
23 |
+
seen_val = h5py.File(input / "extracted_features_of_seen_val.hdf5", "r", libver="latest")[
|
24 |
+
f"encoded_{key_type}_feature"
|
25 |
+
][:]
|
26 |
+
unseen_val = h5py.File(input / "extracted_features_of_unseen_val.hdf5", "r", libver="latest")[
|
27 |
+
f"encoded_{key_type}_feature"
|
28 |
+
][:]
|
29 |
|
30 |
# FlatIP and FlatL2
|
31 |
if index_type == "IndexFlatIP":
|
32 |
+
test_index = faiss.IndexFlatIP(EMBEDDING_SIZE)
|
33 |
elif index_type == "IndexFlatL2":
|
34 |
+
test_index = faiss.IndexFlatL2(EMBEDDING_SIZE)
|
35 |
elif index_type == "IndexIVFFlat":
|
36 |
# IVFFlat
|
37 |
+
quantizer = faiss.IndexFlatIP(EMBEDDING_SIZE)
|
38 |
+
test_index = faiss.IndexIVFFlat(quantizer, EMBEDDING_SIZE, 128)
|
39 |
+
test_index.train(all_keys)
|
40 |
+
test_index.train(seen_test)
|
41 |
+
test_index.train(unseen_test)
|
42 |
+
test_index.train(seen_val)
|
43 |
+
test_index.train(unseen_val)
|
44 |
elif index_type == "IndexHNSWFlat":
|
45 |
# HNSW
|
46 |
# 16: connections for each vertex. efSearch: depth of search during search. efConstruction: depth of search during build
|
47 |
+
test_index = faiss.IndexHNSWFlat(EMBEDDING_SIZE, 16)
|
48 |
test_index.hnsw.efSearch = 32
|
49 |
test_index.hnsw.efConstruction = 64
|
50 |
elif index_type == "IndexLSH":
|
51 |
# LSH
|
52 |
+
test_index = faiss.IndexLSH(EMBEDDING_SIZE, EMBEDDING_SIZE * 2)
|
53 |
else:
|
54 |
raise ValueError(f"Index type {index_type} is not supported")
|
55 |
|
56 |
+
test_index.add(all_keys)
|
57 |
+
test_index.add(seen_test)
|
58 |
+
test_index.add(unseen_test)
|
59 |
+
test_index.add(seen_val)
|
60 |
+
test_index.add(unseen_val)
|
61 |
|
62 |
faiss.write_index(test_index, str(output / f"bioscan_5m_{key_type}_{index_type}.index"))
|
63 |
print("Saved index to", output / f"bioscan_5m_{key_type}_{index_type}.index")
|
|
|
96 |
else:
|
97 |
index_types = [index_type]
|
98 |
|
|
|
99 |
for key_type in key_types:
|
100 |
for index_type in index_types:
|
101 |
+
process(input, output, key_type, index_type)
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
|
104 |
if __name__ == "__main__":
|