Fangrui Liu commited on
Commit
3f1124e
1 Parent(s): e66a418
Files changed (7) hide show
  1. README.md +1 -1
  2. TestSet.py +120 -0
  3. app.py +393 -0
  4. box_utils.py +133 -0
  5. card_model.py +94 -0
  6. classifier.py +121 -0
  7. query_model.py +108 -0
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
  app_file: app.py
9
- pinned: false
10
  license: lgpl-3.0
11
  ---
12
 
 
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
  app_file: app.py
9
+ pinned: true
10
  license: lgpl-3.0
11
  ---
12
 
TestSet.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import requests
3
+ from io import BytesIO
4
+ from os import path
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ class TestImageSetOnline(Dataset):
9
+ """ Test Image set with hugging face CLIP preprocess interface
10
+
11
+ Args:
12
+ Dataset (torch.utils.data.Dataset):
13
+ """
14
+ def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2):
15
+ """
16
+ Args:
17
+ processor (CLIP preprocessor): process data to a CLIP digestable format
18
+ image_list (pandas.DataFrame): pandas.DataFrame that contains image metadata
19
+ timeout_base (float, optional): initial timeout parameter. Defaults to 0.5.
20
+ timeout_mul (int, optional): multiplier on timeout every time reqeust fails. Defaults to 2.
21
+ """
22
+ self.image_list = image_list
23
+ self.processor = processor
24
+ self.timeout_base = timeout_base
25
+ self.timeout = self.timeout_base
26
+ self.timeout_mul = timeout_mul
27
+
28
+ def __getitem__(self, index):
29
+ row = self.image_list[index]
30
+ url = str(row['coco_url'])
31
+ _id = str(row['id'])
32
+ txt, img = None, None
33
+ flag = True
34
+ while flag:
35
+ try:
36
+ # Get images online
37
+ response = requests.get(url)
38
+ img = Image.open(BytesIO(response.content))
39
+ img_s = img.size
40
+ if img.mode in ['L', 'CMYK', 'RGBA']:
41
+ # L is grayscale, CMYK uses alternative color channels
42
+ img = img.convert('RGB')
43
+ # Preprocess image
44
+ ret = self.processor(text=txt, images=img, return_tensor='pt')
45
+ img = ret['pixel_values'][0]
46
+ # If success, then there will be no need to run this again
47
+ flag = False
48
+ # Relief the timeout param
49
+ if self.timeout > self.timeout_base:
50
+ self.timeout /= self.timeout_mul
51
+ except Exception as e:
52
+ print(f"{_id} {url}: {str(e)}")
53
+ if type(e) is KeyboardInterrupt:
54
+ raise e
55
+ time.sleep(self.timeout)
56
+ # Tension the timeout param and turn into a new request
57
+ self.timeout *= self.timeout_mul
58
+ return _id, url, img, img_s
59
+
60
+ def get(self, url):
61
+ _id = url
62
+ txt, img = None, None
63
+ flag = True
64
+ while flag:
65
+ try:
66
+ # Get images online
67
+ response = requests.get(url)
68
+ img = Image.open(BytesIO(response.content))
69
+ img_s = img.size
70
+ if img.mode in ['L', 'CMYK', 'RGBA']:
71
+ # L is grayscale, CMYK uses alternative color channels
72
+ img = img.convert('RGB')
73
+ # Preprocess image
74
+ ret = self.processor(text=txt, images=img, return_tensor='pt')
75
+ img = ret['pixel_values'][0]
76
+ # If success, then there will be no need to run this again
77
+ flag = False
78
+ # Relief the timeout param
79
+ if self.timeout > self.timeout_base:
80
+ self.timeout /= self.timeout_mul
81
+ except Exception as e:
82
+ print(f"{_id} {url}: {str(e)}")
83
+ if type(e) is KeyboardInterrupt:
84
+ raise e
85
+ time.sleep(self.timeout)
86
+ # Tension the timeout param and turn into a new request
87
+ self.timeout *= self.timeout_mul
88
+ return _id, url, img, img_s
89
+
90
+
91
+ def __len__(self,):
92
+ return len(self.image_list)
93
+
94
+ def __add__(self, other):
95
+ self.image_list += other.image_list
96
+ return self
97
+
98
+ class TestImageSet(TestImageSetOnline):
99
+ def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
100
+ super().__init__(processor, image_list, timeout_base, timeout_mul)
101
+ self.droot = droot
102
+
103
+ def __getitem__(self, index):
104
+ row = self.image_list[index]
105
+ url = str(row['coco_url'])
106
+ _id = '_'.join([url.split('/')[-2], str(row['id'])])
107
+ txt, img = None, None
108
+ # Get images online
109
+ img = Image.open(path.join(self.droot,
110
+ url.split('http://images.cocodataset.org/')[1]))
111
+ img_s = img.size
112
+ if img.mode in ['L', 'CMYK', 'RGBA']:
113
+ # L is grayscale, CMYK uses alternative color channels
114
+ img = img.convert('RGB')
115
+ # Preprocess image
116
+ ret = self.processor(text=txt, images=img, return_tensor='pt')
117
+ img = ret['pixel_values'][0]
118
+ # If success, then there will be no need to run this again
119
+ return _id, url, img, img_s
120
+
app.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import time
2
+ import aiohttp
3
+ from io import BytesIO
4
+ import torch
5
+ import streamlit as st
6
+ import streamlit.components.v1 as components
7
+ import numpy as np
8
+ import torch
9
+ import logging
10
+ from os import environ
11
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
12
+
13
+ from myscaledb import Client
14
+ from classifier import Classifier, prompt2vec, tune, SplitLayer
15
+ from query_model import simple_query, topk_obj_query, rev_query
16
+ from card_model import card, obj_card, style
17
+ from box_utils import postprocess
18
+
19
+ environ['TOKENIZERS_PARALLELISM'] = 'true'
20
+
21
+ OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
22
+ IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"
23
+ MODEL_ID = 'google/owlvit-base-patch32'
24
+ DIMS = 512
25
+
26
+ qtime = 0
27
+
28
+
29
+ def build_model(name="google/owlvit-base-patch32"):
30
+ """Model builder function
31
+
32
+ Args:
33
+ name (str, optional): Name for HuggingFace OwlViT model. Defaults to "google/owlvit-base-patch32".
34
+
35
+ Returns:
36
+ (model, processor): OwlViT model and its processor for both image and text
37
+ """
38
+ device = 'cpu'
39
+ if torch.cuda.is_available():
40
+ device = 'cuda'
41
+ model = OwlViTForObjectDetection.from_pretrained(name).to(device)
42
+ processor = OwlViTProcessor.from_pretrained(name)
43
+ return model, processor
44
+
45
+
46
+ @st.experimental_singleton(show_spinner=False)
47
+ def init_owlvit():
48
+ """ Initialize OwlViT Model
49
+
50
+ Returns:
51
+ model, processor
52
+ """
53
+ model, processor = build_model(MODEL_ID)
54
+ return model, processor
55
+
56
+
57
+ @st.experimental_singleton(show_spinner=False)
58
+ def init_db():
59
+ """ Initialize the Database Connection
60
+
61
+ Returns:
62
+ meta_field: Meta field that records if an image is viewed or not
63
+ client: Database connection object
64
+ """
65
+ meta = []
66
+ client = Client(
67
+ url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
68
+ # We can check if the connection is alive
69
+ assert client.is_alive()
70
+ return meta, client
71
+
72
+
73
+ def refresh_index():
74
+ """ Clean the session
75
+ """
76
+ del st.session_state["meta"]
77
+ st.session_state.meta = []
78
+ st.session_state.query_num = 0
79
+ logging.info(f"Refresh for '{st.session_state.meta}'")
80
+ # Need to clear singleton function with streamlit API
81
+ init_db.clear()
82
+ # refresh session states
83
+ st.session_state.meta, st.session_state.index = init_db()
84
+ if 'clf' in st.session_state:
85
+ del st.session_state.clf
86
+ if 'xq' in st.session_state:
87
+ del st.session_state.xq
88
+ if 'topk_img_id' in st.session_state:
89
+ del st.session_state.topk_img_id
90
+
91
+
92
+ def query(xq, exclude_list=None):
93
+ """ Query matched w.r.t a given vector
94
+
95
+ In this part, we will retrieve A LOT OF data from the server,
96
+ including TopK boxes and their embeddings, the counterpart of non-TopK boxes in TopK images.
97
+
98
+ Args:
99
+ xq (numpy.ndarray or list of floats): Query vector
100
+
101
+ Returns:
102
+ matches: list of Records object. Keys referrring to selected columns group by images.
103
+ Exclude the user's viewlist.
104
+ img_matches: list of Records object. Containing other non-TopK but hit objects among TopK images.
105
+ side_matches: list of Records object. Containing REAL TopK objects disregard the user's view history
106
+ """
107
+ attempt = 0
108
+ xq = xq
109
+ xq = xq / np.linalg.norm(xq, axis=-1, ord=2, keepdims=True)
110
+ status_bar = [st.empty(), st.empty()]
111
+ status_bar[0].write("Retrieving Another TopK Images...")
112
+ pbar = status_bar[1].progress(0)
113
+ while attempt < 3:
114
+ try:
115
+ matches = topk_obj_query(
116
+ st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
117
+ exclude_list=exclude_list, topk=5000)
118
+ img_ids = [r['img_id'] for r in matches]
119
+ if 'topk_img_id' not in st.session_state:
120
+ st.session_state.topk_img_id = img_ids
121
+ status_bar[0].write("Retrieving TopK Images...")
122
+ pbar.progress(25)
123
+ o_matches = rev_query(
124
+ st.session_state.index, xq, st.session_state.topk_img_id,
125
+ IMG_DB_NAME, OBJ_DB_NAME, thresh=0.1)
126
+ status_bar[0].write("Retrieving TopKs Objects...")
127
+ pbar.progress(50)
128
+ side_matches = simple_query(st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
129
+ thresh=-1, topk=5000)
130
+ status_bar[0].write(
131
+ "Retrieving Non-TopK in Another TopK Images...")
132
+ pbar.progress(75)
133
+ if len(img_ids) > 0:
134
+ img_matches = rev_query(
135
+ st.session_state.index, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME,
136
+ thresh=0.1)
137
+ else:
138
+ img_matches = []
139
+ status_bar[0].write("DONE!")
140
+ pbar.progress(100)
141
+ break
142
+ except Exception as e:
143
+ # force reload if we have trouble on connections or something else
144
+ logging.warning(str(e))
145
+ st.session_state.meta, st.session_state.index = init_db()
146
+ attempt += 1
147
+ matches = []
148
+ _ = [s.empty() for s in status_bar]
149
+ if len(matches) == 0:
150
+ logging.error(f"No matches found for '{OBJ_DB_NAME}'")
151
+ return matches, img_matches, side_matches, o_matches
152
+
153
+
154
+ @st.experimental_singleton(show_spinner=False)
155
+ def init_random_query():
156
+ """Initialize a random query vector
157
+
158
+ Returns:
159
+ xq: a random vector
160
+ """
161
+ xq = np.random.rand(1, DIMS)
162
+ xq /= np.linalg.norm(xq, keepdims=True, axis=-1)
163
+ return xq
164
+
165
+
166
+ def submit(meta):
167
+ """ Tune the model w.r.t given score from user.
168
+ """
169
+ # Only updating the meta if the train button is pressed
170
+ st.session_state.meta.extend(meta)
171
+ st.session_state.step += 1
172
+ matches = st.session_state.matched_boxes
173
+ X, y = list(zip(*((v[-1],
174
+ st.session_state.text_prompts.index(
175
+ st.session_state[f"label-{i}"])) for i, v in matches.items())))
176
+ st.session_state.xq = tune(st.session_state.clf,
177
+ X, y, iters=int(st.session_state.iters))
178
+ st.session_state.matches, \
179
+ st.session_state.img_matches, \
180
+ st.session_state.side_matches, \
181
+ st.session_state.o_matches = query(
182
+ st.session_state.xq, st.session_state.meta)
183
+
184
+
185
+ # st.set_page_config(layout="wide")
186
+ # To hack the streamlit style we define our own style.
187
+ # Boxes are drawn in SVGs.
188
+ st.write(style(), unsafe_allow_html=True)
189
+
190
+ with st.spinner("Connecting DB..."):
191
+ st.session_state.meta, st.session_state.index = init_db()
192
+
193
+ with st.spinner("Loading Models..."):
194
+ # Initialize model
195
+ model, tokenizer = init_owlvit()
196
+
197
+ # If its a fresh start... (query not set)
198
+ if 'xq' not in st.session_state:
199
+ with st.container():
200
+ st.title('Object Detection Safari')
201
+ start = [st.empty() for _ in range(8)]
202
+ start[0].info("""
203
+ We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
204
+ unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
205
+ You can search with almost any words or phrases you can think of. Please enjoy your journey of
206
+ an adventure to COCO.
207
+ """)
208
+ prompt = start[1].text_input(
209
+ "Prompt:", value="", placeholder="Examples: football, billboard, stop sign, watermark ...",)
210
+ with start[2].container():
211
+ st.write(
212
+ 'You can search with multiple keywords. Plese separate with commas but with no space.')
213
+ st.write('For example: `cat,dog,tree`')
214
+ st.markdown('''
215
+ <p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
216
+ ''',
217
+ unsafe_allow_html=True)
218
+
219
+ upld_model = start[4].file_uploader(
220
+ "Or you can upload your previous run!", type='onnx')
221
+ upld_btn = start[5].button(
222
+ "Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index)
223
+
224
+ with start[3]:
225
+ col = st.columns(8)
226
+ has_no_prompt = (len(prompt) == 0 and upld_model is None)
227
+ prompt_xq = col[6].button("Prompt", disabled=len(
228
+ prompt) == 0, on_click=refresh_index)
229
+ random_xq = col[7].button(
230
+ "Random", disabled=not has_no_prompt, on_click=refresh_index)
231
+ matches = []
232
+ img_matches = []
233
+ if random_xq:
234
+ xq = init_random_query()
235
+ st.session_state.xq = xq
236
+ prompt = 'unknown'
237
+ st.session_state.text_prompts = prompt.split(',') + ['none']
238
+ _ = [elem.empty() for elem in start]
239
+ t0 = time()
240
+ matches, img_matches, side_matches, o_matches = query(
241
+ st.session_state.xq, st.session_state.meta)
242
+ t1 = time()
243
+ qtime = (t1-t0) * 1000
244
+ elif prompt_xq or upld_btn:
245
+ if upld_model is not None:
246
+ import onnx
247
+ from onnx import numpy_helper
248
+ _model = onnx.load(upld_model)
249
+ st.session_state.text_prompts = [
250
+ node.name for node in _model.graph.output] + ['none']
251
+ weights = _model.graph.initializer
252
+ xq = numpy_helper.to_array(weights[0]).T
253
+ assert xq.shape[0] == len(
254
+ st.session_state.text_prompts)-1 and xq.shape[1] == DIMS
255
+ st.session_state.xq = xq
256
+ _ = [elem.empty() for elem in start]
257
+ else:
258
+ logging.info(f"Input prompt is {prompt}")
259
+ st.session_state.text_prompts = prompt.split(',') + ['none']
260
+ input_ids, xq = prompt2vec(
261
+ st.session_state.text_prompts[:-1], model, tokenizer)
262
+ st.session_state.xq = xq
263
+ _ = [elem.empty() for elem in start]
264
+ t0 = time()
265
+ st.session_state.matches, \
266
+ st.session_state.img_matches, \
267
+ st.session_state.side_matches, \
268
+ st.session_state.o_matches = query(
269
+ st.session_state.xq, st.session_state.meta)
270
+ t1 = time()
271
+ qtime = (t1-t0) * 1000
272
+
273
+ # If its not a fresh start (query is set)
274
+ if 'xq' in st.session_state:
275
+ o_matches = st.session_state.o_matches
276
+ side_matches = st.session_state.side_matches
277
+ img_matches = st.session_state.img_matches
278
+ matches = st.session_state.matches
279
+ # initialize classifier
280
+ if 'clf' not in st.session_state:
281
+ st.session_state.clf = Classifier(st.session_state.xq)
282
+ st.session_state.step = 0
283
+ if qtime > 0:
284
+ st.info("Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
285
+ qtime, len(matches), sum([len(m["box_id"]) + len(im["box_id"]) for m, im in zip(matches, img_matches)])))
286
+
287
+ # export the model into executable ONNX
288
+ st.session_state.dnld_model = BytesIO()
289
+ torch.onnx.export(torch.nn.Sequential(st.session_state.clf.model, SplitLayer()),
290
+ torch.zeros([1, len(st.session_state.xq[0])]),
291
+ st.session_state.dnld_model,
292
+ input_names=['input'],
293
+ output_names=st.session_state.text_prompts[:-1])
294
+
295
+ dnld_nam = st.text_input('Download Name:',
296
+ f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
297
+ max_chars=50)
298
+ dnld_btn = st.download_button('Download your classifier!',
299
+ st.session_state.dnld_model,
300
+ dnld_nam)
301
+ # build up a sidebar to display REAL TopK in DB
302
+ # this will change during user's finetune. But sometime it would lead to bad results
303
+ side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
304
+ with st.sidebar:
305
+ with st.expander("Top-K Images"):
306
+ with st.container():
307
+ boxes_w_img, _ = postprocess(o_matches, st.session_state.text_prompts,
308
+ None)
309
+ boxes_w_img = sorted(
310
+ boxes_w_img, key=lambda x: x[4], reverse=True)
311
+ for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
312
+ args = img_url, img_w, img_h, boxes
313
+ st.write(card(*args), unsafe_allow_html=True)
314
+
315
+ with st.expander("Top-K Objects", expanded=True):
316
+ side_cols = st.columns(
317
+ len(st.session_state.text_prompts[:-1]))
318
+ for _cols, m in zip(side_cols, side_matches):
319
+ with _cols.container():
320
+ for cx, cy, w, h, logit, img_url, img_w, img_h \
321
+ in zip(m['cx'], m['cy'], m['w'], m['h'], m['logit'],
322
+ m['img_url'], m['img_w'], m['img_h']):
323
+ st.write("{:s}: {:.4f}".format(
324
+ st.session_state.text_prompts[m['label']], logit))
325
+ _html = obj_card(
326
+ img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len)
327
+ components.html(
328
+ _html, side_bar_len, side_bar_len)
329
+ with st.container():
330
+ # Here let the user interact with batch labeling
331
+ with st.form("batch", clear_on_submit=False):
332
+ col = st.columns([1, 9])
333
+
334
+ # If there is nothing to show about
335
+ if len(matches) <= 0:
336
+ st.warning(
337
+ 'Oops! We didn\'t find anything relevant to your query! Pleas try another one :/')
338
+ else:
339
+ st.session_state.iters = st.slider(
340
+ "Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
341
+ # No matter what happened the user wants a way back
342
+ col[1].form_submit_button(
343
+ "Choose a new prompt", on_click=refresh_index)
344
+
345
+ # If there are things to show
346
+ if len(matches) > 0:
347
+ with st.container():
348
+ prompt_labels = st.session_state.text_prompts
349
+
350
+ # Post processing boxes regarding to their score, intersection
351
+ boxes_w_img, meta = postprocess(matches, st.session_state.text_prompts,
352
+ img_matches)
353
+
354
+ # Sort the result according to their relavancy
355
+ boxes_w_img = sorted(
356
+ boxes_w_img, key=lambda x: x[4], reverse=True)
357
+
358
+ st.session_state.matched_boxes = {}
359
+ # For each images in the retrieved images, DISPLAY
360
+ for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
361
+
362
+ # prepare inputs for training
363
+ st.session_state.matched_boxes.update(
364
+ {b[0]: b for b in boxes})
365
+ args = img_url, img_w, img_h, boxes
366
+
367
+ # display boxes
368
+ with st.expander("{:s}: {:.4f}".format(img_id, img_score), expanded=True):
369
+ ind_b = 0
370
+ # 4 columns: (img, obj, obj, obj)
371
+ img_row = st.columns([4, 2, 2, 2])
372
+ img_row[0].write(
373
+ card(*args), unsafe_allow_html=True)
374
+ # crop objects out of the original image
375
+ for b in boxes:
376
+ _id, cx, cy, w, h, label, logit, is_selected, _ = b
377
+ with img_row[1 + ind_b % 3].container():
378
+ st.write(
379
+ "{:s}: {:.4f}".format(label, logit))
380
+ # quite hacky: with streamlit components API
381
+ _html = \
382
+ obj_card(img_url, img_w, img_h,
383
+ *b[1:5], dst_len=120)
384
+ components.html(_html, 120, 120)
385
+ # the user will choose the right label of the given object
386
+ st.selectbox(
387
+ "Class",
388
+ prompt_labels,
389
+ index=prompt_labels.index(label),
390
+ key=f"label-{_id}")
391
+ ind_b += 1
392
+ col[0].form_submit_button(
393
+ "Train!", on_click=lambda: submit(meta))
box_utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def cxywh2xywh(cx, cy, w, h):
5
+ """ CxCyWH format to XYWH format conversion
6
+ """
7
+ x = cx - w / 2
8
+ y = cy - h / 2
9
+ return x, y, w, h
10
+
11
+
12
+ def cxywh2ltrb(cx, cy, w, h):
13
+ """CxCyWH format to LeftRightTopBottom format
14
+ """
15
+ l = cx - w / 2
16
+ t = cy - h / 2
17
+ r = cx + w / 2
18
+ b = cy + h / 2
19
+ return l, t, r, b
20
+
21
+
22
+ def iou(ba, bb):
23
+ """Calculate Intersection-Over-Union
24
+
25
+ Args:
26
+ ba (tuple): CxCyWH format with score
27
+ bb (tuple): CxCyWH format with score
28
+
29
+ Returns:
30
+ IoU with size of length of given box
31
+ """
32
+ a_l, a_t, a_r, a_b, sa = ba
33
+ b_l, b_t, b_r, b_b, sb = bb
34
+
35
+ x1 = np.maximum(a_l, b_l)
36
+ y1 = np.maximum(a_t, b_t)
37
+ x2 = np.minimum(a_r, b_r)
38
+ y2 = np.minimum(a_b, b_b)
39
+ w = np.maximum(0, x2 - x1)
40
+ h = np.maximum(0, y2 - y1)
41
+ intersec = w * h
42
+ iou = (intersec) / (sa + sb - intersec)
43
+ return iou.squeeze()
44
+
45
+
46
+ def nms(cx, cy, w, h, s, iou_thresh=0.3):
47
+ """Bounding box Non-maximum Suppression
48
+
49
+ Args:
50
+ cx, cy, w, h, s: CxCyWH Format with score boxes
51
+ iou_thresh (float, optional): IoU threshold. Defaults to 0.3.
52
+
53
+ Returns:
54
+ res: indexes of the selected boxes
55
+ """
56
+ l, t, r, b = cxywh2ltrb(cx, cy, w, h)
57
+ areas = w * h
58
+ res = []
59
+ sort_ind = np.argsort(s, axis=-1)[::-1]
60
+ while sort_ind.shape[0] > 0:
61
+ i = sort_ind[0]
62
+ res.append(i)
63
+
64
+ _iou = iou((l[i], t[i], r[i], b[i], areas[i]),
65
+ (l[sort_ind[1:]], t[sort_ind[1:]],
66
+ r[sort_ind[1:]], b[sort_ind[1:]], areas[sort_ind[1:]]))
67
+ sel_ind = np.where(_iou <= iou_thresh)[0]
68
+ sort_ind = sort_ind[sel_ind + 1]
69
+ return res
70
+
71
+
72
+ def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
73
+ """filter out insignificant boxes
74
+
75
+ Args:
76
+ boxes (list of records): returned query to be filtered
77
+ """
78
+ ret = []
79
+ labelwise = {}
80
+ for _id, cx, cy, w, h, label, logit, is_selected, _ in boxes:
81
+ if label not in labelwise:
82
+ labelwise[label] = []
83
+ labelwise[label].append(logit)
84
+ labelwise = {l: max(s) for l, s in labelwise.items()}
85
+ agnostic = max([v for _, v in labelwise.items()])
86
+ for b in boxes:
87
+ _id, cx, cy, w, h, label, logit, is_selected, _ = b
88
+ if logit > class_ratio * labelwise[label] \
89
+ and logit > agnostic_ratio * agnostic:
90
+ ret.append(b)
91
+ return ret
92
+
93
+
94
+ def postprocess(matches, prompt_labels, img_matches=None):
95
+ meta = []
96
+ boxes_w_img = []
97
+ matches_ = {m['img_id']: m for m in matches}
98
+ if img_matches is not None:
99
+ img_matches_ = {m['img_id']: m for m in img_matches}
100
+ for k in matches_.keys():
101
+ m = matches_[k]
102
+ boxes = []
103
+ boxes += list(map(list, zip(m['box_id'], m['cx'], m['cy'], m['w'], m['h'],
104
+ [prompt_labels[int(l)]
105
+ for l in m['label']],
106
+ m['logit'], [1] *
107
+ len(m['box_id']),
108
+ list(np.array(m['cls_emb'])))))
109
+ if img_matches is not None:
110
+ img_m = img_matches_[k]
111
+ # and also those non-TopK hits and those non-topk are not anticipating training
112
+ boxes += [i for i in map(list, zip(img_m['box_id'], img_m['cx'], img_m['cy'], img_m['w'], img_m['h'],
113
+ [prompt_labels[int(
114
+ l)] for l in img_m['label']], img_m['logit'],
115
+ [0] * len(img_m['box_id']), list(np.array(img_m['cls_emb']))))
116
+ if i[0] not in [b[0] for b in boxes]]
117
+ # update record metadata after query
118
+ for b in boxes:
119
+ meta.append(b[0])
120
+
121
+ # remove some non-significant boxes
122
+ boxes = filter_nonpos(
123
+ boxes, agnostic_ratio=0.4, class_ratio=0.7)
124
+
125
+ # doing non-maximum suppression
126
+ cx, cy, w, h, s = list(map(lambda x: np.array(x),
127
+ list(zip(*[(*b[1:5], b[6]) for b in boxes]))))
128
+ ind = nms(cx, cy, w, h, s, 0.3)
129
+ boxes = [boxes[i] for i in ind]
130
+ img_score = img_m['img_score'] if img_matches is not None else m['img_score']
131
+ boxes_w_img.append(
132
+ (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes))
133
+ return boxes_w_img, meta
card_model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from box_utils import cxywh2ltrb, cxywh2xywh
3
+
4
+
5
+ def style():
6
+ """ Style string for card models
7
+ """
8
+ return """
9
+ <link
10
+ rel="stylesheet"
11
+ href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
12
+ />
13
+ <style>
14
+ .img-overlay-wrap {
15
+ position: relative;
16
+ display: inline-block;
17
+ }
18
+ .img-overlay-wrap {
19
+ position: relative;
20
+ display: inline-block;
21
+ /* <= shrinks container to image size */
22
+ transition: transform 150ms ease-in-out;
23
+ }
24
+ .img-overlay-wrap img {
25
+ /* <= optional, for responsiveness */
26
+ display: block;
27
+ max-width: 100%;
28
+ height: auto;
29
+ }
30
+ .img-overlay-wrap svg {
31
+ position: absolute;
32
+ top: 0;
33
+ left: 0;
34
+ }
35
+ </style>
36
+ """
37
+
38
+
39
+ def card(img_url, img_w, img_h, boxes):
40
+ """ This is a hack to streamlit
41
+ Solution thanks to: https://discuss.streamlit.io/t/display-svg/172/5
42
+ Converting SVG to Base64 and display with <img> tag.
43
+ Also we used the
44
+ """
45
+ _boxes = ""
46
+ for _id, cx, cy, w, h, label, logit, is_selected, _ in boxes:
47
+ x, y, w, h = cxywh2xywh(cx, cy, w, h)
48
+ x = round(img_w * x)
49
+ y = round(img_h * y)
50
+ w = round(img_w * w)
51
+ h = round(img_h * h)
52
+ logit = "%.3f" % logit
53
+ _boxes += f'''
54
+ <text fill="white" font-size="20" x="{x}" y="{y}" style="fill:white;opacity:0.7">{label}: {logit}</text>
55
+ <rect x="{x}" y="{y}" width="{w}" height="{h}" style="fill:none;stroke:{"red" if is_selected else "green"};
56
+ stroke-width:4;opacity:0.5" />
57
+ '''
58
+ _svg = f'''
59
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {img_w} {img_h}">
60
+ {_boxes}
61
+ </svg>
62
+ '''
63
+ _svg = r'<img style="position:absolute;top:0;left:0;" src="data:image/svg+xml;base64,%s"/>' % \
64
+ base64.b64encode(_svg.encode('utf-8')).decode('utf-8')
65
+ _img_d = f'''
66
+ <div class="img-overlay-wrap" width="{img_w}" height="{img_h}">
67
+ <img width="{img_w}" height="{img_h}" src="{img_url}">
68
+ {_svg}
69
+ </div>
70
+ '''
71
+ return _img_d
72
+
73
+
74
+ def obj_card(img_url, img_w, img_h, cx, cy, w, h, *args, dst_len=100):
75
+ """object card for displaying cropped object
76
+
77
+ Args:
78
+ Retrieved image and object info
79
+
80
+ Returns:
81
+ _obj_html: html string to display object
82
+ """
83
+ w = img_w * w
84
+ h = img_h * h
85
+ s = max(w, h)
86
+ x = round(img_w * cx - s / 2)
87
+ y = round(img_h * cy - s / 2)
88
+ scale = dst_len / s
89
+ _obj_html = f'''
90
+ <div style="transform-origin:0 0;transform:scale({scale});">
91
+ <img src="{img_url}" style="margin:{-y}px 0px 0px {-x}px;">
92
+ </div>
93
+ '''
94
+ return _obj_html
classifier.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def extract_text_feature(prompt, model, processor, device='cpu'):
5
+ """Extract text features
6
+
7
+ Args:
8
+ prompt: a single text query
9
+ model: OwlViT model
10
+ processor: OwlViT processor
11
+ device (str, optional): device to run. Defaults to 'cpu'.
12
+ """
13
+ device = 'cpu'
14
+ if torch.cuda.is_available():
15
+ device = 'cuda'
16
+ with torch.no_grad():
17
+ input_ids = torch.as_tensor(processor(text=prompt)[
18
+ 'input_ids']).to(device)
19
+ print(input_ids.device)
20
+ text_outputs = model.owlvit.text_model(
21
+ input_ids=input_ids,
22
+ attention_mask=None,
23
+ output_attentions=None,
24
+ output_hidden_states=None,
25
+ return_dict=None,
26
+ )
27
+ text_embeds = text_outputs[1]
28
+ text_embeds = model.owlvit.text_projection(text_embeds)
29
+ text_embeds /= text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
30
+ query_embeds = text_embeds
31
+ return input_ids, query_embeds
32
+
33
+
34
+ def prompt2vec(prompt: str, model, processor):
35
+ """ Convert prompt into a computational vector
36
+
37
+ Args:
38
+ prompt (str): Text to be tokenized
39
+
40
+ Returns:
41
+ xq: vector from the tokenizer, representing the original prompt
42
+ """
43
+ # inputs = tokenizer(prompt, return_tensors='pt')
44
+ # out = clip.get_text_features(**inputs)
45
+ input_ids, xq = extract_text_feature(prompt, model, processor)
46
+ input_ids = input_ids.detach().cpu().numpy()
47
+ xq = xq.detach().cpu().numpy()
48
+ return input_ids, xq
49
+
50
+
51
+ def tune(clf, X, y, iters=2):
52
+ """ Train the Zero-shot Classifier
53
+
54
+ Args:
55
+ X (numpy.ndarray): Input vectors (retreived vectors)
56
+ y (list of floats or numpy.ndarray): Scores given by user
57
+ iters (int, optional): iterations of updates to be run
58
+ """
59
+ assert len(X) == len(y)
60
+ # train the classifier
61
+ clf.fit(X, y, iters=iters)
62
+ # extract new vector
63
+ return clf.get_weights()
64
+
65
+
66
+ class Classifier:
67
+ """Multi-Class Zero-shot Classifier
68
+ This Classifier provides proxy regarding to the user's reaction to the probed images.
69
+ The proxy will replace the original query vector generated by prompted vector and finally
70
+ give the user a satisfying retrieval result.
71
+
72
+ This can be commonly seen in a recommendation system. The classifier will recommend more
73
+ precise result as it accumulating user's activity.
74
+
75
+ This is a multiclass classifier. For N queries it will set the all queries to the first-N classes
76
+ and the last one takes the negative one.
77
+ """
78
+
79
+ def __init__(self, xq: list):
80
+ init_weight = torch.Tensor(xq)
81
+ self.num_class = xq.shape[0]
82
+ DIMS = xq.shape[1]
83
+ # note that the bias is ignored, as we only focus on the inner product result
84
+ self.model = torch.nn.Linear(DIMS, self.num_class, bias=False)
85
+ # convert initial query `xq` to tensor parameter to init weights
86
+ self.model.weight = torch.nn.Parameter(init_weight)
87
+ # init loss and optimizer
88
+ self.loss = torch.nn.BCEWithLogitsLoss()
89
+ self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
90
+
91
+ def fit(self, X: list, y: list, iters: int = 5):
92
+ # convert X and y to tensor
93
+ X = torch.Tensor(X)
94
+ X /= torch.norm(X, p=2, dim=-1, keepdim=True)
95
+ y = torch.Tensor(y).long()
96
+ # Generate labels for binary classification and ignore outbound labels
97
+ non_ind = y > self.num_class
98
+ y = torch.nn.functional.one_hot(y % self.num_class, num_classes=self.num_class).float()
99
+ y[non_ind] = 0
100
+ for i in range(iters):
101
+ # zero gradients
102
+ self.optimizer.zero_grad()
103
+ # Normalize the weight before inference
104
+ # This will constrain the gradient or you will have an explosion on query vector
105
+ self.model.weight.data /= torch.norm(self.model.weight.data, p=2, dim=-1, keepdim=True)
106
+ # forward pass
107
+ out = self.model(X)
108
+ # compute loss
109
+ loss = self.loss(out, y)
110
+ # backward pass
111
+ loss.backward()
112
+ # update weights
113
+ self.optimizer.step()
114
+
115
+ def get_weights(self):
116
+ xq = self.model.weight.detach().numpy()
117
+ return xq
118
+
119
+ class SplitLayer(torch.nn.Module):
120
+ def forward(self, x):
121
+ return torch.split(x, 1, dim=-1)
query_model.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+
4
+ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
5
+ exclude_list=[], topk=10):
6
+ xq_s = [
7
+ f"[{', '.join([str(float(fnum)) for fnum in _xq.tolist() + [1]])}]" for _xq in xq]
8
+ exclude_list_str = ','.join([f'\'{i}\'' for i in exclude_list])
9
+ _cond = (f"WHERE obj_id NOT IN ({exclude_list_str})" if len(
10
+ exclude_list) > 0 else "")
11
+ _subq_str = []
12
+ _img_score_subq = []
13
+ for _l, _xq in enumerate(xq_s):
14
+ _img_score_subq.append(
15
+ f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
16
+ _subq_str.append(f"""
17
+ SELECT img_id, img_url, img_w, img_h, 1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))) AS pred_logit,
18
+ obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
19
+ FROM {OBJ_DB_NAME}
20
+ JOIN {IMG_DB_NAME}
21
+ ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
22
+ PREWHERE obj_id IN (
23
+ SELECT obj_id FROM (
24
+ SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
25
+ ORDER BY dist DESC
26
+ ) {_cond} LIMIT 10
27
+ )
28
+ """)
29
+ _subq_str = ' UNION ALL '.join(_subq_str)
30
+ _img_score_q = ','.join(_img_score_subq)
31
+ _img_score_q = f"arraySum(arrayFilter(x->NOT isNaN(x), array({_img_score_q}))) AS img_score"
32
+ q_str = f"""
33
+ SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
34
+ groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
35
+ groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
36
+ {_img_score_q}
37
+ FROM
38
+ ({_subq_str})
39
+ GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
40
+ """
41
+ xc = client.fetch(q_str)
42
+ return xc
43
+
44
+
45
+ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
46
+ xq_s = [
47
+ f"[{', '.join([str(float(fnum)) for fnum in _xq.tolist() + [1]])}]" for _xq in xq]
48
+ image_list = ','.join([f'\'{i}\'' for i in img_ids])
49
+ _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""
50
+ _subq_str = []
51
+ _img_score_subq = []
52
+ for _l, _xq in enumerate(xq_s):
53
+ _img_score_subq.append(
54
+ f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
55
+ _subq_str.append(f"""
56
+ SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
57
+ (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))))) AS pred_logit,
58
+ obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
59
+ FROM {OBJ_DB_NAME}
60
+ JOIN {IMG_DB_NAME}
61
+ ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
62
+ PREWHERE img_id IN ({image_list})
63
+ {_thresh}
64
+ """)
65
+ _subq_str = ' UNION ALL '.join(_subq_str)
66
+ _img_score_q = ','.join(_img_score_subq)
67
+ _img_score_q = f"arraySum(arrayFilter(x->NOT isNaN(x), array({_img_score_q}))) AS img_score"
68
+ q_str = f"""
69
+ SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
70
+ groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
71
+ groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
72
+ {_img_score_q}
73
+ FROM
74
+ ({_subq_str})
75
+ GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
76
+ """
77
+ xc = client.fetch(q_str)
78
+ return xc
79
+
80
+
81
+ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
82
+ xq_s = [
83
+ f"[{', '.join([str(float(fnum)) for fnum in _xq.tolist() + [1]])}]" for _xq in xq]
84
+ res = []
85
+ subq_str = []
86
+ _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""
87
+ for _l, _xq in enumerate(xq_s):
88
+ subq_str.append(
89
+ f"""
90
+ SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
91
+ obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
92
+ FROM {OBJ_DB_NAME}
93
+ JOIN {IMG_DB_NAME}
94
+ ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
95
+ {_thresh} LIMIT 10
96
+ """)
97
+ subq_str = " UNION ALL ".join(subq_str)
98
+ q_str = f"""
99
+ SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
100
+ groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
101
+ l AS label, groupArray(dist) as d,
102
+ groupArray(1 / (1 + exp(-dist))) AS logit FROM (
103
+ {subq_str}
104
+ )
105
+ GROUP BY l
106
+ """
107
+ res = client.fetch(q_str)
108
+ return res