Fangrui Liu commited on
Commit
98667f6
·
1 Parent(s): 1439326

update compute gradient with sql

Browse files
Files changed (5) hide show
  1. app.py +12 -6
  2. box_utils.py +61 -32
  3. card_model.py +2 -1
  4. classifier.py +46 -37
  5. query_model.py +2 -2
app.py CHANGED
@@ -192,7 +192,7 @@ def submit(meta):
192
  zip(
193
  *(
194
  (
195
- v[-1],
196
  st.session_state.text_prompts.index(st.session_state[f"label-{i}"]),
197
  )
198
  for i, v in matches.items()
@@ -329,7 +329,7 @@ try:
329
  matches = st.session_state.matches
330
  # initialize classifier
331
  if "clf" not in st.session_state:
332
- st.session_state.clf = Classifier(st.session_state.xq)
333
  st.session_state.step = 0
334
  if qtime > 0:
335
  st.info(
@@ -344,11 +344,13 @@ try:
344
  ),
345
  )
346
  )
 
 
347
 
348
  # export the model into executable ONNX
349
  st.session_state.dnld_model = BytesIO()
350
  torch.onnx.export(
351
- torch.nn.Sequential(st.session_state.clf.model, SplitLayer()),
352
  torch.zeros([1, len(st.session_state.xq[0])]),
353
  st.session_state.dnld_model,
354
  input_names=["input"],
@@ -370,7 +372,9 @@ try:
370
  with st.expander("Top-K Images"):
371
  with st.container():
372
  boxes_w_img, _ = postprocess(
373
- o_matches, st.session_state.text_prompts, None
 
 
374
  )
375
  boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
376
  for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
@@ -428,7 +432,9 @@ try:
428
 
429
  # Post processing boxes regarding to their score, intersection
430
  boxes_w_img, meta = postprocess(
431
- matches, st.session_state.text_prompts, img_matches
 
 
432
  )
433
 
434
  # Sort the result according to their relavancy
@@ -452,7 +458,7 @@ try:
452
  img_row[0].write(card(*args), unsafe_allow_html=True)
453
  # crop objects out of the original image
454
  for b in boxes:
455
- _id, cx, cy, w, h, label, logit, is_selected, _ = b
456
  with img_row[1 + ind_b % 3].container():
457
  st.write("{:s}: {:.4f}".format(label, logit))
458
  # quite hacky: with streamlit components API
 
192
  zip(
193
  *(
194
  (
195
+ v[0],
196
  st.session_state.text_prompts.index(st.session_state[f"label-{i}"]),
197
  )
198
  for i, v in matches.items()
 
329
  matches = st.session_state.matches
330
  # initialize classifier
331
  if "clf" not in st.session_state:
332
+ st.session_state.clf = Classifier(st.session_state.index, OBJ_DB_NAME, st.session_state.xq)
333
  st.session_state.step = 0
334
  if qtime > 0:
335
  st.info(
 
344
  ),
345
  )
346
  )
347
+ lnprob = torch.nn.Linear(st.session_state.xq.shape[1], st.session_state.xq.shape[0], bias=False)
348
+ lnprob.weight = torch.nn.Parameter(st.session_state.clf.weight)
349
 
350
  # export the model into executable ONNX
351
  st.session_state.dnld_model = BytesIO()
352
  torch.onnx.export(
353
+ torch.nn.Sequential(lnprob, SplitLayer()),
354
  torch.zeros([1, len(st.session_state.xq[0])]),
355
  st.session_state.dnld_model,
356
  input_names=["input"],
 
372
  with st.expander("Top-K Images"):
373
  with st.container():
374
  boxes_w_img, _ = postprocess(
375
+ o_matches, st.session_state.text_prompts, o_matches,
376
+ agnostic_ratio=1-0.6**(st.session_state.step+1),
377
+ class_ratio=1-0.2**(st.session_state.step+1)
378
  )
379
  boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
380
  for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
 
432
 
433
  # Post processing boxes regarding to their score, intersection
434
  boxes_w_img, meta = postprocess(
435
+ matches, st.session_state.text_prompts, img_matches,
436
+ agnostic_ratio=1-0.6**(st.session_state.step+1),
437
+ class_ratio=1-0.2**(st.session_state.step+1)
438
  )
439
 
440
  # Sort the result according to their relavancy
 
458
  img_row[0].write(card(*args), unsafe_allow_html=True)
459
  # crop objects out of the original image
460
  for b in boxes:
461
+ _id, cx, cy, w, h, label, logit, is_selected = b[:8]
462
  with img_row[1 + ind_b % 3].container():
463
  st.write("{:s}: {:.4f}".format(label, logit))
464
  # quite hacky: with streamlit components API
box_utils.py CHANGED
@@ -2,16 +2,14 @@ import numpy as np
2
 
3
 
4
  def cxywh2xywh(cx, cy, w, h):
5
- """ CxCyWH format to XYWH format conversion
6
- """
7
  x = cx - w / 2
8
  y = cy - h / 2
9
  return x, y, w, h
10
 
11
 
12
  def cxywh2ltrb(cx, cy, w, h):
13
- """CxCyWH format to LeftRightTopBottom format
14
- """
15
  l = cx - w / 2
16
  t = cy - h / 2
17
  r = cx + w / 2
@@ -61,9 +59,16 @@ def nms(cx, cy, w, h, s, iou_thresh=0.3):
61
  i = sort_ind[0]
62
  res.append(i)
63
 
64
- _iou = iou((l[i], t[i], r[i], b[i], areas[i]),
65
- (l[sort_ind[1:]], t[sort_ind[1:]],
66
- r[sort_ind[1:]], b[sort_ind[1:]], areas[sort_ind[1:]]))
 
 
 
 
 
 
 
67
  sel_ind = np.where(_iou <= iou_thresh)[0]
68
  sort_ind = sort_ind[sel_ind + 1]
69
  return res
@@ -77,43 +82,64 @@ def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
77
  """
78
  ret = []
79
  labelwise = {}
80
- for _id, cx, cy, w, h, label, logit, is_selected, _ in boxes:
 
81
  if label not in labelwise:
82
  labelwise[label] = []
83
  labelwise[label].append(logit)
84
  labelwise = {l: max(s) for l, s in labelwise.items()}
85
  agnostic = max([v for _, v in labelwise.items()])
86
  for b in boxes:
87
- _id, cx, cy, w, h, label, logit, is_selected, _ = b
88
- if logit > class_ratio * labelwise[label] \
89
- and logit > agnostic_ratio * agnostic:
90
  ret.append(b)
91
  return ret
92
 
93
 
94
- def postprocess(matches, prompt_labels, img_matches=None):
95
  meta = []
96
  boxes_w_img = []
97
- matches_ = {m['img_id']: m for m in matches}
98
  if img_matches is not None:
99
- img_matches_ = {m['img_id']: m for m in img_matches}
100
  for k in matches_.keys():
101
  m = matches_[k]
102
  boxes = []
103
- boxes += list(map(list, zip(m['box_id'], m['cx'], m['cy'], m['w'], m['h'],
104
- [prompt_labels[int(l)]
105
- for l in m['label']],
106
- m['logit'], [1] *
107
- len(m['box_id']),
108
- list(np.array(m['cls_emb'])))))
 
 
 
 
 
 
 
 
 
109
  if img_matches is not None and k in img_matches_:
110
  img_m = img_matches_[k]
111
  # and also those non-TopK hits and those non-topk are not anticipating training
112
- boxes += [i for i in map(list, zip(img_m['box_id'], img_m['cx'], img_m['cy'], img_m['w'], img_m['h'],
113
- [prompt_labels[int(
114
- l)] for l in img_m['label']], img_m['logit'],
115
- [0] * len(img_m['box_id']), list(np.array(img_m['cls_emb']))))
116
- if i[0] not in [b[0] for b in boxes]]
 
 
 
 
 
 
 
 
 
 
 
 
117
  else:
118
  img_m = None
119
  # update record metadata after query
@@ -121,16 +147,19 @@ def postprocess(matches, prompt_labels, img_matches=None):
121
  meta.append(b[0])
122
 
123
  # remove some non-significant boxes
124
- boxes = filter_nonpos(
125
- boxes, agnostic_ratio=0.4, class_ratio=0.7)
126
 
127
  # doing non-maximum suppression
128
- cx, cy, w, h, s = list(map(lambda x: np.array(x),
129
- list(zip(*[(*b[1:5], b[6]) for b in boxes]))))
 
130
  ind = nms(cx, cy, w, h, s, 0.3)
131
  boxes = [boxes[i] for i in ind]
132
  if img_m is not None:
133
- img_score = img_m['img_score'] if img_matches is not None else m['img_score']
 
 
134
  boxes_w_img.append(
135
- (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes))
136
- return boxes_w_img, meta
 
 
2
 
3
 
4
  def cxywh2xywh(cx, cy, w, h):
5
+ """CxCyWH format to XYWH format conversion"""
 
6
  x = cx - w / 2
7
  y = cy - h / 2
8
  return x, y, w, h
9
 
10
 
11
  def cxywh2ltrb(cx, cy, w, h):
12
+ """CxCyWH format to LeftRightTopBottom format"""
 
13
  l = cx - w / 2
14
  t = cy - h / 2
15
  r = cx + w / 2
 
59
  i = sort_ind[0]
60
  res.append(i)
61
 
62
+ _iou = iou(
63
+ (l[i], t[i], r[i], b[i], areas[i]),
64
+ (
65
+ l[sort_ind[1:]],
66
+ t[sort_ind[1:]],
67
+ r[sort_ind[1:]],
68
+ b[sort_ind[1:]],
69
+ areas[sort_ind[1:]],
70
+ ),
71
+ )
72
  sel_ind = np.where(_iou <= iou_thresh)[0]
73
  sort_ind = sort_ind[sel_ind + 1]
74
  return res
 
82
  """
83
  ret = []
84
  labelwise = {}
85
+ for b in boxes:
86
+ _id, cx, cy, w, h, label, logit, is_selected = b[:8]
87
  if label not in labelwise:
88
  labelwise[label] = []
89
  labelwise[label].append(logit)
90
  labelwise = {l: max(s) for l, s in labelwise.items()}
91
  agnostic = max([v for _, v in labelwise.items()])
92
  for b in boxes:
93
+ _id, cx, cy, w, h, label, logit, is_selected = b[:8]
94
+ if logit > class_ratio * labelwise[label] and logit > agnostic_ratio * agnostic:
 
95
  ret.append(b)
96
  return ret
97
 
98
 
99
+ def postprocess(matches, prompt_labels, img_matches=None, agnostic_ratio=0.4, class_ratio=0.7):
100
  meta = []
101
  boxes_w_img = []
102
+ matches_ = {m["img_id"]: m for m in matches}
103
  if img_matches is not None:
104
+ img_matches_ = {m["img_id"]: m for m in img_matches}
105
  for k in matches_.keys():
106
  m = matches_[k]
107
  boxes = []
108
+ boxes += list(
109
+ map(
110
+ list,
111
+ zip(
112
+ m["box_id"],
113
+ m["cx"],
114
+ m["cy"],
115
+ m["w"],
116
+ m["h"],
117
+ [prompt_labels[int(l)] for l in m["label"]],
118
+ m["logit"],
119
+ [1] * len(m["box_id"]),
120
+ ),
121
+ )
122
+ )
123
  if img_matches is not None and k in img_matches_:
124
  img_m = img_matches_[k]
125
  # and also those non-TopK hits and those non-topk are not anticipating training
126
+ boxes += [
127
+ i
128
+ for i in map(
129
+ list,
130
+ zip(
131
+ img_m["box_id"],
132
+ img_m["cx"],
133
+ img_m["cy"],
134
+ img_m["w"],
135
+ img_m["h"],
136
+ [prompt_labels[int(l)] for l in img_m["label"]],
137
+ img_m["logit"],
138
+ [0] * len(img_m["box_id"]),
139
+ ),
140
+ )
141
+ if i[0] not in [b[0] for b in boxes]
142
+ ]
143
  else:
144
  img_m = None
145
  # update record metadata after query
 
147
  meta.append(b[0])
148
 
149
  # remove some non-significant boxes
150
+ boxes = filter_nonpos(boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio)
 
151
 
152
  # doing non-maximum suppression
153
+ cx, cy, w, h, s = list(
154
+ map(lambda x: np.array(x), list(zip(*[(*b[1:5], b[6]) for b in boxes])))
155
+ )
156
  ind = nms(cx, cy, w, h, s, 0.3)
157
  boxes = [boxes[i] for i in ind]
158
  if img_m is not None:
159
+ img_score = (
160
+ img_m["img_score"] if img_matches is not None else m["img_score"]
161
+ )
162
  boxes_w_img.append(
163
+ (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
164
+ )
165
+ return boxes_w_img, meta
card_model.py CHANGED
@@ -47,7 +47,8 @@ def card(img_url, img_w, img_h, boxes):
47
  """
48
  _boxes = ""
49
  img_url = convert_img_url(img_url)
50
- for _id, cx, cy, w, h, label, logit, is_selected, _ in boxes:
 
51
  x, y, w, h = cxywh2xywh(cx, cy, w, h)
52
  x = round(img_w * x)
53
  y = round(img_h * y)
 
47
  """
48
  _boxes = ""
49
  img_url = convert_img_url(img_url)
50
+ for b in boxes:
51
+ _id, cx, cy, w, h, label, logit, is_selected = b[:8]
52
  x, y, w, h = cxywh2xywh(cx, cy, w, h)
53
  x = round(img_w * x)
54
  y = round(img_h * y)
classifier.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
 
3
 
4
- def extract_text_feature(prompt, model, processor, device='cpu'):
5
  """Extract text features
6
 
7
  Args:
@@ -10,12 +10,11 @@ def extract_text_feature(prompt, model, processor, device='cpu'):
10
  processor: OwlViT processor
11
  device (str, optional): device to run. Defaults to 'cpu'.
12
  """
13
- device = 'cpu'
14
  if torch.cuda.is_available():
15
- device = 'cuda'
16
  with torch.no_grad():
17
- input_ids = torch.as_tensor(processor(text=prompt)[
18
- 'input_ids']).to(device)
19
  print(input_ids.device)
20
  text_outputs = model.owlvit.text_model(
21
  input_ids=input_ids,
@@ -32,7 +31,7 @@ def extract_text_feature(prompt, model, processor, device='cpu'):
32
 
33
 
34
  def prompt2vec(prompt: str, model, processor):
35
- """ Convert prompt into a computational vector
36
 
37
  Args:
38
  prompt (str): Text to be tokenized
@@ -49,7 +48,7 @@ def prompt2vec(prompt: str, model, processor):
49
 
50
 
51
  def tune(clf, X, y, iters=2):
52
- """ Train the Zero-shot Classifier
53
 
54
  Args:
55
  X (numpy.ndarray): Input vectors (retreived vectors)
@@ -62,60 +61,70 @@ def tune(clf, X, y, iters=2):
62
  # extract new vector
63
  return clf.get_weights()
64
 
65
-
66
  class Classifier:
67
  """Multi-Class Zero-shot Classifier
68
  This Classifier provides proxy regarding to the user's reaction to the probed images.
69
  The proxy will replace the original query vector generated by prompted vector and finally
70
  give the user a satisfying retrieval result.
71
 
72
- This can be commonly seen in a recommendation system. The classifier will recommend more
73
  precise result as it accumulating user's activity.
74
-
75
  This is a multiclass classifier. For N queries it will set the all queries to the first-N classes
76
  and the last one takes the negative one.
77
  """
78
 
79
- def __init__(self, xq: list):
80
  init_weight = torch.Tensor(xq)
81
  self.num_class = xq.shape[0]
82
- DIMS = xq.shape[1]
83
- # note that the bias is ignored, as we only focus on the inner product result
84
- self.model = torch.nn.Linear(DIMS, self.num_class, bias=False)
85
  # convert initial query `xq` to tensor parameter to init weights
86
- self.model.weight = torch.nn.Parameter(init_weight)
87
- # init loss and optimizer
88
- self.loss = torch.nn.BCEWithLogitsLoss()
89
- self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
90
 
91
  def fit(self, X: list, y: list, iters: int = 5):
92
  # convert X and y to tensor
93
- X = torch.Tensor(X)
94
- X /= torch.norm(X, p=2, dim=-1, keepdim=True)
95
- y = torch.Tensor(y).long()
96
- # Generate labels for binary classification and ignore outbound labels
97
- non_ind = y > self.num_class
98
- y = torch.nn.functional.one_hot(y % self.num_class, num_classes=self.num_class).float()
99
- y[non_ind] = 0
100
- for i in range(iters):
101
  # zero gradients
102
- self.optimizer.zero_grad()
103
  # Normalize the weight before inference
104
  # This will constrain the gradient or you will have an explosion on query vector
105
- self.model.weight.data /= torch.norm(self.model.weight.data, p=2, dim=-1, keepdim=True)
106
- # forward pass
107
- out = self.model(X)
108
- # compute loss
109
- loss = self.loss(out, y)
110
- # backward pass
111
- loss.backward()
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # update weights
113
- self.optimizer.step()
 
114
 
115
  def get_weights(self):
116
- xq = self.model.weight.detach().numpy()
117
  return xq
118
-
 
 
119
  class SplitLayer(torch.nn.Module):
120
  def forward(self, x):
121
  return torch.split(x, 1, dim=-1)
 
1
  import torch
2
 
3
 
4
+ def extract_text_feature(prompt, model, processor, device="cpu"):
5
  """Extract text features
6
 
7
  Args:
 
10
  processor: OwlViT processor
11
  device (str, optional): device to run. Defaults to 'cpu'.
12
  """
13
+ device = "cpu"
14
  if torch.cuda.is_available():
15
+ device = "cuda"
16
  with torch.no_grad():
17
+ input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
 
18
  print(input_ids.device)
19
  text_outputs = model.owlvit.text_model(
20
  input_ids=input_ids,
 
31
 
32
 
33
  def prompt2vec(prompt: str, model, processor):
34
+ """Convert prompt into a computational vector
35
 
36
  Args:
37
  prompt (str): Text to be tokenized
 
48
 
49
 
50
  def tune(clf, X, y, iters=2):
51
+ """Train the Zero-shot Classifier
52
 
53
  Args:
54
  X (numpy.ndarray): Input vectors (retreived vectors)
 
61
  # extract new vector
62
  return clf.get_weights()
63
 
 
64
  class Classifier:
65
  """Multi-Class Zero-shot Classifier
66
  This Classifier provides proxy regarding to the user's reaction to the probed images.
67
  The proxy will replace the original query vector generated by prompted vector and finally
68
  give the user a satisfying retrieval result.
69
 
70
+ This can be commonly seen in a recommendation system. The classifier will recommend more
71
  precise result as it accumulating user's activity.
72
+
73
  This is a multiclass classifier. For N queries it will set the all queries to the first-N classes
74
  and the last one takes the negative one.
75
  """
76
 
77
+ def __init__(self, client, obj_db:str, xq: list):
78
  init_weight = torch.Tensor(xq)
79
  self.num_class = xq.shape[0]
80
+ self.DIMS = xq.shape[1]
 
 
81
  # convert initial query `xq` to tensor parameter to init weights
82
+ self.weight = init_weight
83
+ self.client = client
84
+ self.obj_db = obj_db
 
85
 
86
  def fit(self, X: list, y: list, iters: int = 5):
87
  # convert X and y to tensor
88
+ xq_s = [
89
+ f"[{', '.join([str(float(fnum)) for fnum in _xq + [1]])}]"
90
+ for _xq in self.get_weights().tolist()
91
+ ]
92
+
93
+ for _ in range(iters):
 
 
94
  # zero gradients
95
+ grad = []
96
  # Normalize the weight before inference
97
  # This will constrain the gradient or you will have an explosion on query vector
98
+ self.weight.data /= torch.norm(
99
+ self.weight.data, p=2, dim=-1, keepdim=True
100
+ )
101
+ for n in range(self.num_class):
102
+ # select all training sample and create labels
103
+ labels, objs = list(map(list, zip(*[[1 if y[i]==n else 0, x] for i, x in enumerate(X) if y[i] in [n, self.num_class+1]])))
104
+
105
+ # NOTE from @fangruil
106
+ # Use SQL to calculate the gradient
107
+ # For binary cross entropy we have
108
+ # g = (1/(1+\exp(-XW))-Y)^TX
109
+ # To simplify the query, we separated
110
+ # the calculation into class numbers
111
+ grad_q_str = f"""
112
+ SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
113
+ FROM (
114
+ SELECT groupArray(arrayPopBack(prelogit)) AS X,
115
+ groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
116
+ FROM {self.obj_db} WHERE obj_id IN {objs})"""
117
+ grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]['grad']))
118
  # update weights
119
+ grad = torch.stack(grad, dim=0)
120
+ self.weight -= 0.1 * grad
121
 
122
  def get_weights(self):
123
+ xq = self.weight.detach().numpy()
124
  return xq
125
+
126
+
127
+
128
  class SplitLayer(torch.nn.Module):
129
  def forward(self, x):
130
  return torch.split(x, 1, dim=-1)
query_model.py CHANGED
@@ -32,7 +32,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
32
  q_str = f"""
33
  SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
34
  groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
35
- groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
36
  {_img_score_q}
37
  FROM
38
  ({_subq_str})
@@ -68,7 +68,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
68
  q_str = f"""
69
  SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
70
  groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
71
- groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
72
  {_img_score_q}
73
  FROM
74
  ({_subq_str})
 
32
  q_str = f"""
33
  SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
34
  groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
35
+ groupArray(pred_logit) AS logit, groupArray(l) as label,
36
  {_img_score_q}
37
  FROM
38
  ({_subq_str})
 
68
  q_str = f"""
69
  SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
70
  groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
71
+ groupArray(pred_logit) AS logit, groupArray(l) as label,
72
  {_img_score_q}
73
  FROM
74
  ({_subq_str})