cwkuo committed
Commit 9051af7 · 1 Parent(s): 507ba2d

code clean up
.gitattributes CHANGED
@@ -33,8 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-.bin filter=lfs diff=lfs merge=lfs -text
-.pt filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
 *.hdf5 filter=lfs diff=lfs merge=lfs -text
 *.index filter=lfs diff=lfs merge=lfs -text
 *.jpg filter=lfs diff=lfs merge=lfs -text
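
Note: the fix here is the leading "*". gitattributes patterns use gitignore-style globbing, so the old entries ".bin" and ".pt" matched only files literally named ".bin" or ".pt" and left checkpoint files outside LFS. Python's fnmatch approximates this globbing for simple patterns and makes the difference easy to see:

    from fnmatch import fnmatch

    print(fnmatch("model.bin", ".bin"))   # False: the old pattern misses checkpoint files
    print(fnmatch("model.bin", "*.bin"))  # True: the corrected pattern matches
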
.vscode/settings.json DELETED
@@ -1,6 +0,0 @@
-{
-    "[python]": {
-        "editor.defaultFormatter": "ms-python.autopep8"
-    },
-    "python.formatting.provider": "none"
-}
app.py CHANGED
@@ -1,6 +1,4 @@
 from pathlib import Path
-import datetime
-import json
 import os
 import time
 import gradio as gr
@@ -258,10 +256,11 @@ The service is a research preview intended for non-commercial use only, subject
 def build_demo():
     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
     imagebox = gr.Image(type="pil")
-    state = gr.State()
 
     with gr.Blocks(title="GPT-K", theme=gr.themes.Base()) as demo:
+        state = gr.State()
         gr.Markdown(title_markdown)
+
         with gr.Row():
             with gr.Column(scale=3):
                 gr.Examples(examples=[
@@ -274,10 +273,10 @@ def build_demo():
 
                 imagebox.render()
                 textbox.render()
-                with gr.Column():
+                with gr.Row():
                     submit_btn = gr.Button(value="📝 Submit")
-                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
-                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
 
                 with gr.Accordion("Parameters", open=True):
                     with gr.Row():
@@ -289,7 +288,7 @@ def build_demo():
                         max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
 
             with gr.Column(scale=6):
-                chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+                chatbot = gr.Chatbot(elem_id="chatbot", label="GPT-K Chatbot", height=550)
 
                 gr.Markdown("Retrieved Knowledge")
                 knwl_img, knwl_txt = [], []
@@ -303,7 +302,7 @@ def build_demo():
                     with gr.Column(scale=7):
                         knwl_txt.append(gr.Markdown())
                 knwl_vis = knwl_img + knwl_txt
-
+
         gr.Markdown(tos_markdown)
         gr.Markdown(learn_more_markdown)
 
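
Note: the changes above move gr.State() inside the gr.Blocks() context (so the per-session state is registered with this demo), place the three buttons in a Row instead of a Column, and rename the chatbot label to match the project. A minimal sketch of the resulting layout, with component names mirroring the diff but callback wiring omitted (illustrative, not the full app):

    import gradio as gr

    with gr.Blocks(title="GPT-K", theme=gr.themes.Base()) as demo:
        state = gr.State()  # per-session conversation state, created inside the context
        chatbot = gr.Chatbot(elem_id="chatbot", label="GPT-K Chatbot", height=550)
        textbox = gr.Textbox(show_label=False, container=False)
        with gr.Row():  # buttons side by side, as in the Column -> Row change
            submit_btn = gr.Button(value="📝 Submit")
            regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
            clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

    # demo.launch()
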
conversation.py CHANGED
@@ -1,6 +1,6 @@
 import dataclasses
 from enum import auto, Enum
-from typing import List, Tuple
+from typing import List
 
 
 class SeparatorStyle(Enum):
@@ -197,7 +197,6 @@ conv_gptk = Conversation(
     sep=""
 )
 
-
 conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
knowledge/__pycache__/__init__.cpython-37.pyc DELETED
Binary file (254 Bytes)
 
knowledge/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (254 Bytes)
 
knowledge/__pycache__/cluster.cpython-38.pyc DELETED
Binary file (5.12 kB)
 
knowledge/__pycache__/dbscan.cpython-37.pyc DELETED
Binary file (2.29 kB)
 
knowledge/__pycache__/dbscan.cpython-38.pyc DELETED
Binary file (2.32 kB)
 
knowledge/__pycache__/image_crops_idx.cpython-38.pyc DELETED
Binary file (10.8 kB)
 
knowledge/__pycache__/image_tokens_idx.cpython-38.pyc DELETED
Binary file (7.7 kB)
 
knowledge/__pycache__/revive.cpython-38.pyc DELETED
Binary file (2.19 kB)
 
knowledge/__pycache__/sentence_db.cpython-37.pyc DELETED
Binary file (6.01 kB)
 
knowledge/__pycache__/sentence_db.cpython-38.pyc DELETED
Binary file (6.39 kB)
 
knowledge/__pycache__/sentence_idx.cpython-37.pyc DELETED
Binary file (9.12 kB)
 
knowledge/__pycache__/sentence_idx.cpython-38.pyc DELETED
Binary file (9.75 kB)
 
knowledge/__pycache__/text_db.cpython-38.pyc DELETED
Binary file (7.22 kB)
 
knowledge/__pycache__/utils.cpython-37.pyc DELETED
Binary file (3.05 kB)
 
knowledge/__pycache__/utils.cpython-38.pyc DELETED
Binary file (4.1 kB)
 
knowledge/__pycache__/vis_vocab.cpython-37.pyc DELETED
Binary file (8.46 kB)
 
knowledge/__pycache__/wordnet.cpython-37.pyc DELETED
Binary file (2.3 kB)
 
knowledge/cluster.py DELETED
@@ -1,178 +0,0 @@
-import argparse
-from pathlib import Path
-import numpy as np
-from tqdm import tqdm
-import h5py
-import time
-
-import faiss
-import torch
-from pytorch_lightning import seed_everything
-
-import sys
-sys.path.append('.')
-from knowledge.text_db import TextDB
-from knowledge.utils import nn_search, build_faiss_index, refine_cosine
-
-
-UNSEEN = -2
-NOISE = -1
-
-
-def dbscan(X, faiss_index, device, eps=0.1, min_points=1, k=2048, bs=512):
-    neighbors = []
-    N = (len(X) - 1) // bs + 1
-    for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
-        Xi = X[i*bs: (i+1)*bs]
-        _, I = faiss_index.search(Xi, k*2)
-        S, I = refine_cosine(X, Xi, I, device, k)
-
-        for sim, idx in zip(S, I):
-            dist = 1. - sim
-            neighbors.append(idx[dist < eps])
-
-    cluster_id = 0
-    n_points = len(X)
-    labels = np.array([
-        NOISE if len(neighbors[i]) < min_points else UNSEEN
-        for i in range(n_points)
-    ])
-
-    with tqdm(total=n_points, dynamic_ncols=True, desc="DBSCAN clustering", mininterval=1.0) as pbar:
-        for i in range(n_points):
-            if labels[i] == UNSEEN:
-                seeds = np.array([i, ])
-                labels[seeds] = cluster_id
-
-                while len(seeds) > 0:
-                    neighbor_seeds = set()
-                    for s in seeds:
-                        n = neighbors[s]
-                        if len(n) > 0:
-                            l = np.array(list(set(labels[n])))
-                            l = l[np.logical_and(l >= 0, l != cluster_id)]
-                            for li in l:
-                                labels[labels == li] = cluster_id
-
-                            n = n[labels[n] == UNSEEN]
-                            neighbor_seeds.update(n)
-
-                    seeds = np.array(list(neighbor_seeds))
-                    if len(seeds) > 0:
-                        assert np.all(labels[seeds] == UNSEEN)
-                        labels[seeds] = cluster_id
-
-                cluster_id += 1
-
-            pbar.set_postfix(num_clusters=cluster_id)
-            pbar.update()
-
-    label_set = np.sort(list(set(labels)))
-    label_set = label_set[label_set >= 0]
-    labels_mapping = {l1: l2 for l2, l1 in enumerate(label_set)}
-    labels_mapping[-1] = -1
-    labels = np.array([labels_mapping[l] for l in labels])
-
-    return labels
-
-
-def extract_clusters(feat, text, labels, faiss_index, device, k=128, bs=8192):
-    clusters = {}
-    for i, l in enumerate(tqdm(labels, dynamic_ncols=True, desc="Label each samples", mininterval=1.0)):
-        if l >= 0:
-            try:
-                clusters[l]["feat"] += feat[i].astype(np.float64)
-                clusters[l]["N"] += 1
-            except KeyError:
-                clusters[l] = {"feat": feat[i].astype(np.float64), "N": 1}
-
-    cc = []
-    for l in tqdm(list(clusters.keys()), dynamic_ncols=True, desc="Compute cluster centers", mininterval=1.0):
-        c = clusters[l]["feat"]/clusters[l]["N"]
-        cc.append(c.astype(np.float32))
-    cc = np.stack(cc)
-    cc /= np.linalg.norm(cc, keepdims=True, axis=-1)
-
-    idx = []
-    N = (len(cc) - 1) // bs + 1
-    for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
-        cc_i = cc[i*bs: (i+1)*bs]
-        _, I = faiss_index.search(cc_i, k)
-        _, I = refine_cosine(feat, cc_i, I, device, 1)
-        idx.append(I[:, 0])
-    idx = np.unique(np.concatenate(idx))
-    text = [text[i] for i in idx]
-    feat = np.stack([feat[i] for i in idx])
-
-    return feat, text
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Cluster knowledge database using DBSCAN")
-    parser.add_argument("--knowledge_db", type=str, required=True)
-    parser.add_argument("--seed", type=int, default=12345)
-    parser.add_argument("--eps", type=float, default=0.1)
-    parser.add_argument("--ms", type=int, default=1)
-    parser.add_argument("--ratio", type=float, default=None)
-    parser.add_argument("--device", type=int, default=None)
-    args = parser.parse_args()
-
-    # parse exp name
-    args.knowledge_db = Path(args.knowledge_db)
-    exp_name = args.knowledge_db.parent.name
-    exp_name += f"(dbscan)(eps-{args.eps})(ms-{args.ms})"
-    save_root = args.knowledge_db.parent.parent/exp_name
-    setattr(args, "save_root", save_root)
-    args.save_root.mkdir(parents=True, exist_ok=True)
-
-    args.device = torch.device("cuda", args.device) \
-        if args.device is not None else torch.device("cpu")
-
-    seed_everything(args.seed, workers=True)
-    print(args)
-
-    # load feature, text, and faiss index from knowledge db
-    knowledge_db = TextDB(args.knowledge_db)
-    feat = knowledge_db.feature.astype(np.float32)
-    text = knowledge_db.text
-    if args.ratio is not None:
-        N = int(len(feat) * args.ratio)
-        feat, text = feat[:N], text[:N]
-    faiss_index = faiss.read_index(str(args.knowledge_db.parent/"faiss.index"))
-    print("Add data to faiss index...", end="\r")
-    ts = time.time()
-    faiss_index.add(feat)
-    print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")
-
-    # DBSCAN clustering
-    labels_file = args.save_root/"labels.npy"
-    if labels_file.exists():
-        labels = np.load(labels_file)
-    else:
-        labels = dbscan(feat, faiss_index, args.device, args.eps, args.ms)
-        with open(labels_file, 'wb') as f:
-            np.save(f, labels)
-
-    # extract clusters
-    feat, text = extract_clusters(feat, text, labels, faiss_index, args.device)
-    with h5py.File(args.save_root/f"knowledge_db.hdf5", "w") as f:
-        bs = 65536
-        N = (len(feat) - 1) // bs + 1
-        for i in tqdm(range(N), dynamic_ncols=True, desc="Saving clustered DB", mininterval=1.0):
-            g = f.create_group(str(i))
-            g.create_dataset("feature", data=feat[i*bs: (i+1)*bs], compression="gzip")
-            g.create_dataset("text", data=text[i*bs: (i+1)*bs], compression="gzip")
-
-    # build faiss index for the clustered DB
-    index = build_faiss_index(feat, gpus=[args.device.index, ])
-    faiss.write_index(index, str(args.save_root/"faiss.index"))
-
-    # some stats
-    noise_ratio = np.sum(labels == -1) / len(labels)
-    n_clusters, n_samples = len(text), len(labels)
-    msg = f"n_samples = {n_samples:,}; n_clusters = {n_clusters:,}; noise_ratio = {noise_ratio*100:.3f}%\n"
-    with open(save_root/"info.txt", "w") as f:
-        f.write(msg)
-    print(msg)
-
-
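
Note: the deleted script hand-rolls DBSCAN over faiss-retrieved neighbor lists so it can scale to millions of vectors. For data that fits in memory, roughly the same clustering (cosine-distance threshold eps, min_points=1) is available off the shelf; a small sketch, assuming scikit-learn is installed:

    import numpy as np
    from sklearn.cluster import DBSCAN

    feat = np.random.rand(1000, 64).astype(np.float32)
    feat /= np.linalg.norm(feat, axis=1, keepdims=True)  # unit-norm, as in the knowledge DB

    labels = DBSCAN(eps=0.1, min_samples=1, metric="cosine").fit_predict(feat)
    print("clusters:", labels.max() + 1, "noise:", int(np.sum(labels == -1)))
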
knowledge/retrieve.py CHANGED
@@ -1,28 +1,8 @@
-import argparse
-from pathlib import Path
 import h5py
-import time
-import shutil
 import numpy as np
-import subprocess
-import time
 from tqdm import tqdm
-
-import faiss
-import open_clip
 import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader
-from pytorch_lightning import callbacks
-from pytorch_lightning import Trainer, LightningModule, seed_everything
-
-import sys
-sys.path.append('.')
-from dataset import coco, cc, llava
-from knowledge.utils import refine_cosine
-from knowledge import text_db
 from knowledge import TextDB
-from train.utils import ExpName
 
 
 class ImageCropsIdx:
@@ -123,205 +103,3 @@ class KnowAugImageCropsCombined:
         }
 
         return ret
-
-
-class ImageCropsIdxBuilder(LightningModule):
-    def __init__(self, args, model: open_clip.model.CLIP):
-        super().__init__()
-
-        self.args = args
-        self.save_root = args.save_root
-        self.k = args.k
-        self.model = model
-
-    def on_validation_epoch_start(self):
-        if self.global_rank == 0:
-            knowledge_db = TextDB(self.args.knowledge_db)
-            self.feature = knowledge_db.feature
-            self.text = knowledge_db.text
-
-            self.faiss_index = faiss.read_index(
-                str(Path(self.args.knowledge_db).parent/"faiss.index")
-            )
-            print("\nAdd data to faiss index...", end="\r")
-            ts = time.time()
-            self.faiss_index.add(self.feature)
-            print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")
-
-            with h5py.File(self.save_root/"knowledge_idx.hdf5", "a") as f:
-                f.attrs["fdim"] = self.feature.shape[-1]
-                f.attrs["file_hash"] = knowledge_db.file_hash
-
-        self.trainer.strategy.barrier()
-
-    def all_gather_object(self, data):
-        if self.trainer.world_size > 1:
-            gathered = [None for _ in range(self.trainer.world_size)]
-            dist.all_gather_object(gathered, data)
-            data = gathered
-        else:
-            data = [data, ]
-
-        return data
-
-    def broadcast_object(self, data, src_rank=0):
-        if self.trainer.world_size > 1:
-            if self.global_rank == src_rank:
-                data_list = [data, ] * self.trainer.world_size
-            else:
-                data_list = [None, ] * self.trainer.world_size
-
-            dist.broadcast_object_list(data_list, src=src_rank)
-            return data_list[0]
-        else:
-            return data
-
-    def search(self, images, topk):
-        query = self.model.encode_image(images, normalize=True)
-        query = query.cpu().numpy()
-        query = self.all_gather_object(query)
-        query = np.concatenate(query)
-
-        if self.global_rank == 0:
-            _, I = self.faiss_index.search(query, 4*topk)
-            S, I = refine_cosine(self.feature, query, I, self.device, topk)
-        else:
-            S = I = None
-
-        return S, I, query
-
-    def validation_step(self, batch, batch_idx):
-        orig_imgs, five_imgs, nine_imgs, ids = batch
-
-        ids = ids.cpu().numpy()
-        ids = np.concatenate(self.all_gather_object(ids))
-
-        S_w, I_w, Q_w = self.search(orig_imgs, topk=self.k)
-
-        S_f, I_f, Q_f = [], [], []
-        for i in range(five_imgs.shape[1]):
-            Si, Ii, Qi = self.search(five_imgs[:, i], topk=self.k)
-            S_f.append(Si)
-            I_f.append(Ii)
-            Q_f.append(Qi)
-
-        S_n, I_n, Q_n = [], [], []
-        for i in range(nine_imgs.shape[1]):
-            Si, Ii, Qi = self.search(nine_imgs[:, i], topk=self.k)
-            S_n.append(Si)
-            I_n.append(Ii)
-            Q_n.append(Qi)
-
-        if self.global_rank == 0:
-            S_w, I_w, Q_w = np.expand_dims(S_w, axis=1), np.expand_dims(I_w, axis=1), np.expand_dims(Q_w, axis=1)
-            S_f, I_f, Q_f = np.stack(S_f, axis=1), np.stack(I_f, axis=1), np.stack(Q_f, axis=1)
-            S_n, I_n, Q_n = np.stack(S_n, axis=1), np.stack(I_n, axis=1), np.stack(Q_n, axis=1)
-
-            with h5py.File(self.save_root/"knowledge_idx.hdf5", "a") as f:
-                g = f.create_group(str(batch_idx))
-
-                g.create_dataset("image_ids", data=ids.astype(np.int32), compression="gzip")
-
-                gw = g.create_group("whole")
-                gw.create_dataset("index", data=I_w.astype(np.int32), compression="gzip")
-                gw.create_dataset("score", data=S_w.astype(np.float32), compression="gzip")
-                gw.create_dataset("query", data=Q_w.astype(np.float32), compression="gzip")
-
-                gf = g.create_group("five")
-                gf.create_dataset("index", data=I_f.astype(np.int32), compression="gzip")
-                gf.create_dataset("score", data=S_f.astype(np.float32), compression="gzip")
-                gf.create_dataset("query", data=Q_f.astype(np.float32), compression="gzip")
-
-                gn = g.create_group("nine")
-                gn.create_dataset("index", data=I_n.astype(np.int32), compression="gzip")
-                gn.create_dataset("score", data=S_n.astype(np.float32), compression="gzip")
-                gn.create_dataset("query", data=Q_n.astype(np.float32), compression="gzip")
-
-    def on_validation_epoch_end(self):
-        if self.args.azcopy and self.global_rank == 0:
-            with open("azcopy/sas_output", "r") as f:
-                sas = f.readline()
-            sas_base, sas_key = sas.split("?")
-            sas = f"{sas_base}/knowledge_idx?{sas_key}"
-
-            cmd = ["azcopy/azcopy", "copy", str(self.args.save_root), sas, "--recursive=true"]
-            print(f"start copying data with command {cmd}")
-            ts = time.time()
-            subprocess.run(cmd)
-            print(f"done copying data in {time.time() - ts:.2f} secs")
-
-
-def main(args):
-    model, _, trans_img = open_clip.create_model_and_transforms(
-        args.clip_model, pretrained=text_db.CLIP_MODELS[args.clip_model]
-    )
-
-    print("load query dataset...")
-    if "coco" in args.query:
-        dset = coco.COCOImageCrops(Path(f"data/{args.query}"), trans=trans_img)
-        collate_crops = coco.collate_coco_crops
-    elif args.query == "cc3m":
-        dset = cc.CC3MImageCrops(Path("data/cc3m_instruct"), trans=trans_img)
-        collate_crops = cc.collate_cc_crops
-    elif args.query == "llava":
-        dset = llava.LLaVAImageCrops(Path("data/llava_bench"), trans=trans_img)
-        collate_crops = llava.collate_llava_crops
-    else:
-        raise ValueError
-    loader = DataLoader(
-        dset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
-        drop_last=False, collate_fn=collate_crops
-    )
-
-    print("build model and trainer...")
-    pl_model = ImageCropsIdxBuilder(args, model)
-    model_summary = callbacks.RichModelSummary()
-    progress_bar = callbacks.TQDMProgressBar(args.refresh_rate)
-    trainer_callbacks = [model_summary, progress_bar]
-    trainer = Trainer(
-        sync_batchnorm=True,
-        precision=16,
-        accelerator='gpu',
-        devices=args.devices,
-        strategy="ddp",
-        default_root_dir=args.save_root,
-        callbacks=trainer_callbacks,
-        limit_val_batches=args.limit_val_batches
-    )
-
-    print("retrieve knowledge...")
-    trainer.validate(pl_model, dataloaders=loader)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Knowledge retrieval using image crops')
-    parser = Trainer.add_argparse_args(parser)
-    parser.add_argument('--query', type=str, choices=["coco14", "coco17", "cc3m", "llava"], required=True)
-    parser.add_argument('--knowledge_db', type=str, required=True)
-    parser.add_argument('--k', type=int, default=128)
-    parser.add_argument("--bs", type=int, default=128)
-    parser.add_argument("--num_workers", type=int, default=7)
-    parser.add_argument("--seed", type=int, default=12345)
-    parser.add_argument("--refresh_rate", type=int, default=1)
-    parser.add_argument("--azcopy", action="store_true")
-    args = parser.parse_args()
-
-    # parse exp_name
-    exp_name = ExpName(f"(query-{args.query})")
-    exp_name += Path(args.knowledge_db).parent.name
-    if args.azcopy:
-        setattr(args, "save_root", Path("azcopy")/str(exp_name))
-    else:
-        setattr(args, "save_root", Path("output")/"knowledge_idx"/str(exp_name))
-    shutil.rmtree(args.save_root, ignore_errors=True)
-    args.save_root.mkdir(parents=True, exist_ok=True)
-
-    # parse model
-    model = exp_name.get("clip-model")[1:-1]
-    model = model[len("clip-model-"):]
-    assert model in text_db.CLIP_MODELS.keys()
-    setattr(args, "clip_model", model)
-
-    print(args)
-    seed_everything(args.seed, workers=True)
-    main(args)
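
Note: the deleted ImageCropsIdxBuilder.search over-fetches 4*topk candidates from faiss and then re-ranks them with exact cosine similarity via the retained knowledge.utils.refine_cosine. A self-contained sketch of that retrieve-then-refine pattern on random data (with a flat index the refine step is a no-op; it matters for approximate IVF/OPQ indexes like the ones the deleted code built):

    import numpy as np
    import faiss

    dim, k = 64, 8
    db = np.random.rand(10_000, dim).astype(np.float32)
    db /= np.linalg.norm(db, axis=1, keepdims=True)

    index = faiss.IndexFlatIP(dim)  # inner product == cosine on unit vectors
    index.add(db)

    query = db[:4]
    _, cand = index.search(query, 4 * k)             # over-fetch candidates
    sims = np.einsum("qd,qkd->qk", query, db[cand])  # exact cosine on the candidates
    order = np.argsort(-sims, axis=1)[:, :k]         # keep the true top-k
    topk = np.take_along_axis(cand, order, axis=1)
    print(topk.shape)  # (4, 8)
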
knowledge/text_db.py CHANGED
@@ -1,26 +1,8 @@
-import argparse
-import itertools
-from pathlib import Path
-import shutil
 import h5py
-import time
-import subprocess
 from tqdm import tqdm
 import numpy as np
 import codecs
-
-import open_clip
-import faiss
-import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader
-from pytorch_lightning import callbacks
-from pytorch_lightning import Trainer, LightningModule, seed_everything
-
-import sys
-sys.path.append("./")
-from dataset import cc, words
-from knowledge.utils import file_hash, build_faiss_index
+from knowledge.utils import file_hash
 
 
 class TextDB:
@@ -59,139 +41,3 @@ class TextDB:
 
         return f, t
 
-
-class TextDBBuilder(LightningModule):
-    def __init__(self, args, model: open_clip.model.CLIP):
-        super().__init__()
-        self.args = args
-        self.model = model
-
-    def validation_step(self, batch, batch_idx):
-        token, text = batch
-        feat = self.model.encode_text(token, normalize=True)
-
-        if self.trainer.world_size > 1:
-            text_gathered = [None for _ in range(self.trainer.world_size)]
-            dist.all_gather_object(text_gathered, text)
-            text = list(itertools.chain.from_iterable(text_gathered))
-
-            feat_gathered = [None for _ in range(self.trainer.world_size)]
-            dist.all_gather_object(feat_gathered, feat)
-            feat = torch.cat([x.to(self.device) for x in feat_gathered])
-        feat = feat.cpu().numpy()
-
-        if self.global_rank == 0:
-            with h5py.File(self.args.save_root/"knowledge_db.hdf5", "a") as f:
-                g = f.create_group(str(batch_idx))
-                g.create_dataset("feature", data=feat, compression="gzip")
-                g.create_dataset("text", data=text, compression="gzip")
-
-    def validation_epoch_end(self, outputs):
-        if self.global_rank == 0:
-            knowledge_db = TextDB(self.args.save_root/"knowledge_db.hdf5")
-            feat = knowledge_db.feature
-
-            if self.args.devices == "-1":
-                num_devices = torch.cuda.device_count()
-                devices = list(range(num_devices))
-            else:
-                devices = [int(x) for x in args.devices.split(",") if x]
-            print(f"CUDA devices: {devices}")
-
-            index = build_faiss_index(feat, gpus=devices)
-            faiss.write_index(index, str(self.args.save_root/"faiss.index"))
-        self.trainer.strategy.barrier()
-
-        if self.args.azcopy and self.global_rank == 0:
-            with open("azcopy/sas_output", "r") as f:
-                sas = f.readline()
-            sas_base, sas_key = sas.split("?")
-            sas = f"{sas_base}/knowledge_db?{sas_key}"
-
-            cmd = ["azcopy/azcopy", "copy", str(self.args.save_root), sas, "--recursive=true"]
-            print(f"start copying data with command {cmd}")
-            ts = time.time()
-            subprocess.run(cmd)
-            print(f"done copying data in {time.time() - ts:.2f} secs")
-        self.trainer.strategy.barrier()
-
-
-DATASETS = {
-    "object": words.ObjsDataset,
-    "attribute": words.AttrsDataset,
-    "action": words.ActsDataset,
-    "cc3m": cc.CC3MTextDataset,
-    "cc12m": cc.CC12MTextDataset
-}
-
-
-def main(args):
-    model, _, _ = open_clip.create_model_and_transforms(
-        args.clip_model, pretrained=CLIP_MODELS[args.clip_model]
-    )
-    trans_txt = open_clip.get_tokenizer(args.clip_model)
-
-    print("load dataset...")
-    dset = DATASETS[args.dataset](Path(args.data_root), trans_txt)
-    loader = DataLoader(
-        dset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
-        drop_last=False, collate_fn=cc.collate_cc_txt
-    )
-
-    print("build model and trainer...")
-    pl_model = TextDBBuilder(args, model)
-    model_summary = callbacks.RichModelSummary()
-    progress_bar = callbacks.TQDMProgressBar(args.refresh_rate)
-    trainer_callbacks = [model_summary, progress_bar]
-    trainer = Trainer(
-        sync_batchnorm=True,
-        precision=16,
-        accelerator='gpu',
-        devices=args.devices,
-        strategy="ddp",
-        default_root_dir=args.save_root,
-        callbacks=trainer_callbacks,
-        limit_val_batches=args.limit_val_batches
-    )
-
-    print("compute textual features...")
-    trainer.validate(pl_model, dataloaders=loader)
-
-
-CLIP_MODELS = {
-    'ViT-B-32': 'openai',
-    'ViT-B-16': 'openai',
-    'ViT-L-14': 'openai',
-    'ViT-g-14': 'laion2b_s34b_b88k',
-    'ViT-bigG-14': 'laion2b_s39b_b160k',
-    'convnext_xxlarge': 'laion2b_s34b_b82k_augreg_soup',
-}
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Build knowledge database of words")
-    parser = Trainer.add_argparse_args(parser)
-    parser.add_argument(
-        "--dataset", type=str, required=True, choices=["object", "attribute", "action", "cc3m", "cc12m"]
-    )
-    parser.add_argument("--data_root", type=str, default="data/conceptnet/conceptnet-assertions-5.7.0.csv")
-    parser.add_argument("--clip_model", type=str, default="ViT-g-14", choices=CLIP_MODELS.keys())
-    parser.add_argument("--bs", type=int, default=2**10)
-    parser.add_argument("--num_workers", type=int, default=7)
-    parser.add_argument("--seed", type=int, default=12345)
-    parser.add_argument("--refresh_rate", type=int, default=1)
-    parser.add_argument("--azcopy", action="store_true")
-    args = parser.parse_args()
-
-    # feature dir
-    exp_name = f"(dataset-{args.dataset})(clip-model-{args.clip_model})"
-    if args.azcopy:
-        setattr(args, "save_root", Path("azcopy")/"knowledge_db"/exp_name)
-    else:
-        setattr(args, "save_root", Path("output")/"knowledge_db"/exp_name)
-    shutil.rmtree(args.save_root, ignore_errors=True)
-    args.save_root.mkdir(parents=True, exist_ok=True)
-
-    print(args)
-    seed_everything(args.seed, workers=True)
-    main(args)
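
Note: text_db.py now keeps only the read-side TextDB class; the deleted TextDBBuilder encoded text with open_clip and wrote unit-norm features to HDF5. A sketch of the core computation per batch (ViT-B-32 chosen here for brevity; the removed CLIP_MODELS table defaulted to ViT-g-14):

    import open_clip
    import torch

    model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
    tokenizer = open_clip.get_tokenizer("ViT-B-32")

    with torch.no_grad():
        tokens = tokenizer(["a photo of a dog", "a photo of a cat"])
        feat = model.encode_text(tokens, normalize=True)  # unit-norm text features
    print(feat.shape)  # (2, 512) for ViT-B-32
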
knowledge/transforms.py CHANGED
@@ -1,6 +1,5 @@
 import itertools
 from torchvision.transforms import functional as F
-import re
 
 
 def five_crop(image, ratio=0.6):
@@ -26,27 +25,3 @@ def nine_crop(image, ratio=0.4):
         images.append(F.crop(image, top, left, height, width))
 
     return images
-
-
-def pre_caption(caption, max_words=None):
-    # Ref: https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py#L49-L68
-    caption = re.sub(
-        r"([.!\"()*#:;~])",
-        " ",
-        caption.lower(),
-    )
-    caption = re.sub(
-        r"\s{2,}",
-        " ",
-        caption,
-    )
-    caption = caption.rstrip("\n")
-    caption = caption.strip(" ")
-
-    # truncate caption
-    caption_words = caption.split(" ")
-    if max_words is not None and len(caption_words) > max_words:
-        caption = " ".join(caption_words[: max_words])
-
-    return caption
-
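
Note: with pre_caption removed, transforms.py keeps only the crop helpers, which tile an image into overlapping regions using torchvision's functional crop. For reference, the primitive that five_crop/nine_crop call repeatedly:

    from PIL import Image
    from torchvision.transforms import functional as F

    img = Image.new("RGB", (200, 100))
    # F.crop(img, top, left, height, width)
    crop = F.crop(img, top=10, left=20, height=50, width=80)
    print(crop.size)  # (80, 50): PIL reports (width, height)
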
knowledge/utils.py CHANGED
@@ -1,11 +1,5 @@
-from tqdm import tqdm
 import numpy as np
-import time
-import math
-import bisect
 import hashlib
-import faiss
-from faiss import StandardGpuResources, index_cpu_to_gpu_multiple_py
 import torch
 
 
@@ -21,78 +15,6 @@ def file_hash(file):
     return hash_fn.hexdigest()
 
 
-def build_faiss_index(x, gpus=None):
-    # Ref: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
-    # Ref: https://gist.github.com/mdouze/46d6bbbaabca0b9778fca37ed2bcccf6
-
-    N, dim = x.shape
-    secs = [2**i for i in range(1, 15)]
-    d = secs[bisect.bisect_right(secs, dim) - 1] // 2
-    m = d // 4
-
-    if N <= 60000:
-        index_factory = "Flat"
-    elif N <= 2555904:
-        index_factory = f"IVF{int(8*math.sqrt(N))},Flat"
-    elif N <= 10223616:
-        index_factory = f"OPQ{m}_{d},IVF65536_HNSW32,PQ{m}x4fsr"
-    elif N <= 1e8:
-        index_factory = f"OPQ{m}_{d},IVF262144_HNSW32,PQ{m}x4fsr"
-    else:
-        index_factory = f"OPQ{m}_{d},IVF1048576_HNSW32,PQ{m}x4fsr"
-    print(f"train {index_factory} index on {N:,} x {dim} data")
-
-    index = faiss.index_factory(dim, index_factory)
-    if gpus is not None and N > 60000:
-        index_ivf = faiss.extract_index_ivf(index)
-        res = []
-        for _ in gpus:
-            r = StandardGpuResources()
-            r.noTempMemory()
-            res.append(r)
-        clustering_index = index_cpu_to_gpu_multiple_py(
-            res, faiss.IndexFlatL2(index_ivf.d), None, gpus
-        )
-        index_ivf.clustering_index = clustering_index
-
-    print("train index...", end="\r")
-    ts = time.time()
-    # commented out for index_factory = "Flat"
-    # assert not index.is_trained
-    index.train(x)
-    assert index.is_trained
-    print(f"train index...done in {time.time() - ts:.2f} secs")
-
-    index.nprobe = 64
-    index.quantizer_efSearch = 32
-
-    return index
-
-
-def nn_search(query, index, topk, bs=256, desc=None, disable_tqdm=True):
-    idx, dist = [], []
-    N = (len(query) - 1) // bs + 1
-    for i in tqdm(range(N), dynamic_ncols=True, desc=desc, disable=disable_tqdm):
-        D, I = index.search(query[i*bs: (i+1)*bs], topk)
-        idx.append(I)
-        dist.append(D)
-    idx = np.concatenate(idx)
-    dist = np.concatenate(dist)
-
-    return idx, dist
-
-
-def radius_search(query, index, r, bs=256, desc=None, disable_tqdm=True):
-    idx, dist = [], []
-    N = (len(query) - 1) // bs + 1
-    for i in tqdm(range(N), dynamic_ncols=True, desc=desc, disable=disable_tqdm):
-        L, D, I = index.range_search(query[i*bs: (i+1)*bs], r)
-        idx.extend([I[L[j]:L[j+1]] for j in range(len(L)-1)])
-        dist.extend([D[L[j]:L[j+1]] for j in range(len(L)-1)])
-
-    return idx, dist
-
-
 @torch.no_grad()
 def refine_cosine(Xa, Xq, I, device, k=None):
     if k is not None:
@@ -114,14 +36,3 @@ def refine_cosine(Xa, Xq, I, device, k=None):
     S_refined = np.stack(S_refined)
 
     return S_refined, I_refined
-
-
-def test_nn_search():
-    key = np.random.random((3000000, 512)).astype(np.float32)
-    key /= np.linalg.norm(key, keepdims=True, axis=1)
-    index = build_faiss_index(key, -1)
-
-    query = np.random.random((100000, 512)).astype(np.float32)
-    query /= np.linalg.norm(query, keepdims=True, axis=1)
-    idx_r = nn_search(query, index, r=0.5)
-    idx_k = nn_search(query, index, topk=10)
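
Note: utils.py now keeps only file_hash and refine_cosine; index construction left with build_faiss_index, whose small-N branch was simply a flat (exact) index via faiss's index_factory. A minimal sketch of that branch (the IVF/OPQ factory strings in the deleted code only pay off past roughly 60k vectors):

    import numpy as np
    import faiss

    x = np.random.rand(5_000, 128).astype(np.float32)
    index = faiss.index_factory(128, "Flat")  # exact search, no training needed
    index.add(x)
    D, I = index.search(x[:3], 5)
    print(I.shape)  # (3, 5)
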
model/.gitattributes DELETED
@@ -1,2 +0,0 @@
-*.hdf5 filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text