Spaces: Runtime error
Commit 40e6e04 · committed by cwkuo · 1 parent: bf5fb05

save memory usage

Files changed:
- app.py +29 -19
- knowledge/text_db.py +1 -1
- model/gptk.py +12 -3
- requirements.txt +6 -4
app.py
CHANGED
@@ -32,7 +32,7 @@ def violates_moderation(text):
     """
     if "OPENAI_API_KEY" not in os.environ:
         print("OPENAI_API_KEY not found, skip content moderation check...")
-        return
+        return False
 
     url = "https://api.openai.com/v1/moderations"
     headers = {
@@ -206,7 +206,7 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
     # generate output
     prompt = state.get_prompt().replace("USER: <image>\n", "")
     prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
-    image_pt =
+    image_pt = gptk_trans(image).to(device).unsqueeze(0)
     samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
     if bool(do_beam_search):
         new_text = gptk_model.generate(
@@ -358,21 +358,12 @@ def build_demo():
     return demo
 
 
-def build_model():
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
-    query_enc, _, query_trans = open_clip.create_model_and_transforms(
-        "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
-    )
-    query_enc = query_enc.to(device).eval()
-
+def build_knowledge():
     def get_knwl(knowledge_db):
         knwl_db = TextDB(Path(knowledge_db)/"knowledge_db.hdf5")
+        knwl_db.feature = knwl_db.feature
         knwl_idx = faiss.read_index(str(Path(knowledge_db)/"faiss.index"))
-        knwl_idx.add(knwl_db.feature)
+        knwl_idx.add(knwl_db.feature.astype(np.float32))
 
         return knwl_db, knwl_idx
 
@@ -381,19 +372,38 @@ def build_model():
         "act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
         "attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
     }
-    d_knwl = knwl_db["obj"][0].feature.shape[-1]
 
-
+    return knwl_db
+
+
+def build_query_model():
+    query_enc, _, query_trans = open_clip.create_model_and_transforms(
+        "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16', device=device
+    )
+    query_enc = query_enc.eval()
+
+    return query_enc, query_trans
+
+
+def build_gptk_model():
+    _, gptk_trans = get_gptk_image_transform()
     topk = {"whole": 60, "five": 24, "nine": 16}
-    gptk_model = get_gptk_model(d_knwl=
+    gptk_model = get_gptk_model(d_knwl=1024, topk=topk)
     gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
     gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
     gptk_model.load_state_dict(gptk_ckpt, strict=False)
     gptk_model = gptk_model.to(device).eval()
 
-    return
+    return gptk_model, gptk_trans, topk
+
 
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
 
-
+gptk_model, gptk_trans, topk = build_gptk_model()
+query_enc, query_trans = build_query_model()
+knwl_db = build_knowledge()
 demo = build_demo()
 demo.queue().launch()
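Note: the split of the old build_model() into build_gptk_model(), build_query_model(), and build_knowledge() lets each heavy component be constructed once at module level, and the CLIP query encoder is now created directly in fp16 on the target device rather than built in fp32 and moved afterwards. A minimal sketch of that pattern, assuming a CUDA device and only the open_clip package (model and pretrained tag names are the ones in the diff):

import torch
import open_clip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# precision="fp16" allocates half-precision weights up front, avoiding a
# transient fp32 copy of the ~1B-parameter ViT-g-14 encoder.
query_enc, _, query_trans = open_clip.create_model_and_transforms(
    "ViT-g-14", pretrained="laion2b_s34b_b88k", precision="fp16", device=device
)
query_enc = query_enc.eval()

# Run inference without autograd state so activations are freed eagerly.
from PIL import Image
with torch.no_grad():
    image = Image.new("RGB", (224, 224))  # stand-in for a user-uploaded image
    image_pt = query_trans(image).unsqueeze(0).to(device=device, dtype=torch.float16)
    feat = query_enc.encode_image(image_pt)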
knowledge/text_db.py
CHANGED
@@ -18,7 +18,7 @@ class TextDB:
         _, d = f[f"0/feature"].shape
 
     with h5py.File(text_db, 'r') as f:
-        feature = np.zeros((db_size, d), dtype=np.
+        feature = np.zeros((db_size, d), dtype=np.float16)
         text = []
         N = 0
         for i in tqdm(range(len(f)), desc="Load text DB", dynamic_ncols=True, mininterval=1.0):
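Allocating the feature matrix as float16 halves the text database's resident memory (a 1M x 1024 matrix shrinks from ~4 GB in float32 to ~2 GB). FAISS kernels accept only float32 input, which is why app.py now casts with .astype(np.float32) when populating the index; that cast is a transient copy, not a second persistent buffer. A small self-contained sketch of the trade-off, using a flat inner-product index and synthetic data:

import faiss
import numpy as np

n, d = 100_000, 1024
feature = np.random.rand(n, d).astype(np.float16)  # stored at half precision

index = faiss.IndexFlatIP(d)
index.add(feature.astype(np.float32))  # FAISS requires float32; cast only here

# Queries must be float32 as well.
scores, ids = index.search(feature[:5].astype(np.float32), 3)
print(scores.shape, ids.shape)  # (5, 3) (5, 3)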
model/gptk.py
CHANGED
@@ -7,8 +7,7 @@ import torch
 from torch import nn
 from torchvision import transforms as T
 from torchvision.transforms.functional import InterpolationMode
-from transformers import LlamaTokenizer
-from transformers import BertTokenizer
+from transformers import LlamaTokenizer, BertTokenizer, BitsAndBytesConfig
 
 import sys
 sys.path.append("./")
@@ -49,7 +48,17 @@ class GPTK(nn.Module):
         llm_config = LlamaConfig.from_pretrained(llm_model)
         llm_config.gradient_checkpointing = True
         llm_config.use_cache = True
-
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4'
+        )
+        self.llm_model = LlamaForCausalLM.from_pretrained(
+            llm_model, config=llm_config, torch_dtype=torch.float16, quantization_config=quantization_config
+        )
         self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
         self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
         self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
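Loading the Vicuna-7B backbone through bitsandbytes 4-bit NF4 quantization is the main saving in this commit: 7B parameters occupy about 14 GB in fp16 but roughly 4 GB at 4 bits, and double quantization additionally compresses the per-block quantization constants. A minimal sketch of the same recipe against a generic Hugging Face causal LM (the checkpoint name is a placeholder; a CUDA device plus the bitsandbytes and accelerate packages from requirements.txt are assumed):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4-bit blocks
    bnb_4bit_quant_type="nf4",             # NormalFloat4, suited to normal-ish weight distributions
    bnb_4bit_use_double_quant=True,        # also quantize the per-block scales
    bnb_4bit_compute_dtype=torch.float16,  # dequantize to fp16 for matmuls
)

model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.5",                # placeholder checkpoint
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
).eval()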
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
---extra-index-url https://download.pytorch.org/whl/
-torch==1.
-torchvision==0.
-torchaudio==0.
+--extra-index-url https://download.pytorch.org/whl/cu117
+torch==1.13.1+cu117
+torchvision==0.14.1+cu117
+torchaudio==0.13.1
 
 transformers==4.30.2
 faiss-gpu==1.7.2
@@ -11,3 +11,5 @@ open_clip_torch
 omegaconf
 h5py>=3.8.0
 spacy>=3.5.0
+bitsandbytes
+accelerate
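bitsandbytes and accelerate become hard runtime requirements for the 4-bit path above: transformers delegates the quantized linear layers to bitsandbytes and uses accelerate for weight placement. A small preflight sketch (hypothetical helper, not part of the app) that checks both packages and CUDA availability before model loading:

import importlib.util
import torch

def check_4bit_support() -> None:
    # Raise a clear error early if the runtime cannot serve a 4-bit model.
    for pkg in ("bitsandbytes", "accelerate"):
        if importlib.util.find_spec(pkg) is None:
            raise RuntimeError(f"{pkg} is required for load_in_4bit; pip install {pkg}")
    if not torch.cuda.is_available():
        raise RuntimeError("bitsandbytes 4-bit quantization requires a CUDA GPU")

check_4bit_support()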