taesiri commited on
Commit
b35ee4e
1 Parent(s): 006cc92

update the UI

Browse files
Files changed (3) hide show
  1. app.py +84 -56
  2. requirements.txt +1 -1
  3. visualization.py +9 -9
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import csv
2
  import sys
3
  import pickle
@@ -7,6 +8,7 @@ import gradio as gr
7
  import gdown
8
  import torchvision
9
  from torchvision.datasets import ImageFolder
 
10
 
11
  from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
12
  from ExtractEmbedding import QueryToEmbedding
@@ -17,42 +19,42 @@ csv.field_size_limit(sys.maxsize)
17
 
18
  concat = lambda x: np.concatenate(x, axis=0)
19
 
20
- # Embeddings
21
- gdown.cached_download(
22
- url="https://static.taesiri.com/chm-corr/embeddings.pickle",
23
- path="./embeddings.pickle",
24
- quiet=False,
25
- md5="002b2a7f5c80d910b9cc740c2265f058",
26
- )
27
-
28
- # embeddings
29
- # gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
30
-
31
- # labels
32
- gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
33
-
34
- # CUB training set
35
- gdown.cached_download(
36
- url="https://static.taesiri.com/chm-corr/CUB_train.zip",
37
- path="./CUB_train.zip",
38
- quiet=False,
39
- md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
40
- )
41
-
42
- # EXTRACT training set
43
- torchvision.datasets.utils.extract_archive(
44
- from_path="CUB_train.zip",
45
- to_path="data/",
46
- remove_finished=False,
47
- )
48
-
49
- # CHM Weights
50
- gdown.cached_download(
51
- url="https://static.taesiri.com/chm-corr/pas_psi.pt",
52
- path="pas_psi.pt",
53
- quiet=False,
54
- md5="6b7b4d7bad7f89600fac340d6aa7708b",
55
- )
56
 
57
 
58
  # Caluclate Accuracy
@@ -72,7 +74,7 @@ id_to_bird_name = {
72
  }
73
 
74
 
75
- def search(query_image, searcher=searcher):
76
  query_embedding = QueryToEmbedding(query_image)
77
  scores, indices, labels = searcher.search(query_embedding, k=50)
78
 
@@ -99,27 +101,53 @@ def search(query_image, searcher=searcher):
99
  query_image, kNN_results, support, training_folder
100
  )
101
 
102
- viz_plot = plot_from_reranker_output(chm_output, draw_arcs=False)
103
 
104
- return viz_plot, predicted_labels, gallery_images
105
 
 
 
 
 
 
 
106
 
107
- demo = gr.Interface(
108
- search,
109
- gr.Image(type="filepath"),
110
- ["plot", "label", "gallery"],
111
- examples=[
112
- ["./examples/bird.jpg"],
113
- ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
114
- ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
115
- ["./examples/sample1.jpeg"],
116
- ["./examples/sample2.jpeg"],
117
- ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
118
- ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
119
- ],
120
- description="WIP - kNN on CUB dataset",
121
- title="Work in Progress - CHM-Corr",
122
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  if __name__ == "__main__":
125
- demo.launch()
 
 
 
 
1
+ import io
2
  import csv
3
  import sys
4
  import pickle
 
8
  import gdown
9
  import torchvision
10
  from torchvision.datasets import ImageFolder
11
+ from PIL import Image
12
 
13
  from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
14
  from ExtractEmbedding import QueryToEmbedding
 
19
 
20
  concat = lambda x: np.concatenate(x, axis=0)
21
 
22
+ # # Embeddings
23
+ # gdown.cached_download(
24
+ # url="https://static.taesiri.com/chm-corr/embeddings.pickle",
25
+ # path="./embeddings.pickle",
26
+ # quiet=False,
27
+ # md5="002b2a7f5c80d910b9cc740c2265f058",
28
+ # )
29
+
30
+ # # embeddings
31
+ # # gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
32
+
33
+ # # labels
34
+ # gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
35
+
36
+ # # CUB training set
37
+ # gdown.cached_download(
38
+ # url="https://static.taesiri.com/chm-corr/CUB_train.zip",
39
+ # path="./CUB_train.zip",
40
+ # quiet=False,
41
+ # md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
42
+ # )
43
+
44
+ # # EXTRACT training set
45
+ # torchvision.datasets.utils.extract_archive(
46
+ # from_path="CUB_train.zip",
47
+ # to_path="data/",
48
+ # remove_finished=False,
49
+ # )
50
+
51
+ # # CHM Weights
52
+ # gdown.cached_download(
53
+ # url="https://static.taesiri.com/chm-corr/pas_psi.pt",
54
+ # path="pas_psi.pt",
55
+ # quiet=False,
56
+ # md5="6b7b4d7bad7f89600fac340d6aa7708b",
57
+ # )
58
 
59
 
60
  # Caluclate Accuracy
 
74
  }
75
 
76
 
77
+ def search(query_image, draw_arcs, searcher=searcher):
78
  query_embedding = QueryToEmbedding(query_image)
79
  scores, indices, labels = searcher.search(query_embedding, k=50)
80
 
 
101
  query_image, kNN_results, support, training_folder
102
  )
103
 
104
+ fig = plot_from_reranker_output(chm_output, draw_arcs=draw_arcs)
105
 
106
+ # Resize the output
107
 
108
+ img_buf = io.BytesIO()
109
+ fig.savefig(img_buf, format="jpg")
110
+ image = Image.open(img_buf)
111
+ width, height = image.size
112
+ new_width = width
113
+ new_height = height
114
 
115
+ left = (width - new_width) / 2
116
+ top = (height - new_height) / 2
117
+ right = (width + new_width) / 2
118
+ bottom = (height + new_height) / 2
119
+
120
+ viz_image = image.crop((left + 540, top + 40, right - 492, bottom - 100))
121
+
122
+ return viz_image, predicted_labels
123
+
124
+
125
+ blocks = gr.Blocks()
126
+
127
+ with blocks:
128
+ gr.Markdown(""" # CHM-Corr DEMO""")
129
+ gr.Markdown(""" ### Parameters: N=50, k=20 - Using ResNet50 features""")
130
+
131
+ # with gr.Row():
132
+ input_image = gr.Image(type="filepath")
133
+ with gr.Column():
134
+ arcs_checkbox = gr.Checkbox(label="Draw Arcs")
135
+ run_btn = gr.Button("Classify")
136
+
137
+ # with gr.Column():
138
+ gr.Markdown(""" ### CHM-Corr Output """)
139
+ viz_plot = gr.Image(type="pil")
140
+ gr.Markdown(""" ### kNN Predicted Labels """)
141
+ predicted_labels = gr.Label(label="kNN Prediction")
142
+
143
+ run_btn.click(
144
+ search,
145
+ inputs=[input_image, arcs_checkbox],
146
+ outputs=[viz_plot, predicted_labels],
147
+ )
148
 
149
  if __name__ == "__main__":
150
+ blocks.launch(
151
+ debug=True,
152
+ enable_queue=True,
153
+ )
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- faiss-cpu==1.7.2
2
  gdown
3
  gradio
4
  numpy
@@ -8,3 +7,4 @@ torchvision
8
  tqdm
9
  tensorboardX==2.5
10
  matplotlib
 
 
 
1
  gdown
2
  gradio
3
  numpy
 
7
  tqdm
8
  tensorboardX==2.5
9
  matplotlib
10
+ faiss-cpu==1.7.2
visualization.py CHANGED
@@ -261,14 +261,14 @@ def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
261
  color="black",
262
  fontsize=22,
263
  )
264
- fig.text(
265
- 0.8,
266
- 0.95,
267
- f"KNN: {reranker_output['knn-prediction']}",
268
- ha="right",
269
- va="bottom",
270
- color="black",
271
- fontsize=22,
272
- )
273
 
274
  return fig
 
261
  color="black",
262
  fontsize=22,
263
  )
264
+ # fig.text(
265
+ # 0.8,
266
+ # 0.95,
267
+ # f"KNN: {reranker_output['knn-prediction']}",
268
+ # ha="right",
269
+ # va="bottom",
270
+ # color="black",
271
+ # fontsize=22,
272
+ # )
273
 
274
  return fig