atwang committed on
Commit df89a31
1 Parent(s): af5d6e8

add pickle prep script and hot fix app to not load dna

Files changed (2)
  1. app.py +4 -2
  2. prepare_pickle.py +205 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image
  import io
  import pickle
  import random
+ import click


  def getRandID():
@@ -106,8 +107,9 @@ with gr.Blocks() as demo:
      # initialize both possible dicts
      with open("big_id_to_image_emb_dict.pickle", "rb") as f:
          id_to_image_emb_dict = pickle.load(f)
-     with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
-         id_to_dna_emb_dict = pickle.load(f)
+     # with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
+     #     id_to_dna_emb_dict = pickle.load(f)
+     id_to_dna_emb_dict = None

      with gr.Column():
          with gr.Row():
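With this hotfix the app no longer loads the DNA embedding pickle, so id_to_dna_emb_dict stays None. Any code path elsewhere in app.py that still dereferences it (not shown in this diff) would need a guard; the sketch below is a hypothetical helper, not the app's actual function, assuming lookups are done by sample id.

# Hypothetical guard, not part of this commit: fall back gracefully while
# the DNA embedding dict is disabled (id_to_dna_emb_dict is None).
def get_dna_embedding(sample_id, id_to_dna_emb_dict):
    if id_to_dna_emb_dict is None:
        # DNA embeddings are not loaded in this build of the app
        return None
    return id_to_dna_emb_dict.get(sample_id)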
prepare_pickle.py ADDED
@@ -0,0 +1,205 @@
+ import pickle
+ from pathlib import Path
+
+ import numpy as np
+ import h5py
+ import faiss
+ import click
+
+
+ def getFlatIP():
+     # exact inner-product index over 768-dim embeddings
+     test_index = faiss.IndexFlatIP(768)
+     return test_index
+
+
+ def getFlatL2():
+     # exact L2 index over 768-dim embeddings
+     test_index = faiss.IndexFlatL2(768)
+     return test_index
+
+
+ def getIVFFlat(all_keys, seen_test, unseen_test, seen_val, unseen_val):
+     # inverted-file index (128 cells) with an inner-product quantizer
+     quantizer = faiss.IndexFlatIP(768)
+     test_index = faiss.IndexIVFFlat(quantizer, 768, 128)
+     test_index.train(all_keys)
+     test_index.train(seen_test)
+     test_index.train(unseen_test)
+     test_index.train(seen_val)
+     test_index.train(unseen_val)
+     return test_index
+
+
+ def getHNSW():
+     # 16: connections for each vertex. efSearch: depth of search during search. efConstruction: depth of search during build
+     test_index = faiss.IndexHNSWFlat(768, 16)
+     test_index.hnsw.efSearch = 32
+     test_index.hnsw.efConstruction = 64
+     return test_index
+
+
+ def getLSH():
+     test_index = faiss.IndexLSH(768, 768 * 2)
+     return test_index
+
+
+ def getIdToEmbedding(allid, stid, utid, svalid, uvalid, all_keys, seen_test, unseen_test, seen_val, unseen_val):
+     # map each sample id to its embedding row; the running index resets for every split
+     id_to_emb_dict = dict()
+     i = 0
+     for id in allid:
+         id_to_emb_dict[id] = np.array([all_keys[i]])
+         i += 1
+     i = 0
+     for id in stid:
+         id_to_emb_dict[id] = np.array([seen_test[i]])
+         i += 1
+     i = 0
+     for id in utid:
+         id_to_emb_dict[id] = np.array([unseen_test[i]])
+         i += 1
+     i = 0
+     for id in svalid:
+         id_to_emb_dict[id] = np.array([seen_val[i]])
+         i += 1
+     i = 0
+     for id in uvalid:
+         id_to_emb_dict[id] = np.array([unseen_val[i]])
+         i += 1
+
+     return id_to_emb_dict
+
+
+ @click.command()
+ @click.option(
+     "--input",
+     type=click.Path(path_type=Path),
+     default="bioscan-clip-scripts/extracted_features",
+     help="Path to extracted features",
+ )
+ @click.option(
+     "--metadata", type=click.Path(path_type=Path), default="data/BIOSCAN_5M/BIOSCAN_5M.hdf5", help="Path to metadata"
+ )
+ @click.option(
+     "--output", type=click.Path(path_type=Path), default="bioscan-clip-scripts/index", help="Path to save the index"
+ )
+ def main(input, metadata, output):
+     # initialize data
+     all_keys = h5py.File(input / "extracted_features_of_all_keys.hdf5", "r", libver="latest")
+     all_keys_dna = all_keys["encoded_dna_feature"][:]
+     all_keys_im = all_keys["encoded_image_feature"][:]
+
+     seen_test = h5py.File(input / "extracted_features_of_seen_test.hdf5", "r", libver="latest")
+     seen_test_dna = seen_test["encoded_dna_feature"][:]
+     seen_test_im = seen_test["encoded_image_feature"][:]
+
+     unseen_test = h5py.File(input / "extracted_features_of_unseen_test.hdf5", "r", libver="latest")
+     unseen_test_dna = unseen_test["encoded_dna_feature"][:]
+     unseen_test_im = unseen_test["encoded_image_feature"][:]
+
+     seen_val = h5py.File(input / "extracted_features_of_seen_val.hdf5", "r", libver="latest")
+     seen_val_dna = seen_val["encoded_dna_feature"][:]
+     seen_val_im = seen_val["encoded_image_feature"][:]
+
+     unseen_val = h5py.File(input / "extracted_features_of_unseen_val.hdf5", "r", libver="latest")
+     unseen_val_dna = unseen_val["encoded_dna_feature"][:]
+     unseen_val_im = unseen_val["encoded_image_feature"][:]
+
+     dataset = h5py.File(metadata, "r", libver="latest")
+     id_field = "sampleid"  # alternative: "processid"
+     allid = [item.decode("utf-8") for item in dataset["all_keys"][id_field][:]]
+     stid = [item.decode("utf-8") for item in dataset["test_seen"][id_field][:]]
+     utid = [item.decode("utf-8") for item in dataset["test_unseen"][id_field][:]]
+     svalid = [item.decode("utf-8") for item in dataset["val_seen"][id_field][:]]
+     uvalid = [item.decode("utf-8") for item in dataset["val_unseen"][id_field][:]]
+
+     all_keys = dataset["all_keys"]
+     seen_test = dataset["test_seen"]
+     unseen_test = dataset["test_unseen"]
+     seen_val = dataset["val_seen"]
+     unseen_val = dataset["val_unseen"]
+
+     # d = getIdToEmbedding(allid, stid, utid, svalid, uvalid, all_keys_dna, seen_test_dna, unseen_test_dna, seen_val_dna, unseen_val_dna)
+     # d = getIdToEmbedding(allid, stid, utid, svalid, uvalid, all_keys_im, seen_test_im, unseen_test_im, seen_val_im, unseen_val_im)
+
+     # sample id -> (1, dim) image embedding, built split by split
+     big_id_to_image_emb_dict = dict()
+     i = 0
+     for object in allid:
+         big_id_to_image_emb_dict[object] = np.array([all_keys_im[i]])
+         i += 1
+     i = 0
+     for object in stid:
+         big_id_to_image_emb_dict[object] = np.array([seen_test_im[i]])
+         i += 1
+     i = 0
+     for object in utid:
+         big_id_to_image_emb_dict[object] = np.array([unseen_test_im[i]])
+         i += 1
+     i = 0
+     for object in svalid:
+         big_id_to_image_emb_dict[object] = np.array([seen_val_im[i]])
+         i += 1
+     i = 0
+     for object in uvalid:
+         big_id_to_image_emb_dict[object] = np.array([unseen_val_im[i]])
+         i += 1
+
+     ###
+
+     # sample id -> (1, dim) DNA embedding, built split by split
+     big_id_to_dna_emb_dict = dict()
+     i = 0
+     for object in allid:
+         big_id_to_dna_emb_dict[object] = np.array([all_keys_dna[i]])
+         i += 1
+     i = 0
+     for object in stid:
+         big_id_to_dna_emb_dict[object] = np.array([seen_test_dna[i]])
+         i += 1
+     i = 0
+     for object in utid:
+         big_id_to_dna_emb_dict[object] = np.array([unseen_test_dna[i]])
+         i += 1
+     i = 0
+     for object in svalid:
+         big_id_to_dna_emb_dict[object] = np.array([seen_val_dna[i]])
+         i += 1
+     i = 0
+     for object in uvalid:
+         big_id_to_dna_emb_dict[object] = np.array([unseen_val_dna[i]])
+         i += 1
+
+     ###
+
+     # running index <-> sample id, with ids enumerated in split order
+     processid_to_indx = dict()
+     big_indx_to_id_dict = dict()
+     i = 0
+     for object in allid:
+         big_indx_to_id_dict[i] = object
+         processid_to_indx[object] = i
+         i += 1
+
+     for object in stid:
+         big_indx_to_id_dict[i] = object
+         processid_to_indx[object] = i
+         i += 1
+
+     for object in utid:
+         big_indx_to_id_dict[i] = object
+         processid_to_indx[object] = i
+         i += 1
+
+     for object in svalid:
+         big_indx_to_id_dict[i] = object
+         processid_to_indx[object] = i
+         i += 1
+
+     for object in uvalid:
+         big_indx_to_id_dict[i] = object
+         processid_to_indx[object] = i
+         i += 1
+
+     ###
+
+     with open(output / "big_id_to_image_emb_dict.pickle", "wb") as f:
+         pickle.dump(big_id_to_image_emb_dict, f)
+     with open(output / "big_id_to_dna_emb_dict.pickle", "wb") as f:
+         pickle.dump(big_id_to_dna_emb_dict, f)
+     with open(output / "big_indx_to_id_dict.pickle", "wb") as f:
+         pickle.dump(big_indx_to_id_dict, f)
+
+
+ if __name__ == "__main__":
+     main()
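For reference, the script runs with the click defaults shown above (--input, --metadata, --output), and the resulting pickles can be inspected before pointing app.py at them. Below is a minimal sanity-check sketch, not part of the commit; it assumes the default --output location and that embeddings are stored as single-row arrays (768-dimensional, as the FAISS helpers above suggest).

# Quick sanity check of the generated pickles (a sketch, not part of the commit).
import pickle
import numpy as np

with open("bioscan-clip-scripts/index/big_id_to_image_emb_dict.pickle", "rb") as f:
    id_to_image_emb = pickle.load(f)

sample_id, emb = next(iter(id_to_image_emb.items()))
print(sample_id, emb.shape)  # each value is stored as a (1, dim) row, e.g. (1, 768)
assert isinstance(emb, np.ndarray) and emb.ndim == 2 and emb.shape[0] == 1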