cwkuo committed
Commit 40e6e04 · Parent: bf5fb05

save memory usage
Files changed (4):
  1. app.py +29 -19
  2. knowledge/text_db.py +1 -1
  3. model/gptk.py +12 -3
  4. requirements.txt +6 -4
app.py CHANGED
@@ -32,7 +32,7 @@ def violates_moderation(text):
     """
     if "OPENAI_API_KEY" not in os.environ:
         print("OPENAI_API_KEY not found, skip content moderation check...")
-        return True
+        return False
 
     url = "https://api.openai.com/v1/moderations"
     headers = {
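Note: flipping this return value changes the missing-key fallback from fail-closed to fail-open. Returning True treated every message as a violation whenever OPENAI_API_KEY was unset, so the demo rejected all input; returning False skips moderation instead. A minimal sketch of the whole check, assuming the standard response shape of the OpenAI moderations endpoint (results[0]["flagged"]); the requests call is illustrative, not copied from the repo:

    import os
    import requests

    def violates_moderation(text: str) -> bool:
        # Fail open: without a key, skip the check rather than reject everything.
        if "OPENAI_API_KEY" not in os.environ:
            print("OPENAI_API_KEY not found, skip content moderation check...")
            return False
        url = "https://api.openai.com/v1/moderations"
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
        }
        resp = requests.post(url, headers=headers, json={"input": text})
        return resp.json()["results"][0]["flagged"]  # True if any category is flagged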
@@ -206,7 +206,7 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
     # generate output
     prompt = state.get_prompt().replace("USER: <image>\n", "")
     prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
-    image_pt = image_trans(image).to(device).unsqueeze(0)
+    image_pt = gptk_trans(image).to(device).unsqueeze(0)
     samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
     if bool(do_beam_search):
         new_text = gptk_model.generate(
@@ -358,21 +358,12 @@ def build_demo():
     return demo
 
 
-def build_model():
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
-    query_enc, _, query_trans = open_clip.create_model_and_transforms(
-        "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
-    )
-    query_enc = query_enc.to(device).eval()
-
+def build_knowledge():
     def get_knwl(knowledge_db):
         knwl_db = TextDB(Path(knowledge_db)/"knowledge_db.hdf5")
+        knwl_db.feature = knwl_db.feature
         knwl_idx = faiss.read_index(str(Path(knowledge_db)/"faiss.index"))
-        knwl_idx.add(knwl_db.feature)
+        knwl_idx.add(knwl_db.feature.astype(np.float32))
 
         return knwl_db, knwl_idx
 
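Note: faiss indexes only accept float32 input, so once the knowledge features are stored as float16 (see knowledge/text_db.py below), the .astype(np.float32) upcast at add() time is mandatory; the float32 copy is transient, while the resident TextDB array stays at half size. A small sketch of the constraint, assuming 1024-dim ViT-g-14 features and a flat index (the repo loads its index from disk instead):

    import numpy as np
    import faiss

    d = 1024                                              # ViT-g-14 embedding width
    feats = np.random.rand(10_000, d).astype(np.float16)  # resident store: 2 bytes/value

    index = faiss.IndexFlatIP(d)
    index.add(feats.astype(np.float32))                   # faiss requires float32; upcast per call
    print(index.ntotal)                                   # 10000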
 
@@ -381,19 +372,38 @@ def build_model():
         "act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
         "attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
     }
-    d_knwl = knwl_db["obj"][0].feature.shape[-1]
 
-    _, image_trans = get_gptk_image_transform()
+    return knwl_db
+
+
+def build_query_model():
+    query_enc, _, query_trans = open_clip.create_model_and_transforms(
+        "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16', device=device
+    )
+    query_enc = query_enc.eval()
+
+    return query_enc, query_trans
+
+
+def build_gptk_model():
+    _, gptk_trans = get_gptk_image_transform()
     topk = {"whole": 60, "five": 24, "nine": 16}
-    gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
+    gptk_model = get_gptk_model(d_knwl=1024, topk=topk)
     gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
     gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
     gptk_model.load_state_dict(gptk_ckpt, strict=False)
     gptk_model = gptk_model.to(device).eval()
 
-    return knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device
+    return gptk_model, gptk_trans, topk
 
 
-knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device = build_model()
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+gptk_model, gptk_trans, topk = build_gptk_model()
+query_enc, query_trans = build_query_model()
+knwl_db = build_knowledge()
 demo = build_demo()
 demo.queue().launch()
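Note: splitting build_model() into build_knowledge(), build_query_model(), and build_gptk_model() lets each component be constructed independently, so temporaries from one step can be freed before the next allocates. Hardcoding d_knwl=1024 also removes the old ordering constraint where the knowledge DB had to load first just to read feature.shape[-1]; 1024 is the ViT-g-14 embedding width, so the constant and the stored features must stay in sync. A one-line guard (a hypothetical sketch, reusing the old lookup) would make that assumption explicit:

    # Hypothetical startup check: fail fast if stored features diverge from d_knwl.
    assert knwl_db["obj"][0].feature.shape[-1] == 1024, "knowledge features must be 1024-dim"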
 
knowledge/text_db.py CHANGED
@@ -18,7 +18,7 @@ class TextDB:
            _, d = f[f"0/feature"].shape
 
        with h5py.File(text_db, 'r') as f:
-           feature = np.zeros((db_size, d), dtype=np.float32)
+           feature = np.zeros((db_size, d), dtype=np.float16)
            text = []
            N = 0
            for i in tqdm(range(len(f)), desc="Load text DB", dynamic_ncols=True, mininterval=1.0):
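Note: preallocating the feature matrix as float16 halves its resident footprint, from 4*db_size*d bytes to 2*db_size*d. A quick check of the arithmetic, with the DB shape assumed for illustration:

    import numpy as np

    db_size, d = 1_000_000, 1024                 # illustrative shape
    fp32 = np.zeros((db_size, d), dtype=np.float32)
    fp16 = np.zeros((db_size, d), dtype=np.float16)
    print(fp32.nbytes / 2**30)                   # ~3.81 GiB
    print(fp16.nbytes / 2**30)                   # ~1.91 GiB

The trade-off is reduced precision, usually acceptable for retrieval features; paths that need float32 (such as the faiss add() above) upcast on demand.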
 
model/gptk.py CHANGED
@@ -7,8 +7,7 @@ import torch
 from torch import nn
 from torchvision import transforms as T
 from torchvision.transforms.functional import InterpolationMode
-from transformers import LlamaTokenizer
-from transformers import BertTokenizer
+from transformers import LlamaTokenizer, BertTokenizer, BitsAndBytesConfig
 
 import sys
 sys.path.append("./")
@@ -49,7 +48,17 @@ class GPTK(nn.Module):
         llm_config = LlamaConfig.from_pretrained(llm_model)
         llm_config.gradient_checkpointing = True
         llm_config.use_cache = True
-        self.llm_model = LlamaForCausalLM.from_pretrained(llm_model, config=llm_config, torch_dtype=torch.float16)
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4'
+        )
+        self.llm_model = LlamaForCausalLM.from_pretrained(
+            llm_model, config=llm_config, torch_dtype=torch.float16, quantization_config=quantization_config
+        )
         self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
         self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
         self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
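Note: this is the main memory saving. With NF4 4-bit weights, a 7B-parameter Vicuna takes roughly half a byte per weight (~3.5 GB, plus small double-quantized scale factors) instead of ~14 GB in fp16, while bnb_4bit_compute_dtype=torch.float16 keeps the matmuls in half precision. The llm_int8_* fields configure the 8-bit path and are inert when load_in_4bit=True. A standalone sketch of the same pattern via the public transformers API; the checkpoint name is illustrative, not the repo's:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # dequantized matmuls run in fp16
        bnb_4bit_use_double_quant=True,        # quantize the quantization constants too
        bnb_4bit_quant_type="nf4",             # NormalFloat4, suited to normal-ish weights
    )
    model = AutoModelForCausalLM.from_pretrained(
        "lmsys/vicuna-7b-v1.5",                # illustrative checkpoint
        quantization_config=quantization_config,
        device_map="auto",                     # accelerate handles GPU placement
    )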
 
requirements.txt CHANGED
@@ -1,7 +1,7 @@
---extra-index-url https://download.pytorch.org/whl/cu113
-torch==1.11.0+cu113
-torchvision==0.12.0+cu113
-torchaudio==0.11.0
+--extra-index-url https://download.pytorch.org/whl/cu117
+torch==1.13.1+cu117
+torchvision==0.14.1+cu117
+torchaudio==0.13.1
 
 transformers==4.30.2
 faiss-gpu==1.7.2
@@ -11,3 +11,5 @@ open_clip_torch
 omegaconf
 h5py>=3.8.0
 spacy>=3.5.0
+bitsandbytes
+accelerate
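Note: bitsandbytes and accelerate are runtime requirements of the load_in_4bit path above, and the 4-bit kernels need a CUDA build of PyTorch, hence the bump to the cu117 wheels. A minimal post-install sanity check (a sketch; expected versions follow the pins above):

    import torch
    import bitsandbytes  # noqa: F401  (required by transformers' 4-bit loading)
    import accelerate    # noqa: F401  (handles device placement)

    assert torch.cuda.is_available(), "4-bit quantization requires a CUDA GPU"
    print(torch.__version__, torch.version.cuda)  # expect 1.13.1 and 11.7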